import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")

Clustering using k-means#

K-means is the most straightforward algorithm for splitting the dataset \(\mathbf{x}_{1:n}\) into \(K\) clusters. If you want to study the algorithm independently, I suggest reading Chapter 20.1 of D. MacKay (2003).

In K-means the clusters are defined by their centroids \(\mathbf{m}_{1:K}\), which are the means of the data points assigned to the cluster. Each observation \(\mathbf{x}_i\) is assigned to the cluster with the closest centroid. We can write this assignment as a one-hot encoding \(\mathbf{z}_i\) of the cluster index \(k\):

\[\begin{split} \mathbf{z}_i = \begin{bmatrix} 0 \\ \vdots \\ 1 \\ \vdots \\ 0 \end{bmatrix} \in \mathbb{R}^K, \quad \text{where} \quad z_{ik} = \begin{cases} 1 & \text{if } k = \arg\min_{k'} \|\mathbf{x}_i - \mathbf{m}_{k'}\|^2 \\ 0 & \text{otherwise} \end{cases} \end{split}\]

The centroids are found by minimizing the sum of squared distances between the data points and their assigned centroids:

\[ \min_{\mathbf{m}_{1:K}} \sum_{i=1}^n \sum_{k=1}^K z_{ik} \|\mathbf{x}_i - \mathbf{m}_k\|^2 \]
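
This also explains why the centroids end up being the means of their clusters: for fixed assignments \(z_{ik}\), setting the gradient of the objective with respect to \(\mathbf{m}_k\) to zero gives

\[ \mathbf{m}_k = \frac{\sum_{i=1}^n z_{ik} \mathbf{x}_i}{\sum_{i=1}^n z_{ik}}, \]

which is exactly the mean of the data points assigned to cluster \(k\).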

The algorithm starts by initializing the centroids randomly. Then, it alternates between two steps until convergence:

  1. Assign each data point to the cluster with the closest centroid.

  2. Update each centroid to the mean of the data points assigned to its cluster.

Convergence is checked by comparing the centroids of the current and previous iterations. The algorithm is guaranteed to converge, but it may converge to a local minimum of the objective.

There is a nice visualization of the algorithm by Naftali Harris.
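
To make the two steps concrete, here is a minimal from-scratch sketch in plain NumPy. The function name kmeans_naive, the initialization by picking \(K\) random data points, and the tolerance tol are illustrative choices for this sketch, not what scikit-learn does internally.

import numpy as np

def kmeans_naive(X, K, max_iter=100, tol=1e-8, seed=0):
    """A bare-bones K-means: X has shape (n, d), K is the number of clusters."""
    rng = np.random.default_rng(seed)
    # Initialize the centroids by picking K distinct data points at random
    centroids = X[rng.choice(X.shape[0], size=K, replace=False)]
    for _ in range(max_iter):
        # Step 1: assign each point to the closest centroid
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=-1)
        labels = np.argmin(dists, axis=1)
        # Step 2: move each centroid to the mean of its assigned points
        new_centroids = np.array([
            X[labels == k].mean(axis=0) if np.any(labels == k) else centroids[k]
            for k in range(K)
        ])
        # Convergence check: stop when the centroids no longer move
        if np.linalg.norm(new_centroids - centroids) < tol:
            centroids = new_centroids
            break
        centroids = new_centroids
    return centroids, labels

Once you have generated the synthetic data in the example below, you can try centroids, labels = kmeans_naive(data, 3) and compare the result with what scikit-learn finds.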

Example#

Let’s start by generating a synthetic dataset with three clusters:

np.random.seed(123456)

# Make synthetic dataset for clustering
num_clusters_true = 3
# The means of each cluster
mu_true = 3.0 * np.random.randn(num_clusters_true, 2)
# The variance of the observations around the cluster
sigma_true = 0.5
# How many observations to generate per cluster
num_obs_cluster = [50, 50, 50]

# Generate the data
data = []
for i in range(num_clusters_true):
    x_i = mu_true[i] + sigma_true * np.random.randn(num_obs_cluster[i], 2)
    data.append(x_i)
data = np.vstack(data)
# Permute the data so that order info is lost
data = np.random.permutation(data)

Now let’s visualize the data, forgetting about the underlying clusters that gave rise to them.

fig, ax = plt.subplots()
ax.plot(data[:, 0], data[:, 1], '.')
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
sns.despine(trim=True);

Let’s apply K-means to the data:

from sklearn.cluster import KMeans

model = KMeans(n_clusters=3, n_init='auto').fit(data)

Here is how you can access the cluster centers (the \(\mathbf{m}_k\)’s) from the trained model:

model.cluster_centers_
array([[ 3.55976184, -0.43723791],
       [-4.57184776, -3.2974181 ],
       [ 1.35520175, -0.91177514]])

Compare the identified cluster centers to the actual cluster centers. Note that K-means labels the clusters in an arbitrary order, so the rows may be permuted relative to mu_true:

mu_true
array([[ 1.4073369 , -0.84859003],
       [-4.52717551, -3.40689711],
       [ 3.63633608, -0.51964395]])

K-means has also labeled each observation with its cluster index. Here is how to get this information:

model.labels_
array([2, 2, 1, 1, 1, 1, 0, 1, 2, 0, 0, 2, 1, 1, 2, 0, 0, 2, 1, 1, 0, 0,
       1, 1, 2, 2, 0, 2, 2, 2, 0, 0, 2, 1, 0, 2, 1, 2, 0, 0, 2, 0, 2, 0,
       0, 0, 1, 2, 1, 0, 0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 0, 2, 0, 0,
       1, 0, 1, 1, 2, 2, 1, 1, 1, 2, 1, 0, 2, 1, 2, 2, 0, 1, 2, 1, 1, 2,
       0, 0, 2, 2, 0, 1, 2, 0, 2, 2, 0, 2, 0, 1, 1, 0, 1, 2, 1, 1, 0, 1,
       1, 2, 2, 2, 1, 2, 2, 1, 0, 2, 0, 1, 1, 0, 2, 0, 0, 1, 1, 1, 1, 1,
       1, 0, 0, 0, 2, 2, 0, 2, 0, 2, 1, 0, 2, 0, 2, 1, 2, 2], dtype=int32)
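
As a quick sanity check, you can count how many observations were assigned to each cluster. This is plain NumPy, not part of the KMeans API:

# Number of points assigned to clusters 0, 1, and 2
np.bincount(model.labels_)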

Since we have 2D observations, we can actually visualize the clusters. Here is a nice way to do this:

labels = model.predict(data)
fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1], c=labels)
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
sns.despine(trim=True);

Okay, this seems to work perfectly. However, notice that we asked K-means to find three clusters, which happens to be the true number of clusters in our dataset. What would happen if we had asked K-means to find a larger number of clusters, say 5? Here it is:

model5 = KMeans(n_clusters=5, n_init='auto').fit(data)

labels = model5.predict(data)

fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1], c=labels)
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
sns.despine(trim=True);
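
One way to quantify what is going on is to look at the minimized objective itself, which scikit-learn exposes as the inertia_ attribute of a fitted model (the sum of squared distances of each point to its closest centroid). Here is a small sketch that refits K-means for a few arbitrary values of \(K\) and prints this quantity. Note that the inertia typically keeps decreasing as \(K\) grows, so it cannot by itself tell you the “right” number of clusters.

for K in [1, 2, 3, 5, 10]:
    model_K = KMeans(n_clusters=K, n_init='auto').fit(data)
    # inertia_ is the minimized sum of squared distances to the closest centroids
    print(K, model_K.inertia_)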

Questions#

  • We saw what happens when you ask K-means to find more clusters than actually exist. What would happen if you asked it to find fewer clusters? Try \(K=1\) and \(K=2\) in the code block immediately above. What do you observe? Can you choose between \(K=1, 2,\) or \(3\)?

  • Rerun the entire example from the first code block, but set the number of true clusters to 6 this time. Investigate what happens when you fit K-means with very few clusters, when you pick \(K\) to be around 6, and when you pick a very large \(K\), say 10.