UP | HOME

Table of Contents

1 The Image Classification Dataset

(ns clj-d2l.image-classification
  (:require [clojure.java.io :as io]
            [clj-djl.ndarray :as nd]
            [clj-djl.device :as device]
            [clj-djl.engine :as engine]
            [clj-djl.training.dataset :as ds]
            [clj-djl.model :as model]
            [clj-djl.nn :as nn]
            [clj-djl.training.loss :as loss]
            [clj-djl.training.tracker :as tracker]
            [clj-djl.training.optimizer :as optimizer]
            [clj-djl.training :as training]
            [clj-djl.training.listener :as listener])
  (:import [ai.djl.ndarray.types DataType]
           [ai.djl.basicdataset FashionMnist]
           [ai.djl.training.dataset Dataset$Usage]
           [java.nio.file Paths]))

1.1 Getting the Dataset

;; Raise org-babel's synchronous Clojure/nREPL evaluation timeout to 1000.
;; NOTE(review): units are presumably seconds (see ob-clojure's defcustom);
;; presumably raised so the long-running dataset preparation below does not
;; time out inside the notebook — confirm.
(set 'org-babel-clojure-sync-nrepl-timeout 1000)
1000
(def batch-size
  "Number of examples per minibatch; passed to `ds/set-sampling` for each split."
  256)

(def random-shuffle
  "Shuffle flag passed to `ds/set-sampling`; when true, examples are sampled in random order."
  true)

(defn- prepare-fashion-mnist-split
  "Builds and prepares the FashionMnist split selected by `usage`
  (a `Dataset$Usage` value), sampled in minibatches of `batch-size`
  and shuffled according to `random-shuffle`."
  [usage]
  (let [builder (FashionMnist/builder)]
    (-> builder
        (ds/opt-usage usage)
        (ds/set-sampling batch-size random-shuffle)
        ds/build
        ;; NOTE(review): for DJL datasets, prepare presumably downloads/loads
        ;; the data on first use — confirm.
        ds/prepare)))

;; Training split of FashionMnist.
(def mnist-train (prepare-fashion-mnist-split Dataset$Usage/TRAIN))

;; Test split of FashionMnist.
;; NOTE(review): this split is shuffled too; the order does not affect aggregate
;; metrics, but a fixed order (shuffle false) may be preferable for reproducible inspection.
(def mnist-test (prepare-fashion-mnist-split Dataset$Usage/TEST))

;; Report how many examples each split holds; println separates its two
;; arguments with a single space.
(doseq [[label dataset] [["train dataset size: " mnist-train]
                         ["test dataset size: " mnist-test]]]
  (println label (nd/size dataset)))
train dataset size:  60000
test dataset size:  10000

Created: 2021-04-11 Sun 20:59