UP | HOME

Table of Contents

1 Dropout

(ns clj-d2l.dropout
  (:require
   [clojure.spec.alpha :as s]
   [clj-djl.ndarray :as nd]
   [clj-djl.training :as t]
   [clj-djl.training.dataset :as ds]
   [clj-djl.training.loss :as loss]
   [clj-djl.training.optimizer :as optimizer]
   [clj-djl.training.tracker :as tracker]
   [clj-djl.training.listener :as listener]
   [clj-djl.model :as m]
   [clj-djl.nn :as nn]
   [clj-djl.device :as dev]
   [clj-d2l.core :as d2l]))
;; Manager obtained from clj-djl's `nd/base-manager`; the definitions in this
;; file pass it to the NDArray constructors (`nd/arange`, `nd/zeros`,
;; `nd/random-normal`, `nd/random-uniform`).
(def ndm (nd/base-manager))
(defn dropout-layer
  "Applies dropout with rate `dropout` to the NDArray `X`.

  `dropout` must be a number in the closed interval [0, 1] (enforced by the
  precondition).

  * rate 1 -> returns `(nd/zeros-like X)`
  * rate 0 -> returns `X` itself, untouched
  * otherwise -> draws a uniform sample between 0 and 1 with the shape of `X`,
    keeps the positions where the sample is greater than `dropout`, casts
    that mask to float32, multiplies it by `X` and divides the product by
    `(- 1.0 dropout)`.

  The boundary rates are matched numerically, so `1`, `1.0`, `0` and `0.0`
  all hit the short-circuit branches."
  ([^ai.djl.ndarray.NDArray X dropout]
   {:pre [(s/valid? #(<= 0 % 1) dropout)]}
   (cond
     ;; `==` is numeric equality; `=` would treat 1 and 1.0 as different and
     ;; let a rate of 1.0 reach the division by (- 1.0 1.0) below.
     (== dropout 1) (nd/zeros-like X)
     (== dropout 0) X
     :else
     (-> (nd/random-uniform ndm 0 1 (nd/shape X))
         (nd/> dropout)
         (nd/to-type :float32 false)
         (nd/* X)
         (nd// (- 1.0 dropout))))))

;; Sample input for trying out `dropout-layer`: the result of `nd/arange`
;; over 16, reshaped to 2 rows by 8 columns.
(def X (-> (nd/arange ndm 16)
           (nd/reshape 2 8)))
X
ND: (2, 8) cpu() int32
[[ 0,  1,  2,  3,  4,  5,  6,  7],
 [ 8,  9, 10, 11, 12, 13, 14, 15],
]
(dropout-layer X 0)
ND: (2, 8) cpu() int32
[[ 0,  1,  2,  3,  4,  5,  6,  7],
 [ 8,  9, 10, 11, 12, 13, 14, 15],
]
(dropout-layer X 1)
ND: (2, 8) cpu() int32
[[ 0,  0,  0,  0,  0,  0,  0,  0],
 [ 0,  0,  0,  0,  0,  0,  0,  0],
]
(dropout-layer X 0.5)
ND: (2, 8) cpu() float32
[[ 0.,  2.,  4.,  6.,  0.,  0., 12.,  0.],
 [16., 18., 20., 22., 24.,  0., 28., 30.],
]
(defn train
  "Runs a hand-written training/validation loop and returns per-epoch metrics.

  Arguments:
    ndm        - manager handed to `ds/get-data` when iterating the datasets
    net        - function called as `(net X training?)`; `true` is passed
                 while training and `false` while validating
    train-ds   - dataset iterated during the training phase
    test-ds    - dataset iterated during the validation phase
    nepochs    - last epoch index; epochs 0 through `nepochs` inclusive are
                 run, i.e. `(inc nepochs)` passes over the data
    loss-fn    - function called as `(loss-fn y-hat y)`
    updater-fn - zero-argument function invoked once after each training
                 batch's backward pass

  Returns a map with the vectors `:epoch`, `:train-loss`, `:train-accuracy`,
  `:test-loss` and `:test-accuracy`, each indexed by epoch. Every value is the
  accumulated sum for the phase divided by the accumulated label count.

  Side effects: prints progress to *out* and closes every batch it consumes."
  [ndm net train-ds test-ds nepochs loss-fn updater-fn]
  (let [results (atom {:epoch []
                       :train-loss []
                       :train-accuracy []
                       :test-loss []
                       :test-accuracy []})
        zeroed {:size 0
                :loss 0
                :accuracy 0}
        acc (atom zeroed)
        ;; Folds one batch's label count, summed loss and accuracy into `acc`.
        accumulate! (fn [y l a]
                      (swap! acc update :size + (nd/size y))
                      (swap! acc update :loss + (nd/get-element (nd/sum l)))
                      (swap! acc update :accuracy + a))
        ;; Per-sample average of the accumulated quantity `k`.
        average (fn [k] (/ (@acc k) (@acc :size)))]
    (doseq [i (range (inc nepochs))]
      (print "epoch " i ", training...")
      (swap! results assoc-in [:epoch i] i)
      ;; Start every training phase from a clean accumulator; otherwise the
      ;; previous epoch's validation totals leak into this epoch's train stats.
      (reset! acc zeroed)
      (doseq [batch (t/iter-seq (ds/get-data train-ds ndm))]
        (try
          (let [X (-> batch ds/get-batch-data nd/head)
                y (-> batch ds/get-batch-labels nd/head)]
            (with-open [gc (t/gradient-collector)]
              (let [y-hat (net X true)
                    l (loss-fn y-hat y)
                    a (d2l/accuracy y-hat y)]
                (.backward gc l)
                (accumulate! y l a)))
            ;; Parameter update happens after the gradient collector is closed.
            (updater-fn))
          (finally
            (.close batch))))

      (swap! results assoc-in [:train-loss i] (average :loss))
      (swap! results assoc-in [:train-accuracy i] (average :accuracy))

      (println "validating...")
      (reset! acc zeroed)
      (doseq [batch (t/iter-seq (ds/get-data test-ds ndm))]
        (try
          (let [X (-> batch ds/get-batch-data nd/head)
                y (-> batch ds/get-batch-labels nd/head)
                y-hat (net X false)
                l (loss-fn y-hat y)
                a (d2l/accuracy y-hat y)]
            (accumulate! y l a))
          (finally
            ;; Validation batches are released just like the training ones.
            (.close batch))))
      (swap! results assoc-in [:test-loss i] (average :loss))
      (swap! results assoc-in [:test-accuracy i] (average :accuracy)))
    @results))
(defn softmax
  "Exponentiates `ndarray` element-wise and divides the result by its sum
  taken over axis 1 (the trailing `true` is presumably DJL's keep-dims flag,
  so the divisor broadcasts against each row)."
  [ndarray]
  (let [exps (nd/exp ndarray)]
    (nd// exps (nd/sum exps [1] true))))

(defn softmax-cross-entropy
  "Negated log of the softmax output selected by the labels.

  `y-hat` is passed through `softmax`; `y` is cast to int32 and used as the
  index argument of the `\":, {}\"` selection; the selected values are logged
  and negated."
  [y-hat y]
  (let [probs (softmax y-hat)
        labels (nd/to-type y :int32 false)
        picked (nd/get probs ":, {}" labels)]
    (nd/- (nd/log picked))))

(defn make-net
  "Builds the forward function of a three-layer network.

  Returns `(fn [X training?])` which reshapes `X` to `[-1 ninputs]`, then
  applies dot/add/relu with (W1, b1) and (W2, b2), and a final dot/add with
  (W3, b3). After each relu, `dropout-layer` is applied with `dropout1` /
  `dropout2` respectively, but only when `training?` is truthy."
  [ninputs W1 b1 W2 b2 W3 b3 dropout1 dropout2]
  (fn [X training?]
    (let [maybe-drop (fn [layer rate]
                       (if training?
                         (dropout-layer layer rate)
                         layer))
          flat (nd/reshape X [-1 ninputs])
          hidden1 (-> flat
                      (nd/dot W1)
                      (nd/+ b1)
                      (nn/relu)
                      (maybe-drop dropout1))
          hidden2 (-> hidden1
                      (nd/dot W2)
                      (nd/+ b2)
                      (nn/relu)
                      (maybe-drop dropout2))]
      (-> hidden2
          (nd/dot W3)
          (nd/+ b3)))))
(defn do-train
  "Trains the from-scratch dropout network and plots the outcome.

  Creates weights with `nd/random-normal` (0, 0.01) and zero biases for a
  784 -> 256 -> 256 -> 10 network, attaches gradients to all of them, loads
  Fashion-MNIST with batch size 256 via `d2l/load-data-fashion-mnist`, runs
  `train` with `softmax-cross-entropy` and a `d2l/sgd` updater using `lr`,
  then passes train loss, train accuracy and test accuracy to
  `d2l/plot-lines` with `output` as its first argument."
  [nepochs lr dropout1 dropout2 output]
  (let [batchsize 256
        ninputs 784
        noutputs 10
        nhiddens1 256
        nhiddens2 256
        weights (fn [rows cols] (nd/random-normal ndm 0 0.01 [rows cols]))
        biases (fn [n] (nd/zeros ndm n))
        ;; Parameters are created in the same order as they are listed.
        W1 (weights ninputs nhiddens1)
        b1 (biases nhiddens1)
        W2 (weights nhiddens1 nhiddens2)
        b2 (biases nhiddens2)
        W3 (weights nhiddens2 noutputs)
        b3 (biases noutputs)
        params [W1 b1 W2 b2 W3 b3]]
    (doseq [p params]
      (nd/attach-gradient p))
    (let [dataset (d2l/load-data-fashion-mnist batchsize)
          net (make-net ninputs W1 b1 W2 b2 W3 b3 dropout1 dropout2)
          step (fn [] (d2l/sgd params lr batchsize))
          {:keys [epoch train-loss train-accuracy test-accuracy]}
          (train ndm net (dataset 0) (dataset 1) nepochs softmax-cross-entropy step)]
      (d2l/plot-lines output
                      ["train loss" "train acc" "test acc"]
                      epoch
                      [train-loss train-accuracy test-accuracy]))))
(setq org-babel-clojure-sync-nrepl-timeout 1000)
1000
(do-train 10 0.4 0.5 0.5 "figure/dropout_11.svg")
epoch  0 , training...validating...
epoch  1 , training...validating...
epoch  2 , training...validating...
epoch  3 , training...validating...
epoch  4 , training...validating...
epoch  5 , training...validating...
epoch  6 , training...validating...
epoch  7 , training...validating...
epoch  8 , training...validating...
epoch  9 , training...validating...
epoch  10 , training...validating...

Sorry, your browser does not support SVG.

(do-train 1 0.4 0.1 0.5 "figure/dropout_12.svg")
epoch  0 , training...validating...
epoch  1 , training...validating...

Sorry, your browser does not support SVG.

(do-train 10 0.4 0.5 0.1 "figure/dropout_13.svg")

Sorry, your browser does not support SVG.

(do-train 10 0.4 0.1 0.1 "figure/dropout_14.svg")

Sorry, your browser does not support SVG.

(do-train 10 0.1 0.5 0.5 "figure/dropout_21.svg")

Sorry, your browser does not support SVG.

(do-train 10 0.1 0.1 0.5 "figure/dropout_22.svg")

Sorry, your browser does not support SVG.

(do-train 10 0.1 0.5 0.1 "figure/dropout_23.svg")

Sorry, your browser does not support SVG.

(do-train 10 0.1 0.1 0.1 "figure/dropout_24.svg")

Sorry, your browser does not support SVG.

(defn do-train-concise
  "Trains the dropout MLP using the high-level block/trainer API and plots
  the resulting metrics.

  Arguments:
    nepochs  - number of epochs passed to `t/fit`
    lr       - learning rate used for the fixed-rate SGD tracker
    dropout1 - rate of the dropout block after the first hidden layer
    dropout2 - rate of the dropout block after the second hidden layer
    output   - first argument given to `d2l/plot-lines`

  The network is flatten -> linear(256) -> relu -> dropout -> linear(256) ->
  relu -> dropout -> linear(10) with a normal initializer. The model and the
  trainer are both closed when training and plotting are done. Returns
  whatever `d2l/plot-lines` returns."
  [nepochs lr dropout1 dropout2 output]
  (let [batchsize 256
        [ninputs noutputs nhiddens1 nhiddens2] [784 10 256 256]
        dataset (d2l/load-data-fashion-mnist batchsize)
        net (-> (nn/sequential-block)
                (nn/add (nn/batch-flatten-block ninputs))
                (nn/add (nn/linear-block {:units nhiddens1}))
                (nn/add nn/relu)
                (nn/add (nn/dropout {:rate dropout1}))
                (nn/add (nn/linear-block {:units nhiddens2}))
                (nn/add nn/relu)
                (nn/add (nn/dropout {:rate dropout2}))
                (nn/add (nn/linear-block {:units noutputs}))
                (nn/set-initializer (nn/normal-initializer)))
        ;; Honour the caller's learning rate instead of a hard-coded 0.5.
        opt (optimizer/sgd {:tracker (tracker/fixed lr)})
        ;; NOTE(review): the `sotfmax` spelling is kept as-is; presumably it
        ;; is the name exported by clj-djl.training.loss -- confirm.
        loss-fn (loss/sotfmax-cross-entropy-loss)
        config (t/default-training-config {:loss loss-fn :optimizer opt
                                           :evaluator (t/accuracy)
                                           :listeners (listener/logging)})]
    (with-open [model (m/model {:name "mlp" :block net})
                trainer (t/trainer model config)]
      ;; Input shape is derived from `ninputs` so it cannot drift from the
      ;; flatten block above.
      (t/initialize trainer (nd/shape [1 ninputs]))
      (t/set-metrics trainer (t/metrics))
      (t/fit trainer nepochs (dataset 0) (dataset 1))
      (let [metrics (t/get-metrics trainer)]
        (d2l/plot-lines output
                        ;; The third series is the validation accuracy metric,
                        ;; so it is labelled accordingly.
                        ["train loss" "validate loss" "validate accuracy"]
                        (range nepochs)
                        [(map :value (metrics "train_epoch_SoftmaxCrossEntropyLoss"))
                         (map :value (metrics "validate_epoch_SoftmaxCrossEntropyLoss"))
                         (map :value (metrics "validate_epoch_Accuracy"))])))))
(do-train-concise 10 0.1 0.1 0.5 "figure/dropout_41.svg")
Training:    100% |████████████████████████████████████████| Accuracy: 0.55, SoftmaxCrossEntropyLoss: 1.17

Validating:  100% |████████████████████████████████████████|

Training:    100% |████████████████████████████████████████| Accuracy: 0.78, SoftmaxCrossEntropyLoss: 0.58

Validating:  100% |████████████████████████████████████████|

Training:    100% |████████████████████████████████████████| Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.49

Validating:  100% |████████████████████████████████████████|

Training:    100% |████████████████████████████████████████| Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.45

Validating:  100% |████████████████████████████████████████|

Training:    100% |████████████████████████████████████████| Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.41

Validating:  100% |████████████████████████████████████████|

Training:    100% |████████████████████████████████████████| Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39

Validating:  100% |████████████████████████████████████████|

Training:    100% |████████████████████████████████████████| Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.37

Validating:  100% |████████████████████████████████████████|

Training:    100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.36

Validating:  100% |████████████████████████████████████████|

Training:    100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.34

Validating:  100% |████████████████████████████████████████|

Training:    100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.34

Validating:  100% |████████████████████████████████████████|

[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Training on: cpu().
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Load MXNet Engine Version 1.7.0 in 0.029 ms.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 1 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.55, SoftmaxCrossEntropyLoss: 1.17
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.76, SoftmaxCrossEntropyLoss: 0.67
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 2 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.78, SoftmaxCrossEntropyLoss: 0.58
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.79, SoftmaxCrossEntropyLoss: 0.54
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 3 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.49
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.75, SoftmaxCrossEntropyLoss: 0.73
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 4 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.45
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.49
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 5 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.41
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.44
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 6 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.40
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 7 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.37
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.78, SoftmaxCrossEntropyLoss: 0.59
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 8 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.36
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.45
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 9 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.34
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 10 finished.
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.34
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.40
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - forward P50: 0.645 ms, P90: 0.987 ms
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - training-metrics P50: 0.016 ms, P90: 0.025 ms
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - backward P50: 0.859 ms, P90: 1.741 ms
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - step P50: 1.172 ms, P90: 1.552 ms
[nREPL-session-c4c3b9d9-6684-4cfd-bbcb-c3cd6ed5fa99] INFO ai.djl.training.listener.LoggingTrainingListener - epoch P50: 31.315 s, P90: 103.561 s

Sorry, your browser does not support SVG.

Created: 2021-04-11 Sun 20:59