Table of Contents
1 Concise Implementation of Multilayer Perceptron
(ns clj-d2l.multilayer-perceptron-concise (:require [clojure.java.io :as io] [clj-djl.ndarray :as nd] [clj-djl.nn :as nn] [clj-djl.model :as m] [clj-djl.training :as t] [clj-djl.training.dataset :as ds] [clj-djl.engine :as engine] [clj-djl.training.loss :as loss] [clj-djl.training.tracker :as tracker] [clj-djl.training.optimizer :as optimizer] [clj-djl.training.listener :as listener] [clj-djl.utils :as utils] [clj-d2l.core :as d2l] [com.hypirion.clj-xchart :as c]) (:import [ai.djl.basicdataset FashionMnist] [ai.djl.training.listener TrainingListener$Defaults] [ai.djl.nn Activation]))
1.1 The Model
;; Network: flatten the input to 784 features, one hidden linear layer of
;; 256 units followed by ReLU, then a linear output layer of 10 units.
;; Parameters use the normal initializer.
(def net
  (let [hidden-layer (-> (nn/new-linear-builder)
                         (nn/set-units 256)
                         (nn/build))
        output-layer (-> (nn/new-linear-builder)
                         (nn/set-units 10)
                         (nn/build))]
    (-> (nn/sequential-block)
        (nn/add (nn/batch-flatten-block 784))
        (nn/add hidden-layer)
        (nn/add (utils/as-function nn/relu))
        (nn/add output-layer)
        (nn/set-initializer (nn/new-normal-initializer)))))

;; Hyper-parameters.
(def batch-size 256)
(def nepochs 10)
(def lr 0.5)

;; Mutable holders for per-epoch statistics.
(def epoch-loss (atom 0.0))
(def accuracy-val (atom 0.0))
(def train-loss (atom []))
(def train-accuracy (atom []))
(def test-accuracy (atom []))

(defn- fashion-mnist
  "Builds and prepares the FashionMnist dataset for `usage` (:train or :test),
  sampled with `batch-size` and the same `true` sampling flag for both splits."
  [usage]
  (-> (FashionMnist/builder)
      (ds/opt-usage usage)
      (ds/set-sampling batch-size true)
      (ds/build)
      (ds/prepare)))

(def mnist-train (fashion-mnist :train))
(def mnist-test (fashion-mnist :test))

;; Map from metric name (string) to the sequence of values recorded for it;
;; starts empty.
(def evaluator-metrics (atom {}))
;; Optimizer: SGD with a fixed learning rate. The rate comes from `lr`
;; rather than a second hard-coded literal, so the two cannot drift apart.
(def lrt (tracker/fixed lr))
(def sgd (-> (optimizer/sgd)
             (optimizer/set-learning-rate-tracker lrt)
             (optimizer/build)))

;; Softmax cross-entropy loss (the previous spelling `sotfmax-...` was a typo).
(def loss (loss/softmax-cross-entropy-loss))

;; Training configuration: loss, optimizer, an accuracy evaluator and the
;; default logging listeners.
(def config (-> (t/new-training-config loss)
                (t/opt-optimizer sgd)
                (t/add-evaluator (t/new-accuracy))
                (t/add-training-listeners (TrainingListener$Defaults/logging))))

;; Hold the trainer's evaluators and metrics for inspection after training.
(def evals (atom nil))
(def mets (atom nil))

;; Train for `nepochs` epochs, then copy the per-epoch train/validate values of
;; every evaluator into `evaluator-metrics`, keyed by
;; "train_epoch_<name>" and "validate_epoch_<name>".
(with-open [model (-> (m/new-instance "mlp")
                      (m/set-block net))
            trainer (m/new-trainer model config)]
  (-> trainer
      (t/initialize [(nd/shape [1 784])])
      (t/set-metrics (t/metrics))
      (t/fit nepochs mnist-train mnist-test))
  (let [evaluators (t/get-evaluators trainer)
        metrics (t/get-metrics trainer)]
    (reset! evals evaluators)
    (reset! mets metrics)
    (doseq [evaluator evaluators
            prefix ["train_epoch_" "validate_epoch_"]]
      (let [metric-name (str prefix (.getName evaluator))]
        ;; mapv (eager) so the values are realized while the trainer is still
        ;; open, instead of leaking a lazy seq out of `with-open`.
        (swap! evaluator-metrics assoc metric-name
               (mapv :value (metrics metric-name)))))))
;; Plot validation accuracy, training accuracy and training loss against the
;; epoch number (1..nepochs) and write the chart to "mlp-concise.svg".
(let [epochs (range 1 (inc nepochs))
      recorded @evaluator-metrics
      ;; One line series without markers for the values stored under metric-key.
      series (fn [metric-key]
               {:x epochs
                :y (recorded metric-key)
                :style {:marker-type :none}})
      chart (c/xy-chart
             {"test acc" (series "validate_epoch_Accuracy")
              "train acc" (series "train_epoch_Accuracy")
              "train loss" (series "train_epoch_SoftmaxCrossEntropyLoss")})]
  (c/spit chart "mlp-concise.svg"))