Hide code cell source
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")

import scipy
import scipy.stats as st
import urllib.request
import os

def download(
    url : str,
    local_filename : "str | None" = None
) -> str:
    """Download a file from a url.

    Arguments
    url            -- The url we want to download.
    local_filename -- The filename to write to. If not specified, the last
                      component of the URL's path is used (any query string
                      or fragment is ignored).

    Returns
    The name of the file that was written.

    Raises
    ValueError -- if ``local_filename`` is omitted and no filename can be
                  inferred from the URL (e.g. its path ends with "/").
    """
    # Local import: the notebook header only imports urllib.request.
    import urllib.parse

    if local_filename is None:
        # Use only the URL's path so "…/data.csv?raw=true" becomes "data.csv"
        # rather than the literal "data.csv?raw=true".
        local_filename = os.path.basename(urllib.parse.urlparse(url).path)
        if not local_filename:
            raise ValueError(
                f"cannot infer a filename from {url!r}; pass local_filename"
            )
    urllib.request.urlretrieve(url, local_filename)
    return local_filename

Variational Inference Examples

We introduce variational inference (VI) for approximate Bayesian inference. We implement VI from scratch in PyTorch and also show how to do it in Pyro.

Note: This notebook was originally developed by Dr. Rohit Tripathy in PyMC. It was modified by Prof. Bilionis using pytorch and pyro.

# Do this in Google Colab
!pip install pyro-ppl

Example 1 - normal-normal model

Let’s demonstrate the VI process end-to-end with a simple example. Consider the task of inferring the acceleration due to gravity \(g\) from data. We perform \(n\) experiments with outcomes \(x_{1:n}\) that measure the acceleration of gravity, and we know that the measurement standard deviation is \(\sigma = 0.1\). Here are some synthetic data.

# Synthetic data for the normal-normal example: n noisy measurements of g.
import torch
import scipy.constants
g_true = scipy.constants.g  # standard acceleration of gravity from scipy.constants

# Generate some synthetic data
n = 10
sigma = 0.1  # known measurement-noise standard deviation
data = g_true + sigma * torch.randn(n)

# Plot each measurement against the true value (dashed red line).
plt.plot(np.arange(n), data, 'o', label='Data')
plt.plot(np.linspace(0, n, 100), g_true*np.ones(100), '--', color="r", label='True value')
plt.legend(loc='best', frameon=False)
sns.despine(trim=True)
../_images/c1a962e299eb3f6768b997599b810086a51807a8a216c3c7fb8eb9beefebdb4e.svg

The likelihood of each measurement is given by:

\[ x_i | g, \sigma \sim N(g, \sigma^2). \]

So, the model says that the measured acceleration of gravity is around the true one with some Gaussian noise. Assume that our prior state-of-knowledge over \(g\) is:

\[ g | g_0, s_0 \sim N(g_0, s_0^2), \]

with known \(g_0 = 10\) and \(s_0 = 0.4\). This is a so-called conjugate prior, so the posterior over \(g\) is available analytically:

\[ g|x_{1:n} \sim N(\tilde{g}, \tilde{s}^2), \]

where

\[ \tilde{s}^2 = \left( \frac{n}{\sigma^2} + \frac{1}{s_0^2} \right)^{-1}, \]

and

\[ \tilde{g} = \tilde{s}^2 \left( \frac{g_0}{s_0^2} + \frac{\sum_{i=1}^{n} x_i}{\sigma^2}\right). \]

Let’s write some code to get this analytical posterior:

def post_mean_and_variance(prior_mean, prior_variance, data, likelihood_var):
    """Return the conjugate normal-normal posterior ``(mean, variance)``.

    Prior ``g ~ N(prior_mean, prior_variance)`` and likelihood
    ``x_i | g ~ N(g, likelihood_var)`` for every entry of the tensor ``data``.
    Precisions add, and the posterior mean is the precision-weighted
    combination of the prior mean and the data sum.
    """
    precision = len(data) / likelihood_var + 1. / prior_variance
    variance = 1. / precision
    weighted_sum = prior_mean / prior_variance + torch.sum(data) / likelihood_var
    return variance * weighted_sum, variance

# Analytic (conjugate) posterior with g0 = 10, s0 = 0.4 and sigma = 0.1.
gtilde, s2tilde = post_mean_and_variance(10., 0.4**2, data, 0.1**2)

# Compare prior and posterior; the posterior is plotted on a narrower grid.
xs1 = torch.linspace(7, 12, 100)
xs2 = torch.linspace(8, 11, 100)
plt.plot(xs1, st.norm(loc=10., scale=0.4).pdf(xs1), label='Prior')
# s2tilde is a variance, so take its square root for the scale.
plt.plot(xs2, st.norm(loc=gtilde, scale=np.sqrt(s2tilde)).pdf(xs2), label='Posterior')
plt.axvline(g_true, color='r', linestyle='--', label='True value')
plt.legend(loc='best', frameon=False)
sns.despine(trim=True);
../_images/ff98a30cd53edb4e47f44a361d3d07320faa38ca175db8bc6a6497707f9a662f.svg

Now let’s try to infer the posterior over \(g\) using VI. We specify the joint log probability of the model first. We do everything using pytorch so that we can use automatic differentiation.

# Fixed hyper-parameters: prior N(g0, s0^2) and known measurement noise sigma.
g0 = torch.tensor(10.)
s0 = torch.tensor(0.4)
sigma = torch.tensor(0.1)

def logprior(g):
    """Log density of the prior N(g0, s0^2), evaluated element-wise at ``g``."""
    return torch.distributions.Normal(g0, s0).log_prob(g)

def loglikelihood(data, g):
    """Total log-likelihood ``sum_i log N(x_i | g, sigma^2)``.

    ``data`` is a 1-D tensor of the n measurements. ``g`` may be a scalar or
    a tensor of candidate values (e.g. S Monte Carlo samples). The result has
    the same shape as ``g``, with one log-likelihood per candidate.
    """
    # Add a trailing axis so that every candidate g is paired with ALL data
    # points (shape g.shape + (n,)), then sum over that data axis. Without it,
    # a vector of S samples would be paired element-wise with the n data
    # points: this only broadcasts when S == n (or 1), and it mixes samples
    # across data points.
    g = torch.as_tensor(g)
    return torch.distributions.Normal(g[..., None], sigma).log_prob(data).sum(-1)

def logjoint(data, g):
    """Unnormalized log posterior ``log p(g) + log p(data | g)``, shaped like ``g``."""
    return logprior(g) + loglikelihood(data, g)

We must specify a parameterized approximate posterior, \(q_{\phi}(\cdot)\). The obvious choice here is a Gaussian:

\[ q_{\phi}(g) = N(g | \phi_1, \exp(\phi_2)^2), \]

where, \(\phi = (\phi_1, \phi_2)\) are the variational parameters. The ELBO needs to be maximized with respect to \(\phi\). Let’s go ahead and set up the ELBO. Recall that the ELBO is given by:

\[ \text{ELBO}(\phi) = \mathbb{E}_{q_\phi(\theta)}[\log p(\theta, \mathbf{x})] + \mathbb{H}[q_\phi(\theta)]. \]

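Here the latent variable is \(\theta = g\) and the data are \(\mathbf{x} = x_{1:n}\). For our Gaussian \(q_\phi\), the entropy is: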

\[ \mathbb{H}[q(\theta)] = 1/2\log(2\pi \exp(\phi_2)^2) + 1/2 = \phi_2 + \text{const}. \]

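To optimize the ELBO, we must compute an expectation over the variational distribution \(q\) (first term on the RHS in the above equation). In general, this cannot be done analytically (for this simple conjugate model it could, but we want a recipe that carries over to harder models). Instead, we resort to a Monte Carlo approximation: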

\[ \mathbb{E}_q [\log p(\theta, \mathbf{x})] \approx \frac{1}{S}\sum_{s=1}^{S} \log p(\theta^{(s)}, \mathbf{x}), \]

where the samples \(\theta^{(s)}\) are drawn from \(q\).

Here is the code for the ELBO:

def ELBO(phi, data, num_samples):
    """Monte Carlo estimate of the ELBO for ``q_phi(g) = N(phi[0], exp(phi[1])^2)``.

    Samples are drawn with the reparameterization ``g = mu + exp(log_std) * eps``
    with ``eps ~ N(0, 1)``, so the estimate is differentiable w.r.t. ``phi``.
    The Gaussian entropy equals ``log_std`` up to an additive constant, which
    is dropped.
    """
    mu, log_std = phi[0], phi[1]
    eps = torch.randn(num_samples)
    g_samples = mu + torch.exp(log_std) * eps
    expected_logjoint = logjoint(data, g_samples).mean()
    return expected_logjoint + log_std


def negELBO(phi, data, num_samples):
    """Negated ELBO estimate: the loss minimized by the optimizer."""
    return -ELBO(phi, data, num_samples)

Now we can setup the optimization problem. Here we go:

# Maximize the ELBO (i.e. minimize negELBO) over phi = (mean, log-std) with Adam.
from torch import optim

num_samples = 10  # Monte Carlo samples per ELBO estimate
num_iter = 20_000
phi = torch.tensor([9.0, -1.0])  # initial guess: mean 9, std exp(-1)
phi.requires_grad_(True)
optimizer = optim.Adam([phi], lr=0.001)

elbos = []  # ELBO estimate at every iteration, for plotting
for i in range(num_iter):
    optimizer.zero_grad()
    loss = negELBO(phi, data, num_samples)
    loss.backward()
    optimizer.step()
    if i % 1_000 == 0:
        print(f"Iteration: {i} Loss: {loss.item()}")
    elbos.append(-loss.item())
# Stop tracking gradients on phi; the cells below only read its values.
phi.requires_grad_(False);
Iteration: 0 Loss: 375.7824401855469
Iteration: 1000 Loss: 26.652555465698242
Iteration: 2000 Loss: 2.5472326278686523
Iteration: 3000 Loss: -1.334648847579956
Iteration: 4000 Loss: 0.9369990825653076
Iteration: 5000 Loss: -7.821878433227539
Iteration: 6000 Loss: -0.5406231880187988
Iteration: 7000 Loss: -6.518988609313965
Iteration: 8000 Loss: -5.652396202087402
Iteration: 9000 Loss: -5.485063076019287
Iteration: 10000 Loss: -5.161447525024414
Iteration: 11000 Loss: -7.376476287841797
Iteration: 12000 Loss: -5.037634372711182
Iteration: 13000 Loss: -3.8640382289886475
Iteration: 14000 Loss: -3.771669626235962
Iteration: 15000 Loss: -4.731598854064941
Iteration: 16000 Loss: -5.233105182647705
Iteration: 17000 Loss: -5.517479419708252
Iteration: 18000 Loss: -5.228094100952148
Iteration: 19000 Loss: -6.046360015869141

Here is the evolution of the ELBO:

# Per-iteration (noisy) ELBO estimates collected in the loop above.
fig, ax = plt.subplots()
ax.plot(torch.arange(num_iter), elbos)
ax.set(xlabel='Iteration', ylabel='ELBO')
sns.despine(trim=True);
../_images/5db6c7043c088b948dc10e12366cd321b6405e8879c423a7ebecc2fb35925fd0.svg

Let’s build and visualize the approximate posterior:

# q(g) = N(phi[0], exp(phi[1])^2): phi[1] is the log standard deviation.
postmean = phi[0]
poststdev = torch.exp(phi[1])
gpost = st.norm(postmean, poststdev)  # from-scratch VI approximation of the posterior

# Overlay the analytic posterior and the VI approximation.
xs = np.linspace(9.6, 10.1, 100)
fig, ax = plt.subplots()
ax.plot(xs, st.norm(loc=gtilde, scale=np.sqrt(s2tilde)).pdf(xs), label='True Posterior')
ax.plot(xs, gpost.pdf(xs), label='VI Posterior')
ax.axvline(g_true, color='r', linestyle='--', label='True value')
plt.legend(loc='best', frameon=False)
sns.despine(trim=True);
../_images/9ba16c1bfbf1a0f6bae363f0be4d605b6d9032adb851aeddffd23f9f76106ac4.svg

As you can see, our approximation of the posterior is close. The normal-normal model is a straightforward example with one latent variable. In practice, setting up the variational posterior for all latent variables, keeping track of transformations, and optimizing the variational parameters can become tedious for models of any reasonable level of complexity.

Now let’s do the same thing using pyro:

import pyro
import pyro.distributions as dist


def model(data):
    """Normal-normal model: ``g ~ N(g0, s0^2)`` and ``x_i | g ~ N(g, sigma^2)``.

    The observations in ``data`` are conditionally independent given ``g``,
    which the ``pyro.plate`` declares.
    """
    prior = dist.Normal(g0, s0)
    g = pyro.sample("g", prior)
    num_obs = data.shape[0]
    with pyro.plate("data", num_obs):
        pyro.sample("obs", dist.Normal(g, sigma), obs=data)

We need to define the guide. We can either write it by hand or use the AutoNormal guide. AutoNormal scans the model for latent variables (in our model the only such variable is g) and constructs a guide that is a normal distribution with learnable parameters.

# Automatic guide: a normal distribution with learnable parameters for each
# latent sample site of `model` (here only "g").
guide = pyro.infer.autoguide.AutoNormal(model)

Then we can do inference:

# This removes old parameters from the parameter store.
# If you don't do this, pyro.param returns the values left over from
# previous runs instead of re-initializing them.
pyro.clear_param_store()

# This is the stochastic variational inference algorithm:
svi = pyro.infer.SVI(
    model,              # model to optimize
    guide,              # variational distribution
    pyro.optim.Adam(    # optimizer to use
        {"lr": 0.001}   # parameters of the optimizer
    ),
    pyro.infer.JitTrace_ELBO() # loss: Monte Carlo estimate of the negative ELBO (JIT-compiled)
)

# And we can iterate:
num_iter = 20_000
elbos = []
for i in range(num_iter):
    # One optimization step; returns the loss, i.e. an estimate of -ELBO.
    loss = svi.step(data)
    elbos.append(-loss)
    if i % 1_000 == 0:
        print(f"Iteration: {i} Loss: {loss}")
Iteration: 0 Loss: 49686.4765625
Iteration: 1000 Loss: 40596.5390625
Iteration: 2000 Loss: 30614.01953125
Iteration: 3000 Loss: 24874.3671875
Iteration: 4000 Loss: 19644.50390625
Iteration: 5000 Loss: 12995.2060546875
Iteration: 6000 Loss: 9540.4306640625
Iteration: 7000 Loss: 6226.10107421875
Iteration: 8000 Loss: 3715.46435546875
Iteration: 9000 Loss: 2070.754150390625
Iteration: 10000 Loss: 1292.1932373046875
Iteration: 11000 Loss: 252.4281463623047
Iteration: 12000 Loss: 38.22501754760742
Iteration: 13000 Loss: -5.493821620941162
Iteration: 14000 Loss: -2.04667329788208
Iteration: 15000 Loss: -7.136266708374023
Iteration: 16000 Loss: -5.793765544891357
Iteration: 17000 Loss: -7.158337116241455
Iteration: 18000 Loss: -7.028347015380859
Iteration: 19000 Loss: -6.882040023803711

Let’s put the training loop in a function for later use:

def train(model, guide, data, num_iter=5_000, lr=0.001, print_every=100):
    """Train a model with a guide using stochastic variational inference.

    The global Pyro parameter store is cleared first, so every call starts
    from a fresh initialization of the guide's parameters.

    Arguments
    ---------
    model       -- The Pyro model to train.
    guide       -- The guide to train (must accept the same arguments as the model).
    data        -- Tuple of positional arguments forwarded to ``svi.step``,
                   e.g. ``(data,)``.
    num_iter    -- The number of SVI iterations.
    lr          -- Learning rate of the Adam optimizer.
    print_every -- Print the loss every this many iterations; ``None`` or
                   ``0`` disables printing.

    Returns
    -------
    elbos       -- The ELBO estimate (negated SVI loss) at each iteration.
    param_store -- Pyro's global parameter store holding the trained parameters.
    """
    pyro.clear_param_store()

    svi = pyro.infer.SVI(
        model,
        guide,
        pyro.optim.Adam({"lr": lr}),
        loss=pyro.infer.JitTrace_ELBO(),
    )

    elbos = []
    for i in range(num_iter):
        loss = svi.step(*data)
        elbos.append(-loss)
        if print_every and i % print_every == 0:
            print(f"Iteration: {i} Loss: {loss}")

    return elbos, pyro.get_param_store()

Let’s plot the ELBO:

# ELBO trace from the AutoNormal SVI loop above.
fig, ax = plt.subplots()
ax.plot(torch.arange(num_iter), elbos)
ax.set(xlabel='Iteration', ylabel='ELBO')
sns.despine(trim=True);
../_images/7aa14cddac6c9151b57a61d68b647c35383790065cb1c27423670fa3881cce2a.svg

Here is how you can extract information from the posterior:

# Median of each latent site under the guide, as a dict keyed by site name.
guide.median() # Only works for AutoGuide
{'g': tensor(9.7988)}
# 25% and 75% quantiles of each latent site under the guide.
guide.quantiles([0.25, 0.75]) # Only works for AutoGuide
{'g': tensor([9.7757, 9.8218])}

The other thing that you can do is use pyro.infer.Predictive to get samples from the guide:

# Draw 1,000 samples of g by running the trained guide through
# pyro.infer.Predictive; the result is a dict keyed by sample-site name.
guide_g_samples = pyro.infer.Predictive(model, guide=guide, num_samples=1_000)(data)["g"]
fig, ax = plt.subplots()
# Detach before converting: .numpy() refuses tensors that require grad, and
# the returned samples may still be attached to the autograd graph.
ax.hist(guide_g_samples.detach().numpy(), bins=20, density=True, label='VI Posterior', alpha=0.5)
# Plot the analytic conjugate posterior N(gtilde, s2tilde). (gpost is the
# from-scratch VI approximation, not the true posterior.)
ax.plot(xs, st.norm(loc=gtilde, scale=np.sqrt(s2tilde)).pdf(xs), label='True Posterior')
ax.axvline(g_true, color='r', linestyle='--', label='True value')
ax.legend(loc='best', frameon=False)
sns.despine(trim=True);
../_images/9f44bde354b4d838117863b20730649f5262d6c9e2ca476e173f75e9bb847735.svg

The above code works with every guide, even hand-written ones.

Let’s repeat the above example using our own guide. The results should be essentially the same, because AutoNormal is also just a normal distribution with learnable parameters; only the initialization and the optimization noise differ.

def guide(data):
    """Hand-written guide: ``g ~ Normal(loc, scale)`` with learnable parameters.

    ``loc`` is stored as param "mu" (initialized at 9.0). ``scale`` is stored
    as param "sigma" (initialized at 2.0 and constrained to be positive).
    ``data`` is unused; it only mirrors the model's signature.
    """
    loc = pyro.param("mu", torch.tensor(9.0))
    scale = pyro.param(
        "sigma",
        torch.tensor(2.0),
        constraint=dist.constraints.positive,
    )
    q_g = pyro.distributions.Normal(loc, scale)
    pyro.sample("g", q_g)

That’s it. Now, whenever Pyro sees a pyro.param, it understands that it needs to optimize it. Let’s test it:

# Fit the hand-written guide; data goes in a 1-tuple because train calls svi.step(*data).
elbos, params = train(model, guide, (data,), num_iter=20_000);
/var/folders/5y/28n32xmx0551k29hd21qs87c0000gp/T/ipykernel_58779/2942119673.py:2: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  mu = pyro.param("mu", torch.tensor(9.0))
/var/folders/5y/28n32xmx0551k29hd21qs87c0000gp/T/ipykernel_58779/2942119673.py:3: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  sigma = pyro.param("sigma", torch.tensor(2.0),
Iteration: 0 Loss: 6863.5166015625
Iteration: 100 Loss: 4726.62841796875
Iteration: 200 Loss: 6681.0615234375
Iteration: 300 Loss: 1037.7930908203125
Iteration: 400 Loss: 5512.32958984375
Iteration: 500 Loss: 1687.9530029296875
Iteration: 600 Loss: 5955.34765625
Iteration: 700 Loss: 318.48486328125
Iteration: 800 Loss: -0.43179094791412354
Iteration: 900 Loss: 1379.1673583984375
Iteration: 1000 Loss: 1462.825927734375
Iteration: 1100 Loss: 62.86969757080078
Iteration: 1200 Loss: 1187.942138671875
Iteration: 1300 Loss: -10.333321571350098
Iteration: 1400 Loss: 1810.6552734375
Iteration: 1500 Loss: 17.4462890625
Iteration: 1600 Loss: 1257.30078125
Iteration: 1700 Loss: 7.44822883605957
Iteration: 1800 Loss: 1563.3077392578125
Iteration: 1900 Loss: 198.15220642089844
Iteration: 2000 Loss: 833.2096557617188
Iteration: 2100 Loss: 280.9329833984375
Iteration: 2200 Loss: -1.0198445320129395
Iteration: 2300 Loss: 342.187744140625
Iteration: 2400 Loss: 11.269716262817383
Iteration: 2500 Loss: -8.514278411865234
Iteration: 2600 Loss: 512.078369140625
Iteration: 2700 Loss: 542.8423461914062
Iteration: 2800 Loss: 32.65810012817383
Iteration: 2900 Loss: 20.113100051879883
Iteration: 3000 Loss: 723.205322265625
Iteration: 3100 Loss: 9.511959075927734
Iteration: 3200 Loss: 399.13018798828125
Iteration: 3300 Loss: -0.7385709285736084
Iteration: 3400 Loss: 582.1942138671875
Iteration: 3500 Loss: 129.11476135253906
Iteration: 3600 Loss: 133.29835510253906
Iteration: 3700 Loss: 161.52444458007812
Iteration: 3800 Loss: 13.85338020324707
Iteration: 3900 Loss: 275.53204345703125
Iteration: 4000 Loss: 37.831024169921875
Iteration: 4100 Loss: 420.3694763183594
Iteration: 4200 Loss: 618.74365234375
Iteration: 4300 Loss: -8.76242733001709
Iteration: 4400 Loss: -2.659954786300659
Iteration: 4500 Loss: 125.52645111083984
Iteration: 4600 Loss: 9.65471076965332
Iteration: 4700 Loss: 24.79283905029297
Iteration: 4800 Loss: -6.347537994384766
Iteration: 4900 Loss: 26.814159393310547
Iteration: 5000 Loss: 150.06910705566406
Iteration: 5100 Loss: 231.2753448486328
Iteration: 5200 Loss: 186.62539672851562
Iteration: 5300 Loss: 2.983067274093628
Iteration: 5400 Loss: -6.850186347961426
Iteration: 5500 Loss: 53.442626953125
Iteration: 5600 Loss: 35.3376579284668
Iteration: 5700 Loss: 22.088422775268555
Iteration: 5800 Loss: 4.767504692077637
Iteration: 5900 Loss: 43.433319091796875
Iteration: 6000 Loss: -8.174591064453125
Iteration: 6100 Loss: 15.100848197937012
Iteration: 6200 Loss: -5.886928558349609
Iteration: 6300 Loss: 2.2705323696136475
Iteration: 6400 Loss: 0.807808518409729
Iteration: 6500 Loss: 16.042556762695312
Iteration: 6600 Loss: 160.491455078125
Iteration: 6700 Loss: 88.98486328125
Iteration: 6800 Loss: 16.722379684448242
Iteration: 6900 Loss: -7.700603485107422
Iteration: 7000 Loss: -7.422141075134277
Iteration: 7100 Loss: 20.69399642944336
Iteration: 7200 Loss: 12.6574068069458
Iteration: 7300 Loss: 75.57760620117188
Iteration: 7400 Loss: -8.795297622680664
Iteration: 7500 Loss: 2.135460376739502
Iteration: 7600 Loss: 15.581229209899902
Iteration: 7700 Loss: 34.729217529296875
Iteration: 7800 Loss: 5.033120632171631
Iteration: 7900 Loss: 64.08548736572266
Iteration: 8000 Loss: -8.337617874145508
Iteration: 8100 Loss: -6.07992696762085
Iteration: 8200 Loss: 0.6487371921539307
Iteration: 8300 Loss: -8.175068855285645
Iteration: 8400 Loss: -0.6840838193893433
Iteration: 8500 Loss: -0.5311352014541626
Iteration: 8600 Loss: 7.673257350921631
Iteration: 8700 Loss: -7.948726654052734
Iteration: 8800 Loss: 39.045772552490234
Iteration: 8900 Loss: -5.875088214874268
Iteration: 9000 Loss: -6.13936710357666
Iteration: 9100 Loss: -5.38765287399292
Iteration: 9200 Loss: -3.403750419616699
Iteration: 9300 Loss: -7.115334510803223
Iteration: 9400 Loss: -3.978450298309326
Iteration: 9500 Loss: -8.129117965698242
Iteration: 9600 Loss: -8.257028579711914
Iteration: 9700 Loss: -8.088420867919922
Iteration: 9800 Loss: -8.254077911376953
Iteration: 9900 Loss: -5.425186634063721
Iteration: 10000 Loss: 0.27979975938796997
Iteration: 10100 Loss: -3.1878466606140137
Iteration: 10200 Loss: -0.5475963354110718
Iteration: 10300 Loss: -6.924840927124023
Iteration: 10400 Loss: -8.02815055847168
Iteration: 10500 Loss: -1.2006770372390747
Iteration: 10600 Loss: -5.848913669586182
Iteration: 10700 Loss: -2.4069879055023193
Iteration: 10800 Loss: -6.6708269119262695
Iteration: 10900 Loss: -6.061272621154785
Iteration: 11000 Loss: -6.134784698486328
Iteration: 11100 Loss: -2.4381508827209473
Iteration: 11200 Loss: -0.30546820163726807
Iteration: 11300 Loss: 11.974300384521484
Iteration: 11400 Loss: -7.944908618927002
Iteration: 11500 Loss: -0.7399576902389526
Iteration: 11600 Loss: -1.2389609813690186
Iteration: 11700 Loss: 25.194276809692383
Iteration: 11800 Loss: -7.398594379425049
Iteration: 11900 Loss: -7.51558780670166
Iteration: 12000 Loss: -7.035286903381348
Iteration: 12100 Loss: -7.560616493225098
Iteration: 12200 Loss: -7.617308616638184
Iteration: 12300 Loss: -7.54684591293335
Iteration: 12400 Loss: -4.699420928955078
Iteration: 12500 Loss: -2.934391498565674
Iteration: 12600 Loss: -4.508200645446777
Iteration: 12700 Loss: 1.2815603017807007
Iteration: 12800 Loss: -6.792910099029541
Iteration: 12900 Loss: -7.661536693572998
Iteration: 13000 Loss: -7.5992207527160645
Iteration: 13100 Loss: -5.169339179992676
Iteration: 13200 Loss: -7.44121789932251
Iteration: 13300 Loss: 1.4796619415283203
Iteration: 13400 Loss: -7.568403244018555
Iteration: 13500 Loss: -4.503005027770996
Iteration: 13600 Loss: -7.495359897613525
Iteration: 13700 Loss: -7.481572151184082
Iteration: 13800 Loss: -6.9588398933410645
Iteration: 13900 Loss: -1.478920817375183
Iteration: 14000 Loss: -4.450002193450928
Iteration: 14100 Loss: -4.372940540313721
Iteration: 14200 Loss: -7.3279242515563965
Iteration: 14300 Loss: -6.8539018630981445
Iteration: 14400 Loss: -6.539721965789795
Iteration: 14500 Loss: -7.326222896575928
Iteration: 14600 Loss: -5.344386577606201
Iteration: 14700 Loss: -6.607089519500732
Iteration: 14800 Loss: -7.293154239654541
Iteration: 14900 Loss: -7.291573524475098
Iteration: 15000 Loss: -6.9718499183654785
Iteration: 15100 Loss: -7.021787643432617
Iteration: 15200 Loss: -7.099093437194824
Iteration: 15300 Loss: -4.358132362365723
Iteration: 15400 Loss: -5.313650608062744
Iteration: 15500 Loss: -6.455587863922119
Iteration: 15600 Loss: -6.4632182121276855
Iteration: 15700 Loss: -5.837204456329346
Iteration: 15800 Loss: -7.031949520111084
Iteration: 15900 Loss: -7.05983304977417
Iteration: 16000 Loss: -6.796117305755615
Iteration: 16100 Loss: -6.874577045440674
Iteration: 16200 Loss: -6.93787956237793
Iteration: 16300 Loss: -6.6703972816467285
Iteration: 16400 Loss: -6.997195243835449
Iteration: 16500 Loss: -7.0669331550598145
Iteration: 16600 Loss: -6.786100387573242
Iteration: 16700 Loss: -5.548006534576416
Iteration: 16800 Loss: -7.049220561981201
Iteration: 16900 Loss: -6.343357086181641
Iteration: 17000 Loss: -7.048659324645996
Iteration: 17100 Loss: -6.335675239562988
Iteration: 17200 Loss: -6.211577892303467
Iteration: 17300 Loss: -6.943521022796631
Iteration: 17400 Loss: -6.881516933441162
Iteration: 17500 Loss: -6.718965530395508
Iteration: 17600 Loss: -6.93778133392334
Iteration: 17700 Loss: -6.441198825836182
Iteration: 17800 Loss: -6.462118148803711
Iteration: 17900 Loss: -6.886161804199219
Iteration: 18000 Loss: -6.8985772132873535
Iteration: 18100 Loss: -6.964326858520508
Iteration: 18200 Loss: -6.686488628387451
Iteration: 18300 Loss: -6.87588357925415
Iteration: 18400 Loss: -6.595379829406738
Iteration: 18500 Loss: -6.88122034072876
Iteration: 18600 Loss: -6.807351112365723
Iteration: 18700 Loss: -6.926386833190918
Iteration: 18800 Loss: -6.8103742599487305
Iteration: 18900 Loss: -6.609341621398926
Iteration: 19000 Loss: -6.961921691894531
Iteration: 19100 Loss: -6.779007911682129
Iteration: 19200 Loss: -6.730861186981201
Iteration: 19300 Loss: -6.651981830596924
Iteration: 19400 Loss: -6.811529636383057
Iteration: 19500 Loss: -6.911961555480957
Iteration: 19600 Loss: -6.773965835571289
Iteration: 19700 Loss: -6.892739772796631
Iteration: 19800 Loss: -6.684868812561035
Iteration: 19900 Loss: -6.803631782531738

The result is essentially the same as before.

# ELBO trace for the hand-written guide.
fig, ax = plt.subplots()
ax.plot(torch.arange(num_iter), elbos)
ax.set(xlabel='Iteration', ylabel='ELBO')
sns.despine(trim=True);
../_images/fb008fd9c6d7358cd2c729ab8d9dc707dcc68a6694067ff304e2fbeebc8882a2.svg
# Draw 1,000 samples of g by running the trained guide through
# pyro.infer.Predictive; the result is a dict keyed by sample-site name.
guide_g_samples = pyro.infer.Predictive(model, guide=guide, num_samples=1_000)(data)["g"]
fig, ax = plt.subplots()
# Detach before converting: .numpy() refuses tensors that require grad, and
# the returned samples may still be attached to the autograd graph.
ax.hist(guide_g_samples.detach().numpy(), bins=20, density=True, label='VI Posterior', alpha=0.5)
# Plot the analytic conjugate posterior N(gtilde, s2tilde). (gpost is the
# from-scratch VI approximation, not the true posterior.)
ax.plot(xs, st.norm(loc=gtilde, scale=np.sqrt(s2tilde)).pdf(xs), label='True Posterior')
ax.axvline(g_true, color='r', linestyle='--', label='True value')
ax.legend(loc='best', frameon=False)
sns.despine(trim=True);
../_images/ef33e44dc9d592fb1e6eba002a4950dea7fe9186f0a9682d1202c0e4a2d59942.svg

Example 2 - Coin-toss example

Just like in the MCMC lecture, let’s look at the process of setting up a model and performing variational inference and diagnostics with the coin toss example.

The probabilistic model is as follows. We observe binary coin toss data:

\[ x_i|\theta \sim \text{Bernoulli}(\theta), \]

for \(i=1, \dots, n\).

The prior over the latent variable \(\theta\) is a Beta distribution:

\[ \theta \sim \text{Beta}(2, 2). \]

We assign the prior as a Beta distribution with shape parameters 2 and 2, corresponding to a weak a priori belief that the coin is most likely fair.

# Beta(2, 2) prior on the coin bias theta, plotted on a grid inside (0, 1).
thetaprior = st.beta(2., 2.)
x = np.linspace(0.001, 0.999, 1000)
plt.plot(x, thetaprior.pdf(x))
plt.title('Prior')
sns.despine(trim=True);
../_images/2ab1787854f2bdd4ee0ac28fa6564c0665c47e8d2725de24a02465c4110c8f59.svg

We wish to perform posterior inference on \(\theta\):

\[ p(\theta| x_{1:n}) \propto p(\theta) \prod_{i=1}^{n} p(x_i | \theta). \]

Since this is a conjugate model, we know the posterior in closed form:

\[ \theta | x_{1:n} \sim \text{Beta}\left(2+ \sum_{i=1}^n x_i, 2 + n - \sum_{i=1}^nx_i\right) \]

Let’s generate some fake data and get the analytical posterior for comparison.

thetatrue = 0.3
n = 200
# Simulate n i.i.d. coin tosses (1 = heads) with P(heads) = thetatrue.
data = torch.tensor(np.random.binomial(1, thetatrue, size=(n,)), dtype=torch.float32)
nheads = data.sum()
ntails = n - nheads
# Conjugate update: Beta(2, 2) prior -> Beta(2 + #heads, 2 + #tails) posterior.
theta_post = st.beta(2. + nheads, 2. + ntails)

# plot data
plt.figure()
plt.subplot(121)
_ = plt.bar(*np.unique(data, return_counts=True), width=0.2)
_ = plt.xticks([0, 1])
_ = plt.title('Observed H/T frequencies')

# plot posterior
plt.subplot(122)
x = np.linspace(0.001, 0.999, 1000)
postpdf = theta_post.pdf(x)
plt.plot(x, postpdf, label='Posterior')
plt.plot(x, thetaprior.pdf(x), label='Prior')
# Mark the true bias with a vertical line, as in the other plots.
plt.axvline(thetatrue, color="r", linestyle='--', label='True $\\theta$')
plt.legend(loc='best', frameon=False)
plt.title('Coin Toss Bayesian Inference')
sns.despine(trim=True);
../_images/5f99d614869c22769b8197b826af7830a3d2ea6c0f5fea27938fd1d325b90273.svg

Let’s make the model in pyro:

def model(data):
    """Beta-Bernoulli coin-toss model.

    ``theta ~ Beta(2, 2)`` and each toss ``x_i | theta ~ Bernoulli(theta)``,
    conditionally independent given ``theta`` (declared by the plate).
    """
    # The prior must be the same Beta(2, 2) used for `thetaprior` and the
    # analytic posterior `theta_post`. Otherwise the VI result would be
    # compared against the posterior of a different model.
    theta = pyro.sample("theta", dist.Beta(2., 2.))
    with pyro.plate('data', data.shape[0]):
        pyro.sample('obs', dist.Bernoulli(theta), obs=data)

Now we need to make the guide. We have many choices. Let’s try a couple. First, let’s just use a Gaussian. We cannot use a Gaussian directly on \(\theta\) though because its support is \([0, 1]\). Instead, we put a Gaussian on a variable \(z\) with support \(\mathbb{R}\):

\[ z \sim N(\mu, \sigma^2). \]

Then we transform \(z\) to \(\theta\) using the sigmoid function:

\[ \theta = \frac{1}{1 + \exp(-z)}. \]

Here is how we can do this in pyro. We can either do it by hand:

from pyro.distributions import constraints
from pyro.distributions import transforms

def gaussian_guide(data):
    """Logit-normal guide for ``theta``.

    A ``Normal(loc, scale)`` is placed on an unconstrained variable ``z`` and
    pushed through the sigmoid, so ``theta = 1 / (1 + exp(-z))`` lies in
    (0, 1). ``loc`` and ``scale`` are stored as params "mu" and "sigma".
    ``data`` is unused; it only mirrors the model's signature.
    """
    loc = pyro.param('mu', torch.tensor(0.0))
    scale = pyro.param('sigma', torch.tensor(.1), constraint=constraints.positive)
    base = dist.Normal(loc, scale)
    theta_dist = dist.TransformedDistribution(base, transforms.SigmoidTransform())
    pyro.sample("theta", theta_dist)

Or we can use the AutoNormal guide:

# Automatic counterpart of gaussian_guide: a normal distribution in an
# unconstrained space, mapped onto theta's (0, 1) support.
auto_gaussian_guide = pyro.infer.autoguide.AutoNormal(model)

These are identical. Let’s just train the second one:

# Fit the automatic guide; data goes in a 1-tuple because train calls svi.step(*data).
elbos, params = train(model, auto_gaussian_guide, (data,), num_iter=20_000);
Iteration: 0 Loss: 146.40711975097656
Iteration: 100 Loss: 141.49867248535156
Iteration: 200 Loss: 134.42784118652344
Iteration: 300 Loss: 129.84580993652344
Iteration: 400 Loss: 128.21202087402344
Iteration: 500 Loss: 130.95372009277344
Iteration: 600 Loss: 129.74700927734375
Iteration: 700 Loss: 128.5478515625
Iteration: 800 Loss: 122.86455535888672
Iteration: 900 Loss: 126.47606658935547
Iteration: 1000 Loss: 124.5881118774414
Iteration: 1100 Loss: 124.98439025878906
Iteration: 1200 Loss: 125.36991119384766
Iteration: 1300 Loss: 125.0341567993164
Iteration: 1400 Loss: 123.77005004882812
Iteration: 1500 Loss: 124.68547058105469
Iteration: 1600 Loss: 124.85704040527344
Iteration: 1700 Loss: 124.83834838867188
Iteration: 1800 Loss: 124.78903198242188
Iteration: 1900 Loss: 124.77840423583984
Iteration: 2000 Loss: 124.79694366455078
Iteration: 2100 Loss: 124.2547378540039
Iteration: 2200 Loss: 124.69900512695312
Iteration: 2300 Loss: 124.7703857421875
Iteration: 2400 Loss: 124.54391479492188
Iteration: 2500 Loss: 124.7482681274414
Iteration: 2600 Loss: 124.72441101074219
Iteration: 2700 Loss: 124.7225570678711
Iteration: 2800 Loss: 124.5625
Iteration: 2900 Loss: 124.7328109741211
Iteration: 3000 Loss: 124.63932037353516
Iteration: 3100 Loss: 124.71368408203125
Iteration: 3200 Loss: 124.7137680053711
Iteration: 3300 Loss: 124.69998168945312
Iteration: 3400 Loss: 124.72123718261719
Iteration: 3500 Loss: 124.69303131103516
Iteration: 3600 Loss: 124.7291259765625
Iteration: 3700 Loss: 124.6519775390625
Iteration: 3800 Loss: 124.72945404052734
Iteration: 3900 Loss: 124.67196655273438
Iteration: 4000 Loss: 124.6773910522461
Iteration: 4100 Loss: 124.67179870605469
Iteration: 4200 Loss: 124.59809875488281
Iteration: 4300 Loss: 124.79076385498047
Iteration: 4400 Loss: 124.67097473144531
Iteration: 4500 Loss: 124.6942138671875
Iteration: 4600 Loss: 124.68804168701172
Iteration: 4700 Loss: 124.74881744384766
Iteration: 4800 Loss: 124.62818908691406
Iteration: 4900 Loss: 124.75748443603516
Iteration: 5000 Loss: 124.6734390258789
Iteration: 5100 Loss: 124.70122528076172
Iteration: 5200 Loss: 124.689208984375
Iteration: 5300 Loss: 124.73365783691406
Iteration: 5400 Loss: 124.67781829833984
Iteration: 5500 Loss: 124.63542938232422
Iteration: 5600 Loss: 124.6197738647461
Iteration: 5700 Loss: 124.66807556152344
Iteration: 5800 Loss: 124.70440673828125
Iteration: 5900 Loss: 124.77986145019531
Iteration: 6000 Loss: 124.6969985961914
Iteration: 6100 Loss: 124.66635131835938
Iteration: 6200 Loss: 124.66107940673828
Iteration: 6300 Loss: 124.65746307373047
Iteration: 6400 Loss: 124.61973571777344
Iteration: 6500 Loss: 124.67634582519531
Iteration: 6600 Loss: 124.75775146484375
Iteration: 6700 Loss: 124.818115234375
Iteration: 6800 Loss: 124.67281341552734
Iteration: 6900 Loss: 124.69622039794922
Iteration: 7000 Loss: 124.6736831665039
Iteration: 7100 Loss: 124.70228576660156
Iteration: 7200 Loss: 124.62702941894531
Iteration: 7300 Loss: 124.63257598876953
Iteration: 7400 Loss: 124.68147277832031
Iteration: 7500 Loss: 124.73921203613281
Iteration: 7600 Loss: 124.63043212890625
Iteration: 7700 Loss: 124.66267395019531
Iteration: 7800 Loss: 124.7544174194336
Iteration: 7900 Loss: 124.66488647460938
Iteration: 8000 Loss: 124.64690399169922
Iteration: 8100 Loss: 124.80863189697266
Iteration: 8200 Loss: 124.59138488769531
Iteration: 8300 Loss: 124.572998046875
Iteration: 8400 Loss: 124.65132141113281
Iteration: 8500 Loss: 124.72315979003906
Iteration: 8600 Loss: 124.62680053710938
Iteration: 8700 Loss: 124.70709228515625
Iteration: 8800 Loss: 124.59783935546875
Iteration: 8900 Loss: 124.52831268310547
Iteration: 9000 Loss: 124.68631744384766
Iteration: 9100 Loss: 124.63186645507812
Iteration: 9200 Loss: 124.74536895751953
Iteration: 9300 Loss: 124.70989227294922
Iteration: 9400 Loss: 124.61333465576172
Iteration: 9500 Loss: 124.81995391845703
Iteration: 9600 Loss: 124.8375244140625
Iteration: 9700 Loss: 124.60042572021484
Iteration: 9800 Loss: 124.78146362304688
Iteration: 9900 Loss: 124.74494934082031
Iteration: 10000 Loss: 124.73494720458984
Iteration: 10100 Loss: 124.68846130371094
Iteration: 10200 Loss: 124.67047119140625
Iteration: 10300 Loss: 124.60511779785156
Iteration: 10400 Loss: 124.61138153076172
Iteration: 10500 Loss: 124.61978912353516
Iteration: 10600 Loss: 124.62849426269531
Iteration: 10700 Loss: 124.71034240722656
Iteration: 10800 Loss: 124.76557159423828
Iteration: 10900 Loss: 124.65545654296875
Iteration: 11000 Loss: 124.73983001708984
Iteration: 11100 Loss: 124.66980743408203
Iteration: 11200 Loss: 124.72845458984375
Iteration: 11300 Loss: 124.65077209472656
Iteration: 11400 Loss: 124.68489074707031
Iteration: 11500 Loss: 124.71796417236328
Iteration: 11600 Loss: 124.6515884399414
Iteration: 11700 Loss: 124.70018768310547
Iteration: 11800 Loss: 124.8572769165039
Iteration: 11900 Loss: 124.78717041015625
Iteration: 12000 Loss: 124.71411895751953
Iteration: 12100 Loss: 124.67271423339844
Iteration: 12200 Loss: 124.47366333007812
Iteration: 12300 Loss: 124.68609619140625
Iteration: 12400 Loss: 124.72593688964844
Iteration: 12500 Loss: 124.66497039794922
Iteration: 12600 Loss: 124.69234466552734
Iteration: 12700 Loss: 124.6677017211914
Iteration: 12800 Loss: 124.67044830322266
Iteration: 12900 Loss: 124.65167999267578
Iteration: 13000 Loss: 124.66949462890625
Iteration: 13100 Loss: 124.60675811767578
Iteration: 13200 Loss: 124.75745391845703
Iteration: 13300 Loss: 124.63687896728516
Iteration: 13400 Loss: 124.67958068847656
Iteration: 13500 Loss: 124.68183898925781
Iteration: 13600 Loss: 124.70476531982422
Iteration: 13700 Loss: 124.62003326416016
Iteration: 13800 Loss: 124.70536041259766
Iteration: 13900 Loss: 124.66747283935547
Iteration: 14000 Loss: 124.63028717041016
Iteration: 14100 Loss: 124.65388488769531
Iteration: 14200 Loss: 124.66047668457031
Iteration: 14300 Loss: 124.7000503540039
Iteration: 14400 Loss: 124.63937377929688
Iteration: 14500 Loss: 124.7081527709961
Iteration: 14600 Loss: 124.58367919921875
Iteration: 14700 Loss: 124.73462677001953
Iteration: 14800 Loss: 124.57368469238281
Iteration: 14900 Loss: 124.71100616455078
Iteration: 15000 Loss: 124.70767211914062
Iteration: 15100 Loss: 124.68736267089844
Iteration: 15200 Loss: 124.70284271240234
Iteration: 15300 Loss: 124.64090728759766
Iteration: 15400 Loss: 124.76419830322266
Iteration: 15500 Loss: 124.64111328125
Iteration: 15600 Loss: 124.69387817382812
Iteration: 15700 Loss: 124.68409729003906
Iteration: 15800 Loss: 124.6617660522461
Iteration: 15900 Loss: 124.68851470947266
Iteration: 16000 Loss: 124.71793365478516
Iteration: 16100 Loss: 124.70984649658203
Iteration: 16200 Loss: 124.69921112060547
Iteration: 16300 Loss: 124.51740264892578
Iteration: 16400 Loss: 124.77755737304688
Iteration: 16500 Loss: 124.66885375976562
Iteration: 16600 Loss: 124.71269989013672
Iteration: 16700 Loss: 124.70817565917969
Iteration: 16800 Loss: 124.25801086425781
Iteration: 16900 Loss: 124.83180236816406
Iteration: 17000 Loss: 124.67884063720703
Iteration: 17100 Loss: 124.67129516601562
Iteration: 17200 Loss: 124.70156860351562
Iteration: 17300 Loss: 124.67505645751953
Iteration: 17400 Loss: 124.71238708496094
Iteration: 17500 Loss: 124.7290267944336
Iteration: 17600 Loss: 124.76727294921875
Iteration: 17700 Loss: 124.67666625976562
Iteration: 17800 Loss: 124.63899993896484
Iteration: 17900 Loss: 124.52882385253906
Iteration: 18000 Loss: 124.59961700439453
Iteration: 18100 Loss: 124.64694213867188
Iteration: 18200 Loss: 124.68242645263672
Iteration: 18300 Loss: 124.67755126953125
Iteration: 18400 Loss: 124.6822280883789
Iteration: 18500 Loss: 124.6970443725586
Iteration: 18600 Loss: 124.69671630859375
Iteration: 18700 Loss: 124.67792510986328
Iteration: 18800 Loss: 124.6851577758789
Iteration: 18900 Loss: 124.79765319824219
Iteration: 19000 Loss: 124.63285064697266
Iteration: 19100 Loss: 124.67835235595703
Iteration: 19200 Loss: 124.69660186767578
Iteration: 19300 Loss: 124.70985412597656
Iteration: 19400 Loss: 124.7074966430664
Iteration: 19500 Loss: 124.78852081298828
Iteration: 19600 Loss: 124.65253448486328
Iteration: 19700 Loss: 124.67334747314453
Iteration: 19800 Loss: 124.64019775390625
Iteration: 19900 Loss: 124.61317443847656

The evolution of the ELBO:

# Plot the recorded `elbos` trace to check that the optimization has settled.
fig, ax = plt.subplots()
ax.plot(elbos)
ax.set_xlabel('Iteration')
ax.set_ylabel('ELBO')
ax.set_title('ELBO vs SGD Iterations')
sns.despine(trim=True);
../_images/7115a43faee1ada5afa5caeba6e79ccabc1c96e405e55915dbc77ffe94fb6153.svg

Here is the posterior:

# Draw 1,000 samples of `theta` from the fitted Gaussian guide.
theta_predictive = pyro.infer.Predictive(model, guide=auto_gaussian_guide, num_samples=1_000)
thetas = theta_predictive(data)["theta"]
theta_draws = thetas.detach().numpy()

# Compare the VI histogram with the prior, the reference posterior and the true value.
fig, ax = plt.subplots()
ax.hist(theta_draws, bins=50, density=True, label='VI Posterior (Gaussian)')
for reference_dist, curve_label in ((thetaprior, 'Prior'), (theta_post, 'True Posterior')):
    ax.plot(x, reference_dist.pdf(x), label=curve_label)
ax.plot(thetatrue * np.ones_like(y), y, color="r", linestyle='--', label='True $\\theta$')
ax.legend(loc='best', frameon=False)
ax.set(xlabel='$\\theta$', ylabel='Density', title='Coin Toss Bayesian Inference')
sns.despine(trim=True);
../_images/ad31737493f5ac7e03caf761b97ded7cf238176c7d4610947a3f1a659e06a630.svg

Pretty good.

Example 3 - Challenger Space Shuttle Disaster#

Let’s revisit this example from the MCMC lecture.

url = "https://github.com/PredictiveScienceLab/data-analytics-se/raw/master/lecturebook/data/challenger_data.csv"
# Name the local file once, so the reader below always opens exactly what
# `download` wrote (by default `download` uses the URL's basename).
data_file = os.path.basename(url)
download(url, data_file)
# Load columns 1 and 2 (temperature in F, damage flag); "NA" cells become NaN.
challenger_data = np.genfromtxt(data_file, skip_header=1,
                                usecols=[1, 2], missing_values="NA",
                                delimiter=",")
# Drop every row with a missing value in either column. Checking only the
# label column would let a NaN temperature propagate into the model.
challenger_data = challenger_data[~np.isnan(challenger_data).any(axis=1)]
print("Temp (F), O-Ring failure?")
print(challenger_data)
Temp (F), O-Ring failure?
[[66.  0.]
 [70.  1.]
 [69.  0.]
 [68.  0.]
 [67.  0.]
 [72.  0.]
 [73.  0.]
 [70.  0.]
 [57.  1.]
 [63.  1.]
 [70.  1.]
 [78.  0.]
 [67.  0.]
 [53.  1.]
 [67.  0.]
 [75.  0.]
 [70.  0.]
 [81.  0.]
 [76.  0.]
 [79.  0.]
 [75.  1.]
 [76.  0.]
 [58.  1.]]
# Scatter each launch's damage indicator against the outside temperature.
fig, ax = plt.subplots()
ax.plot(challenger_data[:, 0], challenger_data[:, 1], 'ro')
ax.set(ylabel="Damage Incident?", xlabel="Outside temperature (Fahrenheit)")
# The outcome is binary, so only ticks at 0 and 1 are labelled. Set them on the
# Axes object directly instead of through pyplot's implicit "current axes".
ax.set_yticks([0, 1])
sns.despine(trim=True);
../_images/e720eea53b9cb1914b6d48ddce1aa9605f5d6f9f195ca8205640233ae70b1906.svg

Probabilistic model#

The defect probability is modeled as a function of the outside temperature:

\[ \sigma(t;\alpha,\beta) = \frac{1}{ 1 + e^{ -(\beta t + \alpha) } }, \]

so that \(\alpha + \beta t\) is the log-odds (logit) of a defect at temperature \(t\).

The goal is to infer the latent variables \(\alpha\) and \(\beta\).

We set broad normal priors on the latent variables (standard deviation \(100\), matching `dist.Normal(0., 100.)` in the code below):

\[ \alpha \sim N(0, 100^2), \]

and

\[ \beta \sim N(0, 100^2), \]

and the likelihood of the damage indicator \(x_i\) of launch \(i\), flown at outside temperature \(t_i\), is:

\[ p(x_i \mid \alpha, \beta, t_i) = \text{Bernoulli}\left(x_i \mid \sigma(t_i; \alpha, \beta)\right). \]

Here is the model in pyro:

# Move the cleaned (temperature, damage) table into torch and split its two columns.
challenger_data = torch.tensor(challenger_data)
temperature, defect = challenger_data[:, 0], challenger_data[:, 1]

def model(temperature, defect):
    """Logistic regression of O-ring damage on launch temperature.

    Arguments
    temperature -- Tensor of outside temperatures, one entry per launch
                   (presumably 1-D; its first dimension sizes the data plate).
    defect      -- Observed 0/1 damage indicators aligned with `temperature`;
                   they condition the 'obs' sample site.

    Returns a dict of the function's local values: the two arguments, the
    sampled `alpha` and `beta`, and the computed `logits`.
    """
    # Independent Normal priors; the second argument of dist.Normal is the
    # standard deviation (100), not the variance.
    alpha = pyro.sample('alpha', dist.Normal(0., 100.))
    beta = pyro.sample('beta', dist.Normal(0., 100.))
    num_launches = temperature.shape[0]
    with pyro.plate('data', num_launches):
        # Linear predictor on the log-odds scale, recorded as a deterministic site.
        logits = pyro.deterministic('logits', alpha + beta * temperature)
        pyro.sample('obs', dist.Bernoulli(logits=logits), obs=defect)
    return dict(temperature=temperature, defect=defect,
                alpha=alpha, beta=beta, logits=logits)

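We will use the AutoDiagonalNormal guide, i.e., a Gaussian variational distribution with diagonal covariance over \(\alpha\) and \(\beta\):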

# Variational family: Pyro's AutoDiagonalNormal, a Gaussian with diagonal
# covariance over the model's latent sites (alpha, beta).
guide = pyro.infer.autoguide.AutoDiagonalNormal(model)

# Fit the guide with the notebook's `train` helper for 40,000 iterations.
training_data = (temperature, defect)
elbos, params = train(model, guide, training_data, num_iter=40_000);
Iteration: 0 Loss: 5531.840440750122
Iteration: 100 Loss: 5467.439811468124
Iteration: 200 Loss: 5410.4613201618195
Iteration: 300 Loss: 5439.929831027985
Iteration: 400 Loss: 5344.992549419403
Iteration: 500 Loss: 5336.480083465576
Iteration: 600 Loss: 5237.499451160431
Iteration: 700 Loss: 5137.388278126717
Iteration: 800 Loss: 5212.8310779333115
Iteration: 900 Loss: 5149.711973071098
Iteration: 1000 Loss: 5135.877537846565
Iteration: 1100 Loss: 5152.2343854904175
Iteration: 1200 Loss: 4916.988680720329
Iteration: 1300 Loss: 4904.5195878744125
Iteration: 1400 Loss: 4859.160486340523
Iteration: 1500 Loss: 4790.789856433868
Iteration: 1600 Loss: 4793.873239278793
Iteration: 1700 Loss: 4716.390684723854
Iteration: 1800 Loss: 4715.97630906105
Iteration: 1900 Loss: 4699.5179080963135
Iteration: 2000 Loss: 4561.090007901192
Iteration: 2100 Loss: 4613.398386597633
Iteration: 2200 Loss: 4485.692254781723
Iteration: 2300 Loss: 4550.878390073776
Iteration: 2400 Loss: 4473.403885304928
Iteration: 2500 Loss: 4429.689296364784
Iteration: 2600 Loss: 4329.89715385437
Iteration: 2700 Loss: 4292.824923992157
Iteration: 2800 Loss: 4254.371640324593
Iteration: 2900 Loss: 4187.539891242981
Iteration: 3000 Loss: 4122.154160141945
Iteration: 3100 Loss: 4129.982268214226
Iteration: 3200 Loss: 4145.018218278885
Iteration: 3300 Loss: 4029.89741563797
Iteration: 3400 Loss: 3999.8724485635757
Iteration: 3500 Loss: 3987.2998255491257
Iteration: 3600 Loss: 3925.0976563692093
Iteration: 3700 Loss: 3825.536500453949
Iteration: 3800 Loss: 3795.5549738407135
Iteration: 3900 Loss: 3741.4490801095963
Iteration: 4000 Loss: 3689.951833844185
Iteration: 4100 Loss: 3666.78730905056
Iteration: 4200 Loss: 3522.1737109422684
Iteration: 4300 Loss: 3558.4692071676254
Iteration: 4400 Loss: 3436.4042862653732
Iteration: 4500 Loss: 3473.5815712213516
Iteration: 4600 Loss: 3347.4105162620544
Iteration: 4700 Loss: 3312.775734066963
Iteration: 4800 Loss: 3326.7731975317
Iteration: 4900 Loss: 3255.529089808464
Iteration: 5000 Loss: 3244.5093750953674
Iteration: 5100 Loss: 3099.211985349655
Iteration: 5200 Loss: 3126.5875456929207
Iteration: 5300 Loss: 3212.296360850334
Iteration: 5400 Loss: 3104.723169028759
Iteration: 5500 Loss: 2971.083088517189
Iteration: 5600 Loss: 2919.641637444496
Iteration: 5700 Loss: 2984.3093383312225
Iteration: 5800 Loss: 2986.3900289535522
Iteration: 5900 Loss: 2907.4370554089546
Iteration: 6000 Loss: 2736.9824554920197
Iteration: 6100 Loss: 2755.4019446372986
Iteration: 6200 Loss: 2645.9654043912888
Iteration: 6300 Loss: 2659.0154242515564
Iteration: 6400 Loss: 2601.622731566429
Iteration: 6500 Loss: 2560.792522072792
Iteration: 6600 Loss: 2546.4471768140793
Iteration: 6700 Loss: 2459.5261366963387
Iteration: 6800 Loss: 2483.570861876011
Iteration: 6900 Loss: 2442.6171367764473
Iteration: 7000 Loss: 2415.6285407543182
Iteration: 7100 Loss: 2333.2669653892517
Iteration: 7200 Loss: 2314.1345807909966
Iteration: 7300 Loss: 2233.409161865711
Iteration: 7400 Loss: 2108.001457452774
Iteration: 7500 Loss: 2091.66572278738
Iteration: 7600 Loss: 2147.742255270481
Iteration: 7700 Loss: 1986.1463367938995
Iteration: 7800 Loss: 2074.554480791092
Iteration: 7900 Loss: 1988.131707072258
Iteration: 8000 Loss: 1856.3882100582123
Iteration: 8100 Loss: 1912.0738167762756
Iteration: 8200 Loss: 1778.7179647684097
Iteration: 8300 Loss: 1725.1792503595352
Iteration: 8400 Loss: 1720.107176065445
Iteration: 8500 Loss: 1602.0498898029327
Iteration: 8600 Loss: 1603.1558706760406
Iteration: 8700 Loss: 1556.1041314601898
Iteration: 8800 Loss: 1492.0800383090973
Iteration: 8900 Loss: 1543.2749433517456
Iteration: 9000 Loss: 1486.2857413291931
Iteration: 9100 Loss: 1415.5964547395706
Iteration: 9200 Loss: 1400.3231258392334
Iteration: 9300 Loss: 1257.0089687108994
Iteration: 9400 Loss: 1329.7085238695145
Iteration: 9500 Loss: 1251.2551156282425
Iteration: 9600 Loss: 1277.0183324813843
Iteration: 9700 Loss: 1234.7495634555817
Iteration: 9800 Loss: 1147.988602757454
Iteration: 9900 Loss: 1057.0762363672256
Iteration: 10000 Loss: 1031.7693130970001
Iteration: 10100 Loss: 892.5639562606812
Iteration: 10200 Loss: 909.6632229089737
Iteration: 10300 Loss: 877.7933025360107
Iteration: 10400 Loss: 753.6981749534607
Iteration: 10500 Loss: 715.0268689393997
Iteration: 10600 Loss: 721.5288599729538
Iteration: 10700 Loss: 693.9215047359467
Iteration: 10800 Loss: 730.3371777534485
Iteration: 10900 Loss: 607.5542094707489
Iteration: 11000 Loss: 479.1928918361664
Iteration: 11100 Loss: 477.115443110466
Iteration: 11200 Loss: 438.5263637304306
Iteration: 11300 Loss: 387.6681262254715
Iteration: 11400 Loss: 397.86532521247864
Iteration: 11500 Loss: 357.18476259708405
Iteration: 11600 Loss: 257.59286415577355
Iteration: 11700 Loss: 203.42165434589057
Iteration: 11800 Loss: 252.13669669628655
Iteration: 11900 Loss: 155.4922615321634
Iteration: 12000 Loss: 108.06966466168193
Iteration: 12100 Loss: 120.54031702338864
Iteration: 12200 Loss: 28.131296986370266
Iteration: 12300 Loss: 48.86347971390331
Iteration: 12400 Loss: 81.91930581011749
Iteration: 12500 Loss: 34.40188861010839
Iteration: 12600 Loss: 50.633073009812
Iteration: 12700 Loss: 46.92756452962382
Iteration: 12800 Loss: 37.385267973718896
Iteration: 12900 Loss: 42.964566793642724
Iteration: 13000 Loss: 49.122924703198976
Iteration: 13100 Loss: 28.693002473911697
Iteration: 13200 Loss: 21.94563579882517
Iteration: 13300 Loss: 24.6155895023483
Iteration: 13400 Loss: 53.0983372642654
Iteration: 13500 Loss: 45.253892200357114
Iteration: 13600 Loss: 31.90642666580855
Iteration: 13700 Loss: 22.36246727093565
Iteration: 13800 Loss: 64.1029375010568
Iteration: 13900 Loss: 38.094568499962506
Iteration: 14000 Loss: 32.348311205312584
Iteration: 14100 Loss: 33.94302588581073
Iteration: 14200 Loss: 41.013375893873125
Iteration: 14300 Loss: 46.296571659837056
Iteration: 14400 Loss: 22.21393667483687
Iteration: 14500 Loss: 28.75701686681686
Iteration: 14600 Loss: 23.146422104396294
Iteration: 14700 Loss: 41.55522962110102
Iteration: 14800 Loss: 23.621197192003187
Iteration: 14900 Loss: 22.23556067286522
Iteration: 15000 Loss: 75.79573570577291
Iteration: 15100 Loss: 28.071933327482274
Iteration: 15200 Loss: 26.54055291123494
Iteration: 15300 Loss: 25.109224081943513
Iteration: 15400 Loss: 33.03099874210464
Iteration: 15500 Loss: 30.37252674959501
Iteration: 15600 Loss: 23.35132493350241
Iteration: 15700 Loss: 25.2834504006692
Iteration: 15800 Loss: 33.63218533849563
Iteration: 15900 Loss: 33.35619829171088
Iteration: 16000 Loss: 33.2399583875767
Iteration: 16100 Loss: 22.823433792856573
Iteration: 16200 Loss: 23.606800600222307
Iteration: 16300 Loss: 26.03480179996911
Iteration: 16400 Loss: 23.373104679618894
Iteration: 16500 Loss: 32.687484351071376
Iteration: 16600 Loss: 27.81495510019398
Iteration: 16700 Loss: 26.497814345600695
Iteration: 16800 Loss: 40.23385912496974
Iteration: 16900 Loss: 23.544696206629716
Iteration: 17000 Loss: 26.092794149913892
Iteration: 17100 Loss: 25.733030370726496
Iteration: 17200 Loss: 23.590334892826025
Iteration: 17300 Loss: 22.771762556847413
Iteration: 17400 Loss: 26.93055204228299
Iteration: 17500 Loss: 23.620860005580717
Iteration: 17600 Loss: 32.72158898102857
Iteration: 17700 Loss: 29.851976563585705
Iteration: 17800 Loss: 30.637191060142325
Iteration: 17900 Loss: 26.575991850100266
Iteration: 18000 Loss: 22.897891640719276
Iteration: 18100 Loss: 24.232797890467587
Iteration: 18200 Loss: 27.709378925535635
Iteration: 18300 Loss: 26.139397063040374
Iteration: 18400 Loss: 24.27635989949359
Iteration: 18500 Loss: 24.3850958452265
Iteration: 18600 Loss: 23.523260291355385
Iteration: 18700 Loss: 20.11709161262572
Iteration: 18800 Loss: 26.137081675258724
Iteration: 18900 Loss: 24.100693261496772
Iteration: 19000 Loss: 24.080908056868388
Iteration: 19100 Loss: 26.71790017765957
Iteration: 19200 Loss: 22.539628616002393
Iteration: 19300 Loss: 28.01652611476491
Iteration: 19400 Loss: 26.264811232480724
Iteration: 19500 Loss: 25.632215380126546
Iteration: 19600 Loss: 26.184025786404945
Iteration: 19700 Loss: 24.991171580304353
Iteration: 19800 Loss: 20.762267645382124
Iteration: 19900 Loss: 23.443045101904502
Iteration: 20000 Loss: 25.557818279907185
Iteration: 20100 Loss: 25.413245100376635
Iteration: 20200 Loss: 24.35236036325578
Iteration: 20300 Loss: 23.695267222021002
Iteration: 20400 Loss: 30.088536627596085
Iteration: 20500 Loss: 25.456576625732943
Iteration: 20600 Loss: 24.855850276203046
Iteration: 20700 Loss: 25.98218212723841
Iteration: 20800 Loss: 25.203806922643402
Iteration: 20900 Loss: 23.739397518029477
Iteration: 21000 Loss: 24.998446105565144
Iteration: 21100 Loss: 26.125275932352075
Iteration: 21200 Loss: 25.284852833565587
Iteration: 21300 Loss: 23.335551004859354
Iteration: 21400 Loss: 24.780393375426073
Iteration: 21500 Loss: 23.53304818189704
Iteration: 21600 Loss: 24.30075760145206
Iteration: 21700 Loss: 25.873531213493262
Iteration: 21800 Loss: 24.763838635407726
Iteration: 21900 Loss: 24.65449638854551
Iteration: 22000 Loss: 25.238088151636
Iteration: 22100 Loss: 25.184846087505
Iteration: 22200 Loss: 26.412554197003594
Iteration: 22300 Loss: 24.31290720351019
Iteration: 22400 Loss: 27.01054110364891
Iteration: 22500 Loss: 24.541662875075783
Iteration: 22600 Loss: 25.171726538521007
Iteration: 22700 Loss: 24.63645757150681
Iteration: 22800 Loss: 24.716924068539175
Iteration: 22900 Loss: 25.09385697342757
Iteration: 23000 Loss: 24.54968923651359
Iteration: 23100 Loss: 24.939830410122596
Iteration: 23200 Loss: 24.77008452932077
Iteration: 23300 Loss: 24.693774089049796
Iteration: 23400 Loss: 25.04027028670434
Iteration: 23500 Loss: 24.71046433660923
Iteration: 23600 Loss: 24.74220483982083
Iteration: 23700 Loss: 24.930580907278273
Iteration: 23800 Loss: 24.63441159764347
Iteration: 23900 Loss: 25.89271028425992
Iteration: 24000 Loss: 22.31327231766053
Iteration: 24100 Loss: 24.86475461533178
Iteration: 24200 Loss: 24.642195178416017
Iteration: 24300 Loss: 24.1808025435154
Iteration: 24400 Loss: 24.32992456372284
Iteration: 24500 Loss: 24.648959134876698
Iteration: 24600 Loss: 24.615056160032665
Iteration: 24700 Loss: 25.01263727221301
Iteration: 24800 Loss: 22.950913192952093
Iteration: 24900 Loss: 26.242057011479908
Iteration: 25000 Loss: 25.99351987860893
Iteration: 25100 Loss: 25.211691610602067
Iteration: 25200 Loss: 24.264029837725033
Iteration: 25300 Loss: 24.766253170246493
Iteration: 25400 Loss: 23.491357273749117
Iteration: 25500 Loss: 23.379243879204736
Iteration: 25600 Loss: 24.22769627558646
Iteration: 25700 Loss: 25.445399706624265
Iteration: 25800 Loss: 24.978734815570046
Iteration: 25900 Loss: 24.042654637850212
Iteration: 26000 Loss: 24.326386272259043
Iteration: 26100 Loss: 24.53441774696436
Iteration: 26200 Loss: 24.28823064848941
Iteration: 26300 Loss: 24.18363859979795
Iteration: 26400 Loss: 25.158189741818024
Iteration: 26500 Loss: 24.947390314835634
Iteration: 26600 Loss: 24.18184237198568
Iteration: 26700 Loss: 24.92485720576039
Iteration: 26800 Loss: 24.085106304787985
Iteration: 26900 Loss: 25.154899753953497
Iteration: 27000 Loss: 24.729930345436323
Iteration: 27100 Loss: 26.68456166479145
Iteration: 27200 Loss: 24.785857837192374
Iteration: 27300 Loss: 24.714093779444056
Iteration: 27400 Loss: 24.951900485138168
Iteration: 27500 Loss: 32.60548680519277
Iteration: 27600 Loss: 24.547519421970584
Iteration: 27700 Loss: 24.776071810375644
Iteration: 27800 Loss: 25.1776069486447
Iteration: 27900 Loss: 25.980662046562063
Iteration: 28000 Loss: 24.8114534916983
Iteration: 28100 Loss: 24.569198951483795
Iteration: 28200 Loss: 26.071178728460126
Iteration: 28300 Loss: 24.52668833912155
Iteration: 28400 Loss: 24.544690495539232
Iteration: 28500 Loss: 24.361628396894943
Iteration: 28600 Loss: 25.047156040604236
Iteration: 28700 Loss: 24.701738551817808
Iteration: 28800 Loss: 24.416437732109834
Iteration: 28900 Loss: 25.21581594681751
Iteration: 29000 Loss: 24.40752640350304
Iteration: 29100 Loss: 24.62851916030602
Iteration: 29200 Loss: 23.400316986498588
Iteration: 29300 Loss: 24.116426538859006
Iteration: 29400 Loss: 25.668033055956013
Iteration: 29500 Loss: 25.01011626478007
Iteration: 29600 Loss: 26.892617660074222
Iteration: 29700 Loss: 24.623005919196082
Iteration: 29800 Loss: 25.363858118834564
Iteration: 29900 Loss: 24.700079187379238
Iteration: 30000 Loss: 23.176525311038823
Iteration: 30100 Loss: 24.794078392183582
Iteration: 30200 Loss: 24.47935007929105
Iteration: 30300 Loss: 24.769586561049437
Iteration: 30400 Loss: 25.493242609219337
Iteration: 30500 Loss: 25.126654099587313
Iteration: 30600 Loss: 24.677962226451413
Iteration: 30700 Loss: 22.933394798496593
Iteration: 30800 Loss: 24.137780233145975
Iteration: 30900 Loss: 24.095835924508656
Iteration: 31000 Loss: 24.489420498804947
Iteration: 31100 Loss: 23.630513066941244
Iteration: 31200 Loss: 24.72929591966595
Iteration: 31300 Loss: 24.93994864282298
Iteration: 31400 Loss: 23.566433791228064
Iteration: 31500 Loss: 24.241119301102124
Iteration: 31600 Loss: 24.199153438155506
Iteration: 31700 Loss: 21.00680318833583
Iteration: 31800 Loss: 24.893609652941482
Iteration: 31900 Loss: 23.47715658267218
Iteration: 32000 Loss: 24.88978917170494
Iteration: 32100 Loss: 24.6301785334378
Iteration: 32200 Loss: 25.18476292518303
Iteration: 32300 Loss: 24.920443379892866
Iteration: 32400 Loss: 24.680385032943512
Iteration: 32500 Loss: 24.728282281483672
Iteration: 32600 Loss: 23.635705259289374
Iteration: 32700 Loss: 25.95586955837196
Iteration: 32800 Loss: 24.723580397017614
Iteration: 32900 Loss: 24.666875734401636
Iteration: 33000 Loss: 25.032087855050513
Iteration: 33100 Loss: 24.723617230234183
Iteration: 33200 Loss: 25.25495396426038
Iteration: 33300 Loss: 26.741601171418573
Iteration: 33400 Loss: 25.25516262143809
Iteration: 33500 Loss: 25.23764309868569
Iteration: 33600 Loss: 25.21635326284201
Iteration: 33700 Loss: 26.260585099524526
Iteration: 33800 Loss: 24.841972574427604
Iteration: 33900 Loss: 24.708010783228623
Iteration: 34000 Loss: 24.578025781224785
Iteration: 34100 Loss: 25.067507734425412
Iteration: 34200 Loss: 23.26360834401894
Iteration: 34300 Loss: 24.589993648834046
Iteration: 34400 Loss: 22.88195302138935
Iteration: 34500 Loss: 25.25582019984487
Iteration: 34600 Loss: 24.19041685155333
Iteration: 34700 Loss: 24.874783366052842
Iteration: 34800 Loss: 23.928824580766328
Iteration: 34900 Loss: 23.57181191540545
Iteration: 35000 Loss: 25.4204153442002
Iteration: 35100 Loss: 27.3725087588739
Iteration: 35200 Loss: 25.82517000924866
Iteration: 35300 Loss: 24.672316666841574
Iteration: 35400 Loss: 24.954654894457246
Iteration: 35500 Loss: 25.064021662474637
Iteration: 35600 Loss: 24.467243174303327
Iteration: 35700 Loss: 24.58402136130203
Iteration: 35800 Loss: 25.183380740669044
Iteration: 35900 Loss: 26.84162246961659
Iteration: 36000 Loss: 24.871028071220614
Iteration: 36100 Loss: 25.961890336413756
Iteration: 36200 Loss: 24.70609163104905
Iteration: 36300 Loss: 25.214652789167886
Iteration: 36400 Loss: 25.53647420010599
Iteration: 36500 Loss: 24.887390255542947
Iteration: 36600 Loss: 24.52048166771558
Iteration: 36700 Loss: 25.187576438120992
Iteration: 36800 Loss: 24.97245007480887
Iteration: 36900 Loss: 24.514226378996316
Iteration: 37000 Loss: 26.312269941474767
Iteration: 37100 Loss: 24.709581556815486
Iteration: 37200 Loss: 24.912317514930283
Iteration: 37300 Loss: 24.759607036111817
Iteration: 37400 Loss: 23.855255417048244
Iteration: 37500 Loss: 24.83772185189435
Iteration: 37600 Loss: 24.244234649730217
Iteration: 37700 Loss: 24.878440081829496
Iteration: 37800 Loss: 24.755173988172373
Iteration: 37900 Loss: 25.22221595510717
Iteration: 38000 Loss: 25.835765420953486
Iteration: 38100 Loss: 27.64800215890736
Iteration: 38200 Loss: 24.890146405070634
Iteration: 38300 Loss: 24.791493714645217
Iteration: 38400 Loss: 24.79983196491165
Iteration: 38500 Loss: 24.626213595485872
Iteration: 38600 Loss: 24.59559494682597
Iteration: 38700 Loss: 24.390439427679482
Iteration: 38800 Loss: 26.236338326435767
Iteration: 38900 Loss: 24.699624270748654
Iteration: 39000 Loss: 24.34034355986111
Iteration: 39100 Loss: 24.099570085433996
Iteration: 39200 Loss: 23.612481569072724
Iteration: 39300 Loss: 25.33456478873308
Iteration: 39400 Loss: 25.796397527083784
Iteration: 39500 Loss: 22.78682486263129
Iteration: 39600 Loss: 24.111162545394915
Iteration: 39700 Loss: 24.733798883841885
Iteration: 39800 Loss: 25.33748707753372
Iteration: 39900 Loss: 25.219414985062468

It is always a good idea to plot the ELBO:

# Plot the recorded `elbos` trace to check that the optimization has settled.
fig, ax = plt.subplots()
ax.plot(elbos)
ax.set_xlabel('Iteration')
ax.set_ylabel('ELBO')
ax.set_title('ELBO vs SGD Iterations')
sns.despine(trim=True);
../_images/49ba4d61db17a752bd5a5ae0ccaf5bf47e0ecf5475b01465d14dd5046fb73292.svg

Now let’s plot the posterior of the parameters:

# Sample 1,000 parameter draws from the fitted guide (via the model's sites).
param_samples = pyro.infer.Predictive(model, guide=guide, num_samples=1_000)(temperature, defect)

alpha_samples = param_samples["alpha"]
beta_samples = param_samples["beta"]
alpha_draws = alpha_samples.detach().numpy()
beta_draws = beta_samples.detach().numpy()

# One histogram per marginal: (draws, histogram label, x-axis label).
marginal_specs = (
    (alpha_draws, 'VI Posterior ($\\alpha$)', '$\\alpha$'),
    (beta_draws, 'VI Posterior ($\\beta$)', '$\\beta$'),
)
for draws, hist_label, axis_label in marginal_specs:
    fig, ax = plt.subplots()
    ax.hist(draws, bins=50, density=True, label=hist_label, alpha=0.5)
    ax.set(xlabel=axis_label, ylabel='VI Posterior Density')
    sns.despine(trim=True)

# Joint view of the two parameters.
fig, ax = plt.subplots()
ax.scatter(alpha_draws, beta_draws, alpha=0.5)
ax.set(xlabel='$\\alpha$', ylabel='$\\beta$')
sns.despine(trim=True);
../_images/e15e81a6d3e158f0e817c34f43dd344d86e5c13e038f9e0a906b34be879a26ad.svg../_images/ffc4f517c4c4d928dc13788321a25ba9a6bf50fafab31d91b36c7f940e8477e8.svg../_images/0a8e5b4e13ca1176c4d360074e673bcdff0e06cd5459da39d8b76f3add81f394.svg

Now we would like to plot the posterior predictive distribution as a function of the temperature.

def predictive_model(temperature, defect, num_points=100):
    """Extend `model` with the predicted damage probability on a temperature grid.

    Arguments
    temperature -- Observed launch temperatures; passed to `model` and used
                   to set the range [min, max] of the prediction grid.
    defect      -- Observed damage indicators; passed to `model`.
    num_points  -- Number of evenly spaced grid temperatures (default 100).

    Records a deterministic site 'predictions' holding
    sigmoid(alpha + beta * t) for every grid temperature t, and returns the
    function's local values.
    """
    # Re-run the original model so 'alpha' and 'beta' are the same sample
    # sites; under pyro.infer.Predictive with a guide they take guide samples.
    model_sites = model(temperature, defect)
    alpha = model_sites["alpha"]
    beta = model_sites["beta"]
    temps = torch.linspace(temperature.min(), temperature.max(), num_points)
    predictions = pyro.deterministic('predictions', torch.sigmoid(alpha + beta * temps))
    return locals()
# Draw 1,000 posterior-predictive probability curves over the temperature grid.
param_samples = pyro.infer.Predictive(
    predictive_model,
    guide=guide,
    num_samples=1_000
)(temperature, defect)

# The same 100-point grid that `predictive_model` evaluates by default.
temps = torch.linspace(temperature.min(), temperature.max(), 100)
predictions = param_samples["predictions"]
# torch.quantile requires `q` to share the input's dtype, so build the levels in
# the predictions' dtype instead of relying on the float32 default.
quantile_levels = torch.tensor([0.5, 0.025, 0.975], dtype=predictions.dtype)
pred_500, pred_025, pred_975 = predictions.quantile(quantile_levels, dim=0)

fig, ax = plt.subplots()
ax.plot(temperature, defect, 'ro')
# Pointwise median curve with a 2.5%-97.5% quantile band.
ax.plot(temps, pred_500.flatten(), 'b')
ax.fill_between(temps, pred_025.flatten(), pred_975.flatten(), color='b', alpha=0.1)
ax.set(ylabel="Damage Incident?", xlabel="Outside temperature (Fahrenheit)")
# Overlay the first few individual draws to show curve-to-curve variability.
for sample_curve in predictions[:10]:
    ax.plot(temps, sample_curve.flatten(), 'b', alpha=0.1)
# Binary outcome: label only 0 and 1, set on the Axes object directly.
ax.set_yticks([0, 1])
sns.despine(trim=True);
../_images/c30a518a72e6b135a7c66c17c9ef096d6f4326afa276a481addfa954e89a6ebb.svg