1 Weight Decay

1.1 High-Dimensional Linear Regression
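
The synthetic regression problem below follows the weight-decay chapter of Dive into Deep Learning: each label is a linear function of 200 inputs plus Gaussian noise, and only 20 training examples are available, so the model overfits easily. Assuming clj-d2l's synthetic-data matches the book's setup (noise standard deviation 0.01), the data-generating process is

y = 0.05 + \sum_{i=1}^{200} 0.01 x_i + \epsilon, \quad \epsilon \sim \mathcal{N}(0, 0.01^2)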

(ns clj-d2l.weight-decay
  (:require [clj-djl.ndarray :as nd]
            [clj-djl.training.dataset :as ds]
            [clj-djl.training.loss :as loss]
            [clj-djl.training :as t]
            [clj-djl.training.optimizer :as o]
            [clj-djl.training.tracker :as tracker]
            [clj-djl.training.listener :as listener]
            [clj-djl.model :as m]
            [clj-djl.nn :as nn]
            [clj-djl.device :as dev]
            [clj-d2l.core :as d2l]))
(def ntrain 20)
(def ntest 100)
(def ninputs 200)
(def batchsize 5)
(def ndm (nd/new-base-manager))
(def truew (-> (nd/ones ndm [ninputs 1]) (nd/* 0.01)))
(def trueb 0.05) ;; true bias, passed to synthetic-data as a plain scalar
(def train-data (d2l/synthetic-data ndm truew trueb ntrain))
(def test-data (d2l/synthetic-data ndm truew trueb ntest))
(def train-ds (d2l/load-array train-data batchsize false))
(def test-ds (d2l/load-array test-data batchsize false))
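
As a quick sanity check, the generated tensors can be inspected with functions already required above; this sketch assumes train-data is an indexable pair of features and labels, which the training loop below relies on as well.

(take 3 (nd/to-vec truew))          ;; every entry of the true weight vector is 0.01
(take 3 (nd/to-vec (train-data 1))) ;; the first few noisy training labels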

1.2 Implementation from Scratch

1.2.1 Defining L2 Norm Penalty
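
The penalty adds half of the squared L2 norm of the weights to the loss:

\frac{\lambda}{2} \lVert \mathbf{w} \rVert^2 = \frac{\lambda}{2} \sum_i w_i^2

Dividing by 2 simply makes the gradient of the penalty \lambda \mathbf{w} rather than 2 \lambda \mathbf{w}.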

(defn l2penalty [w]
  (-> (nd/** w 2)
      (nd/sum)
      (nd// 2)))
(def metrics (atom {:train-loss []
                    :test-loss []
                    :epoch []}))
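
A quick check of l2penalty on a tiny tensor, a sketch that reuses only constructors already required above: a 3x1 vector of ones has squared norm 3, so the penalty should come out as 1.5.

(-> (nd/ones ndm [3 1])
    (l2penalty)
    (nd/get-element))
;; => 1.5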

1.2.2 Defining the Training Loop
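
For every epoch the loop below draws minibatches from train-ds, computes the squared loss of the linear model plus lambd times the L2 penalty, backpropagates inside a gradient collector, and applies one step of minibatch SGD; every fifth epoch the losses on the full training and test sets are stored in the metrics atom for plotting.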

(defn train [lambd]
  (let [nepochs 100
        lr 0.003
        w (nd/random-normal ndm 0 1.0 [ninputs 1] :float32)
        b (nd/zeros ndm [1])
        _ (dorun (map nd/attach-gradient [w b]))
        params (nd/ndlist w b)]

    (reset! metrics {:train-loss []
                     :test-loss []
                     :epoch []})

    (doseq [epoch (range 0 (inc nepochs))]
      (doseq [batch (t/iter-seq (ds/get-data train-ds ndm))]
        (let [X (first (ds/get-batch-data batch))
              y (first (ds/get-batch-labels batch))
              w (nd/get params 0)
              b (nd/get params 1)]
          (with-open [gc (t/new-gradient-collector)]
            ;; minibatch loss: squared loss of the linear model plus lambd times the L2 penalty
            (let [l (-> (d2l/linreg X w b)
                        (d2l/squared-loss y)
                        (nd/+ (nd/* (l2penalty w) lambd)))]
              (t/backward gc l))))
        (ds/close-batch batch)
        (d2l/sgd params lr batchsize))
      (when (zero? (mod epoch 5))
        (let [index (quot epoch 5)
              train-loss (-> (d2l/linreg (train-data 0) (nd/get params 0) (nd/get params 1))
                             (d2l/squared-loss (train-data 1)))
              test-loss (-> (d2l/linreg (test-data 0) (nd/get params 0) (nd/get params 1))
                            (d2l/squared-loss (test-data 1)))]
          (swap! metrics assoc-in [:epoch index] epoch)
          (swap! metrics assoc-in [:train-loss index] train-loss)
          (swap! metrics assoc-in [:test-loss index] test-loss))))
    (println "l1 norm of w: " (-> (nd/get params 0) (nd/abs) (nd/sum)))))

1.2.3 Training without Regularization

(train 0)
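
The losses stored in metrics are NDArrays, so the plot call below first reduces each one to a scalar with nd/mean and then takes log10, which makes the gap between training and test loss easy to see.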

(d2l/plot-lines "figure/train_no_regularization.svg"
                ["train loss" "test loss"]
                (@metrics :epoch)
                (map #(map (fn [ndarray]
                             (nd/get-element (nd/log10 (nd/mean ndarray)))) %)
                     [(@metrics :train-loss) (@metrics :test-loss)]))

[Figure: figure/train_no_regularization.svg, train loss and test loss vs. epoch]

1.2.4 Using Weight Decay

(train 3)
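
With lambd set to 3 the penalty pulls the weights toward zero; the training loss should end up higher than in the unregularized run, while the test loss drops, which is the expected effect of weight decay on this deliberately overfit problem.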

(d2l/plot-lines "figure/train_weight_decay.svg"
                ["train loss" "test loss"]
                (@metrics :epoch)
                (map #(map (fn [ndarray]
                             (nd/get-element (nd/log10 (nd/mean ndarray)))) %)
                     [(@metrics :train-loss) (@metrics :test-loss)]))

[Figure: figure/train_weight_decay.svg, train loss and test loss vs. epoch]

1.3 Concise Implementation
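
In the concise version weight decay is not added to the loss by hand; it is handed to DJL's SGD optimizer through the :weight-decay option, so the shrinkage happens inside the parameter update. For learning rate \eta, decay \lambda and minibatch B the update is effectively

\mathbf{w} \leftarrow (1 - \eta \lambda)\, \mathbf{w} - \frac{\eta}{|B|} \sum_{i \in B} \nabla_{\mathbf{w}} \ell^{(i)}(\mathbf{w}, b)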

(defn train-concise [wd]
  (let [nepochs 100
        lr 0.003
        config (t/default-training-config {:loss (loss/l2)
                                           :optimizer (o/sgd {:tracker (tracker/fixed lr)
                                                              :weight-decay wd})
                                           :devices (dev/get-devices 1)
                                           :evaluator (loss/l2)
                                           :listeners (listener/logging)})]
    (with-open [model (m/new-model {:name "mlp"
                                    :block (-> (nn/sequential-block)
                                               (nn/add (nn/linear-block {:bias true :units 1})))})
                trainer (t/new-trainer model config)]

      (reset! metrics {:train-loss []
                       :test-loss []
                       :epoch []})

      (t/initialize trainer (nd/shape [batchsize ninputs]))

      (doseq [epoch (range 0 (inc nepochs))]
        (doseq [batch (t/iterate-dataset trainer train-ds)]
          (t/train-batch trainer batch)
          (t/step trainer)
          (ds/close-batch batch))
        (doseq [batch (t/iterate-dataset trainer test-ds)]
          (t/validate-batch trainer batch)
          (ds/close-batch batch))
        (t/notify-listeners trainer (fn [listener] (.onEpoch listener trainer)))
        (when (zero? (mod epoch 5))
          (let [train-result (t/get-training-result trainer)
                index (quot epoch 5)]
            (swap! metrics assoc-in [:epoch index] (train-result :epoch))
            (swap! metrics assoc-in [:train-loss index] (train-result :train-loss))
            (swap! metrics assoc-in [:test-loss index] (train-result :validate-loss)))))
      (println "l2 norm of w: " (-> model
                                    (m/get-block)
                                    (nn/get-parameters)
                                    (.get "01Linear_weight")
                                    (.getArray)
                                    (nd/to-vec))))))

1.3.1 Without Weight Decay

Extend the org-babel evaluation timeout so the long training run does not time out:

(setq org-babel-clojure-sync-nrepl-timeout 1000)

(train-concise 0)

(d2l/plot-lines "figure/train_no_wd_concise.svg"
                ["train loss" "test loss"]
                (@metrics :epoch)
                [(@metrics :train-loss) (@metrics :test-loss)])

[Figure: figure/train_no_wd_concise.svg, train loss and test loss vs. epoch]

1.3.2 With Weight Decay

(train-concise 3)

(d2l/plot-lines "figure/train_wd_concise.svg"
                ["train loss" "test loss"]
                (@metrics :epoch)
                [(@metrics :train-loss) (@metrics :test-loss)])

[Figure: figure/train_wd_concise.svg, train loss and test loss vs. epoch]

Created: 2021-04-11 Sun 20:59