Initialization of Neural Network Parameters

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks");


When it comes to deep neural nets, you should be careful about how you initialize the weights. Bad initialization can cause anything from failure to learn at all to vanishing or exploding gradients. Let's demonstrate this with our trivial example.

import jax.numpy as jnp
import jax.random as jrandom

key = jrandom.PRNGKey(0)

# Generate some synthetic data
N = 1_000
X = jrandom.normal(key, (N,))
key, subkey = jrandom.split(key)
y = 1.5 * X ** 2 - 2 * X + jrandom.normal(subkey, (N,)) * 0.5

# Make also a test set (here an ideal one)
N_test = 50
X_test = jnp.linspace(-3, 3, N_test)
key, subkey = jrandom.split(key)
y_test = 1.5 * X_test ** 2 - 2 * X_test + jrandom.normal(subkey, (N_test,)) * 0.5

import numpy as np
import equinox as eqx
import jax
import optax
from functools import partial

    
# The function below generates batches of data
def data_generator(X, y, batch_size, shuffle=True):
    num_samples = X.shape[0]
    indices = np.arange(num_samples)
    if shuffle:
        np.random.shuffle(indices)
    
    for start_idx in range(0, num_samples, batch_size):
        end_idx = min(start_idx + batch_size, num_samples)
        batch_indices = indices[start_idx:end_idx]
        yield X[batch_indices], y[batch_indices]

# This is the loss function
def loss(model, x, y):
    y_pred = model(x)
    return optax.l2_loss(y_pred, y).mean()

# This is the training loop
def train_batch(
        model,
        x, y,
        optimizer,
        x_test, y_test,
        n_batch=10,
        n_epochs=10,
        freq=1,
    ):

    # This is the step of the optimizer. We **always** jit:
    @eqx.filter_jit
    def step(opt_state, model, xi, yi):
        value, grads = eqx.filter_value_and_grad(loss)(model, xi, yi)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, value
    
    # The state of the optimizer
    opt_state = optimizer.init(model)
    # The path of the model
    path = []
    # The training loss history
    losses = []
    # The test loss history
    test_losses = []
    for e in range(n_epochs):
        for i, (xb, yb) in enumerate(data_generator(x, y, n_batch)):
            model, opt_state, value = step(opt_state, model, xb, yb)
            if i % freq == 0:
                path.append(model)
                losses.append(value)
                test_losses.append(loss(model, x_test, y_test))
                print(f"Epoch {e}, step {i}, loss {value:.3f}, test {test_losses[-1]:.3f}")
    return model, path, losses, test_losses

This time we are going to make the model a proper neural network. Let's go with two hidden layers of 10 neurons each (three linear layers in total). We will use the ReLU activation function for the hidden layers.

class NeuralNetwork(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3 = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(1, 10, key=key1),
            eqx.nn.Linear(10, 10, key=key2),
            eqx.nn.Linear(10, 1, key=key3)
        ]
    
    @partial(jax.vmap, in_axes=(None, 0))
    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jax.nn.relu(layer(x))
        return self.layers[-1](x)

The code above actually initializes the weights the correct way. We will talk about it in a while. For now, let’s start with a stupid choice. First, set all the weights and biases to the same number, say zero. Recall that Jax models are immutable, so we need to create a new model.

key, subkey = jrandom.split(key)

model = NeuralNetwork(subkey)

# Set all weights and biases to zero
zero_model = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), model)

# Confirm that everything is set to zero
jax.tree_util.tree_leaves(zero_model)
[Array([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
 Array([0.], dtype=float32)]

Let’s now demonstrate that this model cannot learn because most of the gradients are zero:

from jax import jit, grad

grad_loss = jit(grad(loss))

g = grad_loss(zero_model, X[:, None], y[:, None])

jax.tree_util.tree_leaves(g)
[Array([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.]], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
 Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
 Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
 Array([-1.4736239], dtype=float32)]

Why does this happen? The reason is the ReLU function. Recall that ReLU is defined as:

\[ \text{relu}(x) = \max(0, x) \]

So, whenever the input to a ReLU is non-positive, its gradient is zero. With all parameters set to zero, every pre-activation is exactly zero (and JAX takes the derivative of ReLU at zero to be zero), so the gradient cannot flow back through the hidden layers, the weights are never updated, and the model cannot learn. This is known as the dying ReLU problem.

The same thing happens with other activation functions if we are not careful. For example, with the sigmoid, if the input is very large or very small, the gradient is practically zero. We say that the gradients are saturated.
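Here is a quick numerical check of both effects, using jax.grad directly on the scalar activation functions (a side check, independent of our model):

relu_grad = jax.grad(jax.nn.relu)
# The ReLU gradient is zero for negative inputs, and JAX takes relu'(0) to be 0
print(relu_grad(-1.0), relu_grad(0.0), relu_grad(2.0))

sigmoid_grad = jax.grad(jax.nn.sigmoid)
# The sigmoid gradient is essentially zero far from the origin (saturation)
print(sigmoid_grad(0.0), sigmoid_grad(10.0), sigmoid_grad(-10.0))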

Okay, let's shift all weights and biases a bit so that the ReLUs are no longer stuck at zero. Since everything is currently zero, adding a small number, say 0.1, is the same as setting every parameter to 0.1:

new_model = jax.tree_util.tree_map(lambda x: 0.1 * jnp.ones_like(x), model)
jax.tree_util.tree_leaves(new_model)
[Array([[0.1],
        [0.1],
        [0.1],
        [0.1],
        [0.1],
        [0.1],
        [0.1],
        [0.1],
        [0.1],
        [0.1]], dtype=float32),
 Array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float32),
 Array([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
        [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], dtype=float32),
 Array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float32),
 Array([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], dtype=float32),
 Array([0.1], dtype=float32)]

Let’s see what the gradients are now:

g = grad_loss(new_model, X[:, None], y[:, None])

jax.tree_util.tree_leaves(g)
[Array([[0.02420982],
        [0.02420982],
        [0.02420982],
        [0.02420982],
        [0.02420982],
        [0.02420982],
        [0.02420982],
        [0.02420982],
        [0.02420982],
        [0.02420982]], dtype=float32),
 Array([-0.01698821, -0.01698821, -0.01698821, -0.01698821, -0.01698821,
        -0.01698821, -0.01698821, -0.01698821, -0.01698821, -0.01698821],      dtype=float32),
 Array([[0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216,
         0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216],
        [0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216,
         0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216],
        [0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216,
         0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216],
        [0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216,
         0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216],
        [0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216,
         0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216],
        [0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216,
         0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216],
        [0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216,
         0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216],
        [0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216,
         0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216],
        [0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216,
         0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216],
        [0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216,
         0.00072216, 0.00072216, 0.00072216, 0.00072216, 0.00072216]],      dtype=float32),
 Array([-0.11663406, -0.11663406, -0.11663406, -0.11663406, -0.11663406,
        -0.11663406, -0.11663406, -0.11663406, -0.11663406, -0.11663406],      dtype=float32),
 Array([[-0.10941248, -0.10941248, -0.10941248, -0.10941248, -0.10941248,
         -0.10941248, -0.10941248, -0.10941248, -0.10941248, -0.10941248]],      dtype=float32),
 Array([-1.1663406], dtype=float32)]

Now we get non-zero gradients, but we have another problem. Pay attention to the gradients for the weights of a given layer: they are all identical. This means the model will move all the weights of a layer in the same direction, so the neurons never differentiate from one another and the network never learns anything useful. This is called the symmetry problem. We need to initialize the parameters in a way that breaks the symmetry.
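To make the symmetry concrete, here is a minimal sketch (the name stepped_model and the learning rate are just for illustration): take a single plain gradient-descent step from the constant initialization and look at the 10×10 hidden weight matrix. Every entry is still the same number, so the ten hidden neurons keep computing exactly the same function.

# One hand-rolled gradient step with learning rate 0.1 (illustration only;
# the actual training below uses optax)
stepped_model = jax.tree_util.tree_map(lambda p, dp: p - 0.1 * dp, new_model, g)

# Every entry of the second layer's weight matrix is still identical
print(stepped_model.layers[1].weight)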

Xavier or Glorot initialization

The Xavier (or Glorot) initialization was introduced in the paper Understanding the difficulty of training deep feedforward neural networks by Xavier Glorot and Yoshua Bengio. The method initializes the weights of a layer with a uniform distribution in the range \([-a, a]\) with \(a\) being:

\[ a = \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}} \]

where \(n_{\text{in}}\) is the number of inputs to the layer and \(n_{\text{out}}\) is the number of outputs from the layer. The weights are initialized by:

\[ w_{ij} \sim \mathcal{U}([-a, a]). \]

For the 10×10 hidden layer of our network, for instance, \(a = \sqrt{6 / (10 + 10)} \approx 0.55\). The biases are initialized to zero or to a small positive number (if the activation function is ReLU).

Here is how we can do this in Jax:

def random_weight(key, shape, lim):
    return jrandom.uniform(key, shape, minval=-lim, maxval=lim)


xavier_model = model
for i in range(3):
    # Initialize the weight for layer i
    key, subkey = jrandom.split(key)
    shape = xavier_model.layers[i].weight.shape
    xavier_model = eqx.tree_at(
        lambda m: m.layers[i].weight,
        xavier_model,
        random_weight(
            subkey,
            shape,
            jnp.sqrt(6 / (shape[0] + shape[1]))
        ),
    )
    # Set the bias to 0.1
    xavier_model = eqx.tree_at(
        lambda m: m.layers[i].bias,
        xavier_model,
        0.1 * jnp.ones_like(xavier_model.layers[i].bias),
    )

jax.tree_util.tree_leaves(xavier_model)
[Array([[ 0.32371843],
        [-0.09270354],
        [ 0.6612312 ],
        [-0.6682776 ],
        [-0.24418366],
        [-0.37407488],
        [-0.25916752],
        [ 0.5819582 ],
        [-0.15898605],
        [-0.3486244 ]], dtype=float32),
 Array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float32),
 Array([[ 0.3676433 ,  0.37487116,  0.08109037,  0.03349315, -0.16421711,
         -0.36349794, -0.07435115, -0.44226253,  0.18982252,  0.5187848 ],
        [-0.5399611 , -0.37376797,  0.27842256,  0.12041358, -0.33761817,
         -0.22923872,  0.3782641 ,  0.3010011 ,  0.20852078,  0.3341495 ],
        [-0.29488283,  0.40181315,  0.06843881,  0.1779477 ,  0.12430247,
          0.34012973,  0.41524926,  0.30109957,  0.02873337, -0.1421873 ],
        [ 0.08526394,  0.03162144,  0.14541529, -0.29462987,  0.456133  ,
          0.42672893,  0.23817806, -0.22955617, -0.28543863, -0.27825502],
        [-0.00537615, -0.46676868,  0.15361108,  0.2696966 , -0.3681953 ,
         -0.5060839 , -0.06666727, -0.5398564 , -0.4126733 , -0.33224645],
        [-0.46251455,  0.3704435 , -0.05446833,  0.43191114,  0.36721903,
          0.15557119,  0.2147323 ,  0.3821736 , -0.03419362, -0.27052712],
        [ 0.13798892, -0.29381084, -0.3722884 ,  0.1156525 ,  0.35628912,
          0.46102703, -0.29348907,  0.45198894,  0.3082577 ,  0.31575575],
        [-0.4943312 , -0.35612732,  0.48273635, -0.4380861 , -0.31668007,
         -0.2622906 , -0.41907677, -0.04267656, -0.22444655, -0.29398558],
        [-0.38443485, -0.16884956, -0.53369087, -0.06442156,  0.42611882,
          0.03432498,  0.20492494,  0.3643281 , -0.01832165, -0.2769835 ],
        [-0.26088378, -0.5202522 , -0.4201553 , -0.30814174,  0.20963025,
          0.38605714, -0.39654225, -0.05371719,  0.18494013, -0.15446864]],      dtype=float32),
 Array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], dtype=float32),
 Array([[-0.19861881, -0.00998325, -0.04571752,  0.58426976,  0.1256166 ,
          0.7041999 , -0.08622682,  0.34889665,  0.2071388 , -0.229225  ]],      dtype=float32),
 Array([0.1], dtype=float32)]

Test the gradients:

g = grad_loss(xavier_model, X[:, None], y[:, None])

jax.tree_util.tree_leaves(g)
[Array([[ 0.02699523],
        [ 0.37146565],
        [-0.02454231],
        [ 0.18399827],
        [ 1.2244885 ],
        [ 0.76599014],
        [ 0.69746107],
        [ 0.00779888],
        [-0.53853697],
        [-1.0709085 ]], dtype=float32),
 Array([-0.09308025, -0.20425463,  0.03026273, -0.11606807, -0.7400332 ,
        -0.4753357 , -0.42907578,  0.03981695,  0.2839816 ,  0.6556119 ],      dtype=float32),
 Array([[-1.53877772e-02,  5.92427403e-02, -2.56949235e-02,
          2.88320154e-01,  1.20761983e-01,  1.72247618e-01,
          1.26719296e-01, -2.32813507e-02,  8.65963027e-02,
          1.62184060e-01],
        [ 1.10576606e-04,  2.97773862e-03,  4.12342313e-04,
          1.44919343e-02,  6.06990140e-03,  8.65774229e-03,
          6.36933604e-03,  3.41097708e-04,  4.35262080e-03,
          8.15191399e-03],
        [ 5.06376964e-04,  1.36363255e-02,  1.88829028e-03,
          6.63647130e-02,  2.77966522e-02,  3.96474712e-02,
          2.91678905e-02,  1.56203099e-03,  1.99324954e-02,
          3.73310670e-02],
        [-6.47149421e-03, -1.74272239e-01, -2.41323281e-02,
         -8.48140836e-01, -3.55241179e-01, -5.06694555e-01,
         -3.72765601e-01, -1.99627522e-02, -2.54737228e-01,
         -4.77090985e-01],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [-7.79986661e-03, -2.10044265e-01, -2.90858671e-02,
         -1.02223456e+00, -4.28159773e-01, -6.10701203e-01,
         -4.49281394e-01, -2.40603983e-02, -3.07025880e-01,
         -5.75020969e-01],
        [ 9.55066003e-04,  2.57191807e-02,  3.56146414e-03,
          1.25169054e-01,  5.24266735e-02,  7.47782290e-02,
          5.50129265e-02,  2.94611370e-03,  3.75942476e-02,
          7.04093128e-02],
        [-4.95177973e-03,  3.79367894e-03, -1.56388544e-02,
          0.00000000e+00,  6.33984338e-04,  1.48322244e-04,
          5.46185591e-04, -1.31287407e-02,  1.64722861e-03,
          2.00687762e-04],
        [ 8.80013104e-04, -6.39121532e-02,  1.05306366e-03,
         -3.00687969e-01, -1.26234561e-01, -1.79674804e-01,
         -1.32399261e-01,  1.02004455e-03, -9.11822915e-02,
         -1.69203907e-01],
        [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00]], dtype=float32),
 Array([ 0.19894518,  0.01097653,  0.05026617, -0.6424015 ,  0.        ,
        -0.77426416,  0.09480594,  0.0529851 , -0.2577178 ,  0.        ],      dtype=float32),
 Array([[-0.31732795, -0.39008602, -1.0298618 , -0.13849749,  0.        ,
         -1.112479  , -1.0132195 , -0.00427387, -0.1674771 ,  0.        ]],      dtype=float32),
 Array([-1.0994947], dtype=float32)]

Another thing that is commonly done is to look at a histogram of the gradients of all the parameters:

all_grads = jnp.hstack(
    jax.tree_util.tree_map(
        lambda p: p.flatten(),
        jax.tree_util.tree_leaves(g)
    )
)
fig, ax = plt.subplots()
ax.hist(all_grads, bins=50, density=True, alpha=0.5)
ax.set(xlabel="Gradient value", ylabel="Density");
../_images/6f1dda6851b1a89684175c18eeb593b9f3228452e86d8e19b3187aea11a01996.svg

It looks better than before, but it is not perfect. The reason is that Xavier initialization was designed for sigmoid and tanh activation functions. For ReLU, we need to use a different initialization.

He initialization

The He initialization was introduced in the paper Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. The idea is to scale the weights so that their variance is proportional to \(1/n_{\text{in}}\), where \(n_{\text{in}}\) is the number of inputs to the layer; the paper draws them from a zero-mean Gaussian with variance \(2/n_{\text{in}}\). The default initialization of equinox.nn.Linear follows the same spirit: the weights are drawn from a uniform distribution in the range \([-a, a]\) with \(a\) being:

\[ a = \sqrt{\frac{1}{n_{\text{in}}}} \]

Let's see how this default behaves on our model:

g = grad_loss(model, X[:, None], y[:, None])
jax.tree_util.tree_leaves(g)
[Array([[-0.137161  ],
        [-0.14594285],
        [ 0.1972345 ],
        [-0.05604376],
        [-0.0023809 ],
        [ 0.00070995],
        [ 0.36193305],
        [ 0.00602406],
        [-0.01431112],
        [ 0.        ]], dtype=float32),
 Array([ 0.09015069,  0.12984118, -0.13102864,  0.05271221, -0.00669131,
        -0.00057694, -0.22451383,  0.00280493, -0.00255928,  0.        ],      dtype=float32),
 Array([[ 1.0778138e-02,  1.2327660e-02,  8.2478148e-04,  1.7101418e-02,
          2.3551367e-02,  2.0171790e-03,  5.2287388e-03,  2.0648468e-02,
          2.0914447e-02,  0.0000000e+00],
        [ 0.0000000e+00, -2.7174037e-03,  0.0000000e+00, -2.0803507e-02,
         -3.9404470e-02, -2.9371942e-03,  1.6780789e-03, -2.9066479e-02,
         -3.0834939e-02,  0.0000000e+00],
        [-1.3764948e-01, -1.4081910e-01, -5.5947408e-02, -4.5996405e-02,
         -1.1695730e-02, -8.3952531e-04, -8.1302613e-02, -8.0701225e-03,
         -8.8563329e-03,  0.0000000e+00],
        [-4.8459396e-01, -4.9094301e-01, -1.9145952e-01, -1.5886921e-01,
         -3.5563625e-02, -3.0460316e-03, -2.8462160e-01, -3.1180121e-02,
         -3.1581759e-02,  0.0000000e+00],
        [-1.1312931e-02, -3.4168558e-03, -6.0383808e-03, -2.5754301e-03,
         -1.0359661e-02, -8.8759384e-04, -3.5075848e-03, -9.0856841e-03,
         -9.2027187e-03,  0.0000000e+00],
        [-1.5130803e-01, -1.5622641e-01, -5.9776969e-02, -4.5845456e-02,
          4.0456787e-04,  0.0000000e+00, -9.0529509e-02,  0.0000000e+00,
          8.2088269e-07,  0.0000000e+00],
        [ 4.1664112e-01,  4.2209977e-01,  1.6461185e-01,  1.3659155e-01,
          3.0576676e-02,  2.6188979e-03,  2.4471018e-01,  2.6807845e-02,
          2.7153166e-02,  0.0000000e+00],
        [ 2.6695612e-01,  2.7357277e-01,  1.0685656e-01,  7.9259232e-02,
          0.0000000e+00,  0.0000000e+00,  1.5908462e-01,  0.0000000e+00,
          0.0000000e+00,  0.0000000e+00],
        [ 5.4095697e-01,  5.4804456e-01,  2.1372816e-01,  1.7734724e-01,
          3.9700039e-02,  3.4003155e-03,  3.1772596e-01,  3.4806676e-02,
          3.5255034e-02,  0.0000000e+00],
        [-3.6420399e-01, -3.6933330e-01, -1.4719671e-01, -1.0505517e-01,
          0.0000000e+00,  0.0000000e+00, -2.1582879e-01,  0.0000000e+00,
          0.0000000e+00,  0.0000000e+00]], dtype=float32),
 Array([ 0.02216513, -0.02088149, -0.11294242, -0.3923564 , -0.00411087,
        -0.1204321 ,  0.3373377 ,  0.20993245,  0.4379914 , -0.28158748],      dtype=float32),
 Array([[-0.0014199 , -0.09171595, -0.24393448, -0.63895756, -0.10640763,
         -0.4988277 , -1.0547736 , -0.5008075 , -0.6322478 , -0.27538082]],      dtype=float32),
 Array([-1.5230243], dtype=float32)]
all_grads = jnp.hstack(
    jax.tree_util.tree_map(
        lambda p: p.flatten(),
        jax.tree_util.tree_leaves(g)
    )
)
fig, ax = plt.subplots()
ax.hist(all_grads, bins=50, density=True, alpha=0.5)
ax.set(xlabel="Gradient value", ylabel="Density");
../_images/b24e8a66b870e6958467795196afda5cd210a01d0ecc814c2b052bcc1b2e567c.svg

It looks better because the gradients are more spread out.

You should use He initialization for ReLU and Xavier initialization for sigmoid and tanh, and prefer He initialization for deeper networks. But at the end of the day, these are just rules of thumb: always look at your gradient histograms and make sure that they are not too peaked at zero.
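If you ever need to apply He initialization by hand, for example when a library default is not what you want, you can reuse the Xavier loop from above. Here is a sketch, assuming the uniform variant with bound \(\sqrt{6 / n_{\text{in}}}\) (the Gaussian variant would instead use a standard deviation of \(\sqrt{2 / n_{\text{in}}}\)); he_model is just an illustrative name:

he_model = model
for i in range(3):
    key, subkey = jrandom.split(key)
    # equinox stores the weight as (n_out, n_in), so n_in is shape[1]
    shape = he_model.layers[i].weight.shape
    he_model = eqx.tree_at(
        lambda m: m.layers[i].weight,
        he_model,
        random_weight(subkey, shape, jnp.sqrt(6 / shape[1])),
    )
    # Small positive bias, as before, so that the ReLUs start out active
    he_model = eqx.tree_at(
        lambda m: m.layers[i].bias,
        he_model,
        0.1 * jnp.ones_like(he_model.layers[i].bias),
    )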

Let’s finish up by seeing how our network can learn:

model, path, losses, test_losses = train_batch(
    model,
    X[:, None], y[:, None],
    optax.adam(0.01),
    X_test[:, None], y_test[:, None],
    n_batch=100,
    n_epochs=100,
    freq=10,
)
Epoch 0, step 0, loss 4.046, test 26.340
Epoch 1, step 0, loss 2.562, test 20.785
Epoch 2, step 0, loss 2.682, test 14.236
Epoch 3, step 0, loss 3.412, test 8.201
Epoch 4, step 0, loss 0.886, test 5.213
Epoch 5, step 0, loss 0.531, test 2.700
Epoch 6, step 0, loss 0.367, test 1.592
Epoch 7, step 0, loss 0.277, test 1.419
Epoch 8, step 0, loss 0.219, test 1.221
Epoch 9, step 0, loss 0.214, test 1.027
Epoch 10, step 0, loss 0.200, test 0.797
Epoch 11, step 0, loss 0.174, test 0.584
Epoch 12, step 0, loss 0.201, test 0.477
Epoch 13, step 0, loss 0.166, test 0.366
Epoch 14, step 0, loss 0.315, test 0.341
Epoch 15, step 0, loss 0.202, test 0.348
Epoch 16, step 0, loss 0.186, test 0.315
Epoch 17, step 0, loss 0.169, test 0.294
Epoch 18, step 0, loss 0.139, test 0.317
Epoch 19, step 0, loss 0.156, test 0.237
Epoch 20, step 0, loss 0.252, test 0.215
Epoch 21, step 0, loss 0.294, test 0.245
Epoch 22, step 0, loss 0.167, test 0.210
Epoch 23, step 0, loss 0.214, test 0.220
Epoch 24, step 0, loss 0.286, test 0.191
Epoch 25, step 0, loss 0.262, test 0.179
Epoch 26, step 0, loss 0.145, test 0.204
Epoch 27, step 0, loss 0.170, test 0.224
Epoch 28, step 0, loss 0.172, test 0.139
Epoch 29, step 0, loss 0.186, test 0.155
Epoch 30, step 0, loss 0.218, test 0.149
Epoch 31, step 0, loss 0.147, test 0.153
Epoch 32, step 0, loss 0.135, test 0.138
Epoch 33, step 0, loss 0.152, test 0.142
Epoch 34, step 0, loss 0.234, test 0.102
Epoch 35, step 0, loss 0.207, test 0.104
Epoch 36, step 0, loss 0.132, test 0.138
Epoch 37, step 0, loss 0.138, test 0.105
Epoch 38, step 0, loss 0.133, test 0.096
Epoch 39, step 0, loss 0.197, test 0.087
Epoch 40, step 0, loss 0.170, test 0.083
Epoch 41, step 0, loss 0.111, test 0.086
Epoch 42, step 0, loss 0.115, test 0.085
Epoch 43, step 0, loss 0.140, test 0.090
Epoch 44, step 0, loss 0.112, test 0.083
Epoch 45, step 0, loss 0.159, test 0.081
Epoch 46, step 0, loss 0.143, test 0.102
Epoch 47, step 0, loss 0.110, test 0.071
Epoch 48, step 0, loss 0.121, test 0.079
Epoch 49, step 0, loss 0.120, test 0.069
Epoch 50, step 0, loss 0.155, test 0.084
Epoch 51, step 0, loss 0.128, test 0.082
Epoch 52, step 0, loss 0.141, test 0.072
Epoch 53, step 0, loss 0.137, test 0.072
Epoch 54, step 0, loss 0.113, test 0.072
Epoch 55, step 0, loss 0.138, test 0.072
Epoch 56, step 0, loss 0.130, test 0.093
Epoch 57, step 0, loss 0.126, test 0.074
Epoch 58, step 0, loss 0.156, test 0.090
Epoch 59, step 0, loss 0.131, test 0.067
Epoch 60, step 0, loss 0.109, test 0.075
Epoch 61, step 0, loss 0.155, test 0.065
Epoch 62, step 0, loss 0.129, test 0.075
Epoch 63, step 0, loss 0.137, test 0.081
Epoch 64, step 0, loss 0.098, test 0.069
Epoch 65, step 0, loss 0.155, test 0.101
Epoch 66, step 0, loss 0.137, test 0.079
Epoch 67, step 0, loss 0.149, test 0.073
Epoch 68, step 0, loss 0.144, test 0.086
Epoch 69, step 0, loss 0.133, test 0.061
Epoch 70, step 0, loss 0.145, test 0.065
Epoch 71, step 0, loss 0.129, test 0.093
Epoch 72, step 0, loss 0.114, test 0.066
Epoch 73, step 0, loss 0.109, test 0.075
Epoch 74, step 0, loss 0.133, test 0.124
Epoch 75, step 0, loss 0.139, test 0.073
Epoch 76, step 0, loss 0.151, test 0.103
Epoch 77, step 0, loss 0.121, test 0.079
Epoch 78, step 0, loss 0.149, test 0.080
Epoch 79, step 0, loss 0.116, test 0.076
Epoch 80, step 0, loss 0.122, test 0.064
Epoch 81, step 0, loss 0.162, test 0.070
Epoch 82, step 0, loss 0.137, test 0.079
Epoch 83, step 0, loss 0.131, test 0.073
Epoch 84, step 0, loss 0.155, test 0.083
Epoch 85, step 0, loss 0.158, test 0.092
Epoch 86, step 0, loss 0.140, test 0.095
Epoch 87, step 0, loss 0.127, test 0.070
Epoch 88, step 0, loss 0.134, test 0.068
Epoch 89, step 0, loss 0.123, test 0.087
Epoch 90, step 0, loss 0.125, test 0.073
Epoch 91, step 0, loss 0.150, test 0.091
Epoch 92, step 0, loss 0.129, test 0.094
Epoch 93, step 0, loss 0.101, test 0.101
Epoch 94, step 0, loss 0.110, test 0.079
Epoch 95, step 0, loss 0.166, test 0.076
Epoch 96, step 0, loss 0.150, test 0.080
Epoch 97, step 0, loss 0.119, test 0.097
Epoch 98, step 0, loss 0.125, test 0.106
Epoch 99, step 0, loss 0.148, test 0.072

Here is the evolution of the loss:

fig, ax = plt.subplots()
ax.plot(losses, label="Train")
ax.plot(test_losses, label="Test")
ax.set(xlabel="Iteration $\\times$ 100", ylabel="Loss", title="Loss")
plt.legend(loc='best', frameon=False)
sns.despine(trim=True);
../_images/c5f37ac5a31ca68a468438074e87e28b8beec228094e1199de52af3cddab329a.svg

Let’s make some predictions:

xs = jnp.linspace(-4, 4, 100)
fig, ax = plt.subplots()
ax.scatter(X, y, color='black', label="Data", alpha=0.5, s=2)
ax.plot(xs, model(xs[:, None]).flatten(), label="Model")
ax.plot(xs, 1.5 * xs ** 2 - 2 * xs, '--', label="True")
ax.set(xlabel="x", ylabel="y", title="Model fit")
plt.legend(loc='best', frameon=False)
sns.despine(trim=True);
../_images/1bfdf3c69a6d4f558eba050c95ae9a08f3538f430bcf59c7910ec1e2b2d8a39a.svg

The fit looks good. It is not perfect because we are using a stupid model (we should have used a simple polynomial regression), but it is good enough to demonstrate the basic concepts.