MAKE_BOOK_FIGURES = True

import numpy as np
import scipy.stats as st
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")

def set_book_style():
    plt.style.use('seaborn-v0_8-white')
    sns.set_style("ticks")
    sns.set_palette("deep")
    mpl.rcParams.update({
        # Font settings
        'font.family': 'serif',  # For academic publishing
        'font.size': 8,
        'axes.labelsize': 8,
        'axes.titlesize': 8,
        'xtick.labelsize': 7,  # Slightly smaller for better readability
        'ytick.labelsize': 7,
        'legend.fontsize': 7,
        # Line and marker settings for consistency
        'axes.linewidth': 0.5,
        'grid.linewidth': 0.5,
        'lines.linewidth': 1.0,
        'lines.markersize': 4,
        # Layout to prevent clipped labels
        'figure.constrained_layout.use': True,
        # Default DPI (will override when saving)
        'figure.dpi': 600,
        'savefig.dpi': 600,
        # Despine - remove top and right spines
        'axes.spines.top': False,
        'axes.spines.right': False,
        # Remove legend frame
        'legend.frameon': False,
        # Additional trim settings
        'figure.autolayout': True,  # Alternative to constrained_layout
        'savefig.bbox': 'tight',    # Trim when saving
        'savefig.pad_inches': 0.1   # Small padding to ensure nothing gets cut off
    })

def set_notebook_style():
    plt.style.use('seaborn-v0_8-white')
    sns.set_style("ticks")
    sns.set_palette("deep")
    mpl.rcParams.update({
        # Font settings - using default sizes
        'font.family': 'serif',
        'axes.labelsize': 10,
        'axes.titlesize': 10,
        'xtick.labelsize': 9,
        'ytick.labelsize': 9,
        'legend.fontsize': 9,
        # Line and marker settings
        'axes.linewidth': 0.5,
        'grid.linewidth': 0.5,
        'lines.linewidth': 1.0,
        'lines.markersize': 4,
        # Layout settings
        'figure.constrained_layout.use': True,
        # Remove only top and right spines
        'axes.spines.top': False,
        'axes.spines.right': False,
        # Remove legend frame
        'legend.frameon': False,
        # Additional settings
        'figure.autolayout': True,
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0.1
    })

def save_for_book(fig, filename, is_vector=True, **kwargs):
    """
    Save a figure with book-optimized settings.

    Parameters:
    -----------
    fig : matplotlib figure
        The figure to save
    filename : str
        Filename without extension
    is_vector : bool
        If True, saves as vector at 1000 dpi. If False, saves as raster at 600 dpi.
    **kwargs : dict
        Additional kwargs to pass to savefig
    """
    # Set appropriate DPI and format based on figure type
    if is_vector:
        dpi = 1000
        ext = '.pdf'
    else:
        dpi = 600
        ext = '.tif'
    # Save the figure with book settings
    fig.savefig(f"{filename}{ext}", dpi=dpi, **kwargs)

def make_full_width_fig():
    return plt.subplots(figsize=(4.7, 2.9), constrained_layout=True)

def make_half_width_fig():
    return plt.subplots(figsize=(2.35, 1.45), constrained_layout=True)

if MAKE_BOOK_FIGURES:
    set_book_style()
else:
    set_notebook_style()

make_full_width_fig = make_full_width_fig if MAKE_BOOK_FIGURES else lambda: plt.subplots()
make_half_width_fig = make_half_width_fig if MAKE_BOOK_FIGURES else lambda: plt.subplots()

Probabilistic numerics using pyro#

pyro is a probabilistic programming language built on top of pytorch. It is a powerful tool for building probabilistic models and performing Bayesian inference. pyro can do both sampling (this lecture) and variational inference (next lecture). These notes are necessarily incomplete; you may also want to go over the official introductory tutorial.

Coin toss example#

Let’s generate some data to play with. The data are the results of 100 tosses of a fair coin. (We do not fix a random seed, so your numbers will differ from run to run.)

import torch

data = torch.randint(0, 2, (100,)).float()
data
tensor([1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 1., 1., 1.,
        1., 1., 1., 0., 0., 1., 1., 1., 0., 1., 0., 1., 0., 1., 1., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0., 1., 0., 1., 0., 0.,
        1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0., 1., 0., 0., 0., 0., 1., 1., 1., 1., 0., 1., 0., 0., 1., 1., 1.,
        1., 0., 1., 0., 0., 0., 0., 1., 1., 0.])

Our model is the usual:

\[ \theta \sim \text{Uniform}(0,1), \]

and, for each \(i=1,\ldots,100\),

\[ x_i | \theta \sim \text{Bernoulli}(\theta), \]

independently.

We need the following imports:

# Do this in Google Colab
!pip install pyro-ppl
import pyro
import pyro.distributions as dist

We use dist to define distributions. To make a random variable theta that is uniformly distributed on \([0,1]\), we use dist.Uniform(0,1). Specifically, we can write either:

```python
theta = dist.Uniform(0,1).sample()
```

or

```python
theta = pyro.sample("theta", dist.Uniform(0,1))
```

We are going to go with the latter syntax. It is more flexible: it registers the random variable with pyro under the name "theta", so that inference algorithms can keep track of it.
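
To see what the name buys us, here is a minimal sketch (the model tiny_model and this check are ours, not part of the lecture code): pyro records every named sample site in an execution trace, and that trace is what the inference machinery works with.

```python
def tiny_model():
    return pyro.sample("theta", dist.Uniform(0.0, 1.0))

# Record one execution of the model; every named site appears in the trace.
trace = pyro.poutine.trace(tiny_model).get_trace()
print(trace.nodes["theta"]["value"])  # the value sampled at site "theta"
```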

If we had a single observation, we would write:

```python
x = pyro.sample("x", dist.Bernoulli(theta), obs=data)
```

where data is the observed value of x. But we have 100 observations, so we could use a loop:

```python
for i in range(100):
    x = pyro.sample("x_{}".format(i), dist.Bernoulli(theta), obs=data[i])
```

However, loops are very inefficient in pyro, so we use pyro.plate instead:

```python
with pyro.plate("data", 100):
    x = pyro.sample("x", dist.Bernoulli(theta), obs=data)
```

The pyro.plate statement tells pyro that the random variables inside the block are conditionally independent given \(\theta\). This allows pyro to vectorize the computation across the plate dimension instead of iterating.
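
As a quick sanity check (our illustrative snippet, with the obs argument dropped so the model generates data), a sample drawn inside the plate comes back as a single tensor of 100 draws rather than 100 separate scalar sites:

```python
def plated_model():
    theta = pyro.sample("theta", dist.Uniform(0.0, 1.0))
    with pyro.plate("data", 100):
        return pyro.sample("x", dist.Bernoulli(theta))

x = plated_model()
print(x.shape)  # torch.Size([100]): one vectorized draw for all 100 tosses
```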

Let’s put everything together in a model. pyro models are defined as ordinary Python functions.

def coin_toss_model(data):
    theta = pyro.sample("theta", dist.Uniform(0, 1))
    with pyro.plate("n", len(data)):
        return pyro.sample("obs", dist.Bernoulli(theta), obs=data)

Here is how you can visualize the model:

pyro.render_model(
    coin_toss_model,
    model_args=(data,),
    render_distributions=True,
    render_params=True
)
[Figure: plate diagram of the coin toss model, as rendered by pyro.render_model.]

This follows the standard graphical model notation we have been using in class.

Now we are ready to do inference using sampling. In this course, we implement very simple samplers based on the Metropolis-Hastings algorithm. But pyro ships with much more powerful samplers. In particular, you should be aware of NUTS (the No-U-Turn Sampler), an adaptive variant of Hamiltonian Monte Carlo, which is what we use in practice. Here it is:

from pyro.infer import MCMC, NUTS

nuts_kernel = NUTS(coin_toss_model, jit_compile=True, ignore_jit_warnings=True)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=1000)
mcmc.run(data)
posterior_samples = mcmc.get_samples()
Sample: 100%|██████████| 2000/2000 [00:02, 964.14it/s, step size=9.97e-01, acc. prob=0.932] 

The first line constructs the NUTS kernel. The jit_compile=True option tells pyro to compile the model with the PyTorch JIT, which makes sampling faster. The next line builds the MCMC sampler and tells it to do 1000 warmup iterations followed by 1000 sampling iterations. The warmup iterations are used to tune the sampler (for example, its step size) and are then discarded. Next, we run the sampler. Finally, we extract the posterior samples.
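
For concreteness, get_samples() returns a dictionary mapping each latent site name to a tensor with one entry per retained sample:

```python
print(posterior_samples.keys())          # dict_keys(['theta'])
print(posterior_samples["theta"].shape)  # torch.Size([1000])
```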

Here is a summary of the posterior:

mcmc.summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
     theta      0.50      0.05      0.50      0.42      0.59    373.53      1.00

Number of divergences: 0

You can understand most of the statistics: the mean, standard deviation, median, and the 5%/95% quantiles of the posterior. The r_hat column is the Gelman-Rubin statistic, which is used to check for convergence; values close to 1 indicate that the sampler has converged. The n_eff column is the effective sample size: roughly, the number of independent samples our correlated chain is worth.
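
You can reproduce the mean, standard deviation, and quantiles of this table directly from the samples (a quick check, using the numpy import from the setup cell):

```python
theta_samples = posterior_samples["theta"]
print(theta_samples.mean().item(), theta_samples.std().item())  # about 0.50 and 0.05
print(np.quantile(theta_samples.numpy(), [0.05, 0.5, 0.95]))    # about [0.42, 0.50, 0.59]
```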

Let’s look at the samples we get at each iteration. This is called the trace plot.

fig, ax = plt.subplots()
ax.plot(posterior_samples["theta"])
ax.set_xlabel("Sampling iteration")
ax.set_ylabel("$\\theta$")
sns.despine(trim=True)
[Figure: trace plot of \(\theta\) against sampling iteration.]

Let’s also look at the histogram of the samples:

fig, ax = plt.subplots()
ax.hist(posterior_samples["theta"], bins=20, density=True, label="Posterior")
ax.plot([0.5, 0.5], [0, 10], "r--", label="True value")
ax.set_xlim([0.0, 1.0])
ax.set_xlabel("$\\theta$")
ax.set_ylabel("Posterior probability density")
plt.legend(loc="best", frameon=False)
sns.despine(trim=True);
[Figure: histogram of the posterior samples of \(\theta\), with the true value 0.5 marked.]
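
As an aside (this check is ours, not part of the lecture), the model is conjugate: the \(\text{Uniform}(0,1)\) prior is a \(\text{Beta}(1,1)\), so the exact posterior is \(\text{Beta}\left(1 + \sum_i x_i, 1 + n - \sum_i x_i\right)\). You can overlay it on the histogram above as a sanity check:

```python
# Exact posterior under the conjugate Beta(1, 1) = Uniform(0, 1) prior.
n = len(data)
s = data.sum().item()
thetas = np.linspace(0.0, 1.0, 200)
exact_pdf = st.beta(1 + s, 1 + n - s).pdf(thetas)
# For example: ax.plot(thetas, exact_pdf, "g-", label="Exact posterior")
```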

Questions#

  • Repeat the analysis with 1000 observations.

Posterior predictive checking with pyro#

You can use pyro to do posterior predictive checking. First, we need to make a version of the model that does not have the obs statement. We can do this by using the pyro.poutine.uncondition function:

unconditioned_coin_toss_model = pyro.poutine.uncondition(coin_toss_model)
pyro.render_model(
    unconditioned_coin_toss_model,
    model_args=(data,),
    render_distributions=True,
    render_params=True
)
[Figure: plate diagram of the unconditioned coin toss model.]
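
A quick way to convince yourself that the conditioning is gone (our check, not part of the lecture): calling the unconditioned model now draws fresh data instead of returning the observations.

```python
fake_data = unconditioned_coin_toss_model(data)
print(fake_data.shape)               # torch.Size([100])
print(torch.equal(fake_data, data))  # almost surely False: obs is now sampled
```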

Now, we can use pyro.infer.Predictive on the unconditioned model:

replicated_data = pyro.infer.Predictive(unconditioned_coin_toss_model, posterior_samples)(data)

Here the experiment is replicated 1000 times, as many times as we have samples from the posterior.

replicated_data["obs"].shape
torch.Size([1000, 100])

Let’s visualize the samples:

fig, ax = plt.subplots()
ax.imshow(replicated_data["obs"], cmap="gray", aspect="auto")
ax.set_xlabel("Coin toss")
ax.set_ylabel("Replicated dataset")
# Remove all spines
for spine in ax.spines.values():
    spine.set_visible(False)
[Figure: the 1000 replicated datasets displayed as a binary image, one row per replication and one column per toss.]

And we can also use the results to calculate Bayesian \(p\)-values for any test statistic we like. Here we use the number of switches between 0s and 1s; the Bayesian \(p\)-value is the posterior predictive probability that a replicated dataset has at least as many switches as the observed one:

def T_s(x):
    """Return the number of switches between 0s and 1s."""
    s = 0
    for i in range(1, x.shape[0]):
        if x[i] != x[i-1]:
            s += 1
    return s
tests = [T_s(x) for x in replicated_data["obs"]]
observed = T_s(data)
# Bayesian p-value: fraction of replications at least as extreme as the data
# (the +1 terms are the usual correction so the estimate is never exactly 0 or 1)
bp_val = (np.sum(np.array(tests) >= observed) + 1) / (len(tests) + 1)
fig, ax = plt.subplots()
ax.hist(tests, bins=20, density=True, label="Posterior", alpha=0.5)
ax.axvline(observed, color="red", label="Observed")
ax.set_xlabel("Number of switches")
ax.set_ylabel("Posterior probability density")
ax.set_title(f"Bayesian p-value: {bp_val:.3f}")
plt.legend(loc="best", frameon=False)
sns.despine(trim=True);
[Figure: histogram of the number-of-switches statistic across replicated datasets, with the observed value marked and the Bayesian p-value in the title.]
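
The Python loop in T_s is fine for 100 tosses, but it vectorizes naturally; here is an equivalent sketch (the name T_s_vec is ours):

```python
def T_s_vec(x):
    """Count switches between consecutive tosses, without a Python loop."""
    return (x[1:] != x[:-1]).sum().item()

assert T_s_vec(data) == T_s(data)
```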