import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")
import scipy
import scipy.stats as st
import urllib.request
import os
def download(
    url : str,
    local_filename : str = None
):
    """Download a file from a url.

    Arguments
    url            -- The url we want to download.
    local_filename -- The filename to write to. If not specified,
                      the name of the remote file is used.
    """
    if local_filename is None:
        local_filename = os.path.basename(url)
    urllib.request.urlretrieve(url, local_filename)
import math
import tqdm
The Metropolis-Hastings Algorithm#
We build the Metropolis Adjusted Langevin Dynamics (MALA) algorithm. This algorithm uses the gradient of the log of the target distribution to guide the proposal distribution. It is a pretty cool algorithm that shares some similarities with the No-U-Turn Sampler (NUTS), which is better but more complicated.
Metropolis Adjusted Langevin Dynamics (MALA)#
The proposal distribution of the MALA algorithm is given by

\[
q(x' \mid x) = \mathcal{N}\left(x' \,\middle|\, x + \Delta t\, \nabla \log h(x),\; 2\Delta t\, I\right),
\]

i.e., the proposed point is \(x' = x + \Delta t\,\nabla\log h(x) + \sqrt{2\Delta t}\,\xi\) with \(\xi\sim\mathcal{N}(0,I)\), where \(\Delta t\) is a step size parameter and \(h\) is the (possibly unnormalized) target density. Recall from calculus that the gradient of a function points in the direction of steepest ascent. So, the MALA algorithm is proposing a new state by taking a random walk step in the direction of steepest ascent of the log of the target distribution. It tries to take a step towards the peak of the distribution. The gradient of the log of the target distribution is called the score function. Remember the name if you dig deeper into generative models such as diffusion models.
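Because this proposal is not symmetric, the proposal densities do not cancel in the Metropolis-Hastings acceptance step, and the acceptance probability is

\[
\alpha(x, x') = \min\left\{1, \frac{h(x')\,q(x \mid x')}{h(x)\,q(x' \mid x)}\right\}.
\]

The implementation below works with \(\log \alpha\) to avoid numerical over- and underflow.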
We provide a basic implementation of the MALA algorithm below. We use PyTorch to compute the gradient of the log of the target distribution (the score).
import torch

def score(z, log_h):
    """Evaluate the log of the target and its gradient (the score) at z."""
    z = z.detach().requires_grad_()
    log_h_z = log_h(z)
    grad_log_h_z = torch.autograd.grad(log_h_z, z)[0]
    return log_h_z.detach(), grad_log_h_z

def propose(z, dt, grad_log_h_z):
    """Propose a new point using one Euler step of the Langevin dynamics."""
    return z + dt * grad_log_h_z + math.sqrt(2.0 * dt) * torch.randn(z.shape[0])
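As a quick sanity check (this snippet is not part of the original notebook), the score of the standard Gaussian log-density \(\log h(z) = -\tfrac{1}{2}\|z\|^2\) is \(-z\), so the autograd-based helper above should reproduce it:

# Hypothetical sanity check: for log h(z) = -0.5 ||z||^2 the score is -z.
z = torch.randn(3)
log_h_val, score_val = score(z, lambda w: -0.5 * (w ** 2).sum())
print(score_val)  # should match -z
print(-z)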
def mala(x0, log_h, n, dt):
    """Metropolis adjusted Langevin algorithm (MALA).

    Arguments
    ---------
    x0     -- The initial point.
    log_h  -- The logarithm of the (possibly unnormalized) target density.
    n      -- The number of steps you want to take.
    dt     -- The time step you want to use.

    Returns
    -------
    x, acceptance_rate -- The samples and the acceptance rate.
    """
    X = []
    # The log target and the score at the current point
    log_h_x0, grad_log_h_x0 = score(x0, log_h)
    count_accepted = 0
    for t in tqdm.tqdm(range(1, n + 1)):
        # Propose
        x_next = propose(x0, dt, grad_log_h_x0)
        # The log target and the score at the proposed point
        log_h_x_next, grad_log_h_x_next = score(x_next, log_h)
        # Log of the Metropolis-Hastings acceptance ratio.  The correction
        # terms are the log densities of the Gaussian Langevin proposal in
        # the reverse and the forward direction.
        log_alpha = log_h_x_next - log_h_x0 \
            - 0.25 / dt * (
                ((x0 - x_next - dt * grad_log_h_x_next) ** 2).sum()
                - ((x_next - x0 - dt * grad_log_h_x0) ** 2).sum()
            )
        alpha = torch.exp(log_alpha)
        u = torch.rand(1)
        if u <= alpha:
            x0 = x_next
            log_h_x0 = log_h_x_next
            grad_log_h_x0 = grad_log_h_x_next
            count_accepted += 1
        X.append(x0.detach().numpy())
    # Empirical acceptance rate
    acceptance_rate = count_accepted / (1. * n)
    X = np.array(X)
    return X, acceptance_rate
Example 1: Sampling from a Gaussian with MALA#
Let’s take \(\mathcal{X}=\mathbb{R}^2\) and:

\[
h(x) = \exp\left\{-\frac{1}{2}(x-\mu)^T\Lambda(x-\mu)\right\},
\]

where \(\mu\in\mathbb{R}^2\) is the mean and \(\Lambda = \Sigma^{-1}\in\mathbb{R}^{2\times 2}\) is the precision matrix. We need:

\[
\log h(x) = -\frac{1}{2}(x-\mu)^T\Lambda(x-\mu)
\]

and

\[
\nabla \log h(x) = -\Lambda(x-\mu).
\]
# The parameters of the distribution from which we wish to sample
mu = torch.tensor([5., 2.])
Sigma = torch.tensor([[1., .4],
[.4, 0.2]]) # This has to be positive definite - otherwise you will get garbage!
Lambda = torch.linalg.inv(Sigma)
def log_h_mvn(x):
    """The log of the unnormalized Gaussian density (up to an additive constant)."""
    tmp = x - mu
    return -0.5 * tmp @ (Lambda @ tmp)
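As an optional check (not in the original notebook; x_test is just an arbitrary test point), the gradient PyTorch computes for log_h_mvn should agree with the analytic score \(-\Lambda(x-\mu)\) written above:

# Hypothetical check: autograd score vs. the analytic score -Lambda (x - mu).
x_test = torch.randn(2, requires_grad=True)
g_autograd = torch.autograd.grad(log_h_mvn(x_test), x_test)[0]
g_analytic = -Lambda @ (x_test.detach() - mu)
print(g_autograd)
print(g_analytic)  # the two should agree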
# Initialization:
x0 = torch.randn(2)
# Parameters of the proposal:
dt = 0.005
# Number of steps:
n = 10000
# Start sampling
X, acceptance_rate = mala(x0, log_h_mvn, n, dt)
print(f"Acceptance rate: {acceptance_rate:1.2f}")
100%|██████████| 10000/10000 [00:01<00:00, 7296.78it/s]
Acceptance rate: 0.75
fig, ax = plt.subplots()
ax.plot(X, lw=1)
ax.set_xlabel('$n$ (steps)')
ax.set_ylabel('$X_{ni}$');
fig, ax = plt.subplots()
ax.plot(X[:, 0], X[:, 1], lw=1)
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
sns.despine(trim=True);
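To see how well the chain explores the target, here is a hypothetical visualization (not part of the original notebook; the grid ranges are picked by hand around \(\mu\)) overlaying the samples on the contours of the known Gaussian density:

# Hypothetical visualization: samples on top of the target's contours.
x1 = np.linspace(2., 8., 100)
x2 = np.linspace(0., 4., 100)
X1, X2 = np.meshgrid(x1, x2)
pts = np.stack([X1.ravel(), X2.ravel()], axis=1)
pdf = st.multivariate_normal(mu.numpy(), Sigma.numpy()).pdf(pts).reshape(X1.shape)
fig, ax = plt.subplots()
ax.contour(X1, X2, pdf)
ax.plot(X[:, 0], X[:, 1], 'k.', markersize=1, alpha=0.2)
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
sns.despine(trim=True);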
# How many samples do you want to burn?
burn = 500
# How many samples do you want to throw away in between?
thin = 10 # Keep one every thin samples
# Here are the remaining samples:
X_rest = X[burn::thin]
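As another optional check (not in the original notebook), the sample mean and covariance of the retained samples should be close to the known \(\mu\) and \(\Sigma\):

# Hypothetical check: sample moments vs. the true mean and covariance.
print("sample mean:", X_rest.mean(axis=0))
print("true mean:  ", mu.numpy())
print("sample covariance:\n", np.cov(X_rest, rowvar=False))
print("true covariance:\n", Sigma.numpy())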
for i in range(X_rest.shape[1]):
    fig, ax = plt.subplots()
    # Autocorrelation of the i-th component of the thinned chain
    ax.acorr(X_rest[:, i], detrend=plt.mlab.detrend_mean, maxlags=50)
    ax.set_xlim(0, 50)
    ax.set_ylabel('$R_{%d}(%d k)$ (Autocorrelation)' % (i + 1, thin))
    ax.set_xlabel(r'$k$ ($%d \times$ lag)' % thin)
    sns.despine(trim=True);
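If you prefer a number over a plot, here is a rough sketch (not part of the original notebook) that estimates the empirical lag-\(k\) autocorrelation of one component directly; the thinning is satisfactory when it drops close to zero already for small \(k\):

# Hypothetical helper: empirical lag-k autocorrelation of a 1D chain.
def autocorr(x, k):
    x = x - x.mean()
    return 1.0 if k == 0 else np.sum(x[:-k] * x[k:]) / np.sum(x * x)

for k in [1, 2, 5, 10, 20]:
    print(k, autocorr(X_rest[:, 0], k))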
Questions#
Play with the thin parameter until you get a satisfactory autocorrelation plot.