1. Concise Implementation of Linear Regression

Broad and intense interest in deep learning over the past several years has inspired companies, academics, and hobbyists to develop a variety of mature open-source frameworks that automate the repetitive work of implementing gradient-based learning algorithms. In Section 3.2, we relied only on (i) tensors for data storage and linear algebra; and (ii) automatic differentiation for calculating gradients. In practice, because data iterators, loss functions, optimizers, and neural network layers are so common, modern libraries implement these components for us as well.

In this section, we will show you how to implement the linear regression model from Section 3.2 concisely by using the high-level APIs of a deep learning framework.

1.1. Generating the Dataset

To start, we will generate the same dataset as in Section 3.2.

(ns clj-d2l.linreg-easy
  (:require [clojure.java.io :as io]
            [clj-d2l.core :as d2l]
            [clj-djl.ndarray :as nd]
            [clj-djl.device :as device]
            [clj-djl.engine :as engine]
            [clj-djl.training.dataset :as ds]
            [clj-djl.model :as model]
            [clj-djl.nn :as nn]
            [clj-djl.training.loss :as loss]
            [clj-djl.training.tracker :as tracker]
            [clj-djl.training.optimizer :as optimizer]
            [clj-djl.training.parameter :as parameter]
            [clj-djl.training.initializer :as initializer]
            [clj-djl.training :as t]
            [clj-djl.training.listener :as listener])
  (:import [ai.djl.ndarray.types DataType]
           [java.nio.file Paths]))
(def ndm (nd/base-manager))
(def true-w (nd/create ndm (float-array [2 -3.4])))
(def true-b 4.2)
(def dp (d2l/synthetic-data ndm true-w true-b 1000))
(def features (get dp 0))
(def labels (get dp 1))
(println "features(0): "(nd/to-vec (nd/get features [0])))
(println "labels(0): " (nd/to-vec (nd/get labels [0])))
features(0):  [0.36353156 1.8406333]
labels(0):  [-1.3278697]

1.2. Reading the Dataset

Rather than rolling our own iterator, we can call upon the existing API in a framework to read data. We pass in features and labels as arguments and specify batch-size when building the dataset object. The boolean second argument of set-sampling indicates whether we want the data iterator to shuffle the data on each epoch (each pass through the dataset); here we pass false.

(def batch-size 10)
(def datasets (-> (ds/new-array-dataset-builder)
                  (ds/set-data features)
                  (ds/opt-labels labels)
                  ;; (defn set-sampling [batch-size shuffle] ...)
                  (ds/set-sampling batch-size false)
                  (ds/build)))
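
To make concrete what the batch size and the shuffle flag control, here is a minimal sketch of a minibatch reader in plain Clojure. The function minibatches and its shuffle? argument are hypothetical names chosen for illustration; they are not part of clj-djl, and the real iterator may behave differently in its details.

```clojure
;; Hypothetical helper, not part of clj-djl: splits examples into
;; minibatches, optionally shuffling them first.
(defn minibatches
  [examples batch-size shuffle?]
  (partition-all batch-size
                 (if shuffle? (shuffle examples) examples)))

;; 25 examples with batch size 10 give batches of 10, 10 and 5 examples.
(def batches (minibatches (range 25) 10 false))
(println (map count batches))
```

With shuffling enabled, the same examples are visited once per pass, only in a different order.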

Now we can use get-data-iterator in much the same way as we called the data-iter function in Section 3.2. To verify that it is working, we can read and print the first minibatch of examples. In contrast to Section 3.2, here we use the Clojure function first to obtain the first item from the iterator.

(-> datasets
    (ds/get-data-iterator ndm) ;; generate a data iterator
    first ;; get the first batch
    ds/get-data ;; get the data
    first)
ND: (10, 2) cpu() float32
[[-2.2896,  1.0315],
 [-0.6617,  0.5531],
 [ 0.3967, -0.9902],
 [ 0.9992, -1.9574],
 [-0.9857, -0.1098],
 [-0.5344,  0.0834],
 [ 1.0844,  0.2221],
 [ 1.3125, -0.8627],
 [-0.581 ,  0.7608],
 [-1.4804,  0.2687],
]
(-> datasets
    (ds/get-data-iterator ndm) ;; generate a data iterator
    first ;; get the first batch
    ds/get-labels ;; get the labels
    first)
ND: (10) cpu() float32
[ 5.7955,  3.7533, 11.8775,  6.7918,  3.2628,  2.2621,  1.5316, 10.5247, 15.5371, -0.0894]

2. Defining the Model

When we implemented linear regression from scratch in Section 3.2, we defined our model parameters explicitly and coded up the calculations to produce output using basic linear algebra operations. You should know how to do this. But once your models get more complex, and once you have to do this nearly every day, you will be glad for the assistance. The situation is similar to coding up your own blog from scratch. Doing it once or twice is rewarding and instructive, but you would be a lousy web developer if every time you needed a blog you spent a month reinventing the wheel.

For standard operations, we can use a framework’s predefined layers, which allow us to focus on the layers used to construct the model rather than on their implementation. We will first define a model variable net, which refers to an instance of the sequential-block class. The sequential-block class defines a container for several layers that are chained together. Given input data, a sequential-block instance passes it through the first layer, then passes that output as the second layer’s input, and so forth. In the following example, our model consists of only one layer, so we do not really need sequential-block. But since nearly all of our future models will involve multiple layers, we will use it anyway to familiarize you with the most standard workflow.

Recall the architecture of a single-layer network as shown in the figure in Section 3.1. The layer is said to be fully-connected because each of its inputs is connected to each of its outputs by means of a matrix-vector multiplication.
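
As a rough illustration of these two ideas, the sketch below models a fully-connected layer and a sequential container in plain Clojure. The names dense, sequential, and toy-net are hypothetical stand-ins and do not call clj-djl; the weights reuse the ground-truth values from the dataset section.

```clojure
;; Hypothetical stand-ins for a linear layer and a sequential container,
;; written in plain Clojure (no clj-djl calls).

(defn dense
  "Returns a layer function computing W x + b.
  `weights` is a vector of rows, one row per output unit."
  [weights biases]
  (fn [x]
    (mapv (fn [row b] (+ b (reduce + (map * row x))))
          weights biases)))

(defn sequential
  "Chains layers: the output of each layer is the input of the next."
  [& layers]
  (fn [x]
    (reduce (fn [out layer] (layer out)) x layers)))

;; One fully-connected layer with two inputs and one output.
(def toy-net (sequential (dense [[2.0 -3.4]] [4.2])))

(println (toy-net [1.0 1.0])) ; roughly [2.8]
```

Every output unit reads every input, which is exactly the matrix-vector product described above.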

Now we define a model named “lin-reg” and create a sequential-block containing a single linear-block. Later, after setting the initializer, we attach the sequential-block to the model with set-block.

(def model (model/new-instance "lin-reg"))
(def net (nn/sequential-block))
(def linear-block (nn/linear-block {:bias true
                                    :units 1}))
(nn/add net linear-block)

2.1. Initializing Model Parameters

Before using net, we need to initialize the model parameters, such as the weights and bias in the linear regression model. Deep learning frameworks often have a predefined way to initialize the parameters. Here we specify that each weight parameter should be randomly sampled from a normal distribution with mean 0 and standard deviation 0.01. The bias parameter will be initialized to zero.

We import the initializer namespace from clj-djl. This namespace provides various methods for model parameter initialization. We only specify how to initialize the weight by calling (normal-initializer 0.01). Bias parameters are initialized to zero by default.

(nn/set-initializer net (initializer/normal-initializer 0.01) parameter/weight)
(model/set-block model net)
Model (
	Name: lin-reg
	Data Type: float32
)

The code above may look straightforward, but note that something strange is happening here. We are initializing parameters for a network even though clj-djl does not yet know how many dimensions the input will have! It might be 2 as in our example or it might be 2000. clj-djl lets us get away with this because, behind the scenes, the initialization is actually deferred. The real initialization takes place only when we first attempt to pass data through the network. Just remember that, since the parameters have not been initialized yet, we cannot access or manipulate them.
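
The idea of deferred initialization can be sketched in plain Clojure as follows. This is only an illustration under stated assumptions: lazy-linear, its sigma and seed arguments, and the use of java.util.Random are hypothetical and say nothing about how clj-djl implements the mechanism internally.

```clojure
;; Hypothetical sketch of deferred initialization in plain Clojure.
;; The layer does not know its input size until data first flows through it.
(import 'java.util.Random)

(defn lazy-linear
  "Returns a map with a `:forward` function and a `:params` atom.
  Weights are sampled from N(0, sigma^2) on the first forward pass;
  the bias starts at zero."
  [sigma seed]
  (let [rng    (Random. seed)
        params (atom nil)]
    {:params  params
     :forward (fn [x]
                (when (nil? @params)
                  ;; The input dimension is inferred from the first example.
                  (reset! params
                          {:w (vec (repeatedly (count x)
                                               #(* sigma (.nextGaussian rng))))
                           :b 0.0}))
                (let [{:keys [w b]} @params]
                  (+ b (reduce + (map * w x)))))}))

(def layer (lazy-linear 0.01 42))
(println @(:params layer))              ; nil: nothing is initialized yet
((:forward layer) [1.0 2.0])
(println (count (:w @(:params layer)))) ; 2, inferred from the input
```

Reading the parameters before the first forward pass yields nothing useful, which mirrors the warning above.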

2.2. Defining the Loss Function

In clj-djl, the loss namespace defines various loss functions. In this example, we will use the squared loss (l2-loss).

(def loss (loss/l2-loss))
#'clj-d2l.linreg-easy/loss
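
For reference, the squared loss can be written out by hand in a few lines. The sketch below assumes the common convention of one half times the mean squared error; l2-loss-value is a hypothetical name, and the exact scaling used by the library should be checked against its documentation.

```clojure
;; Hypothetical plain-Clojure version of the squared loss,
;; assuming the convention 1/2 * mean((y-hat - y)^2).
(defn l2-loss-value
  [predictions labels]
  (let [sq-errors (map (fn [p y] (let [d (- p y)] (* d d)))
                       predictions labels)]
    (* 0.5 (/ (reduce + sq-errors) (count sq-errors)))))

(println (l2-loss-value [1.0 4.0] [1.0 2.0])) ; 1.0
```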

2.3. Defining the Optimization Algorithm

Minibatch stochastic gradient descent is a standard tool for optimizing neural networks, and thus clj-djl supports it alongside a number of variations on this algorithm through its optimizer namespace. When we instantiate the optimizer (sgd), we pass a map of the hyperparameters required by the algorithm. Minibatch stochastic gradient descent just requires a learning rate, which we supply as a fixed tracker with value 0.3. The optimizer is later handed to the trainer together with the model and the loss.

(def lrt (tracker/fixed 0.3))
(def sgd (optimizer/sgd {:tracker lrt}))
#'clj-d2l.linreg-easy/sgd
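
The update rule behind plain stochastic gradient descent is small enough to write out directly. In the sketch below, sgd-step is a hypothetical helper working on Clojure vectors rather than on clj-djl parameters, and the learning rate of 0.5 is chosen only to keep the arithmetic exact.

```clojure
;; Hypothetical helper: one SGD update, w <- w - lr * grad.
(defn sgd-step
  [params grads lr]
  (mapv (fn [p g] (- p (* lr g))) params grads))

(println (sgd-step [1.0 2.0] [0.5 -1.0] 0.5)) ; [0.75 2.5]
```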

2.4. Instantiate Configuration and Trainer

(def trainer (t/trainer {:model model
                         :loss loss
                         :optimizer sgd
                         :listeners (listener/logging)}))
#'clj-d2l.linreg-easy/trainer

2.5. Initializing Model Parameters

(t/initialize trainer [(nd/shape batch-size 2)])
ai.djl.training.Trainer@4b641b7c

2.6. Metrics

(def metrics (t/metrics))
(t/set-metrics trainer metrics)

2.7. Training

You might have noticed that expressing our model through high-level APIs of a deep learning framework requires comparatively few lines of code. We did not have to individually allocate parameters, define our loss function, or implement minibatch stochastic gradient descent. Once we start working with much more complex models, advantages of high-level APIs will grow considerably. However, once we have all the basic pieces in place, the training loop itself is strikingly similar to what we did when implementing everything from scratch.

To refresh your memory: for some number of epochs, we will make a complete pass over the dataset (datasets), iteratively grabbing one minibatch of inputs and the corresponding ground-truth labels. For each minibatch, we go through the following ritual:

  • Generate predictions by calling train-batch and calculate the loss l (the forward propagation).
  • Calculate gradients by running backpropagation.
  • Update the model parameters by invoking our optimizer (step).

For good measure, we compute the loss after each epoch and print it to monitor progress.
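
Before turning to the clj-djl version, the ritual above can be sketched end to end in plain Clojure. Everything in this block is an assumption made for illustration: the names sgd-on-batch, train, and fitted are hypothetical, the data is noise-free, and the learning rate of 0.1 is a conservative choice rather than the value used in this section.

```clojure
;; Plain-Clojure sketch of minibatch SGD for linear regression.
;; All names here are hypothetical; none of this is clj-djl API.
(import 'java.util.Random)

(def rng (Random. 0))

;; Noise-free synthetic data: y = 2*x1 - 3.4*x2 + 4.2
(def xs (vec (repeatedly 1000 (fn [] [(.nextGaussian rng) (.nextGaussian rng)]))))
(def ys (mapv (fn [[x1 x2]] (+ (* 2.0 x1) (* -3.4 x2) 4.2)) xs))

(defn predict [{:keys [w b]} x]
  (+ b (reduce + (map * w x))))

(defn sgd-on-batch
  "One forward/backward/update step on a minibatch of [x y] pairs."
  [params batch lr]
  (let [n      (count batch)
        ;; Forward pass: prediction error for every example.
        errs   (map (fn [[x y]] (- (predict params x) y)) batch)
        ;; Backward pass: gradient of 1/2 * mean(err^2) w.r.t. w and b.
        grad-w (apply mapv (fn [& cols] (/ (reduce + cols) n))
                      (map (fn [e [x _]] (mapv #(* e %) x)) errs batch))
        grad-b (/ (reduce + errs) n)]
    ;; Update: plain SGD.
    {:w (mapv (fn [w g] (- w (* lr g))) (:w params) grad-w)
     :b (- (:b params) (* lr grad-b))}))

(defn train [params examples batch-size lr epochs]
  (reduce (fn [p _epoch]
            (reduce (fn [p batch] (sgd-on-batch p batch lr))
                    p
                    (partition-all batch-size examples)))
          params
          (range epochs)))

(def fitted (train {:w [0.0 0.0] :b 0.0} (map vector xs ys) 10 0.1 3))
(println fitted)
```

The fitted map should end up close to the generating parameters, since the data contains no noise.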

(def epochs 3)

(doseq [epoch (range epochs)]
  (doseq [batch (t/iterate-dataset trainer datasets)]
    (t/train-batch trainer batch)
    (t/step trainer)
    (ds/close-batch batch))
  (t/notify-listeners trainer (fn [listener] (.onEpoch listener trainer))))

Training:      1% |=                                       | L2Loss: _
...
Training:      5% |===                                     | L2Loss: 6.82
...
Training:     50% |=====================                   | L2Loss: 0.71
...
Training:    100% |========================================| L2Loss: 0.35

Training:      1% |=                                       | L2Loss: 0.35
...
Training:      5% |===                                     | L2Loss: 7.43E-05
...
Training:     50% |=====================                   | L2Loss: 5.64E-05
...
Training:    100% |========================================| L2Loss: 5.63E-05

Training:      1% |=                                       | L2Loss: 5.63E-05
...
Training:      5% |===                                     | L2Loss: 7.43E-05
...
Training:     50% |=====================                   | L2Loss: 5.64E-05
...
Training:    100% |========================================| L2Loss: 5.63E-05

Below, we compare the model parameters learned by training on finite data and the actual parameters that generated our dataset. To access parameters, we first access the layer that we need from net and then access that layer’s weights and bias. As in our from-scratch implementation, note that our estimated parameters are close to their ground-truth counterparts.

(def params (-> model (model/get-block) (model/get-parameters)))
(def w (.getArray (.valueAt params 0)))
(def b (.getArray (.valueAt params 1)))
(def w-error (nd/to-vec (nd/- true-w (nd/reshape w (nd/get-shape true-w)))))
(println "Error in estimating w:" (vec w-error))
(println "Error in estimating b:" (- true-b (nd/get-element b)))
Error in estimating w: [-0.0019903183 7.4744225E-4]
Error in estimating b: -4.289627075193536E-4

2.8. Saving Your Model

You can also save the model for future prediction tasks.

(defn save-model [model path epoch name]
  (let [nio-path (java.nio.file.Paths/get path (into-array [""]))]
    (io/make-parents path)
    (model/set-property model "Epoch" epoch)
    (model/save model nio-path name)))

(save-model model "models/lin-reg" "3" "lin-reg")
(println (str model))
Model (
	Name: lin-reg
	Model location: /home/kimim/workspace/clj-d2l/models/lin-reg
	Data Type: float32
	Epoch: 3
)

2.9. Summary

  • Using clj-djl, we can implement models much more concisely.
  • In clj-djl, the dataset namespace provides tools for data processing, the nn namespace defines a large number of neural network layers, and the loss namespace defines many common loss functions.
  • The initializer namespace provides various methods for model parameter initialization.
  • Dimensionality and storage are automatically inferred, but be careful not to attempt to access parameters before they have been initialized.

Author: Kimi Ma

Created: 2022-05-17 Tue 08:06