No-U-Turn Sampler with Blackjax

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")

from functools import partial
import jax.numpy as jnp
from jax import lax, vmap
from jax.scipy.stats import multivariate_normal
import jax.random as jrandom
import blackjax

import numpy as np

key = jrandom.PRNGKey(123)


TODO: You can do the catalysis problem with blackjax, but keep the sigma fixed.

def banana_logdensity(x, a=1.15, b=0.5, rho=0.5):
    """A banana-shaped distribution. Comes from a nonlinear transformation of a correlated Gaussian."""
    x1, x2 = x
    u1 = x1/a
    u2 = a*(x2 - b*(u1**2 + a**2))
    return multivariate_normal.logpdf(jnp.array([u1, u2]), jnp.zeros(2), jnp.array([[1, rho], [rho, 1]]))

def plot_2d_function(f, alpha=1.0, plot_type="pcolormesh", ax=None, levels=None):
    x = jnp.linspace(-4, 4, 100)
    y = jnp.linspace(-1, 11, 100)
    X, Y = jnp.meshgrid(x, y)
    Z = vmap(f)(jnp.stack([X.flatten(), Y.flatten()], axis=1)).reshape(X.shape)

    if ax is None:
        _, ax = plt.subplots()
    if plot_type == "contour":
        ax.contour(X, Y, jnp.exp(Z), alpha=alpha, cmap="Greens", levels=levels)
    elif plot_type == "pcolormesh":
        ax.pcolormesh(X, Y, jnp.exp(Z), alpha=alpha)
    ax.set_xticks([-2, 0, 2])
    ax.set_yticks([0, 2, 4, 6])
    ax.set_aspect("equal")
    sns.despine(trim=True, left=True, bottom=True)
    return ax
banana_logdensity_high_curv = partial(banana_logdensity, a=1.15, b=1.0, rho=0.9)
plot_2d_function(banana_logdensity_high_curv);
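NUTS relies on the gradient of the log density, which JAX gives us through automatic differentiation, so we never have to derive it by hand. A quick sanity check at an arbitrary point (the point below is chosen just for illustration):

from jax import grad

x0 = jnp.array([0.5, 1.0])   # an arbitrary test point
print(banana_logdensity_high_curv(x0))        # value of the log density
print(grad(banana_logdensity_high_curv)(x0))  # gradient that NUTS uses inside the leapfrog integrator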

No-U-Turn Sampler (NUTS) recap

Recall that NUTS is an adaptive variant of Hamiltonian Monte Carlo: instead of fixing the number of leapfrog steps per proposal, it keeps doubling the simulated trajectory until it starts to turn back on itself (the "U-turn" criterion). That leaves only two hyperparameters to set by hand: the integrator step size and the inverse mass matrix.

# NUTS hyperparameters
step_size = 0.1
inverse_mass_matrix = jnp.eye(2)

# Create the NUTS kernel
nuts = blackjax.nuts(
    banana_logdensity_high_curv, step_size, inverse_mass_matrix
)

def step(state, _):
    """A single step of the Hamiltonian Monte Carlo sampler. Used with `lax.scan`."""
    key, kernel_state = state
    key, subkey = jrandom.split(key)
    kernel_state, info = nuts.step(subkey, kernel_state)
    return (key, kernel_state), (kernel_state.position, info)

def run_mcmc_chain(key, init_state, num_samples):
    """Run a chain of MCMC."""
    _, (samples, info) = lax.scan(step, (key, init_state), None, length=num_samples)
    return samples, info
num_chains = 5
num_samples_per_chain = 200

key, key_run, key_init = jrandom.split(key, 3)
keys = jrandom.split(key_run, num_chains)
init_state_spread = 2.0
init_state = vmap(nuts.init)(init_state_spread*jrandom.normal(key_init, (num_chains, 2)))
samples, info = vmap(run_mcmc_chain, in_axes=(0, 0, None))(keys, init_state, num_samples_per_chain)
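Above we picked the step size and the inverse mass matrix by hand. Blackjax also provides a Stan-style window adaptation routine that tunes both during a warm-up phase. Here is a minimal sketch (the exact signature of blackjax.window_adaptation may differ slightly between blackjax versions); for the rest of this section we keep the hand-picked values:

# Sketch: let blackjax tune the step size and inverse mass matrix during warm-up.
key_warmup = jrandom.PRNGKey(0)
warmup = blackjax.window_adaptation(blackjax.nuts, banana_logdensity_high_curv)
(warm_state, tuned_params), _ = warmup.run(key_warmup, jnp.zeros(2), num_steps=500)
tuned_nuts = blackjax.nuts(banana_logdensity_high_curv, **tuned_params)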

Let’s make the trace plot (with the help of the arviz library):

import arviz as az
az.plot_trace(np.array(samples), compact=False, backend_kwargs=dict(figsize=(8,4), tight_layout=True));
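As an aside, arviz can label the two coordinates for us if we wrap the chains in an InferenceData object; az.summary then reports the posterior means together with $\hat{R}$ and the ESS in a single table (the variable names x1 and x2 below are our own choice):

idata = az.from_dict(posterior={"x1": np.array(samples[..., 0]),
                                "x2": np.array(samples[..., 1])})
az.summary(idata)   # includes r_hat and ess columns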

Let’s look at $\hat{R}$ to assess convergence (see the Metropolis-Hastings hands-on activity for more details):

compute_diagnostics_every = 10
rhats = []
for i in range(2, num_samples_per_chain, compute_diagnostics_every):
    rhat = blackjax.diagnostics.potential_scale_reduction(samples[:, :i])
    rhats.append(rhat)
rhats = jnp.array(rhats)
fig, ax = plt.subplots(figsize=(5,4))
ax.plot(range(2, num_samples_per_chain, compute_diagnostics_every), rhats[:,0], label=r"$\hat{R}$ for $x_1$")
ax.plot(range(2, num_samples_per_chain, compute_diagnostics_every), rhats[:,1], label=r"$\hat{R}$ for $x_2$")
ax.axhline(1.0, color="black", linestyle="--")
ax.set_xlabel("Number of samples")
ax.set_ylabel(r"$\hat{R}$")
ax.legend()
sns.despine(trim=True)

It looks like the chains converge fairly quickly.
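As a reminder, $\hat{R}$ compares the between-chain variance to the within-chain variance; values close to 1 mean the chains are exploring the same distribution. A hand-rolled version of the basic (non-split) estimator for a single parameter could look like the sketch below; blackjax's potential_scale_reduction may differ in details such as chain splitting:

def basic_rhat(chains):
    """Basic Gelman-Rubin statistic for one parameter. `chains` has shape (n_chains, n_samples)."""
    n = chains.shape[1]
    chain_means = chains.mean(axis=1)
    W = chains.var(axis=1, ddof=1).mean()   # average within-chain variance
    B = n * chain_means.var(ddof=1)         # between-chain variance
    var_hat = (n - 1) / n * W + B / n       # pooled estimate of the posterior variance
    return jnp.sqrt(var_hat / W)

basic_rhat(samples[:, :, 0])   # should be close to the blackjax value for x_1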

Let’s look at the effective sample size (ESS) to see how many effectively independent samples we have (again, see the Metropolis-Hastings hands-on activity for more details):

n_effs = []
for i in range(2, num_samples_per_chain, compute_diagnostics_every):
    n_eff = blackjax.diagnostics.effective_sample_size(samples[:, :i])
    n_effs.append(n_eff)
n_effs = jnp.array(n_effs)
fig, ax = plt.subplots(figsize=(5,4))
ax.plot(range(2, num_samples_per_chain, compute_diagnostics_every), n_effs[:,0], label=r"ESS for $x_1$")
ax.plot(range(2, num_samples_per_chain, compute_diagnostics_every), n_effs[:,1], label=r"ESS for $x_2$")
ax.set_xlabel("Number of samples")
ax.set_ylabel("Effective sample size")
ax.legend()
sns.despine(trim=True)

The ESS is high. This is good: it means our samples are nearly uncorrelated, and we do not need to thin them out.
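Recall that the ESS discounts the nominal number of draws by the autocorrelation of the chain: roughly $MN / (1 + 2\sum_t \rho_t)$ for $M$ chains of length $N$ with lag-$t$ autocorrelation $\rho_t$. A deliberately crude, single-chain version is sketched below; blackjax's effective_sample_size uses a more careful multi-chain estimator, so the numbers will not match exactly:

def naive_ess(chain, max_lag=100):
    """Crude single-chain ESS estimate. `chain` has shape (n_samples,)."""
    x = np.asarray(chain) - np.mean(chain)
    n = x.shape[0]
    rho = np.array([np.sum(x[:n - t] * x[t:]) / np.sum(x * x) for t in range(1, max_lag)])
    # Sum the autocorrelations up to the first negative lag.
    cutoff = np.argmax(rho < 0) if np.any(rho < 0) else len(rho)
    return n / (1 + 2 * np.sum(rho[:cutoff]))

naive_ess(samples[0, :, 0])   # ESS of x_1 estimated from the first chain only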

Finally, let’s plot the samples:

# The original shape of the `samples` array is (n_chains, n_samples, n_dim)
burn_in = 50  # Discard the first `burn_in` samples from each chain as warm-up
thin = 1  # Keep every `thin`-th sample
true_samples = samples[:, burn_in::thin]

# Concatenate the chains. Final shape is (n_chains * n_true_samples_per_chain, n_dim)
true_samples = true_samples.reshape(-1, 2)

ax = plot_2d_function(banana_logdensity_high_curv, alpha=0.4);
ax.set_title("NUTS samples from \nthe target distribution", fontsize=16)
ax.scatter(true_samples[:, 0], true_samples[:, 1], s=2, alpha=0.5);
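With the pooled, burned-in samples in hand, posterior expectations are just Monte Carlo averages. For example, the mean and covariance of the target are estimated by:

post_mean = true_samples.mean(axis=0)   # Monte Carlo estimate of the mean
post_cov = jnp.cov(true_samples.T)      # Monte Carlo estimate of the covariance
print(post_mean)
print(post_cov)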