Autograd with JAX

Hide code cell source
# Notebook display setup: render inline matplotlib figures as SVG and use the
# seaborn "paper" context with the "ticks" style for every plot below.
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks");

Autograd with JAX

We do not cover all of JAX’s features here, only the ones most important for our purposes. For a more complete introduction, see the Advanced Automatic Differentiation in JAX and the JAX: Autodiff Cookbook tutorials. Here we just cover the basics.

The tutorials above contain some simple examples. Here, let’s do something a bit more interesting. First, take the radial basis functions we made in the Vectorization section:

import jax
import jax.numpy as jnp
from jax import jit, vmap, grad
from functools import partial

# vmap maps only over the centers: x and sigma2 are shared by every center,
# while c is mapped along its first axis (one center per row).
@partial(vmap, in_axes=(None, 0, None), out_axes=0)
def rbf_basis(x, c, sigma2):
    """Evaluate a Gaussian radial basis function at ``x`` for each center.

    After ``vmap``, ``c`` is a stack of centers mapped along axis 0. The
    result holds one value ``exp(-||x - c_i||^2 / sigma2)`` per center, in
    the same order as the rows of ``c``.
    """
    offset = x - c
    sq_dist = jnp.sum(offset * offset, axis=-1)
    return jnp.exp(-sq_dist / sigma2)

Let’s also make a function that builds the generalized linear model and vectorize it with respect to x:

@jit
@partial(vmap, in_axes=(0, None, None, None), out_axes=0)
def model(x, w, c, sigma2):
    """Generalized linear model on RBF features, vectorized over rows of ``x``.

    For a single input row, the RBF features (one per center in ``c``) are
    combined linearly with the weights ``w``. The result is flattened with
    ``reshape(-1)``, so for a 1-D ``w`` each row yields a length-1 array and
    the vmapped output stacks them along axis 0.
    """
    features = rbf_basis(x, c, sigma2)
    prediction = jnp.dot(features, w)
    return prediction.reshape(-1)

Let’s draw some random weights and visualize this in 1D:

import jax.random as random

sigma2 = 0.01
# Ten evenly spaced 1-D centers on [-1, 1], shape (10, 1).
c = jnp.linspace(-1, 1, 10).reshape(-1, 1)
key = random.PRNGKey(0)
# Split before drawing: the subkey is consumed here, and `key` is kept for later draws.
key, subkey = random.split(key)
w = random.normal(subkey, (10,))
# 100 evaluation points on [-1, 1], one point per row, shape (100, 1).
x = jnp.linspace(-1, 1, 100).reshape(-1, 1)
y = model(x, w, c, sigma2)

fig, ax = plt.subplots()
# NOTE(review): a label is set but no legend is drawn in this cell.
ax.plot(x, y, label='model')
ax.set(xlabel='x', ylabel='y', title='RBF model')
sns.despine(trim=True);
../_images/88f311fd4cbb5055aefc59cd06618b6afe231dde75a7ff78614b8de807548ea1.svg

Now let’s generate some training data and set up a mean squared error loss function:

# Synthetic training set: 20 inputs drawn uniformly on [-1, 1], and targets
# x**3 plus Gaussian noise with standard deviation 0.1. Each draw uses its
# own freshly split subkey.
key, subkey = random.split(key)
x_train = random.uniform(subkey, (20, 1), minval=-1, maxval=1)
key, subkey = random.split(key)
y_train = x_train ** 3 + 0.1 * random.normal(subkey, (20,1))

fig, ax = plt.subplots()
# `y` is the random-weight model evaluated in the previous cell.
ax.plot(x, y, label='initial model')
ax.plot(x_train, y_train, 'kx', label='training data')
ax.set(xlabel='x', ylabel='y', title='RBF model')
plt.legend(loc='best', frameon=False)
sns.despine(trim=True);
../_images/c8537f24b1eae14ce8e7b847f378cefcd0190838721510c50bd8afbd3c7b5d0c.svg

The loss function:

def loss(w, c, sigma2, x, y):
    """Mean squared error of the RBF model on the data ``(x, y)``.

    ``model`` returns one length-1 row per input row, i.e. predictions shaped
    ``(n, 1)``. Targets and predictions are both flattened before
    differencing. Without this, targets shaped ``(n,)`` would silently
    broadcast against ``(n, 1)`` into an ``(n, n)`` residual matrix, giving a
    wrong loss (and wrong gradients) instead of an error. Mismatched sizes
    still raise.

    Args:
        w: Linear weights, one per center.
        c: RBF centers.
        sigma2: RBF width parameter.
        x: Inputs, one sample per row.
        y: Targets, one per sample (any shape with ``n`` elements).

    Returns:
        Scalar mean of the squared residuals.
    """
    y_pred = model(x, w, c, sigma2)
    residual = jnp.ravel(y) - jnp.ravel(y_pred)
    return jnp.mean(residual ** 2)

Here is how we can compute the gradient of the loss function with respect to the weights:

# Gradient of `loss` w.r.t. its first positional argument (the weights w);
# argnums defaults to 0, so it need not be spelled out.
loss_grad = jit(grad(loss))
loss_grad(w, c, sigma2, x_train, y_train)
Array([-0.28532007, -0.29728252, -0.19095227, -0.14909464,  0.03849696,
       -0.01216208, -0.4540519 , -0.01067693,  0.04548343, -0.06975526],      dtype=float32)

But we can do more. We can also compute the gradient of the loss function with respect to the centers and sigma2:

# Gradients w.r.t. w, c and sigma2 (positional args 0, 1, 2), returned as a tuple.
full_loss_grad = jax.jit(jax.grad(loss, argnums=(0, 1, 2)))

Here is the result:

w_grad, c_grad, sigma2_grad = full_loss_grad(w, c, sigma2, x_train, y_train)
# Each gradient has the same shape as the parameter it belongs to.
print(*(g.shape for g in (w_grad, c_grad, sigma2_grad)))
(10,) (10, 1) ()

We will discuss optimization in a later lecture. Below is a naive implementation of gradient descent that updates only the weights:

import numpy as np

@jit
def gd_step(w, c, sigma2, x, y, lr=0.1):
    """Take one plain gradient-descent step on the weights only.

    The centers ``c`` and the width ``sigma2`` are passed through unchanged.
    Only ``w`` moves against the gradient of the loss, scaled by ``lr``.

    Returns:
        Tuple ``(w_new, c, sigma2)``, so the result can be fed straight back in.
    """
    step = lr * loss_grad(w, c, sigma2, x, y)
    return w - step, c, sigma2

num_iter = 10_000
lr = 0.01
n_batch = 10  # NOTE(review): unused below (full-batch GD); kept in case later cells read it.

# Draw w and c from independent keys. Reusing a single key for both draws
# correlates them; a JAX PRNG key must never be consumed twice.
key, w_key, c_key = random.split(key, 3)
w = random.normal(w_key, (10,))
c = random.uniform(c_key, (10, 1), minval=-1, maxval=1)
sigma2 = 0.1

# Full-batch gradient descent; gd_step updates only w.
for _ in range(num_iter):
    w, c, sigma2 = gd_step(w, c, sigma2, x_train, y_train, lr)

pred = model(x, w, c, sigma2)

fig, ax = plt.subplots()
# These are the predictions *after* training, not the initial model.
ax.plot(x, pred, label='trained model')
ax.plot(x_train, y_train, 'kx', label='training data')
ax.set(xlabel='x', ylabel='y', title='RBF model')
plt.legend(loc='best', frameon=False)
sns.despine(trim=True);
../_images/e441e224f070515724840fd1ea0106c750a1d195fc76b12949c50b769f89b321.svg