import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")
import jax
from jax import grad, vmap, jit
import jax.numpy as jnp
import jax.scipy.stats as jst
import optax
jax.config.update("jax_enable_x64", True)
The Laplace approximation
The Laplace approximation is a method to approximate a probability distribution with a Gaussian distribution. It is based on a second-order Taylor expansion of the logarithm of the probability density around its mode. The Laplace approximation is a useful tool in Bayesian statistics, where it is used to approximate the posterior distribution of a parameter given some data.
The Taylor Expansion in Higher Dimensions
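Let \(g:\mathbb{R}^d\rightarrow \mathbb{R}\) and \(x_0\in\mathbb{R}^d\). The Taylor series expansion of \(g\) about \(x_0\) is (with \(x_{0i}\) denoting the \(i\)-th component of \(x_0\)):
\[
g(x) = g(x_0) + \sum_{i=1}^d \frac{\partial g(x_0)}{\partial x_i}(x_i - x_{0i}) + \frac{1}{2}\sum_{i=1}^d\sum_{j=1}^d \frac{\partial^2 g(x_0)}{\partial x_i\partial x_j}(x_i - x_{0i})(x_j - x_{0j}) + \dots
\]
Another way of writing this is:
\[
g(x) = g(x_0) + \nabla g(x_0)(x - x_0) + \frac{1}{2}(x - x_0)^T\nabla^2 g(x_0)(x - x_0) + \dots,
\]
where the Jacobian is defined as:
\[
\nabla g(x_0) = \left(\frac{\partial g(x_0)}{\partial x_1},\dots,\frac{\partial g(x_0)}{\partial x_d}\right)
\]
and the Hessian is:
\[
\nabla^2 g(x_0) = \begin{pmatrix} \frac{\partial^2 g(x_0)}{\partial x_1^2} & \cdots & \frac{\partial^2 g(x_0)}{\partial x_1\partial x_d}\\ \vdots & \ddots & \vdots\\ \frac{\partial^2 g(x_0)}{\partial x_d\partial x_1} & \cdots & \frac{\partial^2 g(x_0)}{\partial x_d^2}\end{pmatrix}.
\]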
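To make the expansion concrete, here is a minimal JAX sketch that builds the second-order Taylor approximation from `jax.grad` (for a scalar \(g\), this gives the Jacobian as a length-\(d\) vector) and `jax.hessian`. The test function `g` and the helper `taylor2` are hypothetical and serve only as illustration. If the expansion is correct, the error should shrink by roughly a factor of 1000 each time the step shrinks by a factor of 10, because the first neglected term is cubic in \(x - x_0\).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def g(x):
    """Hypothetical smooth test function g: R^2 -> R (illustration only)."""
    return jnp.sin(x[0]) * jnp.exp(x[1]) + x[0] ** 2 * x[1]

def taylor2(g, x0):
    """Second-order Taylor approximation of g about x0."""
    g0 = g(x0)
    jac = jax.grad(g)(x0)      # Jacobian of the scalar g, stored as a length-d vector
    hess = jax.hessian(g)(x0)  # d x d Hessian matrix

    def approx(x):
        dx = x - x0
        return g0 + jac @ dx + 0.5 * dx @ hess @ dx

    return approx

x0 = jnp.array([0.3, -0.2])
g2 = taylor2(g, x0)

# Shrink the step by 10x each time; the error should drop by roughly 1000x,
# because the first neglected term is third order in (x - x0).
steps = [1e-1, 1e-2, 1e-3]
errors = []
for h in steps:
    x = x0 + h * jnp.array([1.0, 1.0])
    errors.append(float(jnp.abs(g(x) - g2(x))))
    print(f"h = {h:.0e}   |g(x) - g2(x)| = {errors[-1]:.3e}")
```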
The Laplace Approximation in 1D
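If you are not interested in the probabilities of rare events, you may approximate a probability density by a Gaussian, assuming that the density has a single, sufficiently narrow peak. Let \(X\) be a random variable with probability density \(p(x)\). Because \(p(x)\) is positive, its logarithm is well defined, and it is better to work with \(\log p(x)\): the logarithm of a Gaussian density is a quadratic function of \(x\), so a second-order expansion of \(\log p(x)\) leads directly to a Gaussian. First, we find the point \(x_0\) that maximizes \(\log p(x)\), which is called the nominal value (or just the maximum):
\[
x_0 = \arg\max_x \log p(x).
\]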
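Then, we take the Taylor expansion of \(\log p(x)\) about \(x=x_0\):
\[
\log p(x) = \log p(x_0) + \frac{d\log p(x_0)}{dx}(x-x_0) + \frac{1}{2}\frac{d^2\log p(x_0)}{dx^2}(x-x_0)^2 + \dots
\]
Since \(x_0\) is a critical point of \(\log p(x)\), we must have that:
\[
\frac{d\log p(x_0)}{dx} = 0.
\]
So, the expansion becomes:
\[
\log p(x) = \log p(x_0) + \frac{1}{2}\frac{d^2\log p(x_0)}{dx^2}(x-x_0)^2 + \dots
\]
Therefore,
\[
p(x) \approx p(x_0)\exp\left\{\frac{1}{2}\frac{d^2\log p(x_0)}{dx^2}(x-x_0)^2\right\}.
\]
Since \(x_0\) is a maximum of \(\log p(x)\), the second derivative \(\frac{d^2\log p(x_0)}{dx^2}\) cannot be positive. Assuming it is strictly negative, the number:
\[
c^2 = -\left(\frac{d^2\log p(x_0)}{dx^2}\right)^{-1}
\]
is positive. By inspection, then, we see that:
\[
p(x) \approx p(x_0)\exp\left\{-\frac{(x-x_0)^2}{2c^2}\right\}.
\]
Ignoring all higher-order terms and normalizing, we conclude:
\[
p(x) \approx \mathcal{N}\left(x|x_0, c^2\right).
\]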
This is the Laplace approximation in one dimension.
Example: Gamma Distribution
Let’s try it out with the gamma distribution. First, I am going to write a quick function that minimizes a scalar function with the Adam optimizer from Optax. We will use it to minimize the negative log pdf of the gamma distribution.
# Optimization function using Optax (for scalar minimization)
def find_minimum(fun, init_x, max_iter=5000, tol=1e-6):
    """Minimize a scalar function of one variable with Adam.
    Returns the final x, the value fun(x), and a status message.
    """
    optimizer = optax.adam(learning_rate=0.01)
    x = jnp.asarray(init_x, dtype=float)
    opt_state = optimizer.init(x)
    def step(x, opt_state):
        # One Adam update using the gradient at the current x
        loss, grad_val = jax.value_and_grad(fun)(x)
        updates, opt_state = optimizer.update(grad_val, opt_state, x)
        x = optax.apply_updates(x, updates)
        return x, opt_state, loss
    status = 'maximum iterations reached'
    for _ in range(max_iter):
        x, opt_state, _ = step(x, opt_state)
        # Stop once the gradient at the updated x is (nearly) zero
        if jnp.abs(grad(fun)(x)) < tol:
            status = 'converged'
            break
    return x, fun(x), status
Okay, let’s first build our minimization problem. We want the point where the probability density is largest (the mode). Because the logarithm is increasing, we can find it by minimizing the negative log pdf of the gamma distribution, here with shape \(\alpha = 10\) and unit scale.
# Construct the -log pdf of a gamma distribution
alpha = 10
minus_log_pdf_true = lambda x: -jst.gamma.logpdf(x, alpha)
# Find the maximum using an optimization method
init_x = 1.0
x_0, min_logpdf_val, status = find_minimum(minus_log_pdf_true, init_x)
# Recover the maximum value of the pdf from the minimum of -log pdf
p_0 = jnp.exp(-min_logpdf_val)
print(f'The maximum of the gamma pdf is {p_0:.4f} at x = {x_0:.4f}.')
The maximum of the gamma pdf is 0.1318 at x = 8.9997.
We can verify this by checking the derivative of the negative log pdf of the gamma distribution, which should be zero at the maximum.
dp_0 = grad(minus_log_pdf_true)(x_0)
print(f'The derivative of the log pdf at the maximum is: {dp_0:.4f}.')
The derivative of the log pdf at the maximum is: -0.0000.
Great, we can trust that we found the maximum. We need one more ingredient for the Laplace approximation: the standard deviation \(c\) of our Gaussian, given by \(c^2 = -\left(\frac{d^2\log p(x_0)}{dx^2}\right)^{-1}\). To get it, we compute the second derivative of the negative log pdf of the gamma distribution at \(x_0\), which equals \(-\frac{d^2\log p(x_0)}{dx^2}\).
# Second derivative of -log p at the maximum, i.e. -d^2 log p(x_0)/dx^2 (positive)
d2p_0 = grad(grad(minus_log_pdf_true))(x_0)
# Laplace standard deviation: c = (-d^2 log p(x_0)/dx^2)^(-1/2)
std = jnp.sqrt(1. / d2p_0)
# Define the range of x values to plot
x = jnp.linspace(x_0 - 6 * std, x_0 + 6 * std, 200)
# Build the Gaussian approximation to the gamma pdf (evaluated at its argument)
approx = lambda x_eval: jst.norm.pdf(x_eval, loc=x_0, scale=std)
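As a sanity check, the gamma case can also be done by hand. For \(\text{Gamma}(\alpha)\) with unit scale, \(\log p(x) = (\alpha-1)\log x - x - \log\Gamma(\alpha)\). Setting the first derivative to zero gives \(x_0 = \alpha - 1\). The second derivative at the mode is \(-1/(\alpha-1)\), so \(c^2 = \alpha - 1\). For \(\alpha = 10\), this predicts a mode of 9 and a standard deviation of 3, which agrees with the optimizer output above.

The sketch below checks the curvature formula with JAX autodiff for a few values of \(\alpha\). It assumes \(\alpha > 1\), so that the mode is interior. The helper `laplace_gamma` is hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jst

jax.config.update("jax_enable_x64", True)

def laplace_gamma(alpha):
    """Laplace approximation of Gamma(alpha) with unit scale, using the closed-form mode.

    Assumes alpha > 1 so the mode alpha - 1 lies inside (0, inf).
    """
    x0 = alpha - 1.0                                  # analytic mode
    log_p = lambda x: jst.gamma.logpdf(x, alpha)
    d2 = jax.grad(jax.grad(log_p))(x0)                # d^2 log p / dx^2 at the mode
    c = jnp.sqrt(-1.0 / d2)                           # c^2 = -(d^2 log p)^{-1}
    return x0, c

for alpha in [2.0, 5.0, 10.0, 50.0]:
    x0, c = laplace_gamma(alpha)
    print(f"alpha = {alpha:5.1f}   mode = {x0:.4f}   c = {c:.4f}   "
          f"sqrt(alpha - 1) = {(alpha - 1.0) ** 0.5:.4f}")
```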
See how painless this all was with the power of JAX? Let’s plot both distributions to verify our approximation.
# Plot the results
fig, ax = plt.subplots()
pdf_true = jst.gamma.pdf(x, alpha)
ax.plot(x, pdf_true, label='PDF of Gamma(%.2f)' % alpha)
ax.plot(x, approx(x), '--', label='Laplace approximation')
ax.plot([x_0] * 10, jnp.linspace(0, p_0, 10), ':', label='Maximum of true PDF')
ax.legend()
ax.set_xlabel('$x$')
ax.set_ylabel('PDF')
sns.despine(trim=True)
plt.show()
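The plot shows good agreement near the mode, which is exactly where the approximation is built. The caveat at the start of the 1D section, that the approximation is not meant for rare events, shows up in the tails.

The minimal sketch below assumes the same \(\text{Gamma}(10)\) example. Its mode is \(\alpha - 1 = 9\), and its Laplace standard deviation works out to \(\sqrt{\alpha - 1} = 3\). The sketch compares the exact right-tail probability \(P(X > t)\), computed with `jax.scipy.special.gammaincc` (the regularized upper incomplete gamma function), with the Gaussian tail computed via `jax.scipy.special.ndtr`. The thresholds `t` are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaincc, ndtr

jax.config.update("jax_enable_x64", True)

alpha = 10.0
x_0 = alpha - 1.0            # mode of Gamma(alpha) with unit scale
std = jnp.sqrt(alpha - 1.0)  # Laplace standard deviation sqrt(alpha - 1)

# Right tail P(X > t): exact gamma tail vs. Gaussian (Laplace) tail
for t in [12.0, 15.0, 20.0, 25.0]:
    p_exact = gammaincc(alpha, t)            # regularized upper incomplete gamma = P(X > t)
    p_laplace = ndtr(-(t - x_0) / std)       # standard normal upper tail
    print(f"t = {t:4.1f}   exact = {p_exact:.2e}   Laplace = {p_laplace:.2e}   "
          f"exact/Laplace = {p_exact / p_laplace:.1f}")

# Left tail: the Gaussian puts mass on x < 0, where the gamma density is zero
p_negative = ndtr((0.0 - x_0) / std)
print(f"Laplace P(X < 0) = {p_negative:.2e}")
```

The Gaussian underestimates the right tail, increasingly so as \(t\) grows. It also assigns a small but nonzero probability to negative values, which the gamma distribution rules out.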
The Laplace Approximation in Many Dimensions
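Now let \(X\) be a random vector in \(\mathbb{R}^d\) with probability density \(p(x)\). As before, we first find the maximum of \(\log p(x)\):
\[
x_0 = \arg\max_x \log p(x).
\]
Then, we take the Taylor expansion of \(\log p(x)\) about \(x=x_0\):
\[
\log p(x) = \log p(x_0) + \nabla \log p(x_0)(x - x_0) + \frac{1}{2}(x - x_0)^T\nabla^2\log p(x_0)(x - x_0) + \dots
\]
Since \(x_0\) is a critical point of \(\log p(x)\), we must have that:
\[
\nabla \log p(x_0) = 0.
\]
So, the expansion becomes:
\[
\log p(x) = \log p(x_0) + \frac{1}{2}(x - x_0)^T\nabla^2\log p(x_0)(x - x_0) + \dots
\]
Therefore,
\[
p(x) \approx p(x_0)\exp\left\{\frac{1}{2}(x - x_0)^T\nabla^2\log p(x_0)(x - x_0)\right\}.
\]
Since \(x_0\) is a maximum of \(\log p(x)\), the matrix \(\nabla^2 \log p(x_0)\) must be negative semidefinite. Assuming it is negative definite (a nondegenerate maximum), the matrix:
\[
C = -\left(\nabla^2\log p(x_0)\right)^{-1}
\]
is positive definite. By inspection, then, we see that:
\[
p(x) \approx p(x_0)\exp\left\{-\frac{1}{2}(x - x_0)^T C^{-1}(x - x_0)\right\}.
\]
Ignoring all higher-order terms and normalizing, we conclude:
\[
p(x) \approx \mathcal{N}\left(x|x_0, C\right).
\]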
This is the Laplace approximation in many dimensions.
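The multidimensional recipe translates almost line by line into JAX. The sketch below is one possible implementation, not the only one. It locates the mode with a few Newton steps, a swapped-in choice: any optimizer, such as the Adam loop used earlier, would also work. It then returns \(x_0\) and the covariance \(C = -\left(\nabla^2\log p(x_0)\right)^{-1}\). The function name `laplace_approximation` and the fixed number of Newton steps are assumptions. Because the log of a Gaussian density is exactly quadratic, applying the recipe to a Gaussian target should recover that target's mean and covariance, which makes a convenient test.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jst

jax.config.update("jax_enable_x64", True)

def laplace_approximation(log_p, x_init, num_newton_steps=20):
    """Gaussian (Laplace) approximation N(x0, C) of the density exp(log_p).

    Hypothetical helper: Newton's method on log_p finds the mode x0, then
    C = -(Hessian of log_p at x0)^{-1}. Assumes log_p is twice differentiable
    and its Hessian is negative definite near the mode.
    """
    grad_fn = jax.grad(log_p)
    hess_fn = jax.hessian(log_p)
    x = jnp.asarray(x_init, dtype=float)
    for _ in range(num_newton_steps):
        # Newton step toward the stationary point: x <- x - H^{-1} g
        x = x - jnp.linalg.solve(hess_fn(x), grad_fn(x))
    C = -jnp.linalg.inv(hess_fn(x))
    return x, C

# Test target: a 2D Gaussian with known mean and covariance (chosen for illustration)
mu = jnp.array([1.0, -2.0])
Sigma = jnp.array([[2.0, 0.6],
                   [0.6, 1.0]])
log_p = lambda x: jst.multivariate_normal.logpdf(x, mu, Sigma)

x0, C = laplace_approximation(log_p, jnp.zeros(2))
print("mode:", x0)
print("covariance:\n", C)
```

For a non-Gaussian target, `x_init` should be reasonably close to the mode. Plain Newton steps can overshoot, or head toward a saddle point, when the Hessian is not negative definite.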
The Laplace Approximation to the Posterior
If you have a linear model with a Gaussian likelihood and the prior \(p(x)\) is Gaussian, then the posterior is also Gaussian. In general, however, you can only characterize the posterior through sampling or some approximation. We will discuss sampling from the posterior in Lecture 21, when we introduce Markov chain Monte Carlo. For now, let us discuss the Laplace approximation to the posterior.
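The Laplace approximation finds a Gaussian approximation to the posterior. For simplicity, we will work with a Gaussian likelihood, but these ideas apply to any likelihood. To avoid overcomplicating this first attempt, let us assume that we know the value of \(\sigma\); we will generalize our approach later. Suppose we observe \(y_{1:n} = (y_1,\dots,y_n)\) and the model predicts the \(i\)-th observation as \(f_i(x)\), so that the \(y_i\) are independent with \(y_i \sim \mathcal{N}(f_i(x), \sigma^2)\). Our posterior (over \(x\) only, since \(\sigma\) is fixed) is:
\[
p(x|y_{1:n},\sigma) \propto p(y_{1:n}|x,\sigma)\,p(x) = \prod_{i=1}^n \mathcal{N}\left(y_i|f_i(x),\sigma^2\right)p(x).
\]
To implement the Laplace approximation, define the logarithm of the unnormalized posterior:
\[
J(x) = \log p(y_{1:n}|x,\sigma) + \log p(x) = -\frac{1}{2\sigma^2}\sum_{i=1}^n\left(y_i - f_i(x)\right)^2 + \log p(x) + \text{const}.
\]
We need to find the maximum of this:
\[
x^* = \arg\max_x J(x).
\]
Then, the matrix
\[
H = -\nabla^2 J(x^*)
\]
is positive definite at a nondegenerate maximum, and the approximation is:
\[
p(x|y_{1:n},\sigma) \approx \mathcal{N}\left(x|x^*, H^{-1}\right).
\]
The second derivative is:
\[
\nabla^2 J(x) = -\frac{1}{\sigma^2}\sum_{i=1}^n\left[\nabla f_i(x)^T\nabla f_i(x) - \left(y_i - f_i(x)\right)\nabla^2 f_i(x)\right] + \nabla^2\log p(x).
\]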
So, you need the second derivatives of your model…
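With JAX, these second derivatives do not have to be derived by hand, because `jax.hessian` differentiates the log posterior directly. The sketch below makes several assumptions, all for illustration:

- a hypothetical linear model \(y \approx x_0 + x_1 t\),
- synthetic data generated in the code,
- a known noise level `sigma`,
- a zero-mean Gaussian prior with standard deviation `tau`.

For a linear model with a Gaussian prior and a Gaussian likelihood, the posterior is exactly Gaussian, as noted above. The Laplace approximation should therefore reproduce the closed-form posterior \(\mathcal{N}(m, S)\), where \(S^{-1} = \Phi^T\Phi/\sigma^2 + I/\tau^2\), \(m = S\Phi^T y/\sigma^2\), and \(\Phi\) is the design matrix.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Synthetic data for illustration only: y = 0.5 + 2 t + noise
key_noise = jax.random.PRNGKey(0)
t = jnp.linspace(0.0, 1.0, 20)
sigma = 0.1   # known noise standard deviation (assumed)
tau = 10.0    # prior standard deviation of each parameter (assumed)
y = 0.5 + 2.0 * t + sigma * jax.random.normal(key_noise, t.shape)

# Hypothetical linear model f(t; x) = x[0] + x[1] * t
Phi = jnp.stack([jnp.ones_like(t), t], axis=1)   # design matrix, shape (n, 2)

def model(x):
    return Phi @ x

def log_posterior(x):
    """Unnormalized log posterior: Gaussian likelihood + Gaussian prior (constants dropped)."""
    log_like = -0.5 * jnp.sum((y - model(x)) ** 2) / sigma**2
    log_prior = -0.5 * jnp.sum(x**2) / tau**2
    return log_like + log_prior

# Find the maximum a posteriori point with Newton steps, then form H = -Hessian
grad_fn = jax.grad(log_posterior)
hess_fn = jax.hessian(log_posterior)
x_map = jnp.zeros(2)
for _ in range(10):
    x_map = x_map - jnp.linalg.solve(hess_fn(x_map), grad_fn(x_map))
H = -hess_fn(x_map)
S_laplace = jnp.linalg.inv(H)   # Laplace covariance H^{-1}

# Closed-form posterior for the linear-Gaussian case
S_exact = jnp.linalg.inv(Phi.T @ Phi / sigma**2 + jnp.eye(2) / tau**2)
m_exact = S_exact @ Phi.T @ y / sigma**2

print("Laplace mean:", x_map, "  exact mean:", m_exact)
print("Laplace covariance:\n", S_laplace)
print("exact covariance:\n", S_exact)
```

To try a nonlinear model, replace `model` (for example with a hypothetical `x[0] * jnp.exp(x[1] * t)`) and start from a sensible `x_map`. The closed-form check then no longer applies, but `jax.hessian` automatically includes the model's second derivatives.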