Table of Contents
1 Implementation of Multilayer Perceptron from Scratch
(ns clj-d2l.multilayer-perceptron-scratch (:require [ :as io] [clj-djl.ndarray :as nd] [clj-djl.nn :as nn] [clj-djl.model :as m] [ :as t] [ :as ds] [clj-djl.engine :as engine] [ :as loss] [ :as tracker] [ :as optimizer] [ :as listener] [clj-d2l.core :as d2l] [com.hypirion.clj-xchart :as c]) (:import [ai.djl.basicdataset FashionMnist]))
;; Number of examples per mini-batch; also passed to the SGD step later on.
(def batch-size 256)

(defn- fashion-mnist
  "Builds and prepares the Fashion-MNIST split selected by `usage`
  (`:train` or `:test`), sampled in batches of `batch-size` with the
  sampling flag set to `true`."
  [usage]
  (-> (FashionMnist/builder)
      (ds/opt-usage usage)
      (ds/set-sampling batch-size true)
      (ds/build)
      (ds/prepare)))

;; Training and test splits, built through the same pipeline.
(def mnist-train (fashion-mnist :train))
(def mnist-test (fashion-mnist :test))
1.1 Initializing Model Parameters
;; Layer widths: 784 input features, 256 hidden units, 10 outputs.
(def ninputs 784)
(def noutputs 10)
(def nhiddens 256)

;; Manager that owns every parameter array created below.
(def manager (nd/base-manager))

;; Hidden layer: weights drawn via nd/random-normal with arguments 0 and 0.01,
;; bias initialised to zeros.
(def W1
  (nd/random-normal manager 0 0.01 [ninputs nhiddens] :float32 (nd/default-device)))
(def b1
  (nd/zeros manager [nhiddens]))

;; Output layer, initialised the same way.
(def W2
  (nd/random-normal manager 0 0.01 [nhiddens noutputs] :float32 (nd/default-device)))
(def b2
  (nd/zeros manager [noutputs]))

(def params [W1 b1 W2 b2])

;; Attach a gradient to each parameter (done purely for the side effect).
(doseq [param params]
  (nd/attach-gradient param))
1.2 Activation Function
(defn relu
  "ReLU activation: calls the `.maximum` method of `X` with 0.0, i.e. the
  element-wise max(x, 0)."
  [X]
  (.maximum X 0.0))
1.3 The Model
(defn net
  "Forward pass of the two-layer perceptron.

  Reshapes `X` to `[-1 ninputs]`, applies the hidden layer
  `relu(X . W1 + b1)`, and returns the output-layer result
  `H . W2 + b2`."
  [X]
  (let [flat   (nd/reshape X [-1 ninputs])
        hidden (relu (nd/+ (nd/dot flat W1) b1))]
    (nd/+ (nd/dot hidden W2) b2)))
1.4 The Loss Function
;; Loss object evaluated on (labels, predictions) by the training loop.
;; NOTE(review): `sotfmax` looks like a misspelling of `softmax` -- confirm the
;; function name actually exported by the namespace aliased as `loss` before
;; renaming; the call is left untouched here.
(def loss (loss/sotfmax-cross-entropy-loss))
1.5 Training
;; Hyper-parameters.
(def nepochs 10)
(def lr 0.5)

;; Running sums for the pass currently in progress; reset after each use.
(def epoch-loss (atom 0.))
(def accuracy-val (atom 0.))

;; Per-epoch history, one entry appended per epoch.
(def train-loss (atom []))
(def train-accuracy (atom []))
(def test-accuracy (atom []))

(doseq [epoch (range 1 (+ nepochs 1))]
  (print "Running epoch " epoch "......")

  ;; ---- Training pass ----
  (doseq [batch (ds/get-data-iterator mnist-train manager)]
    (let [X (nd/head (ds/get-batch-data batch))
          y (nd/head (ds/get-batch-labels batch))]
      ;; Everything bound by with-open is closed when the body exits, so the
      ;; per-batch intermediates do not accumulate.
      (with-open [gc (t/gradient-collector)
                  yhat (net X)
                  lossvalue (.evaluate loss (nd/ndlist [y]) (nd/ndlist [yhat]))
                  ;; Scale the evaluated loss by batch-size before summing.
                  l (nd/* lossvalue batch-size)]
        (swap! epoch-loss + (nd/get-element (nd/sum l)))
        (swap! accuracy-val + (d2l/accuracy yhat y))
        (.backward gc l))
      (ds/close-batch batch)
      (d2l/sgd params lr batch-size)))
  (swap! train-loss conj (/ @epoch-loss (nd/size mnist-train)))
  (swap! train-accuracy conj (/ @accuracy-val (nd/size mnist-train)))
  (reset! epoch-loss 0.)
  (reset! accuracy-val 0.)

  ;; ---- Evaluation pass ----
  ;; The prediction array and the batch are released after every step, the
  ;; same way the training pass does; otherwise each test batch stays
  ;; allocated for the lifetime of `manager`.
  (doseq [batch (ds/get-data-iterator mnist-test manager)]
    (let [X (nd/head (ds/get-batch-data batch))
          y (nd/head (ds/get-batch-labels batch))]
      (with-open [yhat (net X)]
        (swap! accuracy-val + (d2l/accuracy yhat y)))
      (ds/close-batch batch)))
  (swap! test-accuracy conj (/ @accuracy-val (nd/size mnist-test)))
  (reset! accuracy-val 0.)
  (println "Finished epoch " epoch))
Running epoch 1 ......Finished epoch 1 Running epoch 2 ......Finished epoch 2 Running epoch 3 ......Finished epoch 3 Running epoch 4 ......Finished epoch 4 Running epoch 5 ......Finished epoch 5 Running epoch 6 ......Finished epoch 6 Running epoch 7 ......Finished epoch 7 Running epoch 8 ......Finished epoch 8 Running epoch 9 ......Finished epoch 9 Running epoch 10 ......Finished epoch 10
;; Plot the per-epoch test accuracy, train accuracy and train loss, then write
;; the chart to an SVG file.
(let [epochs (range 1 (+ nepochs 1))
      ;; One marker-less line series over the epoch axis.
      series (fn [ys] {:x epochs :y ys :style {:marker-type :none}})]
  (c/spit
   (c/xy-chart {"test acc"   (series @test-accuracy)
                "train acc"  (series @train-accuracy)
                "train loss" (series @train-loss)})
   "figure/mlp-scratch.svg"))