import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks");
Pseudo Random Numbers without Side Effects
Recall that Jax only works with pure functions.
Pseudo random number generators, however, are typically implemented as stateful objects:
- you initialize the generator with a seed,
- you call the generator to get a random number,
- the generator updates its internal state.
This won’t cut it in Jax, since the hidden state update would violate the purity of the function.
To deal with this, we need to pass the state around explicitly.
This is done using the PRNGKey object.
Please go through Pseudo Random Numbers in Jax and Stateful Computations in Jax for more details.
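To make “passing the state around” concrete, here is a minimal sketch of a pure generator in plain Python. It is a toy linear congruential generator, and its names (lcg_next, A_MULT, C_INC, MOD) are hypothetical and chosen only for illustration. Jax itself uses a different, counter-based algorithm (Threefry by default). What matters is the signature: the function takes the current state and returns the new state together with the sample, so nothing is mutated behind the caller’s back.

```python
# Toy pure PRNG (a linear congruential generator), for illustration only.
# Jax does NOT use an LCG; only the "state in, (new state, sample) out" pattern matters.
A_MULT = 1664525      # multiplier (classic Numerical Recipes constants)
C_INC = 1013904223    # increment
MOD = 2**32           # modulus

def lcg_next(state):
    """Pure function: returns (new_state, uniform sample in [0, 1))."""
    new_state = (A_MULT * state + C_INC) % MOD
    return new_state, new_state / MOD

# The caller threads the state explicitly from one call to the next.
state = 0
state, u1 = lcg_next(state)
state, u2 = lcg_next(state)

# Same input state -> same output, because there is no hidden mutation.
print(lcg_next(0) == lcg_next(0))  # True
print(u1, u2)
```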
The state object is created from an integer seed:
import jax.numpy as jnp
import jax.random as random
key = random.PRNGKey(0)
It is an array of two 32-bit unsigned integers:
key
Array([0, 0], dtype=uint32)
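You can check the structure of this raw key directly. The minimal check below assumes the default Threefry generator, which is what produces the two-word key shown above. Recent Jax versions also provide random.key, which returns a typed key array instead. This notebook sticks with PRNGKey.

```python
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)
# A raw key: a length-2 array of 32-bit unsigned integers.
print(key.shape, key.dtype)  # (2,) uint32
```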
When sampling from a distribution, we explicitly pass the key. Here is a sample from a standard normal:
random.normal(key, shape=(2, 2))
Array([[ 1.8160863 , -0.75488514],
[ 0.33988908, -0.53483534]], dtype=float32)
Now if you pass the same key, you get the same sample:
random.normal(key, shape=(2, 2))
Array([[ 1.8160863 , -0.75488514],
[ 0.33988908, -0.53483534]], dtype=float32)
The key has not been updated:
key
Array([0, 0], dtype=uint32)
To get a different sample, you need to split the key:
key, subkey = random.split(key)
key
Array([4146024105, 967050713], dtype=uint32)
subkey
Array([2718843009, 1272950319], dtype=uint32)
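Splitting branches the key into two new keys, as if starting two independent generators. You can use either one to draw a sample. The usual convention is to keep key for further splitting and to consume subkey: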
random.normal(subkey, shape=(2, 2))
Array([[ 1.1378784 , -0.14331433],
[-0.59153634, 0.79466224]], dtype=float32)
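split also takes a number of keys to produce, which is handy when you need several independent draws at once. The sketch below relies only on random.split and random.normal. It does not assert specific key values, because the exact bits can differ across Jax versions and configurations.

```python
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)

# Produce three new keys at once; each row is one key.
keys = random.split(key, 3)
print(keys.shape)  # (3, 2)

# Splitting is itself deterministic: the same key always yields the same children.
print(jnp.array_equal(random.split(key, 3), keys))  # True

# Each child key gives its own sample.
samples = [random.normal(k, shape=(2,)) for k in keys]
```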
That is all there is to it: you must thread the key through your code explicitly. It feels awkward at first, but you get used to it after doing it a few times.
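In practice, “threading the key” means that every function that needs randomness takes a key, splits off one fresh subkey per random call, and hands the advanced key back to its caller. The sketch below shows that pattern. The helper sample_noisy_point is hypothetical and exists only for illustration.

```python
import jax.numpy as jnp
import jax.random as random

def sample_noisy_point(key):
    """Hypothetical helper: draws a location plus noise and returns the advanced key."""
    # One fresh subkey per random call; the first output becomes the new key.
    key, k_loc, k_noise = random.split(key, 3)
    loc = random.uniform(k_loc, shape=(2,))
    noise = 0.1 * random.normal(k_noise, shape=(2,))
    return key, loc + noise

key = random.PRNGKey(0)
key, p1 = sample_noisy_point(key)
key, p2 = sample_noisy_point(key)  # new key -> a different point
```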
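Let’s look at an example. We will generate a sample from a random walk using only functional programming. The random walk starts at \(x_0\) and evolves according to

$$
x_{n+1} = x_n + \sigma z_n,
$$

where \(\sigma > 0\) and \(z_n \sim \mathcal{N}(0, I)\) are independent standard normal vectors.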
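def rw_step(x, sigma, key):
    """A single step of the random walk."""
    # Split off a fresh subkey for this step and keep `key` for later steps.
    key, subkey = random.split(key)
    # Standard normal increment with the same shape as the state.
    z = random.normal(subkey, shape=x.shape)
    return key, x + sigma * z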
Now we can put it in a loop that takes multiple steps and jit it:
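from functools import partial
from jax import jit
from jax import lax
# n_steps (argument 3) must be static: it controls how many times the Python loop runs.
@partial(jit, static_argnums=(3,))
def sample_rw(x0, sigma, key, n_steps):
    """Sample a random walk."""
    x = x0
    xs = [x0]
    # A plain Python loop: it is unrolled into n_steps copies of rw_step when traced.
    for _ in range(n_steps):
        key, x = rw_step(x, sigma, key)
        xs.append(x)
    xs = jnp.stack(xs)
    # Return the path first and the updated key second, matching the output below.
    return xs, key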
This works:
sample_rw(jnp.zeros(2), 1.0, key, 10)
(Array([[ 0. , 0. ],
[ 0.00870701, -0.04888523],
[-0.8823462 , -0.71072996],
[ 0.3267806 , 1.6009982 ],
[ 1.2447658 , 2.0843194 ],
[ 1.0294966 , 1.6681931 ],
[ 4.0256443 , 0.41421673],
[ 3.6146142 , 0.53783053],
[ 3.1284342 , -0.39070803],
[ 3.636192 , -1.0362725 ],
[ 4.7213855 , -0.90524477]], dtype=float32),
Array([2172655199, 567882137], dtype=uint32))
However, this approach is not ideal. The Python loop is unrolled while Jax traces the function, so the compiled program grows with n_steps, and a new compilation is triggered every time we try a different n_steps.
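To see the recompilation happen, here is a minimal sketch that uses a hypothetical global counter, trace_count. Python side effects inside a jitted function run only while Jax traces it, so the counter counts how many times the function was traced and compiled. The function name sample_rw_unrolled is also hypothetical. It inlines the same step as rw_step so that the block runs on its own.

```python
from functools import partial

import jax.numpy as jnp
import jax.random as random
from jax import jit

trace_count = 0  # hypothetical counter, for illustration only

@partial(jit, static_argnums=(3,))
def sample_rw_unrolled(x0, sigma, key, n_steps):
    global trace_count
    trace_count += 1  # Python side effect: runs only while Jax traces the function
    x = x0
    xs = [x0]
    for _ in range(n_steps):  # unrolled at trace time
        key, subkey = random.split(key)
        x = x + sigma * random.normal(subkey, shape=x.shape)
        xs.append(x)
    return jnp.stack(xs), key

key = random.PRNGKey(0)
x0 = jnp.zeros(2)
sample_rw_unrolled(x0, 1.0, key, 10)  # traced and compiled
sample_rw_unrolled(x0, 1.0, key, 20)  # new n_steps: traced and compiled again
sample_rw_unrolled(x0, 1.0, key, 10)  # cached: no new trace
print(trace_count)  # 2
```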
We can use lax.scan to avoid this:
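@jit
def sample_rw(x0, sigma, keys):
    """Sample a random walk."""
    def step_rw(prev_x, key):
        """A single step of the random walk."""
        z = random.normal(key, shape=prev_x.shape)
        new_x = prev_x + sigma * z
        # (carry, output): carry the new state forward and emit the previous one,
        # so the returned path starts at x0.
        return new_x, prev_x
    # scan loops over the leading axis of `keys` (one key per step) inside the
    # compiled program; the step is traced only once, whatever the number of steps.
    return lax.scan(step_rw, x0, keys)[1]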
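The contract lax.scan expects is easier to see on a toy problem that is not part of the original notebook: a cumulative sum. The body function takes (carry, x) and returns (new_carry, y). scan then returns the final carry together with all the y values stacked along a new leading axis. In sample_rw above, the carry is the current position, the inputs are the keys, and the outputs are the positions. The trailing [1] keeps only those outputs.

```python
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    """scan body: (carry, current input) -> (new carry, per-step output)."""
    new_carry = carry + x
    return new_carry, new_carry

xs = jnp.arange(1.0, 6.0)  # values 1, 2, 3, 4, 5
final, ys = lax.scan(step, jnp.array(0.0), xs)
print(final)  # 15.0
print(ys)     # values 1, 3, 6, 10, 15 (the running sums)
```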
n_steps = 100_000
keys = random.split(key, n_steps)
walk = sample_rw(jnp.zeros(2), 0.1, keys)
walk.shape
(100000, 2)
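As a quick sanity check of the update rule, the increments \(x_{n+1} - x_n = \sigma z_n\) of the sampled walk should have a mean close to 0 and a standard deviation close to \(\sigma = 0.1\) in each coordinate. The minimal sketch below repeats the scan sampler so that the block runs on its own. The tolerance in the checks is an arbitrary but generous choice for 100,000 steps.

```python
import jax.numpy as jnp
import jax.random as random
from jax import jit, lax

@jit
def sample_rw(x0, sigma, keys):
    """Scan-based random walk sampler, as above."""
    def step_rw(prev_x, key):
        z = random.normal(key, shape=prev_x.shape)
        return prev_x + sigma * z, prev_x
    return lax.scan(step_rw, x0, keys)[1]

sigma = 0.1
walk = sample_rw(jnp.zeros(2), sigma, random.split(random.PRNGKey(0), 100_000))

# Consecutive differences recover sigma * z_n.
increments = jnp.diff(walk, axis=0)
print(increments.mean(axis=0))  # both entries close to 0
print(increments.std(axis=0))   # both entries close to sigma
```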
Let’s plot it:
fig, ax = plt.subplots()
ax.plot(walk[:, 0], walk[:, 1], lw=0.5)
ax.set(xlabel="x", ylabel="y", title="Random Walk")
sns.despine(trim=True);
In case you did not notice, the 100,000 steps ran in no time: lax.scan traces the step function once and runs the entire loop inside a single compiled program.
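To measure this yourself, keep two details in mind. Jax dispatches work asynchronously, so call block_until_ready() before stopping the timer. The first call also includes tracing and compilation, so time it separately from later calls. The minimal sketch below uses time.perf_counter. The numbers you get depend on your hardware, so none are claimed here.

```python
import time

import jax.numpy as jnp
import jax.random as random
from jax import jit, lax

@jit
def sample_rw(x0, sigma, keys):
    """Scan-based random walk sampler, as above."""
    def step_rw(prev_x, key):
        z = random.normal(key, shape=prev_x.shape)
        return prev_x + sigma * z, prev_x
    return lax.scan(step_rw, x0, keys)[1]

keys = random.split(random.PRNGKey(0), 100_000)
x0 = jnp.zeros(2)

# First call: includes tracing and XLA compilation.
t0 = time.perf_counter()
sample_rw(x0, 0.1, keys).block_until_ready()
t_first = time.perf_counter() - t0

# Second call with the same shapes and dtypes: reuses the compiled program.
t0 = time.perf_counter()
walk = sample_rw(x0, 0.1, keys).block_until_ready()
t_second = time.perf_counter() - t0

print(f"first call: {t_first:.3f} s, second call: {t_second:.3f} s")
```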