import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
import seaborn as sns

Pseudo Random Numbers without Side Effects

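Recall that Jax transformations such as jit assume pure functions: the output depends only on the inputs, and there are no side effects. Pseudo random number generators, however, are typically implemented as stateful objects: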

  • you initialize the generator with a seed

  • you call the generator to get a random number

  • the generator updates its internal state
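For contrast, here is that stateful pattern sketched with NumPy's np.random.default_rng (NumPy is an assumption here; the rest of this section only uses Jax). The generator object carries hidden state, so two identical calls return different numbers:

```python
import numpy as np

# 1. initialize the generator with a seed
rng = np.random.default_rng(seed=0)

# 2. call the generator to get random numbers
a = rng.normal(size=3)

# 3. the internal state was updated, so the same call now gives new numbers
b = rng.normal(size=3)

# Only re-seeding a fresh generator reproduces the first draw.
a_again = np.random.default_rng(seed=0).normal(size=3)
```

The call rng.normal(size=3) is not a pure function: its result depends on state hidden inside rng, which is exactly what Jax cannot trace.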

This won’t cut it in Jax, since updating hidden state is a side effect that would violate the purity of the function. Instead, we explicitly pass the state around in the form of a key created with random.PRNGKey. Please go through Pseudo Random Numbers in Jax and Stateful Computations in Jax for more details.
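The idea of passing the state explicitly does not depend on Jax. Below is a toy sketch in plain Python: a hypothetical function lcg_next, built on a classic 32-bit linear congruential generator (fine for illustration, not for real use), takes the state as input and returns the new state together with the random value, so nothing is mutated:

```python
def lcg_next(state):
    """Hypothetical pure PRNG step: return (new_state, value) without mutating anything."""
    # Classic 32-bit LCG constants (Numerical Recipes); illustration only.
    new_state = (1664525 * state + 1013904223) % 2**32
    value = new_state / 2**32  # a float in [0, 1)
    return new_state, value


state0 = 42
s1, u1 = lcg_next(state0)
s2, u2 = lcg_next(s1)  # the caller threads the state forward by hand

# Same input state, same output: the function is pure.
_, u1_again = lcg_next(state0)
```

Jax keys follow the same contract: the state is an ordinary value you hand to a function, and you decide what to pass next.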

The state object is:

import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)

It is an array of two 32-bit unsigned integers:

key
Array([0, 0], dtype=uint32)
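To see how the seed is stored, assuming Jax's default threefry PRNG implementation: the integer seed is split into a high and a low 32-bit word, so a small seed ends up in the second entry:

```python
import jax.random as random

key42 = random.PRNGKey(42)

# With the default (threefry) implementation this is a raw uint32 array
# holding the high and low 32 bits of the seed.
print(key42)
print(key42.shape, key42.dtype)
```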

When sampling from a distribution, we explicitly pass the key. Here is a sample from a standard normal:

random.normal(key, shape=(2, 2))
Array([[ 1.8160863 , -0.75488514],
       [ 0.33988908, -0.53483534]], dtype=float32)

If you pass the same key again, you get exactly the same sample:

random.normal(key, shape=(2, 2))
Array([[ 1.8160863 , -0.75488514],
       [ 0.33988908, -0.53483534]], dtype=float32)
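The same behavior can be checked programmatically. A minimal sketch: identical keys give bit-identical samples, while a different key gives a different one:

```python
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)

x1 = random.normal(key, shape=(2, 2))
x2 = random.normal(key, shape=(2, 2))
same = bool(jnp.array_equal(x1, x2))  # same key -> identical sample

other = random.normal(random.PRNGKey(1), shape=(2, 2))
different = not bool(jnp.array_equal(x1, other))  # different key -> different sample
```

This determinism is what makes Jax code reproducible, but it also means that accidentally reusing a key silently repeats the same randomness.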

The key has not been updated; sampling never modifies it:

key
Array([0, 0], dtype=uint32)

To get a different sample, you need to split the key:

key, subkey = random.split(key)
key
Array([4146024105,  967050713], dtype=uint32)
subkey
Array([2718843009, 1272950319], dtype=uint32)
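random.split is not limited to two outputs. Its num argument asks for several new keys at once, returned as the rows of an array (a sketch, assuming the default raw uint32 keys):

```python
import jax.random as random

key = random.PRNGKey(0)

# Three new keys, stacked as rows: shape (3, 2) for raw uint32 keys.
keys = random.split(key, num=3)
print(keys.shape)
```

This is handy whenever you know in advance how many independent draws you need, for example one key per step of a loop.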

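You are kind of branching the key to start two new generators: split deterministically derives two new, independent keys from the old one. You can use either one to get a sample, but the usual convention is to keep key for further splitting and to consume subkey: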

random.normal(subkey, shape=(2, 2))
Array([[ 1.1378784 , -0.14331433],
       [-0.59153634,  0.79466224]], dtype=float32)
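Splitting is not the only way to branch a key. Jax also provides random.fold_in(key, data), which mixes an integer into a key to deterministically derive a new one. A minimal sketch, useful when a natural index such as a step number or worker id is available:

```python
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)

# Derive per-index keys from one parent key without threading split results.
k0 = random.fold_in(key, 0)
k1 = random.fold_in(key, 1)

# Deterministic: folding in the same integer again gives the same key.
same_k0 = random.fold_in(key, 0)
```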

So, this is it: you must thread the key through your code, splitting it whenever you need fresh randomness and never reusing a key you have already consumed. You get used to it after doing it a few times.
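Here is the threading pattern in a small helper. The name draw_samples is hypothetical; it draws n independent scalars, consuming one subkey per draw and returning the updated key so the caller can keep going:

```python
import jax.numpy as jnp
import jax.random as random


def draw_samples(key, n):
    """Hypothetical helper: draw n standard normal scalars by threading the key."""
    samples = []
    for _ in range(n):
        key, subkey = random.split(key)  # keep `key` for later splits
        samples.append(random.normal(subkey))  # consume `subkey` exactly once
    return key, jnp.stack(samples)


key, samples = draw_samples(random.PRNGKey(0), 3)
```

Returning the new key, rather than discarding it, is what lets the next piece of code continue the random stream without reusing anything.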

Let’s look at an example. We will sample a random walk using only pure functions. The random walk starts at \(x_0\) and evolves as

\[ x_{t+1} = x_t + \sigma z_t, \]

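where \(\sigma > 0\) and the \(z_t\) are independent standard normal draws (applied coordinate-wise when \(x_t\) is a vector):

\[ z_t \sim N(0, 1). \]
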
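Unrolling the recursion gives a closed form that is handy as a sanity check for any sampler. Since the \(z_s\) are independent, each coordinate of \(x_t\) satisfies

\[
\begin{aligned}
x_t &= x_0 + \sigma \sum_{s=0}^{t-1} z_s, \\
\mathbb{E}[x_t] &= x_0, \qquad \operatorname{Var}[x_t] = t\,\sigma^2 .
\end{aligned}
\]

In other words, \(x_t \sim N(x_0, t\sigma^2)\), so the spread of the walk grows like \(\sqrt{t}\).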
def rw_step(x, sigma, key):
    """A single step of the random walk."""
    key, subkey = random.split(key)
    z = random.normal(subkey, shape=x.shape)
    return key, x + sigma * z
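A quick usage sketch of rw_step (repeated here so the snippet runs on its own): each call returns a new key, which must be fed into the next call, and with sigma = 0 the walk does not move:

```python
import jax.numpy as jnp
import jax.random as random


def rw_step(x, sigma, key):
    """A single step of the random walk (same as above)."""
    key, subkey = random.split(key)
    z = random.normal(subkey, shape=x.shape)
    return key, x + sigma * z


key = random.PRNGKey(0)
x0 = jnp.zeros(2)

key1, x1 = rw_step(x0, 0.5, key)
key2, x2 = rw_step(x1, 0.5, key1)  # feed the returned key into the next step

# With sigma = 0 the noise is switched off and the position is unchanged.
_, x_same = rw_step(x0, 0.0, key)
```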

Now we can put it in a loop that takes multiple steps and jit it. Because n_steps controls a Python loop, it has to be a static argument:

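from functools import partial
from jax import jit
from jax import lax

@partial(jit, static_argnums=(3,))  # n_steps drives a Python loop, so it must be static
def sample_rw(x0, sigma, key, n_steps):
    """Sample a random walk; returns the n_steps + 1 positions and the updated key."""
    x = x0
    xs = [x0]
    for _ in range(n_steps):
        key, x = rw_step(x, sigma, key)
        xs.append(x)  # record every position, not just the start
    xs = jnp.stack(xs)
    return xs, key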

This works:

sample_rw(jnp.zeros(2), 1.0, key, 10)
(Array([[ 0.        ,  0.        ],
        [ 0.00870701, -0.04888523],
        [-0.8823462 , -0.71072996],
        [ 0.3267806 ,  1.6009982 ],
        [ 1.2447658 ,  2.0843194 ],
        [ 1.0294966 ,  1.6681931 ],
        [ 4.0256443 ,  0.41421673],
        [ 3.6146142 ,  0.53783053],
        [ 3.1284342 , -0.39070803],
        [ 3.636192  , -1.0362725 ],
        [ 4.7213855 , -0.90524477]], dtype=float32),
 Array([2172655199,  567882137], dtype=uint32))
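The cost of a static n_steps is easy to observe by counting traces. The sketch below uses a hypothetical toy function, repeat_add, in place of sample_rw, and relies on the fact that Python side effects (here, incrementing a global counter) run only while Jax traces the function, not on cached calls:

```python
from functools import partial

import jax.numpy as jnp
from jax import jit

trace_count = 0  # incremented only while Jax traces the function


@partial(jit, static_argnums=(1,))
def repeat_add(x, n_steps):
    """Hypothetical stand-in for sample_rw: a Python loop unrolled n_steps times."""
    global trace_count
    trace_count += 1  # side effect: runs at trace time, not on cached calls
    for _ in range(n_steps):
        x = x + 1.0
    return x


result3 = repeat_add(jnp.zeros(2), 3)  # first call: trace and compile
repeat_add(jnp.zeros(2), 3)  # same static value and shapes: cached, no retrace
after_same = trace_count
repeat_add(jnp.zeros(2), 4)  # new static value: traced and compiled again
after_new = trace_count
```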

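But this is not ideal: the Python loop is unrolled during tracing, so the compiled program grows with n_steps, and because n_steps is a static argument, every new value triggers a fresh compilation. We can use lax.scan to avoid the unrolling; it traces the step function once, so the size of the compiled program no longer depends on the number of steps: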

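def sample_rw(x0, sigma, keys):
    """Sample a random walk, consuming one key per step."""

    def step_rw(prev_x, key):
        """A single step of the random walk."""
        z = random.normal(key, shape=prev_x.shape)
        new_x = prev_x + sigma * z
        # carry new_x to the next step; emit prev_x, so the output starts at x0
        return new_x, prev_x

    # scan returns (final carry, stacked outputs); keep the stacked positions
    return lax.scan(step_rw, x0, keys)[1]

n_steps = 100_000
keys = random.split(key, n_steps)  # one key per step, shape (n_steps, 2)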

walk = sample_rw(jnp.zeros(2), 0.1, keys)
walk.shape
(100000, 2)
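Pre-splitting all keys is one design choice. Another, sketched below with a hypothetical function sample_rw_carry, threads a single key through the scan carry and splits it inside each step, so no (n_steps, 2) array of keys has to be materialized up front:

```python
import jax.numpy as jnp
import jax.random as random
from jax import lax


def sample_rw_carry(x0, sigma, key, n_steps):
    """Hypothetical variant: random walk via lax.scan with the key in the carry."""

    def step(carry, _):
        key, x = carry
        key, subkey = random.split(key)  # fresh subkey each step
        z = random.normal(subkey, shape=x.shape)
        new_x = x + sigma * z
        return (key, new_x), x  # emit the current position, as sample_rw does

    # xs=None with an explicit length: the loop has no per-step inputs
    (key, _), xs = lax.scan(step, (key, x0), xs=None, length=n_steps)
    return key, xs


key_out, walk2 = sample_rw_carry(jnp.zeros(2), 0.1, random.PRNGKey(0), 1_000)
```

The trade-off: pre-split keys make each step independent of the others' key handling (and easy to vmap over), while the carried key saves memory at the cost of one split per step inside the loop. Both produce walks with the same distribution, though not the same numbers.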

Let’s plot the walk:

fig, ax = plt.subplots()
ax.plot(walk[:, 0], walk[:, 1], lw=0.5)
ax.set(xlabel="x", ylabel="y", title="Random Walk")

In case you did not notice, the 100,000 steps took very little time: scan runs the whole loop as compiled code instead of stepping through it in Python.
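Because the scan-based sample_rw is a pure function of its keys, it composes with other Jax transformations. As a sketch, jax.vmap below samples many independent walks at once (one key per walk and step), and the spread of the final positions is compared against the closed form \(\operatorname{Var}[x_t] = t\sigma^2\). The walk count, step count, and tolerance are arbitrary choices for illustration:

```python
import jax
import jax.numpy as jnp
import jax.random as random
from jax import lax


def sample_rw(x0, sigma, keys):
    """Scan-based random walk (same as above): one key per step."""

    def step_rw(prev_x, key):
        z = random.normal(key, shape=prev_x.shape)
        return prev_x + sigma * z, prev_x

    return lax.scan(step_rw, x0, keys)[1]


n_walks, n_steps, sigma = 2_000, 100, 0.1

# One key per (walk, step); the reshape works for the default raw uint32 keys.
keys = random.split(random.PRNGKey(0), n_walks * n_steps).reshape(n_walks, n_steps, 2)

# Map over the leading axis of keys only; x0 and sigma are shared by all walks.
walks = jax.vmap(sample_rw, in_axes=(None, None, 0))(jnp.zeros(2), sigma, keys)

# walks[:, t] holds x_t, so the last column is x_{n_steps - 1}.
final = walks[:, -1]
empirical_var = float(jnp.var(final))
expected_var = (n_steps - 1) * sigma**2
```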