No-U-Turn Sampler with Blackjax

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")

from functools import partial
import jax
import jax.numpy as jnp
from jax import lax, vmap
from jax.scipy.stats import multivariate_normal
import jax.random as jrandom
import blackjax
from jaxtyping import Array, Float, PyTree, PRNGKeyArray
from typing import Callable, Tuple

import numpy as np

jax.config.update("jax_enable_x64", True)
key = jrandom.PRNGKey(123)


Now we will use the No-U-Turn Sampler (NUTS) to sample the banana distribution. Here is what the log density of the distribution looks like:

def banana_logdensity(x, a=1.15, b=0.5, rho=0.5):
    """A banana-shaped distribution. Comes from a nonlinear transformation of a correlated Gaussian."""
    x1, x2 = x
    u1 = x1/a
    u2 = a*(x2 - b*(u1**2 + a**2))
    return multivariate_normal.logpdf(jnp.array([u1, u2]), jnp.zeros(2), jnp.array([[1, rho], [rho, 1]]))

def plot_2d_function(f, alpha=1.0, plot_type="pcolormesh", ax=None, levels=None):
    x = jnp.linspace(-4, 4, 100)
    y = jnp.linspace(-1, 11, 100)
    X, Y = jnp.meshgrid(x, y)
    Z = vmap(f)(jnp.stack([X.flatten(), Y.flatten()], axis=1)).reshape(X.shape)

    if ax is None:
        _, ax = plt.subplots()
    if plot_type == "contour":
        ax.contour(X, Y, jnp.exp(Z), alpha=alpha, cmap="Greens", levels=levels)
    elif plot_type == "pcolormesh":
        ax.pcolormesh(X, Y, jnp.exp(Z), alpha=alpha)
    ax.set_xticks([-2, 0, 2])
    ax.set_yticks([0, 2, 4, 6])
    ax.set_aspect("equal")
    sns.despine(trim=True, left=True, bottom=True)
    return ax

banana_logdensity_high_curv = partial(banana_logdensity, a=1.15, b=1.0, rho=0.9)
plot_2d_function(banana_logdensity_high_curv);
Figure: heatmap of the high-curvature banana density.
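
As an aside, the banana density is just a correlated Gaussian pushed through a nonlinear map with unit Jacobian (as the docstring above notes), so we can also draw exact samples by applying the inverse map to Gaussian draws. The helper below is only a sketch for intuition and is not used in the NUTS pipeline that follows; it reuses the same a, b, and rho as banana_logdensity_high_curv:

def sample_banana_exact(key, num_samples, a=1.15, b=1.0, rho=0.9):
    """Draw exact samples by inverting the transformation used in `banana_logdensity`."""
    cov = jnp.array([[1.0, rho], [rho, 1.0]])
    u = jrandom.multivariate_normal(key, jnp.zeros(2), cov, shape=(num_samples,))
    x1 = a * u[:, 0]                               # inverts u1 = x1 / a
    x2 = u[:, 1] / a + b * (u[:, 0]**2 + a**2)     # inverts u2 = a * (x2 - b * (u1**2 + a**2))
    return jnp.stack([x1, x2], axis=1)

exact_samples = sample_banana_exact(jrandom.PRNGKey(0), 1000)  # shape (1000, 2)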

No-U-Turn Sampler (NUTS) recap

The NUTS algorithm is essentially Hamiltonian Monte Carlo (HMC) equipped with a heuristic that adaptively selects the number of leapfrog steps each Hamiltonian trajectory takes. We'll use Blackjax's implementation of NUTS. Much of the code below is modified from this Blackjax tutorial.
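
To make the moving parts concrete, here is a minimal single-chain sketch of the Blackjax NUTS interface. The step size and inverse mass matrix below are hand-picked placeholders; in the full code that follows they are tuned automatically during the warmup phase with window adaptation:

nuts = blackjax.nuts(
    banana_logdensity_high_curv,
    step_size=1e-1,                   # placeholder; normally tuned during warmup
    inverse_mass_matrix=jnp.ones(2),  # placeholder diagonal mass matrix
)
state = nuts.init(jnp.zeros(2))       # start the chain at the origin

def nuts_step(state, key):
    """One NUTS transition; carry the state, record the position."""
    state, info = nuts.step(key, state)
    return state, state.position

keys = jrandom.split(jrandom.PRNGKey(0), 200)
_, positions = lax.scan(nuts_step, state, keys)  # positions has shape (200, 2)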

def run_inference(
    joint_log_prob: Callable, 
    sample_starting_point: Callable, 
    key: PRNGKeyArray, 
    num_warmup: int, 
    num_samples: int, 
    num_chains: int
):
    """Do warmup with window adaptation (to tune step size and mass matrix) and then sample using NUTS.
    Returns a tuple with the NUTS states and some info.
    """
    # Initialize the window adaptation algorithm
    warmup = blackjax.window_adaptation(blackjax.nuts, joint_log_prob)

    # Initialize the MCMC chain
    key, init_key, warmup_key, sample_key = jax.random.split(key, 4)
    init_position = sample_starting_point(init_key, num_chains)

    # Run the warmup phase for each chain
    @jax.vmap
    def call_warmup(seed, param):
        """Run the warmup phase for a single chain."""
        (initial_states, tuned_params), _ = warmup.run(seed, param, num_warmup)
        return initial_states, tuned_params
    warmup_keys = jax.random.split(warmup_key, num_chains)
    initial_states, tuned_params = jax.jit(call_warmup)(warmup_keys, init_position)

    # Run the sampling phase
    states, infos = inference_loop_multiple_chains(sample_key, initial_states, tuned_params, joint_log_prob, num_samples, num_chains)

    # `states` contains the samples, `infos` contains other information about the sampling process
    return states, infos

def sample_starting_point(
    key: PRNGKeyArray, 
    num_chains: int, 
    mu: Array, 
    Sigma: Array
):
    """Draw samples from a multivariate normal as starting points for the MCMC chains.
    Returns an array of shape (num_chains, num_dims).
    """
    keys = jrandom.split(key, num_chains)
    return vmap(lambda k: jrandom.multivariate_normal(k, mu, Sigma))(keys)

def inference_loop_multiple_chains(
    key: PRNGKeyArray, 
    initial_states: PyTree, 
    tuned_params: dict, 
    log_prob_fn: Callable, 
    num_samples: int, 
    num_chains: int
):
    """Do NUTS sampling for multiple chains in a vectorized fashion. Returns a tuple with the NUTS states and some info."""
    # Initialize the NUTS kernel
    kernel = blackjax.nuts.build_kernel()

    def step_fn(key, state, **params):
        """A single step of NUTS for one chain."""
        return kernel(key, state, log_prob_fn, **params)

    def one_step(states, key):
        """A single step of NUTS for multiple chains."""
        keys = jax.random.split(key, num_chains)
        states, infos = jax.vmap(step_fn)(keys, states, **tuned_params)
        return states, (states, infos)

    # Run the NUTS sampling for multiple chains
    keys = jax.random.split(key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_states, keys)

    return (states, infos)
num_chains = 4
num_samples_per_chain = 500
num_warmup = 500

states, infos = run_inference(
    joint_log_prob=banana_logdensity_high_curv,
    sample_starting_point=partial(sample_starting_point, mu=jnp.zeros(2), Sigma=jnp.eye(2)),
    key=key,
    num_warmup=num_warmup,
    num_samples=num_samples_per_chain,
    num_chains=num_chains
)
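
Before computing diagnostics, it helps to know how the outputs are laid out: the lax.scan over sampling steps stacks results along the leading axis, and the vmap over chains adds the second axis, so the following shapes are expected:

print(states.position.shape)     # expected: (num_samples_per_chain, num_chains, 2)
print(infos.is_divergent.shape)  # expected: (num_samples_per_chain, num_chains)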

Let’s print some diagnostics and make the trace plot (with the help of the arviz library):

import arviz as az

def arviz_trace_from_states(states, info, burn_in=0):
    position = states.position
    if isinstance(position, jax.Array):  # if states.position is array of samples
        position = dict(samples=position)
    else:
        try:
            position = position._asdict()
        except AttributeError:
            pass

    samples = {}
    for param in position.keys():
        ndims = len(position[param].shape)
        if ndims >= 2:
            samples[param] = jnp.swapaxes(position[param], 0, 1)[
                :, burn_in:
            ]  # swap n_samples and n_chains
            divergence = jnp.swapaxes(info.is_divergent[burn_in:], 0, 1)

        if ndims == 1:
            divergence = info.is_divergent
            samples[param] = position[param]

    trace_posterior = az.convert_to_inference_data(samples)
    trace_sample_stats = az.convert_to_inference_data(
        {"diverging": divergence}, group="sample_stats"
    )
    trace = az.concat(trace_posterior, trace_sample_stats)
    return trace

trace = arviz_trace_from_states(states, infos)
summ_df = az.summary(trace)
summ_df
             mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
samples[0]  0.116  1.141  -1.777    2.338      0.107    0.076     114.0     119.0   1.04
samples[1]  2.398  1.618   0.438    5.381      0.168    0.119     124.0     134.0   1.03
az.plot_trace(trace)
plt.tight_layout();
Figure: ArviZ trace plots (marginal densities and traces) for the two sampled coordinates.

Let’s look at \(\hat{R}\) to assess convergence (see the Metropolis-Hastings hands-on activity for more details):

samples = states.position.transpose(1, 0, 2)  # samples shape: (num_chains, num_samples, 2)
compute_diagnostics_every = 25
rhats = []
for i in range(2, num_samples_per_chain, compute_diagnostics_every):
    rhat = blackjax.diagnostics.potential_scale_reduction(samples[:, :i])
    rhats.append(rhat)
rhats = jnp.array(rhats)
fig, ax = plt.subplots(figsize=(5,4))
ax.plot(range(2, num_samples_per_chain, compute_diagnostics_every), rhats[:,0], label=r"$\hat{R}$ for $x_1$")
ax.plot(range(2, num_samples_per_chain, compute_diagnostics_every), rhats[:,1], label=r"$\hat{R}$ for $x_2$")
ax.axhline(1.0, color="black", linestyle="--")
ax.set_xlabel("Number of samples")
ax.set_ylabel(r"$\hat{R}$")
ax.legend()
sns.despine(trim=True)
Figure: \(\hat{R}\) for \(x_1\) and \(x_2\) versus the number of samples.

It looks like the chains converge fairly quickly.

Let’s look at the effective sample size (ESS) to see how many independent samples we have (again, see the Metropolis-Hastings hands-on activity for more details):

n_effs = []
for i in range(2, num_samples_per_chain, compute_diagnostics_every):
    n_eff = blackjax.diagnostics.effective_sample_size(samples[:, :i])
    n_effs.append(n_eff)
n_effs = jnp.array(n_effs)
fig, ax = plt.subplots(figsize=(5,4))
ax.plot(range(2, num_samples_per_chain, compute_diagnostics_every), n_effs[:,0], label=r"Effective Sample Size for $x_1$")
ax.plot(range(2, num_samples_per_chain, compute_diagnostics_every), n_effs[:,1], label=r"Effective Sample Size for $x_2$")
ax.set_xlabel("Number of samples")
ax.set_ylabel("Effective sample size")
ax.legend()
sns.despine(trim=True)
Figure: effective sample size for \(x_1\) and \(x_2\) versus the number of samples.

The ESS is reasonably high. This is good: it means the samples are only weakly autocorrelated, so we don't have to thin them out much.

Finally, let’s plot the samples:

# The original shape of the `samples` array is (n_chains, n_samples, n_dim)
burn_in = 50  # Discard the first `burn_in` samples of each chain
thin = 2  # Keep only every `thin`-th sample after that
kept_samples = samples[:, burn_in::thin]

# Concatenate the chains. Final shape is (n_chains * n_kept_samples_per_chain, n_dim)
kept_samples = kept_samples.reshape(-1, 2)
ax = plot_2d_function(banana_logdensity_high_curv, alpha=0.4)
ax.set_title("MCMC samples from \nthe original distribution", fontsize=16)
ax.scatter(kept_samples[:, 0], kept_samples[:, 1], s=2, alpha=0.5);
Figure: MCMC samples scattered over the banana density.