;; ---- Hyperparameters consumed by the training loop ----
(def nepochs
  "Number of training epochs the loop iterates over (epochs are numbered 1..nepochs)."
  10)

(def lr
  "Learning rate passed to d2l/sgd on every training step."
  0.5)

;; ---- Scalar accumulators: summed over batches, reset to 0.0 after each pass ----
(def epoch-loss
  "Running sum of the per-batch loss values during the current training pass."
  (atom 0.0))

(def accuracy-val
  "Running sum of d2l/accuracy results during the current train or test pass."
  (atom 0.0))

;; ---- Per-epoch histories: one value is conj'ed onto each per epoch ----
(def train-loss
  "Vector of per-epoch training loss values."
  (atom []))

(def train-accuracy
  "Vector of per-epoch training accuracy values."
  (atom []))

(def test-accuracy
  "Vector of per-epoch test accuracy values."
  (atom []))
;; Train `net` for `nepochs` epochs with minibatch SGD, then evaluate on the test set.
;;
;; Per epoch:
;;   1. Training pass over mnist-train: forward pass inside a gradient collector,
;;      accumulate the batch loss into `epoch-loss` and d2l/accuracy into
;;      `accuracy-val`, backpropagate, release the batch, then call d2l/sgd.
;;   2. Append epoch-loss / (nd/size mnist-train) to `train-loss` and
;;      accuracy-val / (nd/size mnist-train) to `train-accuracy`, then reset both
;;      accumulators to 0.
;;   3. Test pass over mnist-test: forward pass only, accumulating d2l/accuracy.
;;   4. Append accuracy-val / (nd/size mnist-test) to `test-accuracy` and reset it.
;;
;; Every batch, training and test alike, is closed in a `finally` so its native
;; memory is released even if a step throws. Previously the test batches were never
;; closed, and a failing training step skipped the close.
(doseq [epoch (range 1 (inc nepochs))]
  (print "Running epoch " epoch "......")
  ;; `print` writes no newline, so flush now to show progress before the epoch ends.
  (flush)

  ;; ---- training pass ----
  (doseq [batch (ds/get-data-iterator mnist-train manager)]
    (try
      (let [X (nd/head (ds/get-batch-data batch))
            y (nd/head (ds/get-batch-labels batch))]
        ;; with-open closes the collector and the forward-pass arrays once
        ;; the gradients have been computed.
        (with-open [gc (t/gradient-collector)
                    yhat (net X)
                    lossvalue (.evaluate loss (nd/ndlist [y]) (nd/ndlist [yhat]))
                    ;; NOTE(review): assumes .evaluate returns the batch-mean loss, so
                    ;; multiplying by batch-size yields a per-batch sum — TODO confirm
                    l (nd/* lossvalue batch-size)]
          (swap! epoch-loss + (nd/get-element (nd/sum l)))
          (swap! accuracy-val + (d2l/accuracy yhat y))
          (.backward gc l)))
      (finally
        (ds/close-batch batch)))
    ;; Parameter update runs after the batch is released, matching the original order.
    (d2l/sgd params lr batch-size))

  ;; Record training metrics for this epoch, then reset the accumulators.
  (swap! train-loss conj (/ @epoch-loss (nd/size mnist-train)))
  (swap! train-accuracy conj (/ @accuracy-val (nd/size mnist-train)))
  (reset! epoch-loss 0.)
  (reset! accuracy-val 0.)

  ;; ---- evaluation pass (no gradient collector) ----
  (doseq [batch (ds/get-data-iterator mnist-test manager)]
    (try
      (let [X (nd/head (ds/get-batch-data batch))
            y (nd/head (ds/get-batch-labels batch))]
        ;; Close the prediction array just as the training pass does.
        (with-open [yhat (net X)]
          (swap! accuracy-val + (d2l/accuracy yhat y))))
      (finally
        (ds/close-batch batch))))

  (swap! test-accuracy conj (/ @accuracy-val (nd/size mnist-test)))
  (reset! accuracy-val 0.)
  (println "Finished epoch " epoch))
Running epoch 1 ......Finished epoch 1
Running epoch 2 ......Finished epoch 2
Running epoch 3 ......Finished epoch 3
Running epoch 4 ......Finished epoch 4
Running epoch 5 ......Finished epoch 5
Running epoch 6 ......Finished epoch 6
Running epoch 7 ......Finished epoch 7
Running epoch 8 ......Finished epoch 8
Running epoch 9 ......Finished epoch 9
Running epoch 10 ......Finished epoch 10
;; Draw the per-epoch test accuracy, train accuracy and train loss histories on a
;; single XY chart (epoch number on the x axis) and write it to
;; figure/mlp-scratch.svg.
(let [epochs (range 1 (inc nepochs))
      ;; All three series share the epoch axis and are drawn without point markers.
      series (fn [ys]
               {:x epochs
                :y ys
                :style {:marker-type :none}})
      chart (c/xy-chart {"test acc"   (series @test-accuracy)
                         "train acc"  (series @train-accuracy)
                         "train loss" (series @train-loss)})]
  (c/spit chart "figure/mlp-scratch.svg"))