UP | HOME

Table of Contents

1 Concise Implementation of Softmax Regression

(ns clj-d2l.concise-softmax-reg
  (:require [clojure.java.io :as io]
            [clj-djl.ndarray :as nd]
            [clj-djl.device :as device]
            [clj-djl.engine :as engine]
            [clj-djl.training.dataset :as ds]
            [clj-djl.model :as model]
            [clj-djl.nn :as nn]
            [clj-djl.training.loss :as loss]
            [clj-djl.training.tracker :as tracker]
            [clj-djl.training.optimizer :as optimizer]
            [clj-djl.training :as training]
            [clj-djl.training.listener :as listener])
  (:import [ai.djl.basicdataset FashionMnist]
           [ai.djl.training.dataset Dataset$Usage]
           [java.nio.file Paths]))
;; Mini-batch size and shuffle flag shared by both dataset splits.
(def batch-size 256)
(def random-shuffle true)

(defn- fashion-mnist
  "Builds and prepares a FashionMnist dataset for `usage` (a `Dataset$Usage`
  value), sampled with `batch-size` and `random-shuffle`."
  [usage]
  (-> (FashionMnist/builder)
      (ds/opt-usage usage)
      (ds/set-sampling batch-size random-shuffle)
      (ds/build)
      (ds/prepare)))

;; Training split.
(def mnist-train (fashion-mnist Dataset$Usage/TRAIN))

;; Test split; it is passed to `training/fit` as the last argument.
(def mnist-test (fashion-mnist Dataset$Usage/TEST))

1.1 Initializing Model Parameters

(defn softmax
  "Extracts the single array held in `arrays` with `nd/singleton-or-throw`,
  applies `nd/log-softmax` with argument 1 (presumably the axis -- confirm),
  and wraps the result with `nd/ndlist`.

  Note: despite its name, this calls log-softmax, not a plain softmax."
  [arrays]
  (-> arrays
      nd/singleton-or-throw
      (nd/log-softmax 1)
      nd/ndlist))


;; Base NDArray manager; it is not referenced again in the code shown here.
(def manager (nd/base-manager))

;; Model instance named "softmax-regression"; the network is attached to it below.
(def model (model/instance "softmax-regression"))

;; Sequential container that the layers below are added to.
(def net (nn/sequential-block))

;; First a batch-flatten block sized 28 * 28 = 784 (presumably one 28x28 image
;; per example -- confirm), then a linear layer with 10 output units.
(-> net
    (nn/add (nn/batch-flatten-block (* 28 28)))
    (nn/add (-> (nn/linear-builder) (nn/set-units 10) (nn/build))))

;; Attach the network to the model.
(model/set-block model net)

1.2 The Softmax

;; Loss used for training (reported as SoftmaxCrossEntropyLoss in the training log).
;; NOTE(review): the function name is spelled `sotfmax-...`; confirm whether clj-djl
;; really exports this spelling or whether `softmax-cross-entropy-loss` is intended.
(def loss (loss/sotfmax-cross-entropy-loss))

1.3 Optimization Algorithm

;; Fixed learning-rate tracker created with the value 0.1.
(def lrt (tracker/fixed 0.1))

;; SGD optimizer: create the builder, attach the learning-rate tracker, then build.
(def sgd
  (optimizer/build
   (optimizer/set-learning-rate-tracker (optimizer/sgd) lrt)))

1.4 Instantiate Configuration and Trainer

;; Training configuration bundling the loss, the SGD optimizer, an accuracy
;; evaluator and the logging listener.
(def config (training/training-config {:loss loss
                                       :optimizer sgd
                                       :accuracy (training/accuracy)
                                       :listeners (listener/logging)}))

;; Trainer created from the model and the configuration above.
(def trainer (model/trainer model config))

1.5 Initializing Trainer

;; Initialize the trainer with a single input shape of [1 784] (784 = 28 * 28,
;; matching the size given to the batch-flatten block).
(training/initialize trainer [(nd/shape [1 (* 28 28)])])

1.6 Metrics

;; Create a metrics object and register it on the trainer.
(def metrics (training/metrics))
(training/set-metrics trainer metrics)

