```
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks");
```

# Pseudo Random Numbers without Side Effects

Recall that `Jax` only uses pure functions.
Pseudo random number generators are typically implemented as stateful objects:

- you initialize the generator with a seed,
- you call the generator to get a random number,
- the generator updates its internal state.
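As a point of comparison, the stateful pattern can be sketched with NumPy's `Generator`. NumPy is not used elsewhere in this section; it is brought in here only as an illustration of a generator that hides its state.

```python
import numpy as np

# 1. Initialize the generator with a seed.
rng = np.random.default_rng(0)

# 2. Call the generator to get a random number.
first = rng.normal()

# 3. The call above advanced the hidden state, so the same call
#    now returns a different number.
second = rng.normal()

# Re-seeding restarts the stream and reproduces the first draw.
replay = np.random.default_rng(0).normal()

print(first != second, replay == first)
```

The function `rng.normal` returns different values for identical arguments, which is exactly the kind of side effect a pure function cannot have.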

This does not work in `Jax`, because a generator that silently updates its internal state is not a pure function.
To deal with this we need to pass the state around explicitly.
This is done using the `PRNGKey` object.
Please go through Pseudo Random Numbers in Jax and Stateful Computations in Jax for more details.

The state object is:

```
import jax.numpy as jnp
import jax.random as random
key = random.PRNGKey(0)
```

It is an array of two 32-bit unsigned integers:

```
key
```

```
Array([0, 0], dtype=uint32)
```

When sampling from a distribution, we explicitly pass the key. Here is a sample from a standard normal:

```
random.normal(key, shape=(2, 2))
```

```
Array([[ 1.8160863 , -0.75488514],
       [ 0.33988908, -0.53483534]], dtype=float32)
```

Now if you pass the same key, you get the same sample:

```
random.normal(key, shape=(2, 2))
```

```
Array([[ 1.8160863 , -0.75488514],
       [ 0.33988908, -0.53483534]], dtype=float32)
```

The key has not been updated:

```
key
```

```
Array([0, 0], dtype=uint32)
```

To get a different sample you need to `split` the key:

```
key, subkey = random.split(key)
```

```
key
```

```
Array([4146024105, 967050713], dtype=uint32)
```

```
subkey
```

```
Array([2718843009, 1272950319], dtype=uint32)
```

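Splitting branches the key into two new keys, each of which behaves like a fresh generator. Either one can be used to draw a sample; the usual convention is to sample with `subkey` and keep `key` for later splits: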

```
random.normal(subkey, shape=(2, 2))
```

```
Array([[ 1.1378784 , -0.14331433],
       [-0.59153634,  0.79466224]], dtype=float32)
```
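When several independent draws are needed at once, `random.split` also accepts the number of keys to produce. The following is a minimal sketch; the variable names `subkeys` and `samples` are chosen here for illustration.

```python
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)

# Ask for four keys: keep the first for future splits and
# consume the remaining three, one per draw.
key, *subkeys = random.split(key, 4)

# Each subkey gives its own sample.
samples = [random.normal(k, shape=(2,)) for k in subkeys]

# Reusing a subkey reproduces the corresponding sample exactly.
repeat = random.normal(subkeys[0], shape=(2,))
print(jnp.array_equal(repeat, samples[0]))
```

The rule of thumb is that a key should be consumed by at most one sampling call; anything else gets its own subkey.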

That is all there is to it: you must thread the key through your code explicitly. It becomes second nature after you have done it a few times.

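Let’s look at an example. We will generate a sample from a random walk using only functional programming. The random walk starts at \(x_0\) and evolves as

\[
x_{t+1} = x_t + \sigma z_t,
\]

where \(\sigma > 0\) and the \(z_t\) are independent standard normal random vectors. A single step looks like this: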

```
def rw_step(x, sigma, key):
    """A single step of the random walk."""
    # Keep `key` for the next step and consume `subkey` for this draw.
    key, subkey = random.split(key)
    z = random.normal(subkey, shape=x.shape)
    return key, x + sigma * z
```

Now we can put it in a loop that takes multiple steps and jit it:

```
from functools import partial
from jax import jit
from jax import lax


# n_steps (argument 3) controls a Python loop, so it must be static.
@partial(jit, static_argnums=(3,))
def sample_rw(x0, sigma, key, n_steps):
    """Sample a random walk."""
    x = x0
    xs = [x0]
    for _ in range(n_steps):
        key, x = rw_step(x, sigma, key)
        xs.append(x)
    xs = jnp.stack(xs)
    return key, xs
```

This works:

```
sample_rw(jnp.zeros(2), 1.0, key, 10)
```

```
(Array([[ 0. , 0. ],
[ 0.00870701, -0.04888523],
[-0.8823462 , -0.71072996],
[ 0.3267806 , 1.6009982 ],
[ 1.2447658 , 2.0843194 ],
[ 1.0294966 , 1.6681931 ],
[ 4.0256443 , 0.41421673],
[ 3.6146142 , 0.53783053],
[ 3.1284342 , -0.39070803],
[ 3.636192 , -1.0362725 ],
[ 4.7213855 , -0.90524477]], dtype=float32),
Array([2172655199, 567882137], dtype=uint32))
```

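But this is not good: the Python loop is unrolled during tracing, so the compiled program grows with `n_steps`, and because `n_steps` is a static argument, compilation is triggered again every time we try a new value.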
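The recompilation can be observed with a small sketch. It relies on the fact that Python side effects inside a jitted function run only while the function is being traced. The function `add_n_times` and the list `trace_log` are hypothetical names used only for this demonstration.

```python
from functools import partial

import jax.numpy as jnp
from jax import jit

trace_log = []  # records the n_steps value each time tracing happens


@partial(jit, static_argnums=(1,))
def add_n_times(x, n_steps):
    """Add 1.0 to x in an unrolled Python loop of n_steps iterations."""
    trace_log.append(n_steps)  # Python side effect: runs at trace time only
    for _ in range(n_steps):
        x = x + 1.0
    return x


add_n_times(jnp.zeros(2), 3)  # first call: traced and compiled
add_n_times(jnp.ones(2), 3)   # same static value and shapes: cache hit
add_n_times(jnp.zeros(2), 4)  # new static value: traced and compiled again

print(trace_log)  # [3, 4]
```

Only two of the three calls appear in `trace_log`, one per distinct value of the static argument.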
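We can use `scan` to avoid the unrolling. The number of steps is now set by the number of keys we pass in. A new number of keys still triggers a recompilation, but the loop body is traced only once, so the compiled program no longer grows with the number of steps: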

```
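@jit
def sample_rw(x0, sigma, keys):
    """Sample a random walk."""
    def step_rw(prev_x, key):
        """A single step of the random walk."""
        z = random.normal(key, shape=prev_x.shape)
        new_x = prev_x + sigma * z
        # Carry new_x to the next step and record prev_x in the output.
        return new_x, prev_x

    # scan returns (final carry, stacked outputs); we keep the outputs.
    return lax.scan(step_rw, x0, keys)[1]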
```
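For readers new to `lax.scan`, here is a minimal sketch of its contract on a simpler problem, a running sum. The step function receives the carry and one element of the scanned array, and returns the next carry together with the value to record for that step. The name `running_sum_step` is chosen here for illustration.

```python
import jax.numpy as jnp
from jax import lax


def running_sum_step(carry, x):
    """Add x to the running total and record the new total."""
    new_carry = carry + x
    return new_carry, new_carry  # (next carry, per-step output)


xs = jnp.arange(1.0, 5.0)  # [1., 2., 3., 4.]
init = jnp.zeros((), dtype=jnp.float32)

final, partial_sums = lax.scan(running_sum_step, init, xs)

print(final)         # 10.0
print(partial_sums)  # [ 1.  3.  6. 10.]
```

In `sample_rw` the carry is the current position, the scanned array is the stack of keys, and the recorded output is the position before the step, which is why the returned walk starts at `x0`.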

```
n_steps = 100_000
keys = random.split(key, n_steps)
walk = sample_rw(jnp.zeros(2), 0.1, keys)
walk.shape
```

```
(100000, 2)
```
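As a sanity check, the scanned walk should match a direct computation: draw one increment per key and take a cumulative sum. The sketch below repeats the scan-based sampler so that it runs on its own, uses a shorter walk of 1,000 steps, and compares with a loose tolerance because the two computations may add floating point numbers in a different order.

```python
import jax.numpy as jnp
import jax.random as random
from jax import jit, lax, vmap


@jit
def sample_rw(x0, sigma, keys):
    """Same scan-based sampler as above, repeated to keep this block self-contained."""
    def step_rw(prev_x, key):
        z = random.normal(key, shape=prev_x.shape)
        return prev_x + sigma * z, prev_x

    return lax.scan(step_rw, x0, keys)[1]


sigma = 0.1
keys = random.split(random.PRNGKey(0), 1_000)
walk = sample_rw(jnp.zeros(2), sigma, keys)

# Draw the same increments directly, one per key, without the scan.
zs = vmap(lambda k: random.normal(k, shape=(2,)))(keys)

# The walk records the position before each step, so it starts at x0
# and uses all increments except the last one.
expected = jnp.concatenate(
    [jnp.zeros((1, 2)), sigma * jnp.cumsum(zs[:-1], axis=0)]
)

same_path = bool(jnp.allclose(walk, expected, atol=1e-3))
reproducible = bool(jnp.array_equal(walk, sample_rw(jnp.zeros(2), sigma, keys)))
print(same_path, reproducible)
```

Because the keys fully determine the increments, calling the sampler again with the same keys returns the same path.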

Let’s plot it:

```
fig, ax = plt.subplots()
ax.plot(walk[:, 0], walk[:, 1], lw=0.5)
ax.set(xlabel="x", ylabel="y", title="Random Walk")
sns.despine(trim=True);
```

In case you did not notice, we did 100,000 steps in no time.