MAKE_BOOK_FIGURES = True

import numpy as np
import scipy.stats as st
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")

def set_book_style():
    plt.style.use('seaborn-v0_8-white')
    sns.set_style("ticks")
    sns.set_palette("deep")
    mpl.rcParams.update({
        # Font settings for academic publishing
        'font.family': 'serif',
        'font.size': 8,
        'axes.labelsize': 8,
        'axes.titlesize': 8,
        'xtick.labelsize': 7,
        'ytick.labelsize': 7,
        'legend.fontsize': 7,
        # Line and marker settings for consistency
        'axes.linewidth': 0.5,
        'grid.linewidth': 0.5,
        'lines.linewidth': 1.0,
        'lines.markersize': 4,
        # Layout to prevent clipped labels
        'figure.constrained_layout.use': True,
        # Default DPI (will override when saving)
        'figure.dpi': 600,
        'savefig.dpi': 600,
        # Despine - remove top and right spines
        'axes.spines.top': False,
        'axes.spines.right': False,
        # Remove legend frame
        'legend.frameon': False,
        # Additional trim settings
        'figure.autolayout': True,
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0.1
    })

def set_notebook_style():
    plt.style.use('seaborn-v0_8-white')
    sns.set_style("ticks")
    sns.set_palette("deep")
    mpl.rcParams.update({
        # Font settings - using default sizes
        'font.family': 'serif',
        'axes.labelsize': 10,
        'axes.titlesize': 10,
        'xtick.labelsize': 9,
        'ytick.labelsize': 9,
        'legend.fontsize': 9,
        # Line and marker settings
        'axes.linewidth': 0.5,
        'grid.linewidth': 0.5,
        'lines.linewidth': 1.0,
        'lines.markersize': 4,
        # Layout settings
        'figure.constrained_layout.use': True,
        # Remove only top and right spines
        'axes.spines.top': False,
        'axes.spines.right': False,
        # Remove legend frame
        'legend.frameon': False,
        # Additional settings
        'figure.autolayout': True,
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0.1
    })

def save_for_book(fig, filename, is_vector=True, **kwargs):
    """Save a figure with book-optimized settings.

    Parameters
    ----------
    fig : matplotlib figure
        The figure to save.
    filename : str
        Filename without extension.
    is_vector : bool
        If True, saves as vector at 1000 dpi. If False, saves as raster at 600 dpi.
    **kwargs : dict
        Additional kwargs to pass to savefig.
    """
    if is_vector:
        dpi = 1000
        ext = '.pdf'
    else:
        dpi = 600
        ext = '.tif'
    fig.savefig(f"{filename}{ext}", dpi=dpi, **kwargs)

def make_full_width_fig():
    return plt.subplots(figsize=(4.7, 2.9), constrained_layout=True)

def make_half_width_fig():
    return plt.subplots(figsize=(2.35, 1.45), constrained_layout=True)

if MAKE_BOOK_FIGURES:
    set_book_style()
else:
    set_notebook_style()

make_full_width_fig = make_full_width_fig if MAKE_BOOK_FIGURES else lambda: plt.subplots()
make_half_width_fig = make_half_width_fig if MAKE_BOOK_FIGURES else lambda: plt.subplots()

Clustering using k-means

K-means is the most straightforward algorithm for splitting the dataset \(\mathbf{x}_{1:n}\) into \(K\) clusters. If you want to study the algorithm independently, I suggest reading Chapter 20.1 of D. MacKay (2003).

In K-means the clusters are defined by their centroids \(\mathbf{m}_{1:K}\), which are the means of the data points assigned to the cluster. Each observation \(\mathbf{x}_i\) is assigned to the cluster with the closest centroid. We can write this assignment as a one-hot encoding \(\mathbf{z}_i\) of the cluster index \(k\):

\[\begin{split} \mathbf{z}_i = \begin{bmatrix} 0 \\ \vdots \\ 1 \\ \vdots \\ 0 \end{bmatrix} \in \mathbb{R}^K, \quad \text{where} \quad z_{ik} = \begin{cases} 1 & \text{if } k = \arg\min_{k'} \|\mathbf{x}_i - \mathbf{m}_{k'}\|^2 \\ 0 & \text{otherwise} \end{cases} \end{split}\]

The centroids are found by minimizing the sum of squared distances between the data points and their assigned centroids:

\[ \min_{\mathbf{m}_{1:K}} \sum_{i=1}^n \sum_{k=1}^K z_{ik} \|\mathbf{x}_i - \mathbf{m}_k\|^2 \]

The algorithm starts by initializing the centroids randomly. Then, it iterates between two steps until convergence:

  1. Assign each data point to the cluster with the closest centroid.

  2. Update each centroid to the mean of the data points currently assigned to it.

Convergence is checked by comparing the centroids of the current and previous iterations. The algorithm is guaranteed to converge, but it may converge to a local minimum of the objective.
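Here is a minimal NumPy sketch of these two steps. It is an illustration of the idea only; the function name kmeans_sketch, the stopping tolerance, and the initialization from randomly chosen data points are my own choices, and below we use scikit-learn's implementation instead.

import numpy as np

def kmeans_sketch(X, K, max_iter=100, tol=1e-6, seed=None):
    # Hypothetical helper, not part of the original notebook
    rng = np.random.default_rng(seed)
    # Initialize the centroids by picking K distinct data points at random
    centroids = X[rng.choice(X.shape[0], size=K, replace=False)]
    for _ in range(max_iter):
        # Step 1: assign each point to the cluster with the closest centroid
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=-1)
        labels = dists.argmin(axis=1)
        # Step 2: update each centroid to the mean of its assigned points
        # (keep the old centroid if a cluster happens to be empty)
        new_centroids = np.array([
            X[labels == k].mean(axis=0) if np.any(labels == k) else centroids[k]
            for k in range(K)
        ])
        # Convergence: stop when the centroids no longer move
        if np.linalg.norm(new_centroids - centroids) < tol:
            centroids = new_centroids
            break
        centroids = new_centroids
    return centroids, labels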

There is a nice visualization of the algorithm by Naftali Harris.

Example

Let’s start by generating a synthetic dataset with three clusters:

np.random.seed(123456)

# Make synthetic dataset for clustering
num_clusters_true = 3
# The means of each cluster
mu_true = 3.0 * np.random.randn(num_clusters_true, 2)
# The standard deviation of the observations around each cluster center
sigma_true = 0.5
# How many observations to generate per cluster
num_obs_cluster = [50, 50, 50]

# Generate the data
data = []
for i in range(num_clusters_true):
    x_i = mu_true[i] + sigma_true * np.random.randn(num_obs_cluster[i], 2)
    data.append(x_i)
data = np.vstack(data)
# Permute the data so that order info is lost
data = np.random.permutation(data)

Now let’s visualize the data, forgetting about the underlying clusters that gave rise to them.

fig, ax = plt.subplots()
ax.plot(data[:, 0], data[:, 1], '.')
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
sns.despine(trim=True);
[Figure: scatter plot of the synthetic dataset, \(x_1\) versus \(x_2\).]

Let’s apply K-means to the data:

from sklearn.cluster import KMeans

model = KMeans(n_clusters=3, n_init='auto').fit(data)

Here is how you can access the cluster centers (the \(\mathbf{m}_k\)’s) from the trained model:

model.cluster_centers_
array([[ 3.55976184, -0.43723791],
       [-4.57184776, -3.2974181 ],
       [ 1.35520175, -0.91177514]])

Compare the identified cluster centers to the actual cluster centers:

mu_true
array([[ 1.4073369 , -0.84859003],
       [-4.52717551, -3.40689711],
       [ 3.63633608, -0.51964395]])
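Note that the clusters come out in an arbitrary order, so the rows of model.cluster_centers_ do not line up with the rows of mu_true. If you want a side-by-side comparison, one option (a small sketch added here, which assumes the fit is good enough that nearest-center matching is unambiguous) is to reorder the identified centers by matching each true center to its closest one:

# Distance from every true center to every identified center
center_dists = np.linalg.norm(
    mu_true[:, None, :] - model.cluster_centers_[None, :, :], axis=-1
)
# For each true center, the index of the closest identified center
matching = center_dists.argmin(axis=1)
# Identified centers reordered to align with mu_true
model.cluster_centers_[matching]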

K-means has also labeled each observation with its cluster index. Here is how to get this information:

model.labels_
array([2, 2, 1, 1, 1, 1, 0, 1, 2, 0, 0, 2, 1, 1, 2, 0, 0, 2, 1, 1, 0, 0,
       1, 1, 2, 2, 0, 2, 2, 2, 0, 0, 2, 1, 0, 2, 1, 2, 0, 0, 2, 0, 2, 0,
       0, 0, 1, 2, 1, 0, 0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 0, 2, 0, 0,
       1, 0, 1, 1, 2, 2, 1, 1, 1, 2, 1, 0, 2, 1, 2, 2, 0, 1, 2, 1, 1, 2,
       0, 0, 2, 2, 0, 1, 2, 0, 2, 2, 0, 2, 0, 1, 1, 0, 1, 2, 1, 1, 0, 1,
       1, 2, 2, 2, 1, 2, 2, 1, 0, 2, 0, 1, 1, 0, 2, 0, 0, 1, 1, 1, 1, 1,
       1, 0, 0, 0, 2, 2, 0, 2, 0, 2, 1, 0, 2, 0, 2, 1, 2, 2], dtype=int32)
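If you want to know how many observations ended up in each cluster, you can count the labels (a small addition to the original code):

# Number of points assigned to each of the three clusters
np.bincount(model.labels_)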

Since we have 2D observations, we can actually visualize the clusters. Here is a nice way to do this:

labels = model.predict(data)
fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1], c=labels)
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
sns.despine(trim=True);
[Figure: the data colored by the cluster labels found by K-means with \(K=3\).]

Okay, this seems to work perfectly. However, notice that we asked K-means to find three clusters, which happens to be the true number of clusters in our dataset. What would happen if we had asked K-means to find a larger number of clusters, say 5? Here it is:

model5 = KMeans(n_clusters=5, n_init='auto').fit(data)

labels = model5.predict(data)

fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1], c=labels)
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
sns.despine(trim=True);
[Figure: the data colored by the cluster labels found by K-means with \(K=5\).]
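With \(K=5\), K-means still assigns every point to a cluster and ends up splitting some of the true clusters into pieces. This raises the question of how to pick \(K\) in the first place. One common heuristic, not covered above, is the elbow method: fit K-means for a range of \(K\) values and plot the within-cluster sum of squares, which scikit-learn exposes as the inertia_ attribute of a fitted model. The objective always decreases as \(K\) grows, so you look for the value at which the decrease slows down (the "elbow"). Here is a sketch, assuming the data array from above is still in scope:

# Elbow heuristic: plot the K-means objective (inertia) as a function of K
Ks = range(1, 11)
inertias = [KMeans(n_clusters=K, n_init='auto').fit(data).inertia_ for K in Ks]

fig, ax = plt.subplots()
ax.plot(Ks, inertias, 'o-')
ax.set_xlabel('Number of clusters $K$')
ax.set_ylabel('Within-cluster sum of squares')
sns.despine(trim=True);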

Questions

  • We saw what happens when you ask K-means to find more clusters than actually exist. What would happen if you asked it to find fewer? Try \(K=1\) and \(K=2\) in the code block immediately above. What do you observe? Can you choose between \(K=1, 2,\) or \(3\)?

  • Rerun the entire example from the first code block, but this time set the true number of clusters to 6. Investigate what happens when you fit K-means with too few clusters, what happens when you pick \(K\) close to 6, and what happens when you pick a very large \(K\), say 10.