Table of Contents
1 Dropout
(ns clj-d2l.dropout (:require [clojure.spec.alpha :as s] [clj-djl.ndarray :as nd] [clj-djl.training :as t] [clj-djl.training.dataset :as ds] [clj-djl.training.loss :as loss] [clj-djl.training.optimizer :as optimizer] [clj-djl.training.tracker :as tracker] [clj-djl.training.listener :as listener] [clj-djl.model :as m] [clj-djl.nn :as nn] [clj-djl.device :as dev] [clj-d2l.core :as d2l]))
(def ndm (nd/base-manager))

(defn dropout-layer
  "Applies inverted dropout to NDArray `X`.

  `dropout` is the probability of zeroing an element, in [0, 1]:
  - 1 -> every element dropped (zeros shaped like X)
  - 0 -> X returned unchanged
  - otherwise each element survives with probability (1 - dropout) and the
    survivors are rescaled by 1/(1 - dropout) so the expected value of the
    output equals the input (inverted dropout)."
  ([^ai.djl.ndarray.NDArray X dropout]
   {:pre [(s/valid? #(<= 0 % 1) dropout)]}
   ;; FIX: use numeric == instead of (condp = dropout ...): in Clojure
   ;; (= 1.0 1) is false, so a caller passing the double 1.0 would fall
   ;; through to the general branch and divide by (- 1.0 1.0) = 0.
   (cond
     (== dropout 1) (nd/zeros-like X)
     (== dropout 0) X
     :else (-> (nd/random-uniform ndm 0 1 (nd/shape X))
               (nd/> dropout)              ; keep-mask: 1 where uniform > p
               (nd/to-type :float32 false)
               (nd/* X)
               (nd// (- 1.0 dropout))))))  ; rescale survivors

(def X (-> (nd/arange ndm 16) (nd/reshape 2 8)))
X
ND: (2, 8) cpu() int32 [[ 0, 1, 2, 3, 4, 5, 6, 7], [ 8, 9, 10, 11, 12, 13, 14, 15], ]
(dropout-layer X 0)
ND: (2, 8) cpu() int32 [[ 0, 1, 2, 3, 4, 5, 6, 7], [ 8, 9, 10, 11, 12, 13, 14, 15], ]
(dropout-layer X 1)
ND: (2, 8) cpu() int32 [[ 0, 0, 0, 0, 0, 0, 0, 0], [ 0, 0, 0, 0, 0, 0, 0, 0], ]
(dropout-layer X 0.5)
ND: (2, 8) cpu() float32 [[ 0., 2., 4., 6., 0., 0., 12., 0.], [16., 18., 20., 22., 24., 0., 28., 30.], ]
(defn train
  "Training loop for a hand-rolled network.

  `net`        - fn of [X training?] returning the logits NDArray.
  `train-ds` / `test-ds` - clj-djl datasets iterated via ds/get-data.
  `loss-fn`    - fn of [y-hat y] returning a per-example loss NDArray.
  `updater-fn` - zero-arg fn applying one optimizer step (called per batch).

  Returns {:epoch [..] :train-loss [..] :train-accuracy [..]
           :test-loss [..] :test-accuracy [..]}, one entry per epoch.

  NOTE(review): (range (inc nepochs)) runs nepochs + 1 epochs (0..nepochs),
  which matches the captured output below — confirm the extra epoch is
  intended."
  [ndm net train-ds test-ds nepochs loss-fn updater-fn]
  (let [results (atom {:epoch [] :train-loss [] :train-accuracy []
                       :test-loss [] :test-accuracy []})
        ;; running sums over the current pass (train or validate)
        acc (atom {:size 0 :loss 0 :accuracy 0})]
    (doseq [i (range (inc nepochs))]
      (print "epoch " i ", training...")
      (swap! results assoc-in [:epoch i] i)
      ;; -- training pass --
      (doseq [batch (t/iter-seq (ds/get-data train-ds ndm))]
        (let [X (-> batch ds/get-batch-data nd/head)
              y (-> batch ds/get-batch-labels nd/head)]
          (with-open [gc (t/gradient-collector)]
            (let [y-hat (net X true)  ; training? = true -> dropout active
                  l (loss-fn y-hat y)
                  a (d2l/accuracy y-hat y)]
              (.backward gc l)
              (swap! acc update :size + (nd/size y))
              (swap! acc update :loss + (nd/get-element (nd/sum l)))
              (swap! acc update :accuracy + a))))
        (updater-fn)
        (.close batch))
      (swap! results assoc-in [:train-loss i] (/ (@acc :loss) (@acc :size)))
      (swap! results assoc-in [:train-accuracy i] (/ (@acc :accuracy) (@acc :size)))
      (reset! acc {:size 0 :loss 0 :accuracy 0})
      (println "validating...")
      ;; -- validation pass (no gradients, dropout disabled) --
      (doseq [batch (t/iter-seq (ds/get-data test-ds ndm))]
        (let [X (-> batch ds/get-batch-data nd/head)
              y (-> batch ds/get-batch-labels nd/head)
              y-hat (net X false)
              l (loss-fn y-hat y)
              a (d2l/accuracy y-hat y)]
          (swap! acc update :size + (nd/size y))
          (swap! acc update :loss + (nd/get-element (nd/sum l)))
          (swap! acc update :accuracy + a))
        ;; FIX: close validation batches too; the original closed only the
        ;; training batches and leaked the validation batches' native memory.
        (.close batch))
      (swap! results assoc-in [:test-loss i] (/ (@acc :loss) (@acc :size)))
      (swap! results assoc-in [:test-accuracy i] (/ (@acc :accuracy) (@acc :size)))
      ;; FIX: reset accumulators before the next epoch; the original never
      ;; cleared them after validation, so each epoch's *training* metrics
      ;; silently included the previous epoch's validation sums.
      (reset! acc {:size 0 :loss 0 :accuracy 0}))
    @results))
(defn softmax
  "Row-wise softmax of a 2-D NDArray (normalizes along axis 1).

  NOTE(review): no max-subtraction, so very large logits can overflow
  exp — fine for this teaching example, not for production use."
  [ndarray]
  ;; local renamed from `partition`, which shadowed clojure.core/partition
  (let [Xexp (nd/exp ndarray)
        row-sums (nd/sum Xexp [1] true)]
    (nd// Xexp row-sums)))

(defn softmax-cross-entropy
  "Cross-entropy loss: -log of the softmax probability that `y-hat`
  (logits) assigns to the true class index `y`, per row."
  [y-hat y]
  (-> (softmax y-hat)
      ;; pick, for each row, the probability at that row's label index
      (nd/get ":, {}" (nd/to-type y :int32 false))
      (nd/log)
      (nd/-)))

(defn make-net
  "Builds a 3-layer MLP with dropout after each hidden layer.

  Returns a fn of [X training?]: X is flattened to [-1 ninputs], passed
  through linear+relu(+dropout) twice, then a final linear layer.
  Dropout is applied only when training? is true."
  [ninputs W1 b1 W2 b2 W3 b3 dropout1 dropout2]
  (fn [X training?]
    ;; helper: apply dropout only in training mode
    (letfn [(maybe-drop [layer rate]
              (if training? (dropout-layer layer rate) layer))]
      (-> X
          (nd/reshape [-1 ninputs])
          (nd/dot W1) (nd/+ b1) (nn/relu)
          (maybe-drop dropout1)
          (nd/dot W2) (nd/+ b2) (nn/relu)
          (maybe-drop dropout2)
          (nd/dot W3) (nd/+ b3)))))
(defn do-train
  "Trains the scratch-built dropout MLP on Fashion-MNIST and writes the
  train-loss / train-accuracy / test-accuracy curves to `output` (SVG)."
  [nepochs lr dropout1 dropout2 output]
  (let [batchsize 256
        ninputs   784
        noutputs  10
        nhiddens1 256
        nhiddens2 256
        W1 (nd/random-normal ndm 0 0.01 [ninputs nhiddens1])
        b1 (nd/zeros ndm nhiddens1)
        W2 (nd/random-normal ndm 0 0.01 [nhiddens1 nhiddens2])
        b2 (nd/zeros ndm nhiddens2)
        W3 (nd/random-normal ndm 0 0.01 [nhiddens2 noutputs])
        b3 (nd/zeros ndm noutputs)
        params [W1 b1 W2 b2 W3 b3]
        [train-ds test-ds] (d2l/load-data-fashion-mnist batchsize)]
    ;; mark every parameter as requiring gradients
    (doseq [param params]
      (nd/attach-gradient param))
    (let [results (train ndm
                         (make-net ninputs W1 b1 W2 b2 W3 b3 dropout1 dropout2)
                         train-ds
                         test-ds
                         nepochs
                         softmax-cross-entropy
                         (fn [] (d2l/sgd params lr batchsize)))]
      (d2l/plot-lines output
                      ["train loss" "train acc" "test acc"]
                      (results :epoch)
                      [(results :train-loss)
                       (results :train-accuracy)
                       (results :test-accuracy)]))))
(setq org-babel-clojure-sync-nrepl-timeout 1000)
1000
(do-train 10 0.4 0.5 0.5 "figure/dropout_11.svg")
epoch 0 , training...validating... epoch 1 , training...validating... epoch 2 , training...validating... epoch 3 , training...validating... epoch 4 , training...validating... epoch 5 , training...validating... epoch 6 , training...validating... epoch 7 , training...validating... epoch 8 , training...validating... epoch 9 , training...validating... epoch 10 , training...validating...
(do-train 1 0.4 0.1 0.5 "figure/dropout_12.svg")
epoch 0 , training...validating... epoch 1 , training...validating...
;; Hyperparameter sweep: learning rate 0.4 vs 0.1 crossed with dropout
;; rates (0.5 / 0.1) for the two hidden layers; each run writes its
;; metric curves to the named SVG figure.
(do-train 10 0.4 0.5 0.1 "figure/dropout_13.svg")
(do-train 10 0.4 0.1 0.1 "figure/dropout_14.svg")
(do-train 10 0.1 0.5 0.5 "figure/dropout_21.svg")
(do-train 10 0.1 0.1 0.5 "figure/dropout_22.svg")
(do-train 10 0.1 0.5 0.1 "figure/dropout_23.svg")
(do-train 10 0.1 0.1 0.1 "figure/dropout_24.svg")
(defn do-train-concise
  "Same dropout experiment using DJL's built-in blocks and Trainer API.

  FIX: the optimizer now uses the `lr` argument — the original hard-coded
  (tracker/fixed 0.5) and silently ignored `lr`."
  [nepochs lr dropout1 dropout2 output]
  (let [batchsize 256
        [ninputs noutputs nhiddens1 nhiddens2] [784 10 256 256]
        dataset (d2l/load-data-fashion-mnist batchsize)
        net (-> (nn/sequential-block)
                (nn/add (nn/batch-flatten-block ninputs))
                (nn/add (nn/linear-block {:units nhiddens1}))
                (nn/add nn/relu)
                (nn/add (nn/dropout {:rate dropout1}))
                (nn/add (nn/linear-block {:units nhiddens2}))
                (nn/add nn/relu)
                (nn/add (nn/dropout {:rate dropout2}))
                (nn/add (nn/linear-block {:units noutputs}))
                (nn/set-initializer (nn/normal-initializer)))
        ;; FIX: honor the lr parameter instead of a hard-coded 0.5
        opt (optimizer/sgd {:tracker (tracker/fixed lr)})
        ;; NOTE(review): "sotfmax" is clj-djl's actual (misspelled) symbol
        ;; name — do not "correct" it without checking the library version.
        loss (loss/sotfmax-cross-entropy-loss)
        config (t/default-training-config {:loss loss
                                           :optimizer opt
                                           :evaluator (t/accuracy)
                                           :listeners (listener/logging)})]
    (with-open [model (m/model {:name "mlp" :block net})
                trainer (t/trainer model config)]
      ;; consistency: use ninputs rather than a second literal 784
      (t/initialize trainer (nd/shape [1 ninputs]))
      (t/set-metrics trainer (t/metrics))
      (t/fit trainer nepochs (dataset 0) (dataset 1))
      (let [metrics (t/get-metrics trainer)]
        ;; FIX: third series is validate_epoch_Accuracy, so label it
        ;; "validate accuracy" (was mislabeled "train accuracy").
        (d2l/plot-lines output
                        ["train loss" "validate loss" "validate accuracy"]
                        (range nepochs)
                        [(map :value (metrics "train_epoch_SoftmaxCrossEntropyLoss"))
                         (map :value (metrics "validate_epoch_SoftmaxCrossEntropyLoss"))
                         (map :value (metrics "validate_epoch_Accuracy"))])))))
(do-train-concise 10 0.1 0.1 0.5 "figure/dropout_41.svg")
Training: 100% |████████████████████████████████████████| Accuracy: 0.55, SoftmaxCrossEntropyLoss: 1.17 Validating: 100% |████████████████████████████████████████| Training: 100% |████████████████████████████████████████| Accuracy: 0.78, SoftmaxCrossEntropyLoss: 0.58 Validating: 100% |████████████████████████████████████████| Training: 100% |████████████████████████████████████████| Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.49 Validating: 100% |████████████████████████████████████████| Training: 100% |████████████████████████████████████████| Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.45 Validating: 100% |████████████████████████████████████████| Training: 100% |████████████████████████████████████████| Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.41 Validating: 100% |████████████████████████████████████████| Training: 100% |████████████████████████████████████████| Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39 Validating: 100% |████████████████████████████████████████| Training: 100% |████████████████████████████████████████| Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.37 Validating: 100% |████████████████████████████████████████| Training: 100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.36 Validating: 100% |████████████████████████████████████████| Training: 100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.34 Validating: 100% |████████████████████████████████████████| Training: 100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.34 Validating: 100% |████████████████████████████████████████| [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Training on: cpu(). [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Load MXNet Engine Version 1.7.0 in 0.029 ms. 
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 1 finished. [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.55, SoftmaxCrossEntropyLoss: 1.17 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.76, SoftmaxCrossEntropyLoss: 0.67 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 2 finished. [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.78, SoftmaxCrossEntropyLoss: 0.58 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.79, SoftmaxCrossEntropyLoss: 0.54 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 3 finished. [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.49 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.75, SoftmaxCrossEntropyLoss: 0.73 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 4 finished. [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.45 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.49 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 5 finished. 
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.41 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.44 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 6 finished. [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.40 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 7 finished. [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.37 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.78, SoftmaxCrossEntropyLoss: 0.59 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 8 finished. [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.36 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.45 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 9 finished. 
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.34 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 10 finished. [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.34 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.40 [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - forward P50: 0.645 ms, P90: 0.987 ms [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - training-metrics P50: 0.016 ms, P90: 0.025 ms [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - backward P50: 0.859 ms, P90: 1.741 ms [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - step P50: 1.172 ms, P90: 1.552 ms [nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - epoch P50: 31.315 s, P90: 103.561 s