import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")

import scipy
import scipy.stats as st
import urllib.request
import os

def download(
    url : str,
    local_filename : str = None
):
    """Download a file from a url.
    
    Arguments
    url            -- The url we want to download.
    local_filename -- The filename to write to. If not
                      specified, it is inferred from the url.
    """
    if local_filename is None:
        local_filename = os.path.basename(url)
    urllib.request.urlretrieve(url, local_filename)

import math
import tqdm

The Metropolis-Hastings Algorithm#

We build the Metropolis-adjusted Langevin algorithm (MALA). This algorithm uses the gradient of the logarithm of the target distribution to guide the proposal distribution. It’s a pretty cool algorithm that shares some similarities with the No-U-Turn Sampler (NUTS), which is better but more complicated.

Metropolis Adjusted Langevin Dynamics (MALA)#

The proposal distribution of the MALA algorithm is given by

\[ q(x'|x) = \mathcal{N}\left(x' | x + \Delta t \nabla \log p(x), 2\Delta t I\right), \]

where \(\Delta t\) is a step-size parameter. Recall from calculus that the gradient of a function points in the direction of steepest ascent. So, the MALA algorithm proposes a new state by taking a random step biased in the direction of steepest ascent of the log of the target distribution, i.e., it tries to step toward the peak of the distribution. The gradient of the log of the target distribution is called the score function. Remember this name if you dig deeper into generative models such as diffusion models.
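The proposal is not symmetric, so the Metropolis-Hastings acceptance probability must include the ratio of the proposal densities. Since \(q(x'|x)\) is a Gaussian centered at \(x + \Delta t \nabla \log p(x)\) with covariance \(2\Delta t I\), and since \(\nabla \log p = \nabla \log h\) for any unnormalized version \(h \propto p\) of the target, the acceptance probability is

\[ \alpha(x, x') = \min\left\{1, \frac{h(x')\,q(x|x')}{h(x)\,q(x'|x)}\right\}, \quad \log \frac{q(x|x')}{q(x'|x)} = -\frac{1}{4\Delta t}\left(\left\|x - x' - \Delta t \nabla \log h(x')\right\|^2 - \left\|x' - x - \Delta t \nabla \log h(x)\right\|^2\right). \]

This is the correction term that appears in the code below.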

We provide a basic implementation of the MALA algorithm below. We use PyTorch to compute the gradient of the log of the target distribution.

import torch 


def propose(z, dt, log_h):
    """Propose a new point using the Langevin dynamics."""
    z.requires_grad_()
    log_h_z = log_h(z)
    grad_log_h_z = torch.autograd.grad(log_h_z, z)[0]
    return z + grad_log_h_z * dt + torch.randn(z.shape[0]) * math.sqrt(2 * dt), log_h_z, grad_log_h_z


def mala(x0, log_h, n, dt):
    """Metropolis-adjusted Langevin algorithm (MALA).
    
    Arguments
    ----------
    x0     -- The initial point.
    log_h  -- The logarithm of the target distribution.
    n      -- The number of steps you want to take.
    dt     -- The time step you want to use.
    
    Returns
    -------
    x, acceptance_rate  -- The samples and the acceptance rate.
    """
    d = x0.shape[0]
    X = []
    x0.requires_grad_()
    log_h_x0 = log_h(x0)
    grad_log_h_x0 = torch.autograd.grad(log_h_x0, x0)[0]
    count_accepted = 0
    for t in tqdm.tqdm(range(1, n + 1)):
        # Propose
        x_next, log_h_x_next, grad_log_h_x_next = propose(x0, dt, log_h)
        # Compute acceptance ratio
        # (includes the Metropolis-Hastings correction for the asymmetric proposal)
        log_alpha = log_h_x_next - log_h_x0 \
            - 0.25 / dt * (
                ((x0 - x_next - dt * grad_log_h_x_next) ** 2).sum()
                - ((x_next - x0 - dt * grad_log_h_x0) ** 2).sum()
            )
        alpha = torch.exp(log_alpha)
        u = torch.rand(1)
        if u <= alpha:
            x0 = x_next
            log_h_x0 = log_h_x_next
            grad_log_h_x0 = grad_log_h_x_next
            count_accepted += 1
        X.append(x0.detach().numpy())
    # Empirical acceptance rate
    acceptance_rate = count_accepted / (1. * n)
    X = np.array(X)
    return X, acceptance_rate

Example 1: Sampling from a Gaussian with MALA#

Let’s take \(\mathcal{X}=\mathbb{R}^2\) and:

\[ \pi(x) \propto h(x) = \exp\left\{-\frac{1}{2}\left(x-\mu\right)^T\Lambda(x-\mu)\right\}, \]

where \(\mu\in\mathbb{R}^2\) is the mean and \(\Lambda = \Sigma^{-1}\in\mathbb{R}^{2\times 2}\) is the precision matrix. We need:

\[ \log h(x) = -\frac{1}{2}(x-\mu)^T\Lambda (x-\mu), \]

and

\[ \nabla \log h(x) = -\Lambda (x-\mu). \]
# The parameters of the distribution from which we wish to sample
mu = torch.tensor([5., 2.])
Sigma = torch.tensor([[1., .4],
                      [.4, 0.2]]) # This has to be positive definite - otherwise you will get garbage!

Lambda = torch.linalg.inv(Sigma)


def log_h_mvn(x):
    tmp = x - mu
    return -0.5 * tmp.T @ (Lambda @ tmp)
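
# Quick sanity check (not in the original example): automatic differentiation
# should reproduce the analytic gradient -Lambda @ (x - mu) derived above.
x_test = torch.randn(2, requires_grad=True)
grad_auto = torch.autograd.grad(log_h_mvn(x_test), x_test)[0]
grad_analytic = -Lambda @ (x_test.detach() - mu)
print("autograd gradient:", grad_auto)
print("analytic gradient:", grad_analytic)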


# Initialization:
x0 = torch.randn(2)
# Parameters of the proposal:
dt = 0.005
# Number of steps:
n = 10000

# Start sampling
X, acceptance_rate = mala(x0, log_h_mvn, n, dt)

print(f"Acceptance rate: {acceptance_rate:1.2f}")
100%|██████████| 10000/10000 [00:01<00:00, 7296.78it/s]
Acceptance rate: 0.75

fig, ax = plt.subplots()
ax.plot(X, lw=1)
ax.set_xlabel('$n$ (steps)')
ax.set_ylabel('$X_{ni}$');
[Figure: trace plots of the chain components $X_{ni}$ versus step $n$.]
fig, ax = plt.subplots()
ax.plot(X[:, 0], X[:, 1], lw=1)
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
sns.despine(trim=True);
[Figure: path of the chain in the $(x_1, x_2)$ plane.]
# How many samples do you want to burn?
burn = 500
# How many samples do you want to skip in between? (thinning)
thin = 10 # Keep one out of every `thin` samples
# Here are the remaining samples:
X_rest = X[burn::thin]
for i in range(X_rest.shape[1]):
    fig, ax = plt.subplots()
    ax.acorr(X_rest[:, i], detrend=plt.mlab.detrend_mean, maxlags=50)
    ax.set_xlim(0, 50)
    ax.set_ylabel('$R_{%d}(%d k)$ (Autocorrelation)' % (i + 1, thin))
    ax.set_xlabel(r'$k$ ($%d \times$ lag)' % thin);
    sns.despine(trim=True);
[Figures: autocorrelation plots of the thinned chain components $x_1$ and $x_2$.]
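As a quick sanity check (not part of the original notebook), we can compare the empirical mean and covariance of the thinned samples to the true \(\mu\) and \(\Sigma\) defined above. If the chain has mixed well, they should be close.

# Compare sample statistics of the thinned chain to the true values
print("Empirical mean:", X_rest.mean(axis=0))
print("True mean:     ", mu.numpy())
print("Empirical covariance:")
print(np.cov(X_rest.T))
print("True covariance:")
print(Sigma.numpy())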

Questions#

  • Play with the thin parameter until you get a satisfactory autocorrelation plot.
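If you want a programmatic starting point (this sketch is not part of the original notebook), you can scan a few values of thin and look at the lag-1 autocorrelation of each component of the thinned chain; values close to zero suggest the remaining samples are roughly uncorrelated.

# Scan thinning values and report the lag-1 autocorrelation of each component
for thin_candidate in [1, 5, 10, 20, 50]:
    X_thin = X[burn::thin_candidate]
    for i in range(X_thin.shape[1]):
        x = X_thin[:, i] - X_thin[:, i].mean()
        rho1 = np.sum(x[1:] * x[:-1]) / np.sum(x ** 2)
        print(f"thin = {thin_candidate:3d}, x_{i + 1}: lag-1 autocorrelation = {rho1:.3f}")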