UP | HOME

Table of Contents

1 Implementation of Multilayer Perceptron from Scratch

(ns clj-d2l.multilayer-perceptron-scratch
  (:require [clojure.java.io :as io]
            [clj-djl.ndarray :as nd]
            [clj-djl.nn :as nn]
            [clj-djl.model :as m]
            [clj-djl.training :as t]
            [clj-djl.training.dataset :as ds]
            [clj-djl.engine :as engine]
            [clj-djl.training.loss :as loss]
            [clj-djl.training.tracker :as tracker]
            [clj-djl.training.optimizer :as optimizer]
            [clj-djl.training.listener :as listener]
            [clj-d2l.core :as d2l]
            [com.hypirion.clj-xchart :as c])
  (:import [ai.djl.basicdataset FashionMnist]))
(def batch-size 256)

(defn- fashion-mnist
  "Builds and prepares the FashionMnist split selected by `usage`
  (:train or :test). Batches hold `batch-size` samples; the `true` passed to
  `ds/set-sampling` is kept as-is (presumably random/shuffled sampling --
  confirm against clj-djl)."
  [usage]
  (let [builder (-> (FashionMnist/builder)
                    (ds/opt-usage usage)
                    (ds/set-sampling batch-size true))]
    (ds/prepare (ds/build builder))))

;; Training split first, then test split (same order as they were loaded before).
(def mnist-train (fashion-mnist :train))

(def mnist-test (fashion-mnist :test))

1.1 Initializing Model Parameters

;; Layer sizes: flattened input width, hidden units, output classes.
(def ninputs 784)
(def nhiddens 256)
(def noutputs 10)

;; Every parameter array is created by this manager.
(def manager (nd/base-manager))

(defn- gaussian-weights
  "Returns a float32 [rows cols] array from `nd/random-normal` with loc 0 and
  scale 0.01 on the default device, owned by `manager`."
  [rows cols]
  (nd/random-normal manager 0 0.01 [rows cols] :float32 (nd/default-device)))

;; Hidden layer, then output layer; biases start at zero.
(def W1 (gaussian-weights ninputs nhiddens))
(def b1 (nd/zeros manager [nhiddens]))
(def W2 (gaussian-weights nhiddens noutputs))
(def b2 (nd/zeros manager [noutputs]))

(def params [W1 b1 W2 b2])

;; Attach a gradient to each parameter so the backward pass can fill it in.
(doseq [p params]
  (nd/attach-gradient p))

1.2 Activation Function

(defn relu
  "Element-wise ReLU: returns `X.maximum(0.0)`, i.e. each element clamped
  below at zero. `X` is expected to be an NDArray (it is called on the hidden
  pre-activations in `net`)."
  [X]
  (. X maximum 0.))

1.3 The Model

(defn net
  "Forward pass of the one-hidden-layer MLP.
  Flattens `X` to [-1 ninputs], computes relu(X . W1 + b1), and returns the
  unnormalized output scores hidden . W2 + b2 (these are what the softmax
  cross-entropy loss consumes)."
  [X]
  (let [flat   (nd/reshape X [-1 ninputs])
        hidden (relu (nd/+ (nd/dot flat W1) b1))
        scores (nd/dot hidden W2)]
    (nd/+ scores b2)))

1.4 The Loss Function

;; NOTE(review): the `sotfmax` spelling must match the var exported by
;; clj-djl.training.loss -- confirm against the library version in use.
(def loss
  "Loss object used by the training loop; evaluated on the labels and the
  raw outputs of `net`."
  (loss/sotfmax-cross-entropy-loss))

1.5 Training

(def nepochs 10)
(def lr 0.5)

;; Per-epoch running sums (reset after each use) and per-epoch histories
;; that the plot reads.
(def epoch-loss (atom 0.))
(def accuracy-val (atom 0.))
(def train-loss (atom []))
(def train-accuracy (atom []))
(def test-accuracy (atom []))

(doseq [epoch (range 1 (+ nepochs 1))]
  (print "Running epoch " epoch "......")
  ;; `print` does not flush; flush so progress is visible while the epoch runs.
  (flush)

  ;; Training pass: forward, loss, backward, then one SGD step per batch.
  (doseq [batch (ds/get-data-iterator mnist-train manager)]
    (try
      (let [X (nd/head (ds/get-batch-data batch))
            y (nd/head (ds/get-batch-labels batch))]
        (with-open [gc (t/gradient-collector)
                    yhat (net X)
                    lossvalue (.evaluate loss (nd/ndlist [y]) (nd/ndlist [yhat]))
                    ;; Backward scaling is unchanged: by the nominal batch size,
                    ;; the same `batch-size` that `d2l/sgd` receives below.
                    l (nd/* lossvalue batch-size)]
          ;; `lossvalue` is treated as the batch-mean loss (hence the scaling
          ;; above). For the epoch total, multiply by the number of samples
          ;; actually in this batch: the sampler may yield a short final batch,
          ;; and scaling that one by `batch-size` would over-count it.
          (swap! epoch-loss + (* (nd/get-element (nd/sum lossvalue)) (.size y)))
          (swap! accuracy-val + (d2l/accuracy yhat y))
          (.backward gc l)))
      ;; Close the batch even if the step above throws.
      (finally
        (ds/close-batch batch)))
    (d2l/sgd params lr batch-size))
  (swap! train-loss conj (/ @epoch-loss (nd/size mnist-train)))
  (swap! train-accuracy conj (/ @accuracy-val (nd/size mnist-train)))

  (reset! epoch-loss 0.)
  (reset! accuracy-val 0.)

  ;; Evaluation pass on the test split. These batches are closed as well, so
  ;; they do not accumulate across epochs.
  (doseq [batch (ds/get-data-iterator mnist-test manager)]
    (try
      (let [X (nd/head (ds/get-batch-data batch))
            y (nd/head (ds/get-batch-labels batch))]
        (swap! accuracy-val + (d2l/accuracy (net X) y)))
      (finally
        (ds/close-batch batch))))
  (swap! test-accuracy conj (/ @accuracy-val (nd/size mnist-test)))
  (reset! accuracy-val 0.)
  (println "Finished epoch " epoch))
Running epoch  1 ......Finished epoch  1
Running epoch  2 ......Finished epoch  2
Running epoch  3 ......Finished epoch  3
Running epoch  4 ......Finished epoch  4
Running epoch  5 ......Finished epoch  5
Running epoch  6 ......Finished epoch  6
Running epoch  7 ......Finished epoch  7
Running epoch  8 ......Finished epoch  8
Running epoch  9 ......Finished epoch  9
Running epoch  10 ......Finished epoch  10
;; Plot the three per-epoch curves against the epoch number and write the SVG.
(let [epochs (range 1 (inc nepochs))
      ;; Every curve shares the same x values and hides point markers.
      series (fn [ys]
               {:x epochs
                :y ys
                :style {:marker-type :none}})
      chart  (c/xy-chart {"test acc"   (series @test-accuracy)
                          "train acc"  (series @train-accuracy)
                          "train loss" (series @train-loss)})]
  (c/spit chart "figure/mlp-scratch.svg"))

Figure: test accuracy, training accuracy, and training loss per epoch (saved to figure/mlp-scratch.svg).

Created: 2021-04-11 Sun 20:59