UP | HOME

Table of Contents

1 Concise Implementation of Multilayer Perceptrons

(ns clj-d2l.multilayer-perceptron-concise
  (:require [clojure.java.io :as io]
            [clj-djl.ndarray :as nd]
            [clj-djl.nn :as nn]
            [clj-djl.model :as m]
            [clj-djl.training :as t]
            [clj-djl.training.dataset :as ds]
            [clj-djl.engine :as engine]
            [clj-djl.training.loss :as loss]
            [clj-djl.training.tracker :as tracker]
            [clj-djl.training.optimizer :as optimizer]
            [clj-djl.training.listener :as listener]
            [clj-djl.utils :as utils]
            [clj-d2l.core :as d2l]
            [com.hypirion.clj-xchart :as c])
  (:import [ai.djl.basicdataset FashionMnist]
           [ai.djl.training.listener TrainingListener$Defaults]
           [ai.djl.nn Activation]))

1.1 The Model

;; Multilayer perceptron: flatten each input to 784 features, a 256-unit
;; linear hidden layer followed by ReLU, and a 10-unit linear output layer.
;; All parameters use a normal-distribution initializer.
(def net
  (let [hidden-layer (-> (nn/new-linear-builder)
                         (nn/set-units 256)
                         (nn/build))
        output-layer (-> (nn/new-linear-builder)
                         (nn/set-units 10)
                         (nn/build))]
    (-> (nn/sequential-block)
        (nn/add (nn/batch-flatten-block 784))
        (nn/add hidden-layer)
        (nn/add (utils/as-function nn/relu))
        (nn/add output-layer)
        (nn/set-initializer (nn/new-normal-initializer)))))

;; --- Training hyperparameters ---
(def batch-size 256) ; examples per minibatch
(def nepochs    10)  ; passes over the training set
(def lr         0.5) ; SGD learning rate

;; --- Mutable holders for running / per-epoch statistics ---
;; NOTE(review): none of these are read or written elsewhere in this file;
;; presumably kept for interactive use — confirm before removing.
(def epoch-loss     (atom 0.0))
(def accuracy-val   (atom 0.0))
(def train-loss     (atom []))
(def train-accuracy (atom []))
(def test-accuracy  (atom []))

;; Fashion-MNIST training split, batched by `batch-size` with random
;; (shuffled) sampling, built and prepared (loaded) up front.
(def mnist-train
  (let [builder (-> (FashionMnist/builder)
                    (ds/opt-usage :train)
                    (ds/set-sampling batch-size true))
        dataset (ds/build builder)]
    (ds/prepare dataset)))

;; Fashion-MNIST test split, used as the validation set during `t/fit`.
;; It is batched by `batch-size` with sequential (non-random) sampling.
;; Shuffling an evaluation set buys nothing: accuracy does not depend on
;; example order. It only makes the validation pass nondeterministic.
(def mnist-test (-> (FashionMnist/builder)
                    (ds/opt-usage :test)
                    (ds/set-sampling batch-size false)
                    (ds/build)
                    (ds/prepare)))

;; Map from metric key ("train_epoch_<name>" / "validate_epoch_<name>") to the
;; per-epoch values collected after training; starts empty.
(def evaluator-metrics (atom {}))

;; Fixed learning-rate schedule driven by the `lr` hyperparameter. The value
;; was previously repeated as a literal, which left `lr` unused.
(def lrt (tracker/fixed lr))

;; Plain SGD using the fixed learning-rate tracker above.
(def sgd (-> (optimizer/sgd)
             (optimizer/set-learning-rate-tracker lrt)
             (optimizer/build)))

;; NOTE(review): `sotfmax` looks like a misspelling of `softmax`. It is kept
;; as-is because it has to match the function name exported by the clj-djl
;; version in use. Confirm, and switch if a correctly spelled one exists.
(def loss (loss/sotfmax-cross-entropy-loss))

;; Training configuration: the loss above, the SGD optimizer, an accuracy
;; evaluator, and the listeners returned by TrainingListener.Defaults/logging.
(def config (-> (t/new-training-config loss)
                (t/opt-optimizer sgd)
                (t/add-evaluator (t/new-accuracy))
                (t/add-training-listeners (TrainingListener$Defaults/logging))))

;; Holders for the trainer's evaluators and metrics. The training run resets
;; them (via reset!) so they stay inspectable after the trainer is closed.
(def evals (atom nil)) ; evaluators, as returned by t/get-evaluators
(def mets  (atom nil)) ; metrics, as returned by t/get-metrics


;; Train `net` on `mnist-train` for `nepochs` epochs, validating on
;; `mnist-test`. Then copy each evaluator's per-epoch values into
;; `evaluator-metrics` under "train_epoch_<name>" and "validate_epoch_<name>".
;; `with-open` closes the trainer and model on exit. Every value kept past
;; that point is therefore realized eagerly (mapv) inside the body, rather
;; than left as a lazy seq.
(with-open [model (-> (m/new-instance "mlp")
                      (m/set-block net))
            trainer (m/new-trainer model config)]
  ;; Initialize with a single-sample input shape of [1 784], matching the
  ;; flatten block's 784 features.
  (-> trainer
      (t/initialize [(nd/shape [1 784])])
      (t/set-metrics (t/metrics))
      (t/fit nepochs mnist-train mnist-test))
  (reset! evals (t/get-evaluators trainer))
  (reset! mets (t/get-metrics trainer))
  (let [metrics @mets]
    ;; For each evaluator, record its training and then its validation series.
    (doseq [evaluator @evals
            prefix ["train_epoch_" "validate_epoch_"]]
      (let [metric-key (str prefix (.getName evaluator))]
        (swap! evaluator-metrics
               assoc metric-key (mapv :value (metrics metric-key)))))))
;; Plot test accuracy, training accuracy and training loss against the epoch
;; number (1..nepochs) without point markers, and write it to mlp-concise.svg.
(let [epochs (range 1 (inc nepochs))
      series (fn [metric-key]
               {:x     epochs
                :y     (get @evaluator-metrics metric-key)
                :style {:marker-type :none}})
      chart  (c/xy-chart
              {"test acc"   (series "validate_epoch_Accuracy")
               "train acc"  (series "train_epoch_Accuracy")
               "train loss" (series "train_epoch_SoftmaxCrossEntropyLoss")})]
  (c/spit chart "mlp-concise.svg"))

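Figure: test accuracy, training accuracy and training loss per epoch (mlp-concise.svg). If the image does not appear, your browser may not support SVG.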

Created: 2021-04-11 Sun 20:59