import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks");
Sampling the categorical#
We now show how to sample from a categorical distribution using samples from a uniform distribution. We start with the Bernoulli distribution, a particular case of the categorical distribution. Then, we show how to sample from the categorical distribution in general.
Sampling the Bernoulli distribution#
The Bernoulli distribution arises from a binary random variable representing the outcome of an experiment with a given probability of success. Let us encode success with 1 and failure with 0. It is a particular case of the categorical distribution (2 labels). Then, we say that the random variable \(X\) is a Bernoulli random variable with parameter \(\theta\) if:
\[
p(X = x) = \theta^x (1 - \theta)^{1 - x}, \quad x \in \{0, 1\},
\]
i.e., \(p(X = 1) = \theta\) and \(p(X = 0) = 1 - \theta\).
To sample from it, we do the following steps:
Sample a uniform number \(u\) (i.e., a draw from \(U([0,1])\)).
If \(u\le \theta\), then set \(x = 1\).
Otherwise, set \(x = 0\).
Proof
How do we know that this works? Let us compute the probability of \(X=1\). We start by employing the sum rule:
\[
p(X = 1) = \int_0^1 p(x = 1 | u)\, p(u)\, du.
\]
We know that \(p(u)=1\). Also, following the definition of our algorithm, we have that \(p(x=1|u)=1\) if \(u\le \theta\) and \(p(x=1|u)=0\) otherwise. So, the integral becomes:
\[
p(X = 1) = \int_0^\theta du = \theta.
\]
Well, this is exactly what we wanted.
Let’s test numerically if this process produces the desired result. Here is the code:
import numpy as np

def sample_bernoulli(theta: float):
    """Sample from the Bernoulli.

    Arguments:
    theta -- The probability of success.
    """
    u = np.random.rand()
    if u <= theta:
        return 1
    return 0
And here is how to use it:
for _ in range(10):
    print(sample_bernoulli(0.5))
1
1
0
1
1
0
1
0
0
0
Let’s plot a histogram of a large number of samples:
N = 1000
X = np.array(
    [sample_bernoulli(0.3) for _ in range(N)]
)
fig, ax = plt.subplots()
ax.hist(X, alpha=0.5)
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$p(x)$")
sns.despine(trim=True);
Okay, it looks fine. About \(\theta N = 300\) samples went to 1 and \((1-\theta)N = 700\) samples went to 0.
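As an additional sanity check (a small snippet we add here, reusing the array X from the cell above), the empirical frequency of ones should be close to \(\theta = 0.3\):

# The empirical frequency of 1s estimates theta; it should be near 0.3.
print(f"Empirical frequency of 1s: {X.mean():.3f}")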
Of course, we have already seen this implemented in scipy.stats. Here is a quick reminder of that code.
import scipy.stats as st
X = st.bernoulli(0.3)
X.rvs(size=10)
array([1, 1, 0, 1, 0, 0, 0, 0, 0, 0])
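For completeness, here is a vectorized variant of our sampler (a minimal sketch; the name sample_bernoulli_vectorized is our own, not part of any library). Drawing all the uniforms at once and comparing them to \(\theta\) produces many Bernoulli samples in one shot:

def sample_bernoulli_vectorized(theta, size):
    """Draw `size` Bernoulli(theta) samples with a single vectorized comparison."""
    # Each uniform draw that falls at or below theta counts as a success.
    return (np.random.rand(size) <= theta).astype(int)

print(sample_bernoulli_vectorized(0.3, 10))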
Sampling the \(K\)-label Categorical#
Consider a generic discrete random variable \(X\) taking \(K\) different values. You may assume that these values are integers \(\{0, 1,2,\dots,K-1\}\) (they are just the labels of the discrete objects anyway).
The probability mass function of \(X\) is:
\[
p(X = k) = p_k, \quad k = 0, 1, \dots, K-1,
\]
where, of course, we must have:
\[
p_k \ge 0,
\]
and
\[
\sum_{k=0}^{K-1} p_k = 1.
\]
In any case, here is how you sample from such a distribution:
Draw a uniform sample \(u\).
Find the index \(j\in\{0,1,\dots,K-1\}\) such that:
\[
\sum_{k=0}^{j-1} p_k < u \le \sum_{k=0}^{j} p_k,
\]
with the convention that the empty sum (for \(j=0\)) is zero.
Then, your sample is \(j\).
Why does this work?
The probability that \(u\) falls in the interval \(\left(\sum_{k=0}^{j-1}p_k, \sum_{k=0}^{j}p_k\right]\) is:
\[
P\left(\sum_{k=0}^{j-1} p_k < u \le \sum_{k=0}^{j} p_k\right) = \sum_{k=0}^{j} p_k - \sum_{k=0}^{j-1} p_k = p_j.
\]
So, it is exactly the probability that \(X=j\).
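As a concrete illustration (our own example, using the probabilities we will test below), take \(K=4\) and \(p=(0.2, 0.3, 0.4, 0.1)\). The unit interval is then partitioned as
\[
(0, 0.2] \,\cup\, (0.2, 0.5] \,\cup\, (0.5, 0.9] \,\cup\, (0.9, 1],
\]
and the length of the \(j\)-th piece is exactly \(p_j\), so a uniform \(u\) lands in piece \(j\) with probability \(p_j\).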
Let’s code it:
def sample_categorical(p):
    """Sample from a discrete probability distribution.

    Arguments:
    p -- An array specifying the probability of each possible state.
         The number of states is ``K = len(p)``.
    """
    K = len(p)
    u = np.random.rand()
    c = 0.
    for j in range(K):
        c += p[j]
        if u <= c:
            return j
    # Guard against u exceeding the cumulative sum due to floating-point
    # rounding; in that (rare) case, return the last state.
    return K - 1
Let’s test it with a four-state discrete random variable with probabilities:
p = [0.2, 0.3, 0.4, 0.1]
N = 100
X = np.array(
    [sample_categorical(p) for _ in range(N)]
)
fig, ax = plt.subplots()
ax.hist(X, alpha=0.5)
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$p(x)$")
sns.despine(trim=True);
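Before turning to scipy.stats, note that the linear scan in sample_categorical can be replaced by precomputing the cumulative sums once and binary-searching them. Here is a minimal vectorized sketch (the name sample_categorical_fast is our own):

def sample_categorical_fast(p, size=1):
    """Vectorized categorical sampling via cumulative sums and binary search."""
    cdf = np.cumsum(p)                # cumulative probabilities
    u = np.random.rand(size)          # one uniform draw per requested sample
    # searchsorted returns the first index j with u <= cdf[j].
    j = np.searchsorted(cdf, u)
    # Guard against u landing just above cdf[-1] due to rounding.
    return np.minimum(j, len(cdf) - 1)

print(sample_categorical_fast([0.2, 0.3, 0.4, 0.1], size=10))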
Of course, scipy.stats already implements this functionality. Let’s compare.
K = len(p)
X_st = st.rv_discrete(values=(np.arange(K), p))
x_st_samples = X_st.rvs(size=N)
# Let's compare the two histograms
fig, ax = plt.subplots()
ax.hist(X, alpha=0.5, label="Our implementation")
ax.hist(x_st_samples, alpha=0.5, label="Scipy.stats implementation")
ax.set_xlabel(r"$x$")
ax.set_ylabel(r"$p(x)$")
plt.legend(loc="best", frameon=False)
sns.despine(trim=True);
Questions#
It looks like there is a lot of variability every time you run the code. Go back and increase the number of samples \(N\) until the results stop changing. Then you should be able to observe that our code does exactly the same thing as scipy.stats.rv_discrete.