1 Weight Decay
1.1 High-Dimensional Linear Regression
(ns clj-d2l.weight-decay
  (:require [clj-djl.ndarray :as nd]
            [clj-djl.training.dataset :as ds]
            [clj-djl.training.loss :as loss]
            [clj-djl.training :as t]
            [clj-djl.training.optimizer :as o]
            [clj-djl.training.tracker :as tracker]
            [clj-djl.training.listener :as listener]
            [clj-djl.model :as m]
            [clj-djl.nn :as nn]
            [clj-djl.device :as dev]
            [clj-d2l.core :as d2l]))
As in the d2l chapter, we generate labels by \(y = 0.05 + \sum_i 0.01\, x_i + \epsilon\): only 20 training examples against 200 input dimensions, a setting that invites overfitting.

(def ntrain 20)
(def ntest 100)
(def ninputs 200)
(def batchsize 5)

(def ndm (nd/new-base-manager))

;; true parameters: w = 0.01 in every dimension, b = 0.05
(def truew (-> (nd/ones ndm [ninputs 1]) (nd/* 0.01)))
(def trueb 0.05)

(def train-data (d2l/synthetic-data ndm truew trueb ntrain))
(def test-data (d2l/synthetic-data ndm truew trueb ntest))

(def train-ds (d2l/load-array train-data batchsize false))
(def test-ds (d2l/load-array test-data batchsize false))
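For orientation, here is a minimal sketch of what d2l/synthetic-data is assumed to compute (the real helper lives in clj-d2l.core; nd/dot, nd/get-shape, and the random-normal arity are assumptions about clj-djl's wrappers over DJL):

;; assumed behavior of d2l/synthetic-data:
;; X ~ N(0, 1), y = Xw + b + N(0, 0.01) noise; returns [X y]
(defn synthetic-data-sketch [ndm w b n]
  (let [num-inputs (.get (nd/get-shape w) 0)  ; DJL Shape.get(0)
        X (nd/random-normal ndm 0 1.0 [n num-inputs] :float32)
        y (nd/+ (nd/dot X w) b)
        noise (nd/random-normal ndm 0 0.01 (nd/get-shape y) :float32)]
    [X (nd/+ y noise)]))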
1.2 Implementation from Scratch
1.2.1 Defining L2 Norm Penalty
(defn l2penalty [w]
  (-> (nd/** w 2)
      (nd/sum)
      (nd// 2)))
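This matches the penalized objective of the d2l chapter: the usual squared loss plus a scaled squared L2 norm of the weights,

\[ \min_{\mathbf{w},\, b} \; L(\mathbf{w}, b) + \frac{\lambda}{2} \lVert \mathbf{w} \rVert^2 \]

The division by 2 is cosmetic: it makes the penalty's gradient exactly \(\lambda \mathbf{w}\), with no stray factor of 2.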
(def metrics (atom {:train-loss [] :test-loss [] :epoch []}))
1.2.2 Defining the Training Loop
(defn train [lambd]
  (let [nepochs 100
        lr 0.003
        w (nd/random-normal ndm 0 1.0 [ninputs 1] :float32)
        b (nd/zeros ndm [1])
        _ (dorun (map nd/attach-gradient [w b]))
        params (nd/ndlist w b)]
    (reset! metrics {:train-loss [] :test-loss [] :epoch []})
    (doseq [epoch (range nepochs)]
      (doseq [batch (t/iter-seq (ds/get-data train-ds ndm))]
        (let [X (first (ds/get-batch-data batch))
              y (first (ds/get-batch-labels batch))
              w (nd/get params 0)
              b (nd/get params 1)]
          (with-open [gc (t/new-gradient-collector)]
            ;; loss = squared loss + lambd * L2 penalty
            (let [l (-> (d2l/linreg X w b)
                        (d2l/squared-loss y)
                        (nd/+ (nd/* (l2penalty w) lambd)))]
              (t/backward gc l))))
        (ds/close-batch batch)
        ;; update w and b in place with minibatch SGD
        (d2l/sgd params lr batchsize))
      ;; record train/test loss every 5 epochs
      (when (zero? (mod epoch 5))
        (let [index (quot epoch 5)
              train-loss (-> (d2l/linreg (train-data 0) (nd/get params 0) (nd/get params 1))
                             (d2l/squared-loss (train-data 1)))
              test-loss (-> (d2l/linreg (test-data 0) (nd/get params 0) (nd/get params 1))
                            (d2l/squared-loss (test-data 1)))]
          (swap! metrics assoc-in [:epoch index] epoch)
          (swap! metrics assoc-in [:train-loss index] train-loss)
          (swap! metrics assoc-in [:test-loss index] test-loss))))
    (println "l1 norm of w: " (-> (nd/get params 0) (nd/abs) (nd/sum)))))
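The loop leans on three helpers from clj-d2l.core. For reference, minimal sketches of what they are assumed to do, based on the corresponding d2l definitions (not the library source; nd/dot, nd/reshape, nd/get-gradient, and nd/subi are assumed wrapper names over DJL's dot, reshape, getGradient, and subi):

;; linear model: y-hat = Xw + b
(defn linreg-sketch [X w b]
  (nd/+ (nd/dot X w) b))

;; per-example squared loss: (y-hat - y)^2 / 2
(defn squared-loss-sketch [y-hat y]
  (-> (nd/- y-hat (nd/reshape y (nd/get-shape y-hat)))
      (nd/** 2)
      (nd// 2)))

;; minibatch SGD: param <- param - lr * grad / batch-size, updated in place
(defn sgd-sketch [params lr batch-size]
  (doseq [param params]
    (nd/subi param (-> (nd/get-gradient param)
                       (nd/* lr)
                       (nd// batch-size)))))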
1.2.3 Training without Regularization
Training with lambd = 0 disables the penalty. With only 20 examples in 200 dimensions, the model overfits badly: training loss keeps falling while test loss stays high.

(train 0)

(d2l/plot-lines "figure/train_no_regularization.svg"
                ["train loss" "test loss"]
                (@metrics :epoch)
                (map #(map (fn [ndarray]
                             (nd/get-element (nd/log10 (nd/mean ndarray))))
                           %)
                     [(@metrics :train-loss) (@metrics :test-loss)]))
1.2.4 Using Weight Decay
With lambd = 3 the penalty is active: training loss rises, but test loss comes down, which is exactly the effect weight decay is meant to have.

(train 3)

(d2l/plot-lines "figure/train_weight_decay.svg"
                ["train loss" "test loss"]
                (@metrics :epoch)
                (map #(map (fn [ndarray]
                             (nd/get-element (nd/log10 (nd/mean ndarray))))
                           %)
                     [(@metrics :train-loss) (@metrics :test-loss)]))
1.3 Concise Implementation
(defn train-concise [wd]
  (let [nepochs 100
        lr 0.003
        config (t/default-training-config
                {:loss (loss/l2)
                 :optimizer (o/sgd {:tracker (tracker/fixed lr)
                                    :weight-decay wd})
                 :devices (dev/get-devices 1)
                 :evaluator (loss/l2)
                 :listeners (listener/logging)})]
    (with-open [model (m/new-model {:name "mlp"
                                    :block (-> (nn/sequential-block)
                                               (nn/add (nn/linear-block {:bias true
                                                                         :units 1})))})
                trainer (t/new-trainer model config)]
      (reset! metrics {:train-loss [] :test-loss [] :epoch []})
      (t/initialize trainer (nd/shape [batchsize ninputs]))
      (doseq [epoch (range nepochs)]
        ;; one training pass; each batch's data is batchsize x ninputs (5 x 200)
        (doseq [batch (t/iterate-dataset trainer train-ds)]
          (t/train-batch trainer batch)
          (t/step trainer)
          (ds/close-batch batch))
        ;; evaluate on the test set
        (doseq [batch (t/iterate-dataset trainer test-ds)]
          (t/validate-batch trainer batch)
          (ds/close-batch batch))
        (t/notify-listeners trainer (fn [listener] (.onEpoch listener trainer)))
        ;; record train/test loss every 5 epochs
        (when (zero? (mod epoch 5))
          (let [train-result (t/get-training-result trainer)
                index (quot epoch 5)]
            (swap! metrics assoc-in [:epoch index] (train-result :epoch))
            (swap! metrics assoc-in [:train-loss index] (train-result :train-loss))
            (swap! metrics assoc-in [:test-loss index] (train-result :validate-loss)))))
      (println "learned w: "
               (-> model
                   (m/get-block)
                   (nn/get-parameters)
                   (.get "01Linear_weight")
                   (.getArray)
                   (nd/to-vec))))))
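In the concise version the penalty never appears in the loss; the optimizer's :weight-decay option folds it into the parameter update instead. For plain SGD with the standard (coupled) form of weight decay this amounts to

\[ \mathbf{w} \leftarrow \mathbf{w} - \eta \left( \nabla_{\mathbf{w}} L + \lambda \mathbf{w} \right) = (1 - \eta \lambda)\, \mathbf{w} - \eta\, \nabla_{\mathbf{w}} L \]

so every step first shrinks the weights by a constant factor, which is where the name "weight decay" comes from. The gradient is the same as adding \(\frac{\lambda}{2}\lVert\mathbf{w}\rVert^2\) to the loss, as the from-scratch version does explicitly.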
1.3.1 Without Weight Decay
Training takes a while, so extend the org-babel evaluation timeout first:
(setq org-babel-clojure-sync-nrepl-timeout 1000)
Without weight decay (wd = 0) the concise version overfits just as the scratch one did. The trainer already reports scalar losses, so the values are plotted directly, without the log10/mean conversion used above.

(train-concise 0)

(d2l/plot-lines "figure/train_no_wd_concise.svg"
                ["train loss" "test loss"]
                (@metrics :epoch)
                [(@metrics :train-loss) (@metrics :test-loss)])
1.3.2 Using Weight Decay

Passing wd = 3 to the optimizer reproduces the regularized result: test loss drops relative to the unregularized run.

(train-concise 3)

(d2l/plot-lines "figure/train_wd_concise.svg"
                ["train loss" "test loss"]
                (@metrics :epoch)
                [(@metrics :train-loss) (@metrics :test-loss)])