(ns clj-d2l.dropout
(:require
[clojure.spec.alpha :as s]
[clj-djl.ndarray :as nd]
[clj-djl.training :as t]
[clj-djl.training.dataset :as ds]
[clj-djl.training.loss :as loss]
[clj-djl.training.optimizer :as optimizer]
[clj-djl.training.tracker :as tracker]
[clj-djl.training.listener :as listener]
[clj-djl.model :as m]
[clj-djl.nn :as nn]
[clj-djl.device :as dev]
[clj-d2l.core :as d2l]))
;; Namespace-wide manager obtained from nd/base-manager. It is created once
;; when this form is evaluated and is the manager handed to every nd/...
;; constructor call in this file (random-uniform, arange, random-normal, zeros).
(def ndm (nd/base-manager))
(defn dropout-layer
  "Applies (inverted) dropout to the NDArray `X`.

  `dropout` is the drop probability and must satisfy 0 <= dropout <= 1
  (enforced by the precondition, which throws AssertionError otherwise).

  Returns:
  - an all-zero array shaped like `X` when `dropout` is numerically 1,
  - `X` itself, untouched, when `dropout` is numerically 0,
  - otherwise `X` multiplied by a 0/1 mask and divided by (1 - dropout).
    The mask is 1 wherever a uniform sample drawn from [0, 1) via the shared
    `ndm` manager is greater than `dropout`."
  ([^ai.djl.ndarray.NDArray X dropout]
   {:pre [(s/valid? #(<= 0 % 1) dropout)]}
   (cond
     ;; Compare with `==` (numeric equality) rather than `=`: (= 1 1.0) is
     ;; false in Clojure, so a rate of 1.0 would otherwise fall through to the
     ;; general branch and divide by (- 1.0 1.0).
     (== dropout 1) (nd/zeros-like X)
     (== dropout 0) X
     :else (-> (nd/random-uniform ndm 0 1 (nd/shape X))
               (nd/> dropout)                ; keep-mask: sample > drop rate
               (nd/to-type :float32 false)   ; boolean mask -> 0.0 / 1.0
               (nd/* X)
               ;; rescale the surviving activations by 1 / (1 - dropout)
               (nd// (- 1.0 dropout))))))
;; Sample input for trying out dropout-layer: the values produced by
;; (nd/arange ndm 16), reshaped to 2 rows by 8 columns.
(def X
  (nd/reshape (nd/arange ndm 16) 2 8))
;; Evaluate X at the REPL to display it.
X
ND: (2, 8) cpu() int32
[[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15],
]
ND: (2, 8) cpu() int32
[[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15],
]
ND: (2, 8) cpu() int32
[[ 0, 0, 0, 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0],
]
ND: (2, 8) cpu() float32
[[ 0., 2., 4., 6., 0., 0., 12., 0.],
[16., 18., 20., 22., 24., 0., 28., 30.],
]
(defn train
  "Runs a hand-written training/validation loop and returns the collected
  per-epoch statistics.

  Arguments:
  - `ndm`        manager passed to ds/get-data when iterating the datasets
  - `net`        function of [X training?] returning predictions
  - `train-ds`   dataset iterated for the training pass
  - `test-ds`    dataset iterated for the validation pass
  - `nepochs`    upper epoch index; see the note below
  - `loss-fn`    function of [y-hat y] returning a loss NDArray
  - `updater-fn` zero-argument function invoked once per training batch,
                 after the backward pass

  Returns a map with the vector-valued keys :epoch, :train-loss,
  :train-accuracy, :test-loss and :test-accuracy, one entry per epoch.

  NOTE: the epoch loop iterates over (range (inc nepochs)), i.e. epoch
  indices 0..nepochs inclusive, so nepochs + 1 passes are performed. This
  is kept unchanged because the recorded runs in this file show it.

  Every batch (training and validation) is closed via with-open, so it is
  released even when the forward/backward pass throws."
  [ndm net train-ds test-ds nepochs loss-fn updater-fn]
  (let [results (atom {:epoch []
                       :train-loss []
                       :train-accuracy []
                       :test-loss []
                       :test-accuracy []})
        zero-acc {:size 0 :loss 0 :accuracy 0}
        acc (atom zero-acc)
        ;; Fold one batch's statistics into the running accumulator.
        tally! (fn [y l a]
                 (swap! acc update :size + (nd/size y))
                 (swap! acc update :loss + (nd/get-element (nd/sum l)))
                 (swap! acc update :accuracy + a))
        ;; Accumulated value for key k divided by the accumulated :size.
        per-sample (fn [k] (/ (@acc k) (@acc :size)))]
    (doseq [i (range (inc nepochs))]
      (print "epoch " i ", training...")
      (swap! results assoc-in [:epoch i] i)
      ;; ---- training pass ----
      (doseq [b (t/iter-seq (ds/get-data train-ds ndm))]
        (with-open [batch b]
          (let [X (-> batch ds/get-batch-data nd/head)
                y (-> batch ds/get-batch-labels nd/head)]
            (with-open [gc (t/gradient-collector)]
              (let [y-hat (net X true)
                    l (loss-fn y-hat y)
                    a (d2l/accuracy y-hat y)]
                (.backward gc l)
                (tally! y l a)))
            ;; Parameter update runs after the gradient collector is closed
            ;; and before the batch is released.
            (updater-fn))))
      (swap! results assoc-in [:train-loss i] (per-sample :loss))
      (swap! results assoc-in [:train-accuracy i] (per-sample :accuracy))
      (reset! acc zero-acc)
      (println "validating...")
      ;; ---- validation pass (no gradients, training? = false) ----
      (doseq [b (t/iter-seq (ds/get-data test-ds ndm))]
        ;; The original never closed validation batches; with-open fixes that.
        (with-open [batch b]
          (let [X (-> batch ds/get-batch-data nd/head)
                y (-> batch ds/get-batch-labels nd/head)
                y-hat (net X false)
                l (loss-fn y-hat y)
                a (d2l/accuracy y-hat y)]
            (tally! y l a))))
      (swap! results assoc-in [:test-loss i] (per-sample :loss))
      (swap! results assoc-in [:test-accuracy i] (per-sample :accuracy)))
    @results))
(defn softmax
  "Exponentiates `ndarray` element-wise and divides by the sum of the
  exponentials taken along axis 1 (the summed axis is kept so the division
  broadcasts)."
  [ndarray]
  (let [exps (nd/exp ndarray)]
    (nd// exps (nd/sum exps [1] true))))
(defn softmax-cross-entropy
  "Cross-entropy loss on raw scores: applies `softmax` to `y-hat`, indexes
  the result with the \":, {}\" pattern using `y` cast to :int32, and returns
  the negated logarithm of the selected values."
  [y-hat y]
  (let [probs (softmax y-hat)
        labels (nd/to-type y :int32 false)
        picked (nd/get probs ":, {}" labels)]
    (nd/- (nd/log picked))))
(defn make-net
  "Builds the forward function of a three-layer perceptron from explicit
  parameters.

  The returned function takes [X training?]. It reshapes X to [-1 ninputs],
  then applies dot/+ with (W1, b1) and (W2, b2), each followed by nn/relu,
  and finally dot/+ with (W3, b3). When `training?` is truthy,
  `dropout-layer` is applied after the first and second relu with rates
  `dropout1` and `dropout2` respectively; otherwise those activations pass
  through unchanged."
  [ninputs W1 b1 W2 b2 W3 b3 dropout1 dropout2]
  (fn [X training?]
    (let [affine (fn [h W b] (nd/+ (nd/dot h W) b))
          ;; dropout only while training; identity otherwise
          regularize (fn [h rate]
                       (if training?
                         (dropout-layer h rate)
                         h))
          flat (nd/reshape X [-1 ninputs])
          hidden1 (regularize (nn/relu (affine flat W1 b1)) dropout1)
          hidden2 (regularize (nn/relu (affine hidden1 W2 b2)) dropout2)]
      (affine hidden2 W3 b3))))
(defn do-train
  "Trains the from-scratch dropout network and plots the results.

  Creates the six parameters (weights drawn with nd/random-normal 0 / 0.01,
  biases from nd/zeros) on the shared `ndm`, attaches gradients to them,
  loads the data with d2l/load-data-fashion-mnist at batch size 256, runs
  `train` with `softmax-cross-entropy` and a d2l/sgd updater using `lr`, and
  passes train loss, train accuracy and test accuracy per epoch to
  d2l/plot-lines together with `output`.

  Returns whatever d2l/plot-lines returns."
  [nepochs lr dropout1 dropout2 output]
  (let [batchsize 256
        ninputs 784
        noutputs 10
        nhiddens1 256
        nhiddens2 256
        ;; Parameters are created in the order W1 b1 W2 b2 W3 b3.
        W1 (nd/random-normal ndm 0 0.01 [ninputs nhiddens1])
        b1 (nd/zeros ndm nhiddens1)
        W2 (nd/random-normal ndm 0 0.01 [nhiddens1 nhiddens2])
        b2 (nd/zeros ndm nhiddens2)
        W3 (nd/random-normal ndm 0 0.01 [nhiddens2 noutputs])
        b3 (nd/zeros ndm noutputs)
        params [W1 b1 W2 b2 W3 b3]]
    (doseq [p params]
      (nd/attach-gradient p))
    (let [dataset (d2l/load-data-fashion-mnist batchsize)
          net (make-net ninputs W1 b1 W2 b2 W3 b3 dropout1 dropout2)
          sgd-step #(d2l/sgd params lr batchsize)
          results (train ndm net
                         (dataset 0) (dataset 1)
                         nepochs softmax-cross-entropy sgd-step)
          series (mapv results [:train-loss :train-accuracy :test-accuracy])]
      (d2l/plot-lines output
                      ["train loss" "train acc" "test acc"]
                      (results :epoch)
                      series))))
;; NOTE(review): `setq` is an Emacs Lisp form, not Clojure; this line sets the
;; editor variable org-babel-clojure-sync-nrepl-timeout to 1000 and looks like
;; notebook configuration that was carried into this file. Confirm it is not
;; meant to be loaded as part of the Clojure namespace.
(setq org-babel-clojure-sync-nrepl-timeout 1000)
1000
;; From-scratch run: nepochs 10, lr 0.4, dropout1 0.5, dropout2 0.5;
;; "figure/dropout_11.svg" is the output argument forwarded to d2l/plot-lines.
(do-train 10 0.4 0.5 0.5 "figure/dropout_11.svg")
epoch 0 , training...validating...
epoch 1 , training...validating...
epoch 2 , training...validating...
epoch 3 , training...validating...
epoch 4 , training...validating...
epoch 5 , training...validating...
epoch 6 , training...validating...
epoch 7 , training...validating...
epoch 8 , training...validating...
epoch 9 , training...validating...
epoch 10 , training...validating...
;; Short run: nepochs 1, lr 0.4, dropout1 0.1, dropout2 0.5.
(do-train 1 0.4 0.1 0.5 "figure/dropout_12.svg")
epoch 0 , training...validating...
epoch 1 , training...validating...
;; Remaining runs of the sweep, 10 for nepochs each. Argument order is
;; nepochs, lr, dropout1, dropout2, output.
;; lr 0.4 with the two remaining dropout combinations:
(do-train 10 0.4 0.5 0.1 "figure/dropout_13.svg")
(do-train 10 0.4 0.1 0.1 "figure/dropout_14.svg")
;; lr 0.1 with all four dropout combinations of {0.5, 0.1}:
(do-train 10 0.1 0.5 0.5 "figure/dropout_21.svg")
(do-train 10 0.1 0.1 0.5 "figure/dropout_22.svg")
(do-train 10 0.1 0.5 0.1 "figure/dropout_23.svg")
(do-train 10 0.1 0.1 0.1 "figure/dropout_24.svg")
(defn do-train-concise
  "Trains the dropout MLP using the high-level block/trainer API and plots
  the per-epoch metrics.

  Arguments:
  - `nepochs`  number of epochs passed to t/fit
  - `lr`       learning rate for the SGD optimizer (fixed tracker)
  - `dropout1` rate of the dropout block after the first hidden layer
  - `dropout2` rate of the dropout block after the second hidden layer
  - `output`   forwarded as the first argument of d2l/plot-lines

  The model and trainer are opened with with-open, so both are closed when
  the function returns or throws. Returns whatever d2l/plot-lines returns."
  [nepochs lr dropout1 dropout2 output]
  (let [batchsize 256
        [ninputs noutputs nhiddens1 nhiddens2] [784 10 256 256]
        dataset (d2l/load-data-fashion-mnist batchsize)
        net (-> (nn/sequential-block)
                (nn/add (nn/batch-flatten-block ninputs))
                (nn/add (nn/linear-block {:units nhiddens1}))
                (nn/add nn/relu)
                (nn/add (nn/dropout {:rate dropout1}))
                (nn/add (nn/linear-block {:units nhiddens2}))
                (nn/add nn/relu)
                (nn/add (nn/dropout {:rate dropout2}))
                (nn/add (nn/linear-block {:units noutputs}))
                (nn/set-initializer (nn/normal-initializer)))
        ;; Fix: honor the caller's `lr`; the tracker used to be hard-coded to
        ;; 0.5, which silently ignored the parameter.
        opt (optimizer/sgd {:tracker (tracker/fixed lr)})
        ;; NOTE(review): `sotfmax` is spelled exactly as in the original call;
        ;; confirm the function name against clj-djl.training.loss.
        loss (loss/sotfmax-cross-entropy-loss)
        config (t/default-training-config {:loss loss :optimizer opt
                                           :evaluator (t/accuracy)
                                           :listeners (listener/logging)})]
    (with-open [model (m/model {:name "mlp" :block net})
                trainer (t/trainer model config)]
      ;; Use `ninputs` rather than repeating the literal 784.
      (t/initialize trainer (nd/shape [1 ninputs]))
      (t/set-metrics trainer (t/metrics))
      (t/fit trainer nepochs (dataset 0) (dataset 1))
      (let [metrics (t/get-metrics trainer)]
        (d2l/plot-lines output
                        ;; Fix: the third series is validate_epoch_Accuracy,
                        ;; so label it as validation accuracy.
                        ["train loss" "validate loss" "validate accuracy"]
                        (range nepochs)
                        [(map :value (metrics "train_epoch_SoftmaxCrossEntropyLoss"))
                         (map :value (metrics "validate_epoch_SoftmaxCrossEntropyLoss"))
                         (map :value (metrics "validate_epoch_Accuracy"))])))))
;; High-level API run: nepochs 10, lr argument 0.1, dropout1 0.1, dropout2 0.5;
;; "figure/dropout_41.svg" is the output argument forwarded to d2l/plot-lines.
(do-train-concise 10 0.1 0.1 0.5 "figure/dropout_41.svg")
Training: 100% |████████████████████████████████████████| Accuracy: 0.55, SoftmaxCrossEntropyLoss: 1.17
Validating: 100% |████████████████████████████████████████|
Training: 100% |████████████████████████████████████████| Accuracy: 0.78, SoftmaxCrossEntropyLoss: 0.58
Validating: 100% |████████████████████████████████████████|
Training: 100% |████████████████████████████████████████| Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.49
Validating: 100% |████████████████████████████████████████|
Training: 100% |████████████████████████████████████████| Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.45
Validating: 100% |████████████████████████████████████████|
Training: 100% |████████████████████████████████████████| Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.41
Validating: 100% |████████████████████████████████████████|
Training: 100% |████████████████████████████████████████| Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39
Validating: 100% |████████████████████████████████████████|
Training: 100% |████████████████████████████████████████| Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.37
Validating: 100% |████████████████████████████████████████|
Training: 100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.36
Validating: 100% |████████████████████████████████████████|
Training: 100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.34
Validating: 100% |████████████████████████████████████████|
Training: 100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.34
Validating: 100% |████████████████████████████████████████|
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Training on: cpu().
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Load MXNet Engine Version 1.7.0 in 0.029 ms.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 1 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.55, SoftmaxCrossEntropyLoss: 1.17
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.76, SoftmaxCrossEntropyLoss: 0.67
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 2 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.78, SoftmaxCrossEntropyLoss: 0.58
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.79, SoftmaxCrossEntropyLoss: 0.54
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 3 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.49
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.75, SoftmaxCrossEntropyLoss: 0.73
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 4 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.45
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.49
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 5 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.41
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.44
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 6 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.40
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 7 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.37
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.78, SoftmaxCrossEntropyLoss: 0.59
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 8 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.36
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.45
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 9 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.34
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 10 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.34
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.40
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - forward P50: 0.645 ms, P90: 0.987 ms
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - training-metrics P50: 0.016 ms, P90: 0.025 ms
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - backward P50: 0.859 ms, P90: 1.741 ms
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - step P50: 1.172 ms, P90: 1.552 ms
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - epoch P50: 31.315 s, P90: 103.561 s