Hamiltonian Monte Carlo with Blackjax

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")

from functools import partial
import jax.numpy as jnp
from jax import lax, vmap
from jax.scipy.stats import multivariate_normal
import jax.random as jrandom
import blackjax

import numpy as np

key = jrandom.PRNGKey(123)

Let’s use the Hamiltonian Monte Carlo (HMC) algorithm to sample from a “banana”-shaped distribution (defined in Section 5.1.3 of Wang et al.).

Here is the probability density of the banana distribution with high curvature:

def banana_logdensity(x, a=1.15, b=0.5, rho=0.5):
    """A banana-shaped distribution. Comes from a nonlinear transformation of a correlated Gaussian."""
    x1, x2 = x
    u1 = x1/a
    u2 = a*(x2 - b*(u1**2 + a**2))
    return multivariate_normal.logpdf(jnp.array([u1, u2]), jnp.zeros(2), jnp.array([[1, rho], [rho, 1]]))

def plot_2d_function(f, alpha=1.0, plot_type="pcolormesh", ax=None, levels=None):
    x = jnp.linspace(-4, 4, 100)
    y = jnp.linspace(-1, 11, 100)
    X, Y = jnp.meshgrid(x, y)
    Z = vmap(f)(jnp.stack([X.flatten(), Y.flatten()], axis=1)).reshape(X.shape)

    if ax is None:
        _, ax = plt.subplots()
    if plot_type == "contour":
        ax.contour(X, Y, jnp.exp(Z), alpha=alpha, cmap="Greens", levels=levels)
    elif plot_type == "pcolormesh":
        ax.pcolormesh(X, Y, jnp.exp(Z), alpha=alpha)
    ax.set_xticks([-2, 0, 2])
    ax.set_yticks([0, 2, 4, 6])
    ax.set_aspect("equal")
    sns.despine(trim=True, left=True, bottom=True)
    return ax
banana_logdensity_high_curv = partial(banana_logdensity, a=1.15, b=1.0, rho=0.9)
plot_2d_function(banana_logdensity_high_curv);

Hamiltonian Monte Carlo recap

Intuitively, HMC works by simulating the dynamics of a particle moving in a potential energy field, where the potential energy is the negative log density of the target.

As an analogy, think of the probability density plot above as showing the elevation of a landscape, where the bright region is a valley. Imagine placing a ball somewhere on this landscape and kicking it in a random direction with a random amount of force. Let the ball roll for a while, then stop it and record its position. Repeat this process many times, and the positions you record will be samples from the target distribution. This is HMC in a nutshell.
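To connect the analogy to the algorithm, here is a minimal sketch of a single HMC transition (the helpers `leapfrog` and `hmc_proposal` are illustrative names, not BlackJax API): the random kick is a momentum drawn from a Gaussian, letting the ball roll is a handful of leapfrog steps of Hamiltonian dynamics, and the recorded position is accepted or rejected with a Metropolis correction that accounts for numerical integration error.

from jax import grad

def leapfrog(logdensity, position, momentum, step_size, num_steps):
    """Leapfrog integration of Hamiltonian dynamics (identity mass matrix)."""
    grad_logp = grad(logdensity)
    momentum = momentum + 0.5 * step_size * grad_logp(position)   # initial half step for momentum
    for _ in range(num_steps - 1):
        position = position + step_size * momentum                # full step for position
        momentum = momentum + step_size * grad_logp(position)     # full step for momentum
    position = position + step_size * momentum
    momentum = momentum + 0.5 * step_size * grad_logp(position)   # final half step for momentum
    return position, momentum

def hmc_proposal(key, logdensity, position, step_size=0.1, num_steps=20):
    """One HMC transition: random kick, roll, then Metropolis accept/reject."""
    key_kick, key_accept = jrandom.split(key)
    momentum = jrandom.normal(key_kick, position.shape)           # the random kick
    new_position, new_momentum = leapfrog(logdensity, position, momentum, step_size, num_steps)
    # Total energy = potential (-log density) + kinetic (0.5 * |momentum|^2)
    energy = -logdensity(position) + 0.5 * jnp.sum(momentum ** 2)
    new_energy = -logdensity(new_position) + 0.5 * jnp.sum(new_momentum ** 2)
    accept = jrandom.uniform(key_accept) < jnp.exp(energy - new_energy)
    return jnp.where(accept, new_position, position)

Calling hmc_proposal repeatedly with fresh keys, e.g. hmc_proposal(subkey, banana_logdensity_high_curv, current_position), traces out one chain. In practice we don’t write this ourselves.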

Let’s do it in BlackJax:

# HMC hyperparameters
step_size = 0.1
inverse_mass_matrix = jnp.eye(2)
num_integration_steps = 20

# Create the HMC kernel
hmc = blackjax.hmc(
    banana_logdensity_high_curv, step_size, inverse_mass_matrix, num_integration_steps
)

def step(state, _):
    """A single step of the Hamiltonian Monte Carlo sampler. Used with `lax.scan`."""
    key, kernel_state = state
    key, subkey = jrandom.split(key)
    kernel_state, info = hmc.step(subkey, kernel_state)
    return (key, kernel_state), (kernel_state.position, info)

def run_mcmc_chain(key, init_state, num_samples):
    """Run a chain of MCMC."""
    _, (samples, info) = lax.scan(step, (key, init_state), None, length=num_samples)
    return samples, info
num_chains = 5
num_samples_per_chain = 200

key, key_run, key_init = jrandom.split(key, 3)
keys = jrandom.split(key_run, num_chains)
init_state_spread = 2.0
init_state = vmap(hmc.init)(init_state_spread*jrandom.normal(key_init, (num_chains, 2)))
samples, info = vmap(run_mcmc_chain, in_axes=(0, 0, None))(keys, init_state, num_samples_per_chain)

Let’s make the trace plot (with the help of the arviz library):

import arviz as az
az.plot_trace(np.array(samples), compact=False, backend_kwargs=dict(figsize=(8,4), tight_layout=True));

Let’s look at \(\hat{R}\) to assess convergence (see the Metropolis-Hastings hands-on activity for more details):

compute_diagnostics_every = 10
rhats = []
for i in range(2, num_samples_per_chain, compute_diagnostics_every):
    rhat = blackjax.diagnostics.potential_scale_reduction(samples[:, :i])
    rhats.append(rhat)
rhats = jnp.array(rhats)
fig, ax = plt.subplots(figsize=(5,4))
ax.plot(range(2, num_samples_per_chain, compute_diagnostics_every), rhats[:,0], label=r"$\hat{R}$ for $x_1$")
ax.plot(range(2, num_samples_per_chain, compute_diagnostics_every), rhats[:,1], label=r"$\hat{R}$ for $x_2$")
ax.axhline(1.0, color="black", linestyle="--")
ax.set_xlabel("Number of samples")
ax.set_ylabel(r"$\hat{R}$")
ax.legend()
sns.despine(trim=True)

It looks like the chains converge fairly quickly.
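As a quick reminder of what \(\hat{R}\) measures, here is a bare-bones sketch of the classic (non-split) Gelman-Rubin statistic: it compares the between-chain variance to the within-chain variance, and values near 1 indicate the chains agree. Implementations differ in details (e.g., split chains), so the numbers need not match BlackJax exactly.

def gelman_rubin(chains):
    """Classic (non-split) Gelman-Rubin statistic for a (num_chains, num_draws) array."""
    n = chains.shape[1]
    chain_means = chains.mean(axis=1)
    W = chains.var(axis=1, ddof=1).mean()    # average within-chain variance
    B = n * chain_means.var(ddof=1)          # between-chain variance (scaled by chain length)
    var_hat = (n - 1) / n * W + B / n        # pooled estimate of the posterior variance
    return jnp.sqrt(var_hat / W)

# R-hat for each coordinate using all the samples
print(gelman_rubin(samples[:, :, 0]), gelman_rubin(samples[:, :, 1]))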

Let’s look at the effective sample size (ESS) to estimate how many effectively independent samples we have (again, see the Metropolis-Hastings hands-on activity for more details):

n_effs = []
for i in range(2, num_samples_per_chain, compute_diagnostics_every):
    n_eff = blackjax.diagnostics.effective_sample_size(samples[:, :i])
    n_effs.append(n_eff)
n_effs = jnp.array(n_effs)
Hide code cell source
fig, ax = plt.subplots(figsize=(5,4))
ax.plot(range(2, num_samples_per_chain, compute_diagnostics_every), n_effs[:,0], label=r"ESS for $x_1$")
ax.plot(range(2, num_samples_per_chain, compute_diagnostics_every), n_effs[:,1], label=r"ESS for $x_2$")
ax.set_xlabel("Number of samples")
ax.set_ylabel("Effective sample size")
ax.legend()
sns.despine(trim=True)

The ESS is high and grows steadily with the number of draws. This is good: it means our samples are only weakly correlated, and we don’t have to thin them out.
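As a quick, informal check of that claim (an extra step, not part of the original analysis), we can look at the lag-1 autocorrelation within each chain. The ESS estimate above accounts for autocorrelation at all lags, but lag 1 already gives a feel for how correlated consecutive draws are.

def lag1_autocorrelation(chain):
    """Lag-1 autocorrelation of a single 1D chain."""
    centered = chain - chain.mean()
    return jnp.sum(centered[:-1] * centered[1:]) / jnp.sum(centered ** 2)

# One value per chain and coordinate; values near 0 mean consecutive draws are nearly uncorrelated
print(vmap(lag1_autocorrelation)(samples[:, :, 0]))
print(vmap(lag1_autocorrelation)(samples[:, :, 1]))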

Finally, let’s plot the samples:

# The original shape of the `samples` array is (n_chains, n_samples, n_dim)
burn_in = 50  # Discard the first `burn_in` samples of each chain
thin = 1  # Keep only every `thin`-th sample
true_samples = samples[:, burn_in::thin]

# Concatenate the chains. Final shape is (n_chains * n_true_samples_per_chain, n_dim)
true_samples = true_samples.reshape(-1, 2)

ax = plot_2d_function(banana_logdensity_high_curv, alpha=0.4);
ax.set_title("MCMC samples from \nthe original distribution", fontsize=16)
ax.scatter(true_samples[:, 0], true_samples[:, 1], s=2, alpha=0.5);

They look good! Note that even though the distribution has high curvature, HMC is still able to efficiently sample from it.
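As a final sanity check (an extra step not in the original notebook), recall that the banana density was defined by transforming a correlated Gaussian, so we can invert that transformation to draw exact samples and compare moments. The parameters below are assumed to match banana_logdensity_high_curv (a=1.15, b=1.0, rho=0.9).

def sample_banana(key, num_samples, a=1.15, b=1.0, rho=0.9):
    """Exact samples from the banana distribution, by inverting the transformation in `banana_logdensity`."""
    cov = jnp.array([[1.0, rho], [rho, 1.0]])
    u = jrandom.multivariate_normal(key, jnp.zeros(2), cov, (num_samples,))
    x1 = a * u[:, 0]
    x2 = u[:, 1] / a + b * (u[:, 0] ** 2 + a ** 2)
    return jnp.stack([x1, x2], axis=1)

key, key_exact = jrandom.split(key)
exact_samples = sample_banana(key_exact, true_samples.shape[0])

# Compare the first two moments of the MCMC samples with the exact samples
print("MCMC mean:", true_samples.mean(axis=0), "exact mean:", exact_samples.mean(axis=0))
print("MCMC std: ", true_samples.std(axis=0), "exact std: ", exact_samples.std(axis=0))

If the HMC samples are good, the two sets of moments should be close.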