;; Network definition: a sequential block made of a batch-flatten block
;; (constructed with 784), a 256-unit linear layer, a ReLU activation and a
;; 10-unit linear layer, with a normal initializer set on the whole block.
(def net
  (let [hidden-layer (nn/build (nn/set-units (nn/new-linear-builder) 256))
        output-layer (nn/build (nn/set-units (nn/new-linear-builder) 10))
        activation   (utils/as-function nn/relu)]
    (-> (nn/sequential-block)
        (nn/add (nn/batch-flatten-block 784))
        (nn/add hidden-layer)
        (nn/add activation)
        (nn/add output-layer)
        (nn/set-initializer (nn/new-normal-initializer)))))
;; Training hyperparameters.
;; batch-size is passed to ds/set-sampling for both datasets; nepochs is the
;; epoch count handed to t/fit and the upper bound of the chart's x axis.
(def batch-size 256)
(def nepochs 10)
;; Learning rate; same value as the one given to the fixed tracker (lrt).
(def lr 0.5)
;; Mutable scratch state. None of these atoms is read or written in the
;; visible part of this file.
;; NOTE(review): possibly left over from a manual training loop — confirm
;; whether anything elsewhere still uses them.
(def epoch-loss (atom 0.))
(def accuracy-val (atom 0.))
(def train-loss (atom []))
(def train-accuracy (atom []))
(def test-accuracy (atom []))
(defn- fashion-mnist-split
  "Builds and prepares the FashionMnist dataset for the given usage keyword
  (:train or :test), sampled with `batch-size` and the sampling flag `true`."
  [usage]
  (-> (FashionMnist/builder)
      (ds/opt-usage usage)
      (ds/set-sampling batch-size true)
      (ds/build)
      (ds/prepare)))

;; Training split, passed to t/fit as the training dataset.
(def mnist-train (fashion-mnist-split :train))

;; Test split, passed to t/fit as the validation dataset.
(def mnist-test (fashion-mnist-split :test))
;; Map from metric name (e.g. "train_epoch_Accuracy") to the sequence of
;; :value entries recorded for it. Filled after training and read when the
;; chart is built.
(def evaluator-metrics (atom {}))
;; Fixed learning-rate tracker. Driven by `lr` rather than a repeated literal
;; so that changing the hyperparameter actually changes the optimizer.
(def lrt (tracker/fixed lr))

;; SGD optimizer configured with the fixed learning-rate tracker above.
(def sgd
  (-> (optimizer/sgd)
      (optimizer/set-learning-rate-tracker lrt)
      (optimizer/build)))
;; Loss function handed to t/new-training-config.
;; NOTE(review): "sotfmax" looks like a misspelling of "softmax" — confirm the
;; function name exported by the namespace aliased as `loss` before renaming.
(def loss (loss/sotfmax-cross-entropy-loss))
;; Training configuration: the loss above, the SGD optimizer, an accuracy
;; evaluator and DJL's default logging listeners.
(def config
  (let [accuracy          (t/new-accuracy)
        logging-listeners (TrainingListener$Defaults/logging)]
    (-> (t/new-training-config loss)
        (t/opt-optimizer sgd)
        (t/add-evaluator accuracy)
        (t/add-training-listeners logging-listeners))))
;; Hold the trainer's evaluators and metrics, captured with reset! right after
;; training so they remain reachable once the with-open form has returned.
;; NOTE(review): these values come from a trainer that with-open closes —
;; confirm they are safe to use after the trainer is closed.
(def evals (atom nil))
(def mets (atom nil))
;; Train the network and collect the per-epoch metrics.
;;
;; The model (named "mlp", using `net` as its block) and its trainer are both
;; opened with with-open, so they are closed when this form returns. After
;; fitting for `nepochs` epochs, the trainer's evaluators and metrics are
;; stored in `evals` / `mets`, and for every evaluator the "train_epoch_<name>"
;; and "validate_epoch_<name>" series are copied into `evaluator-metrics`.
(with-open [model (-> (m/new-instance "mlp")
                      (m/set-block net))
            trainer (m/new-trainer model config)]
  (-> trainer
      (t/initialize [(nd/shape [1 784])])
      (t/set-metrics (t/metrics))
      (t/fit nepochs mnist-train mnist-test))
  (let [evaluators (t/get-evaluators trainer)
        metrics    (t/get-metrics trainer)]
    (reset! evals evaluators)
    (reset! mets metrics)
    (doseq [evaluator evaluators
            prefix    ["train_epoch_" "validate_epoch_"]
            :let [metric-key (str prefix (.getName evaluator))]]
      ;; doall forces the sequence while the trainer is still open; a lazy
      ;; seq stored in the atom would otherwise only be realised after
      ;; with-open has closed the trainer and model.
      (swap! evaluator-metrics
             assoc metric-key
             (doall (map :value (metrics metric-key)))))))
;; Plot test accuracy, train accuracy and train loss against the epoch number
;; (1..nepochs) and write the chart to "mlp-concise.svg".
(let [epochs (range 1 (inc nepochs))
      ;; Builds one marker-less series from the values stored under metric-key.
      series (fn [metric-key]
               {:x epochs
                :y (@evaluator-metrics metric-key)
                :style {:marker-type :none}})
      chart  (c/xy-chart
              {"test acc"   (series "validate_epoch_Accuracy")
               "train acc"  (series "train_epoch_Accuracy")
               "train loss" (series "train_epoch_SoftmaxCrossEntropyLoss")})]
  (c/spit chart "mlp-concise.svg"))