import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks");
The Multivariate Normal - Conditioning
Consider the \(N\)-dimensional multivariate normal:
\[
\mathbf{X} \sim N\left(\boldsymbol{\mu}, \boldsymbol{\Sigma}\right),
\]
where \(\boldsymbol{\mu}\) is an \(N\)-dimensional vector and \(\boldsymbol{\Sigma}\) is an \(N\times N\) positive-definite matrix.
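For reference, recall that the joint PDF of this distribution is:
\[
p(\mathbf{x}) = \frac{1}{(2\pi)^{N/2}|\boldsymbol{\Sigma}|^{1/2}}\exp\left\{-\frac{1}{2}\left(\mathbf{x}-\boldsymbol{\mu}\right)^T\boldsymbol{\Sigma}^{-1}\left(\mathbf{x}-\boldsymbol{\mu}\right)\right\}.
\]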
Now split \(\mathbf{X}\) into two vectors \(\mathbf{X}_1\) and \(\mathbf{X}_2\) of dimensions \(N_1\) and \(N_2\), respectively (\(N_1 + N_2 = N\)):
\[
\mathbf{X} = \begin{bmatrix}
\mathbf{X}_1 \\
\mathbf{X}_2
\end{bmatrix}.
\]
Similarly, split \(\boldsymbol{\mu}\) into two vectors \(\boldsymbol{\mu}_1\) and \(\boldsymbol{\mu}_2\) of dimensions \(N_1\) and \(N_2\):
\[
\boldsymbol{\mu} = \begin{bmatrix}
\boldsymbol{\mu}_1 \\
\boldsymbol{\mu}_2
\end{bmatrix}.
\]
Similarly for \(\boldsymbol{\Sigma}\):
\[
\boldsymbol{\Sigma} = \begin{bmatrix}
\boldsymbol{\Sigma}_{11} & \boldsymbol{\Sigma}_{12} \\
\boldsymbol{\Sigma}_{21} & \boldsymbol{\Sigma}_{22}
\end{bmatrix},
\]
where \(\boldsymbol{\Sigma}_{ii}\) are \(N_i\times N_i\) matrices, \(\boldsymbol{\Sigma}_{12}\) is an \(N_1\times N_2\) matrix, and \(\boldsymbol{\Sigma}_{21} = \boldsymbol{\Sigma}_{12}^T\) by symmetry.
Using marginalization, we can show that:
\[
\mathbf{X}_1 \sim N\left(\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_{11}\right)
\]
and
\[
\mathbf{X}_2 \sim N\left(\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_{22}\right).
\]
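Before moving on to conditioning, we can verify the marginalization result numerically. The sketch below is our own addition: it draws samples from a hypothetical 2D Gaussian and compares the sample mean and variance of the first component to \(\boldsymbol{\mu}_1\) and \(\boldsymbol{\Sigma}_{11}\).

import numpy as np
import scipy.stats as st

# A hypothetical 2D Gaussian used only for this check
joint = st.multivariate_normal(
    mean=np.array([1.0, 2.0]),
    cov=np.array(
        [
            [2.0, 0.9],
            [0.9, 4.0]
        ]
    )
)
samples = joint.rvs(size=100_000, random_state=0)
# The first component should be approximately N(mu_1 = 1.0, Sigma_11 = 2.0)
print(f"sample mean of x_1: {samples[:, 0].mean():.3f}")
print(f"sample variance of x_1: {samples[:, 0].var():.3f}")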
Suppose now we know the value of \(\mathbf{X}_2\), i.e., \(\mathbf{X}_2 = \mathbf{x}_2\). What is the distribution of \(\mathbf{X}_1\)? To find it, we apply Bayes’ rule:
\[
p(\mathbf{x}_1 | \mathbf{x}_2) = \frac{p(\mathbf{x}_1, \mathbf{x}_2)}{p(\mathbf{x}_2)}.
\]
We have all the required terms on the right-hand side; we only need to substitute and do the algebra. If we do, using the “complete the square” trick, we get:
\[
\mathbf{X}_1 | \mathbf{X}_2 = \mathbf{x}_2 \sim N\left(\boldsymbol{\mu}_{1|2}, \boldsymbol{\Sigma}_{1|2}\right),
\]
where
\[
\boldsymbol{\mu}_{1|2} = \boldsymbol{\mu}_1 + \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}\left(\mathbf{x}_2 - \boldsymbol{\mu}_2\right)
\]
and
\[
\boldsymbol{\Sigma}_{1|2} = \boldsymbol{\Sigma}_{11} - \boldsymbol{\Sigma}_{12}\boldsymbol{\Sigma}_{22}^{-1}\boldsymbol{\Sigma}_{21}.
\]
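These formulas take only a few lines of code for arbitrary \(N_1\) and \(N_2\). Here is a minimal sketch (the function name and signature are ours), using linear solves instead of explicitly inverting \(\boldsymbol{\Sigma}_{22}\):

import numpy as np

def condition_mvn(mu1, mu2, Sigma11, Sigma12, Sigma22, x2):
    """Return the mean and covariance of X_1 given X_2 = x2."""
    # Sigma_22^{-1} (x2 - mu_2), computed with a linear solve
    a = np.linalg.solve(Sigma22, x2 - mu2)
    # Sigma_22^{-1} Sigma_21, with Sigma_21 = Sigma_12^T by symmetry
    B = np.linalg.solve(Sigma22, Sigma12.T)
    mu_cond = mu1 + Sigma12 @ a
    Sigma_cond = Sigma11 - Sigma12 @ B
    return mu_cond, Sigma_cond

For \(N_1 = N_2 = 1\) this reduces to the scalar formulas we implement by hand in the example below.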
Note
If you want to see the details, read Chapter 2.3 of [Bishop, 2006].
Let’s demonstrate this with an example.
import numpy as np
import scipy.stats as st
# This is the multivariate normal we are going to play with
X = st.multivariate_normal(
mean=np.array([1.0, 2.0]),
cov=np.array(
[
[2.0, 0.9],
[0.9, 4.0]
]
)
)
print("X ~ N(mu, Sigma),")
print(f"mu = {X.mean}")
print("Sigma = ")
print(X.cov)
print("")
x2_observed = -1.0
print(f"x_2 = {x2_observed:.2f} (hypothetical observation)")
X ~ N(mu, Sigma),
mu = [1. 2.]
Sigma =
[[2. 0.9]
[0.9 4. ]]
x_2 = -1.00 (hypothetical observation)
Let’s plot the contour of the joint and see where \(x_2\) falls:
fig, ax = plt.subplots()
x1 = np.linspace(-3, 5, 64)
x2 = np.linspace(-3, 6, 64)
Xg1, Xg2 = np.meshgrid(x1, x2)
Xg_flat = np.hstack(
[
Xg1.flatten()[:, None],
Xg2.flatten()[:, None]
]
)
Z = X.pdf(Xg_flat).reshape(Xg1.shape)
c = ax.contour(Xg1, Xg2, Z)
ax.plot(
x1,
x2_observed * np.ones(x1.shape[0]),
"--",
label=r"Observed $x_2$")
ax.clabel(c, inline=1, fontsize=10)
plt.legend(loc="best", frameon=False)
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
sns.despine(trim=True);
Intuitively, the probability density of getting a particular value \(x_1\) is proportional to the joint PDF of \(x_1\) and \(x_2\) evaluated along the dashed line. Let’s see what answer the theory gives. We need to calculate the mean and variance of \(x_1\) conditional on observing \(x_2\). Because \(x_1\) is one-dimensional, it is very simple to implement the formulas above.
# Extract the blocks of the mean and covariance
Sigma11 = X.cov[0, 0]
Sigma12 = X.cov[0, 1]
Sigma22 = X.cov[1, 1]
mu1 = X.mean[0]
mu2 = X.mean[1]
# Conditional mean and variance of x_1 given x_2 (formulas above)
mu1_cond = mu1 + Sigma12 * (x2_observed - mu2) / Sigma22
Sigma11_cond = Sigma11 - Sigma12 ** 2 / Sigma22
print(f"x_1 | x_2 ~ N(mu = {mu1_cond:.2f}, sigma^2 = {Sigma11_cond:.2f})")
x_1 | x_2 ~ N(mu = 0.32, sigma^2 = 1.80)
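We can also sanity-check this result by brute force. The sketch below is our own addition: it reuses X and x2_observed from above, samples many points from the joint, keeps only those whose \(x_2\) falls in a narrow band around the observed value, and looks at the mean and variance of the surviving \(x_1\) values.

samples = X.rvs(size=200_000, random_state=0)
# Keep only the samples with x_2 close to the observed value
mask = np.abs(samples[:, 1] - x2_observed) < 0.05
x1_kept = samples[mask, 0]
print(f"Monte Carlo: mu = {x1_kept.mean():.2f}, sigma^2 = {x1_kept.var():.2f}")

The estimates should be close to the theoretical values above, and they improve as the band shrinks and the number of samples grows.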
Let’s plot this conditional PDF for \(x_1\) and compare it to its marginal PDF:
X1_cond = st.norm(
loc=mu1_cond,
scale=np.sqrt(Sigma11_cond)
)
X1_marg = st.norm(
loc=X.mean[0],
scale=np.sqrt(Sigma11)
)
fig, ax = plt.subplots()
ax.plot(
x1,
X1_marg.pdf(x1),
label=r"$p(x_1)$"
)
ax.plot(
x1,
X1_cond.pdf(x1),
label=f"$p(x_1|x_2={x2_observed:1.2f})$"
)
ax.set_xlabel(r"$x_1$")
ax.set_ylabel("Probability")
plt.legend(loc="best", frameon=False)
sns.despine(trim=True);
This is our first example of how Bayes’ rule can be used to condition on observations. In the plot above, you can think of \(p(x_1)\) as your state of knowledge about \(x_1\) before you observe \(x_2\). Because \(x_1\) and \(x_2\) are correlated, your state of knowledge about \(x_1\) changes after you observe \(x_2\). This is captured by the conditional \(p(x_1|x_2)\).
Questions
Rerun the code above with different values of x2_observed to see how the conditional PDF moves around as the observation changes.
Modify the code so that you get the conditional PDF of \(X_2\) given \(X_1=x_1\).