1. Concise Implementation of Linear Regression
Broad and intense interest in deep learning for the past several years has inspired companies, academics, and hobbyists to develop a variety of mature open source frameworks for automating the repetitive work of implementing gradient-based learning algorithms. In Section 3.2 (Linear Regression Implementation from Scratch), we relied only on (i) tensors for data storage and linear algebra; and (ii) automatic differentiation for calculating gradients. In practice, because data iterators, loss functions, optimizers, and neural network layers are so common, modern libraries implement these components for us as well.
In this section, we will show you how to implement the linear regression model from Section 3.2 (Linear Regression Implementation from Scratch) concisely by using the high-level APIs of deep learning frameworks.
1.1. Generating the Dataset
To start, we will generate the same dataset as in Section 3.2 (Linear Regression Implementation from Scratch).
(ns clj-d2l.linreg-easy
  (:require [clojure.java.io :as io]
            [clj-d2l.core :as d2l]
            [clj-djl.ndarray :as nd]
            [clj-djl.device :as device]
            [clj-djl.engine :as engine]
            [clj-djl.training.dataset :as ds]
            [clj-djl.model :as model]
            [clj-djl.nn :as nn]
            [clj-djl.training.loss :as loss]
            [clj-djl.training.tracker :as tracker]
            [clj-djl.training.optimizer :as optimizer]
            [clj-djl.training.parameter :as parameter]
            [clj-djl.training.initializer :as initializer]
            [clj-djl.training :as t]
            [clj-djl.training.listener :as listener])
  (:import [ai.djl.ndarray.types DataType]
           [java.nio.file Paths]))
(def ndm (nd/base-manager))
(def true-w (nd/create ndm (float-array [2 -3.4])))
(def true-b 4.2)
(def dp (d2l/synthetic-data ndm true-w true-b 1000))
(def features (get dp 0))
(def labels (get dp 1))
(println "features(0): " (nd/to-vec (nd/get features [0])))
(println "labels(0): " (nd/to-vec (nd/get labels [0])))
features(0):  [0.36353156 1.8406333]
labels(0):  [-1.3278697]
1.2. Reading the Dataset
Rather than rolling our own iterator, we can call upon the existing API in a framework to read data. We pass in features and labels as arguments and specify batch-size when instantiating a data iterator object. In addition, the boolean shuffle argument of set-sampling indicates whether or not we want the data iterator object to shuffle the data on each epoch (pass through the dataset).
(def batch-size 10)
(def datasets (-> (ds/new-array-dataset-builder)
                  (ds/set-data features)
                  (ds/opt-labels labels)
                  ;; (defn set-sampling [batch-size shuffle] ...)
                  (ds/set-sampling batch-size false)
                  (ds/build)))
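Here we pass false so that the minibatch printed below is reproducible. For actual training you would normally reshuffle the examples on every epoch; a minimal sketch, using the same builder functions as above and a hypothetical train-dataset name, would look like this:

;; hypothetical variant with shuffling enabled for training
(def train-dataset (-> (ds/new-array-dataset-builder)
                       (ds/set-data features)
                       (ds/opt-labels labels)
                       ;; true => reshuffle the examples on every epoch
                       (ds/set-sampling batch-size true)
                       (ds/build)))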
Now we can use get-data-iterator in much the same way as we called the data-iter function in Section 3.2 (Linear Regression Implementation from Scratch). To verify that it is working, we can read and print the first minibatch of examples. Compared with Section 3.2, here we use the Clojure function first to obtain the first item from the iterator.
(-> datasets
    (ds/get-data-iterator ndm) ;; generate a data iterator
    first                      ;; get the first batch
    ds/get-data                ;; get the data
    first)
ND: (10, 2) cpu() float32
[[-2.2896,  1.0315],
 [-0.6617,  0.5531],
 [ 0.3967, -0.9902],
 [ 0.9992, -1.9574],
 [-0.9857, -0.1098],
 [-0.5344,  0.0834],
 [ 1.0844,  0.2221],
 [ 1.3125, -0.8627],
 [-0.581 ,  0.7608],
 [-1.4804,  0.2687],
]
(-> datasets
    (ds/get-data-iterator ndm) ;; generate a data iterator
    first                      ;; get the first batch
    ds/get-labels              ;; get the labels
    first)
ND: (10) cpu() float32
[ 5.7955,  3.7533, 11.8775,  6.7918,  3.2628,  2.2621,  1.5316, 10.5247, 15.5371, -0.0894]
1.3. Defining the Model
When we implemented linear regression from scratch in Section 3.2 (Linear Regression Implementation from Scratch), we defined our model parameters explicitly and coded up the calculations to produce output using basic linear algebra operations. You should know how to do this. But once your models get more complex, and once you have to do this nearly every day, you will be glad for the assistance. The situation is similar to coding up your own blog from scratch. Doing it once or twice is rewarding and instructive, but you would be a lousy web developer if every time you needed a blog you spent a month reinventing the wheel.
For standard operations, we can use a framework’s predefined layers, which allow us to focus especially on the layers used to construct the model rather than having to focus on the implementation. We will first define a model variable net, which will refer to an instance of the sequential-block class. The sequential-block class defines a container for several layers that will be chained together. Given input data, a sequential-block instance passes it through the first layer, in turn passing the output as the second layer’s input and so forth. In the following example, our model consists of only one layer, so we do not really need sequential-block. But since nearly all of our future models will involve multiple layers, we will use it anyway just to familiarize you with the most standard workflow.
Recall the architecture of a single-layer network as shown in the figure in Section 3.1 (Linear Regression). The layer is said to be fully-connected because each of its inputs is connected to each of its outputs by means of a matrix-vector multiplication.
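For reference, with two inputs and one output this fully-connected layer computes the familiar linear regression form (standard notation, not specific to clj-djl):

$$\hat{y} = \mathbf{w}^{\top}\mathbf{x} + b, \qquad \mathbf{w} \in \mathbb{R}^{2},\ b \in \mathbb{R}$$

which is exactly the model we implemented by hand in the previous section.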
Now we define a model with the name “lin-reg” and create a sequential-block with a linear-block inside it. Later, after configuring parameter initialization, we will set this sequential-block as the model’s block.
(def model (model/new-instance "lin-reg"))
(def net (nn/sequential-block))
(def linear-block (nn/linear-block {:bias true :units 1}))
(nn/add net linear-block)
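Our net holds a single linear-block, but the same container pattern extends to deeper models. The following sketch of a hypothetical two-layer stack (the layer widths are made up, and a real multilayer network would also insert a nonlinearity between the linear blocks) only illustrates how a sequential-block chains one layer’s output into the next layer’s input; it is not used in the rest of this section:

;; hypothetical example of chaining two layers in one sequential-block
(def deeper-net (nn/sequential-block))
(nn/add deeper-net (nn/linear-block {:bias true :units 8})) ;; hidden layer with 8 units (made up)
(nn/add deeper-net (nn/linear-block {:bias true :units 1})) ;; output layer with 1 unit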
1.4. Initializing Model Parameters
Before using net, we need to initialize the model parameters, such as the weights and bias in the linear regression model. Deep learning frameworks often have a predefined way to initialize the parameters. Here we specify that each weight parameter should be randomly sampled from a normal distribution with mean 0 and standard deviation 0.01. The bias parameter will be initialized to zero.
We require the initializer namespace from clj-djl. This namespace provides various methods for model parameter initialization. We only specify how to initialize the weights by calling (initializer/normal-initializer 0.01). Bias parameters are initialized to zero by default.
(nn/set-initializer net (initializer/normal-initializer 0.01) parameter/weight)
(model/set-block model net)
Model (
	Name: lin-reg
	Data Type: float32
)
The code above may look straightforward, but you should note that something strange is happening here. We are initializing parameters for a network even though clj-djl does not yet know how many dimensions the input will have! It might be 2 as in our example or it might be 2000. clj-djl lets us get away with this because, behind the scenes, the initialization is actually deferred. The real initialization will take place only when we attempt to pass data through the network for the first time. Just be careful to remember that since the parameters have not been initialized yet, we cannot access or manipulate them.
1.5. Defining the Loss Function
In clj-djl, the loss namespace defines various loss functions. In this example, we will use the squared loss (l2-loss).
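For a single example, the squared loss is (the standard definition; DJL’s L2Loss additionally averages over the minibatch, so treat this as a sketch of the idea rather than the exact value reported in the training log below):

$$l^{(i)}(\mathbf{w}, b) = \frac{1}{2}\left(\hat{y}^{(i)} - y^{(i)}\right)^{2}$$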
(def loss (loss/l2-loss))
#'clj-d2l.linreg-easy/loss
1.6. Defining the Optimization Algorithm
Minibatch stochastic gradient descent is a standard tool for optimizing neural networks, and thus clj-djl supports it, alongside a number of variations on this algorithm, through its optimizer namespace and the trainer. When we instantiate the trainer, we will specify the parameters to optimize over, the optimization algorithm we wish to use (sgd), and a map of hyperparameters required by our optimization algorithm. Minibatch stochastic gradient descent just requires that we set the value of the learning rate, which is set to 0.3 here via a fixed tracker.
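Concretely, each update applies the standard minibatch stochastic gradient descent rule from the previous section, with learning rate η = 0.3 and minibatch size 10 in this example (whether the averaging over the minibatch happens inside the loss or inside the optimizer is a library implementation detail):

$$(\mathbf{w}, b) \leftarrow (\mathbf{w}, b) - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \partial_{(\mathbf{w}, b)}\, l^{(i)}(\mathbf{w}, b)$$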
(def lrt (tracker/fixed 0.3))
(def sgd (optimizer/sgd {:tracker lrt}))
#'clj-d2l.linreg-easy/sgd
1.7. Instantiating the Configuration and Trainer
(def trainer (t/trainer {:model model
                         :loss loss
                         :optimizer sgd
                         :listeners (listener/logging)}))
#'clj-d2l.linreg-easy/trainer
1.8. Initializing the Trainer
(t/initialize trainer [(nd/shape batch-size 2)])
ai.djl.training.Trainer@4b641b7c
1.9. Metrics
(def metrics (t/metrics))
(t/set-metrics trainer metrics)
1.10. Training
You might have noticed that expressing our model through the high-level APIs of a deep learning framework requires comparatively few lines of code. We did not have to individually allocate parameters, define our loss function, or implement minibatch stochastic gradient descent. Once we start working with much more complex models, the advantages of high-level APIs will grow considerably. However, once we have all the basic pieces in place, the training loop itself is strikingly similar to what we did when implementing everything from scratch.
To refresh your memory: for some number of epochs, we will make a complete pass over the dataset (datasets), iteratively grabbing one minibatch of inputs and the corresponding ground-truth labels. For each minibatch, we go through the following ritual:
- Generate predictions by calling train-batch and calculate the loss l (the forward propagation).
- Calculate gradients by running the backpropagation.
- Update the model parameters by invoking our optimizer.
For good measure, we compute the loss after each epoch and print it to monitor progress.
(def epochs 3)

(doseq [epoch (range epochs)]
  (doseq [batch (t/iterate-dataset trainer datasets)]
    (t/train-batch trainer batch)
    (t/step trainer)
    (ds/close-batch batch))
  (t/notify-listeners trainer (fn [listener] (.onEpoch listener trainer))))
Training: 100% |========================================| L2Loss: 0.35
Training: 100% |========================================| L2Loss: 5.63E-05
Training: 100% |========================================| L2Loss: 5.63E-05
Below, we compare the model parameters learned by training on finite data and the actual parameters that generated our dataset. To access parameters, we first access the layer that we need from net and then access that layer’s weights and bias. As in our from-scratch implementation, note that our estimated parameters are close to their ground-truth counterparts.
(def params (-> model (model/get-block) (model/get-parameters)))
(def w (.getArray (.valueAt params 0)))
(def b (.getArray (.valueAt params 1)))
(def w-error (nd/to-vec (nd/- true-w (nd/reshape w (nd/get-shape true-w)))))
(println "Error in estimating w:" (vec w-error))
(println "Error in estimating b:" (- true-b (nd/get-element b)))
Error in estimating w: [-0.0019903183 7.4744225E-4]
Error in estimating b: -4.289627075193536E-4
1.11. Saving Your Model
You can also save the model for future prediction tasks.
(defn save-model [model path epoch name]
  (let [nio-path (java.nio.file.Paths/get path (into-array [""]))]
    (io/make-parents path)
    (model/set-property model "Epoch" epoch)
    (model/save model nio-path name)))

(save-model model "models/lin-reg" "3" "lin-reg")
(println (str model))
Model (
	Name: lin-reg
	Model location: /home/kimim/workspace/clj-d2l/models/lin-reg
	Data Type: float32
	Epoch: 3
)
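To reuse the saved parameters later, you would rebuild the same block and load the parameters into a fresh model instance. The following is only a sketch, assuming model/new-instance returns the underlying ai.djl.Model so that DJL’s Model.load can be invoked via Java interop; the path and name must match what was used when saving:

;; hypothetical sketch: restore the parameters saved above
(def loaded-model (model/new-instance "lin-reg"))
(def loaded-net (nn/sequential-block))
;; the block must have the same architecture as the one that was saved
(nn/add loaded-net (nn/linear-block {:bias true :units 1}))
(model/set-block loaded-model loaded-net)
;; load the parameters written by save-model above (Java interop on ai.djl.Model)
(.load loaded-model (java.nio.file.Paths/get "models/lin-reg" (into-array [""])) "lin-reg")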
1.12. Summary
- Using clj-djl, we can implement models much more concisely.
- In clj-djl, the dataset namespace provides tools for data processing, the nn namespace defines a large number of neural network layers, the loss namespace defines many common loss functions, and the initializer namespace provides various methods for model parameter initialization.
- Dimensionality and storage are inferred automatically, but be careful not to attempt to access parameters before they have been initialized.