1.7 Training

(setq org-babel-clojure-sync-nrepl-timeout 1000)
1000
;; Number of epochs to train for.
(def num-epochs 5)
;; Run training with mnist-train; mnist-test is passed as the last argument
;; (the log shows a "Validating" pass after each epoch).
(training/fit trainer num-epochs mnist-train mnist-test)

Training:    100% |████████████████████████████████████████| Accuracy: 0.74, SoftmaxCrossEntropyLoss: 0.80

Validating:  100% |████████████████████████████████████████|

Training:    100% |████████████████████████████████████████| Accuracy: 0.81, SoftmaxCrossEntropyLoss: 0.57

Validating:  100% |████████████████████████████████████████|

Training:    100% |████████████████████████████████████████| Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.53

Validating:  100% |████████████████████████████████████████|

Training:    100% |████████████████████████████████████████| Accuracy: 0.83, SoftmaxCrossEntropyLoss: 0.50

Validating:  100% |████████████████████████████████████████|

Training:    100% |████████████████████████████████████████| Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.49

Validating:  100% |████████████████████████████████████████|
[Finalizer] INFO ai.djl.training.listener.LoggingTrainingListener - forward P50: 0.405 ms, P90: 0.861 ms
[Finalizer] INFO ai.djl.training.listener.LoggingTrainingListener - training-metrics P50: 0.037 ms, P90: 0.051 ms
[Finalizer] INFO ai.djl.training.listener.LoggingTrainingListener - backward P50: 1.054 ms, P90: 1.706 ms
[Finalizer] INFO ai.djl.training.listener.LoggingTrainingListener - step P50: 0.498 ms, P90: 0.753 ms
[Finalizer] INFO ai.djl.training.listener.LoggingTrainingListener - epoch P50: 55.685 s, P90: 63.680 s
[Finalizer] WARN ai.djl.BaseModel - Model: softmax-regression was not closed explicitly.
[nREPL-session-37be5dc0-b384-4058-91e6-e80169eef8ef] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 1 finished.
[nREPL-session-37be5dc0-b384-4058-91e6-e80169eef8ef] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.74, SoftmaxCrossEntropyLoss: 0.80
[nREPL-session-37be5dc0-b384-4058-91e6-e80169eef8ef] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.79, SoftmaxCrossEntropyLoss: 0.63
[nREPL-session-37be5dc0-b384-4058-91e6-e80169eef8ef] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 2 finished.
[nREPL-session-37be5dc0-b384-4058-91e6-e80169eef8ef] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.81, SoftmaxCrossEntropyLoss: 0.57
[nREPL-session-37be5dc0-b384-4058-91e6-e80169eef8ef] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.81, SoftmaxCrossEntropyLoss: 0.57
[nREPL-session-37be5dc0-b384-4058-91e6-e80169eef8ef] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 3 finished.
[nREPL-session-37be5dc0-b384-4058-91e6-e80169eef8ef] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.53
[nREPL-session-37be5dc0-b384-4058-91e6-e80169eef8ef] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.54
[nREPL-session-37be5dc0-b384-4058-91e6-e80169eef8ef] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 4 finished.
[nREPL-session-37be5dc0-b384-4058-91e6-e80169eef8ef] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.83, SoftmaxCrossEntropyLoss: 0.50
[nREPL-session-37be5dc0-b384-4058-91e6-e80169eef8ef] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.53
[nREPL-session-37be5dc0-b384-4058-91e6-e80169eef8ef] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 5 finished.
[nREPL-session-37be5dc0-b384-4058-91e6-e80169eef8ef] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.49
[nREPL-session-37be5dc0-b384-4058-91e6-e80169eef8ef] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.51
;; Fetch the trainer's training result (epoch count plus train/validate loss).
;; NOTE(review): the recorded result below shows :epoch 10 although num-epochs
;; is 5 -- it may come from a session where fit was run twice; confirm.
(training/get-training-result trainer)
:epoch 10 :train-loss 0.44829798 :validate-loss 0.48035732

Created: 2021-04-11 Sun 20:59