import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks");

The Multivariate Normal - Conditioning#

Consider the \(N\)-dimensional multivariate normal:

\[ \mathbf{X} \sim N\left(\boldsymbol{\mu}, \boldsymbol{\Sigma}\right), \]

where \(\boldsymbol{\mu}\) is an \(N\)-dimensional vector and \(\boldsymbol{\Sigma}\) is an \(N\times N\) positive-definite matrix.

Now split \(\mathbf{X}\) into two vectors \(\mathbf{X}_1\) and \(\mathbf{X}_2\) of dimensions \(N_1\) and \(N_2\) (\(N_1 + N_2 = N\)):

\[\begin{split} \mathbf{X} = \begin{pmatrix} \mathbf{X}_1\\ \mathbf{X}_2 \end{pmatrix}. \end{split}\]

Similarly, split \(\boldsymbol{\mu}\) into two vectors \(\boldsymbol{\mu}_1\) and \(\boldsymbol{\mu}_2\) of dimensions \(N_1\) and \(N_2\) (\(N_1 + N_2 = N\)):

\[\begin{split} \boldsymbol{\mu} = \begin{pmatrix} \boldsymbol{\mu}_1\\ \boldsymbol{\mu}_2 \end{pmatrix}. \end{split}\]

Similarly for \(\boldsymbol{\Sigma}\):

\[\begin{split} \boldsymbol{\Sigma} = \begin{pmatrix} \boldsymbol{\Sigma}_1 & \boldsymbol{\Sigma}_{12}\\ \boldsymbol{\Sigma}_{12}^T&\boldsymbol{\Sigma}_2 \end{pmatrix}, \end{split}\]

where \(\boldsymbol{\Sigma}_i\) is an \(N_i\times N_i\) matrix and \(\boldsymbol{\Sigma}_{12}\) is an \(N_1\times N_2\) matrix.

Using marginalization, we can show that:

\[ \mathbf{X}_1 \sim N\left(\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_1\right). \]

and

\[ \mathbf{X}_2 \sim N\left(\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_2\right). \]
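As a quick numerical sanity check of the marginalization result, here is a minimal sketch. The three-dimensional mean and covariance below are made up for illustration; we simply draw joint samples and compare the sample mean and covariance of \(\mathbf{X}_1\) to \(\boldsymbol{\mu}_1\) and \(\boldsymbol{\Sigma}_1\):

import numpy as np
import scipy.stats as st

# A made-up 3D multivariate normal, split as N_1 = 2 and N_2 = 1
mu = np.array([1.0, -2.0, 0.5])
Sigma = np.array(
    [
        [2.0, 0.5, 0.3],
        [0.5, 1.5, -0.4],
        [0.3, -0.4, 1.0]
    ]
)
X_check = st.multivariate_normal(mean=mu, cov=Sigma)

# Draw joint samples and keep only the X_1 components
samples = X_check.rvs(size=100_000)
x1_samples = samples[:, :2]

print("sample mean of X_1:", x1_samples.mean(axis=0))
print("mu_1:              ", mu[:2])
print("sample covariance of X_1:")
print(np.cov(x1_samples, rowvar=False))
print("Sigma_1:")
print(Sigma[:2, :2])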

Suppose now that we observe the value of \(\mathbf{X}_2\), i.e., \(\mathbf{X}_2 = \mathbf{x}_2\). What is then the distribution of \(\mathbf{X}_1\)? To find it, we apply Bayes’ rule:

\[ p(\mathbf{x}_1|\mathbf{x}_2) = \frac{p(\mathbf{x}_1,\mathbf{x}_2)}{p(\mathbf{x}_2)}. \]

We know all the terms on the right-hand side, so all that remains is to substitute and do the algebra. Carrying this out and “completing the square” in the exponent gives:

\[ \mathbf{X}_1|\mathbf{X}_2 = \mathbf{x}_2 \sim N\left(\boldsymbol{\mu}_{1|2}, \boldsymbol{\Sigma}_{1|2}\right), \]

where

\[ \boldsymbol{\mu}_{1|2} = \boldsymbol{\mu}_1+\boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_2^{-1}(\mathbf{x}_2-\boldsymbol{\mu}_2), \]

and

\[ \boldsymbol{\Sigma}_{1|2} = \boldsymbol{\Sigma}_1-\boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_2^{-1}\boldsymbol{\Sigma}_{12}^T. \]

Note

More details: If you want to see the full derivation, read Section 2.3 of [Bishop, 2006].
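
Before the two-dimensional example below, here is a sketch of how these formulas translate into code. The helper function condition_mvn and its argument layout are hypothetical (my own, not from any library); it conditions \(N(\boldsymbol{\mu}, \boldsymbol{\Sigma})\) on observing an arbitrary subset of components:

import numpy as np

def condition_mvn(mu, Sigma, idx2, x2):
    """Condition N(mu, Sigma) on X[idx2] = x2 and return (mu_{1|2}, Sigma_{1|2})."""
    N = mu.shape[0]
    idx2 = np.atleast_1d(idx2)
    idx1 = np.setdiff1d(np.arange(N), idx2)
    # Blocks of the partitioned covariance matrix
    Sigma1 = Sigma[np.ix_(idx1, idx1)]
    Sigma2 = Sigma[np.ix_(idx2, idx2)]
    Sigma12 = Sigma[np.ix_(idx1, idx2)]
    # Sigma_12 Sigma_2^{-1}, computed with a linear solve for numerical stability
    A = np.linalg.solve(Sigma2, Sigma12.T).T
    mu_cond = mu[idx1] + A @ (np.atleast_1d(x2) - mu[idx2])
    Sigma_cond = Sigma1 - A @ Sigma12.T
    return mu_cond, Sigma_cond

# Example: the 2D normal used below, conditioning on x_2 = -1
mu_c, Sigma_c = condition_mvn(
    np.array([1.0, 2.0]),
    np.array([[2.0, 0.9], [0.9, 4.0]]),
    idx2=1,
    x2=-1.0
)
print("mu_{1|2} =", mu_c)
print("Sigma_{1|2} =", Sigma_c)

Applied to the two-dimensional example that follows, this reproduces the conditional mean and variance computed by hand there.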

Let’s demonstrate this with an example.

import numpy as np
import scipy.stats as st

# This is the multivariate normal we are going to play with
X = st.multivariate_normal(
    mean=np.array([1.0, 2.0]),
    cov=np.array(
        [
            [2.0, 0.9],
            [0.9, 4.0]
        ]
    )
)

print("X ~ N(mu, Sigma),")
print(f"mu = {X.mean}")
print("Sigma = ")
print(X.cov)
print("")

x2_observed = -1.0
print(f"x_2 = {x2_observed:.2f} (hypothetical observation)")
X ~ N(mu, Sigma),
mu = [1. 2.]
Sigma = 
[[2.  0.9]
 [0.9 4. ]]

x_2 = -1.00 (hypothetical observation)

Let’s plot the contour of the joint and see where \(x_2\) falls:

fig, ax = plt.subplots()
x1 = np.linspace(-3, 5, 64)
x2 = np.linspace(-3, 6, 64)
Xg1, Xg2 = np.meshgrid(x1, x2)
Xg_flat = np.hstack(
    [
        Xg1.flatten()[:, None],
        Xg2.flatten()[:, None]
    ]
)
Z = X.pdf(Xg_flat).reshape(Xg1.shape)
c = ax.contour(Xg1, Xg2, Z)
ax.plot(
    x1,
    x2_observed * np.ones(x1.shape[0]),
    "--", 
    label=r"Observed $x_2$")
ax.clabel(c, inline=1, fontsize=10)
plt.legend(loc="best", frameon=False)
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
sns.despine(trim=True);
[Figure: contours of the joint PDF \(p(x_1, x_2)\), with the observed value \(x_2 = -1\) marked by a dashed line.]

Intuitively, the probability density of a particular value \(x_1\) should be proportional to the joint PDF of \(x_1\) and \(x_2\) evaluated along the dashed line. Let’s see what answer the theory gives. We need to calculate the mean and variance of \(x_1\) conditional on observing \(x_2\). Because \(x_1\) is one-dimensional, the formulas above are trivial to implement.

# Blocks of the partitioned covariance matrix
Sigma11 = X.cov[0, 0]
Sigma12 = X.cov[0, 1]
Sigma22 = X.cov[1, 1]

# Components of the mean vector
mu1 = X.mean[0]
mu2 = X.mean[1]

# Conditional mean: mu_1 + Sigma_12 * Sigma_2^{-1} * (x_2 - mu_2)
mu1_cond = mu1 + Sigma12 * (x2_observed - mu2) / Sigma22

# Conditional variance: Sigma_1 - Sigma_12 * Sigma_2^{-1} * Sigma_12
Sigma11_cond = Sigma11 - Sigma12 ** 2 / Sigma22

print(f"x_1 | x_2 ~ N(mu = {mu1_cond:.2f}, sigma^2 = {Sigma11_cond:.2f})")
x_1 | x_2 ~ N(mu = 0.32, sigma^2 = 1.80)
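
As a sanity check of the intuition above, we can also recover the conditional numerically: evaluate the joint PDF along the dashed line \(x_2 = -1\) and normalize it so that it integrates to one over \(x_1\). The sketch below reuses X, x1, and x2_observed from the cells above; the agreement is only approximate because the grid truncates the tails of the distribution.

# Joint PDF evaluated along the slice x_2 = x2_observed
joint_on_slice = X.pdf(
    np.column_stack([x1, x2_observed * np.ones_like(x1)])
)

# Normalize with a simple Riemann sum so the slice integrates to one over x_1
dx = x1[1] - x1[0]
cond_numerical = joint_on_slice / (joint_on_slice.sum() * dx)

# Compare to the analytic conditional on the same grid
cond_analytic = st.norm(
    loc=mu1_cond,
    scale=np.sqrt(Sigma11_cond)
).pdf(x1)
print("max abs difference:", np.abs(cond_numerical - cond_analytic).max())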

Let’s plot this conditional pdf for \(x_1\) and compare it to its marginal pdf:

X1_cond = st.norm(
    loc=mu1_cond,
    scale=np.sqrt(Sigma11_cond)
)
X1_marg = st.norm(
    loc=X.mean[0],
    scale=np.sqrt(Sigma11)
)
fig, ax = plt.subplots()
ax.plot(
    x1,
    X1_marg.pdf(x1),
    label=r"$p(x_1)$"
)
ax.plot(
    x1,
    X1_cond.pdf(x1),
    label=f"$p(x_1|x_2={x2_observed:1.2f})$"
)
ax.set_xlabel(r"$x_1$")
ax.set_ylabel("Probability density")
plt.legend(loc="best", frameon=False)
sns.despine(trim=True);
[Figure: the marginal PDF \(p(x_1)\) compared with the conditional PDF \(p(x_1 \mid x_2 = -1)\).]

This is our first example of how Bayes’ rule can be used to condition on observations. In the plot above, you can think of \(p(x_1)\) as your state of knowledge about \(x_1\) before you observe \(x_2\). Because \(x_1\) and \(x_2\) are correlated, your state of knowledge about \(x_1\) changes after you observe \(x_2\). This is captured by the conditional \(p(x_1|x_2)\).

Questions#

  • Rerun the code above with different values of x2_observed (e.g., drawn at random) to see how the conditional PDF moves around.

  • Modify the code so that you get the conditional PDF of \(X_2\) given \(X_1=x_1\).