import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")

Dimensionality Reduction Examples#

Throughout this lecture, we will use the MNIST dataset. The MNIST dataset consists of thousands of images of handwritten digits from \(0\) to \(9\). The dataset is a standard benchmark in machine learning. Here is how to get the dataset from the sklearn library.

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

# Load data from https://www.openml.org/d/554
X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
X = X / 255.0

# Split data into train partition and test partition
np.random.seed(12345)
x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=0, test_size=0.3)

The dataset comes with inputs (images of digits) and labels (the label of the digit). We will not use the labels in this lecture as we will be doing unsupervised learning. Let’s look at the dimensions of the training dataset:

x_train.shape
(49000, 784)

The training dataset is a 2D array. The first dimension is 49,000, the number of different images that we have. Each image consists of 28x28 = 784 pixels, flattened into a vector of length 784. Here is the first image in terms of numbers:

x_train[0]
array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.33333333, 1.        , 0.09019608, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.14901961, 0.96862745,
       0.99607843, 0.09019608, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.18431373, 0.89411765, 0.95686275, 0.18039216, 0.01568627,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.6745098 , 0.9254902 ,
       0.17254902, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.10196078,
       0.73333333, 0.94509804, 0.2627451 , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.18431373, 0.70588235, 0.94509804, 0.27058824,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.48235294, 0.96078431,
       0.82352941, 0.2627451 , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.46666667, 0.96862745, 0.82352941, 0.27058824, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.02745098, 0.53333333, 0.99607843, 0.94509804,
       0.2627451 , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.03529412, 0.51764706,
       0.99215686, 0.94509804, 0.27058824, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.64313725, 0.99215686, 0.95686275, 0.2627451 ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.10980392, 0.45490196, 0.45490196, 0.71372549, 0.08235294,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.33333333, 0.96862745,
       0.95686275, 0.27058824, 0.        , 0.        , 0.        ,
       0.        , 0.10196078, 0.55686275, 0.94117647, 0.99215686,
       0.99607843, 0.99215686, 0.57254902, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.02745098,
       0.40784314, 0.99607843, 0.78039216, 0.06666667, 0.        ,
       0.        , 0.14901961, 0.43921569, 0.76470588, 0.86666667,
       0.99607843, 0.99607843, 0.99607843, 0.99607843, 0.99607843,
       0.80392157, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.29019608, 0.99215686, 0.94509804,
       0.10588235, 0.        , 0.        , 0.        , 0.74117647,
       0.99215686, 0.99215686, 0.94509804, 0.48235294, 0.52941176,
       0.62745098, 0.99607843, 0.82352941, 0.11372549, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.08627451,
       0.56470588, 0.95686275, 0.2627451 , 0.        , 0.        ,
       0.        , 0.        , 0.49411765, 0.5372549 , 0.27843137,
       0.13333333, 0.08627451, 0.63137255, 0.99215686, 0.67058824,
       0.11372549, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.31372549, 0.99215686, 0.36862745,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.4       , 0.7254902 , 0.96862745,
       0.9254902 , 0.27058824, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.2745098 ,
       0.99607843, 0.80392157, 0.        , 0.        , 0.        ,
       0.        , 0.2       , 0.40784314, 0.6       , 0.92941176,
       0.89411765, 0.7254902 , 0.5254902 , 0.06666667, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.2745098 , 0.99215686, 0.49411765,
       0.18431373, 0.18431373, 0.50980392, 0.83921569, 0.96862745,
       0.99607843, 0.89019608, 0.64313725, 0.28627451, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.2745098 , 0.99215686, 0.99215686, 0.99215686, 0.99607843,
       0.99215686, 0.9254902 , 0.9254902 , 0.31372549, 0.08235294,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.0745098 , 0.79607843,
       0.99215686, 0.99215686, 0.8627451 , 0.65882353, 0.1372549 ,
       0.0745098 , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        ])

Each number corresponds to a pixel value. Zero is a white pixel and one is a black pixel (recall that we divided the original values, which run from 0 to 255, by 255). Values in between correspond to shades of gray. Here is how to visualize the first image:

plt.imshow(x_train[0].reshape(28, 28), cmap=plt.cm.gray_r, interpolation='nearest');

I want to work with just images of threes. So, let me keep all the threes and throw away all other data:

threes = x_train[y_train == '3']
threes.shape
(5024, 784)

Okay. We now have a few thousand vectors, each with 784 dimensions. That is our dataset. Let’s apply PCA to it to reduce its dimensionality. We will use the PCA class of scikit-learn. Here is how to import the class:

from sklearn.decomposition import PCA

Here is how to initialize the model and fit it to the data:

pca = PCA(n_components=0.98, whiten=True).fit(threes)

See its documentation for the complete definition of the inputs to the PCA class. The particular parameters that I define above have the following effect:

  • n_components: If you set this to an integer, the PCA will have this many components. If you set it to a number between \(0\) and \(1\), say 0.98, then PCA will keep as many components as it needs to capture 98% of the variance of the data. I use the second type of input.

  • whiten: This ensures that the projections have unit variance. If you don’t specify this, their variance will be the corresponding eigenvalue. Setting whiten=True is consistent with the theory developed in the video.
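
To make these two parameters concrete, here is a small sketch (the variable names pca_fixed and pca_var are just illustrative) that fits PCA both ways and checks that the whitened projections indeed have roughly unit variance:

# Sketch: the two ways to set n_components, and a check of whiten=True
pca_fixed = PCA(n_components=2).fit(threes)                 # keep exactly two components
pca_var = PCA(n_components=0.98, whiten=True).fit(threes)   # keep 98% of the variance
print(pca_fixed.n_components_, pca_var.n_components_)

# With whiten=True, each column of the projections should have variance close to one
Z_w = pca_var.transform(threes)
print(Z_w.std(axis=0)[:5])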

Okay, so now that the model is trained, let’s investigate it. First, we asked PCA to keep enough components to describe 98% of the variance. How many did it keep? Here is how to check this:

pca.n_components_
225

It kept 225 components. This is not a particularly impressive compression (down from 784 dimensions), but we will take it for now.

Now, let’s focus on the eigenvalues of the covariance matrix. Here is how to get them:

fig, ax = plt.subplots()
ax.plot(pca.explained_variance_)
ax.set_xlabel('$i$')
ax.set_ylabel(r'$\lambda_i$')
sns.despine(trim=True);

Remember that the sum of the first \(k\) eigenvalues, \(\sum_{i=1}^k\lambda_i\), tells you how much variance is explained with a model that keeps the first \(k\) PCA components.
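
Here is a quick sketch that checks this using the fitted model (it relies on the explained_variance_ratio_ attribute of the pca object above):

# Cumulative fraction of the variance explained by the first k components
cum_var = np.cumsum(pca.explained_variance_ratio_)
print(cum_var[:5])    # variance captured by the first 1, 2, ..., 5 components
print(cum_var[-1])    # should be about 0.98, as we requested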

Okay. As we discussed in the lecture videos, each of the observations can be expanded as follows:

\[ \mathbf{x}_j = \mathbf{m} + \sum_{i=1}^k z_{ji}\sqrt{\lambda_i}\mathbf{v}_i. \]
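
Here is a small numerical check of this expansion for a single image. It is only a sketch: it uses pca.transform (formally introduced a bit later) to get the \(z_{ji}\)'s, together with the eigenvalues and eigenvectors stored in the fitted pca object:

x0 = threes[:1]                              # the first three, shape (1, 784)
z0 = pca.transform(x0)[0]                    # its whitened principal components z_{0i}
lam = pca.explained_variance_                # the eigenvalues lambda_i
V = pca.components_                          # the eigenvectors v_i (as rows)
x_rec = pca.mean_ + (z0 * np.sqrt(lam)) @ V  # m + sum_i z_i sqrt(lambda_i) v_i

# This matches sklearn's own inverse_transform exactly, and it is close to
# (but not identical to) the original image, since we only kept ~98% of the variance
print(np.allclose(x_rec, pca.inverse_transform(z0[None, :])[0]))
print(np.abs(x_rec - x0[0]).max())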

Let’s first visualize the mean \(\mathbf{m}\). It is this vector:

pca.mean_.shape
(784,)

So let’s reshape it and plot it as an image:

def show_digit_image(data):
    """Show a digit as an image.
    
    Arguments
    data -- The image data.
    """
    fig, ax = plt.subplots()
    ax.imshow(data.reshape((28, 28)), cmap=plt.cm.gray_r, interpolation='nearest')
    ax.set_xticks([])
    ax.set_xticklabels([])
    ax.set_yticks([])
    ax.set_yticklabels([]);

show_digit_image(pca.mean_)

Now let’s go for the eigenvectors \(\mathbf{v}_i\). Here is where they are:

pca.components_.shape
(225, 784)

and here is how to visualize them as images:

for i in range(5):
    show_digit_image(
        pca.components_[i, :]
    )

Let’s visualize the first two principal components of each observation \(\mathbf{x}_j\), i.e., the first two entries of each \(\mathbf{z}_j\). This projects the dataset from 784 dimensions down to two. Here is how to find the principal components:

Z = pca.transform(threes)
Z.shape
(5024, 225)

Visualize the first two:

fig, ax = plt.subplots()
ax.scatter(Z[:, 0], Z[:, 1])
ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$')
sns.despine(trim=True);

Alright! Each dot in this plot corresponds to an image of a 3. This is nice, but we can do better with the visualization. Let’s plot the actual image instead of a dot. Here is how to do this:

# The following code is a modification of the code found here:
# https://stackoverflow.com/questions/35651932/plotting-img-with-matplotlib
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from matplotlib.cbook import get_sample_data

def imscatter(
    x,
    y,
    images,
    cmap=plt.cm.gray_r,
    ax=None,
    zoom=1
):
    """Do a scatter plot with images instead of points.
    
    Arguments
    x      -- The x coordinates.
    y      -- The y coordinates.
    images -- The images. Must be of shape (x.shape[0], d, d).
    
    Keyword Arguments
    cmap   -- A color map.
    ax     -- An axes object to plot on.
    zoom   -- How much to zoom.
    """
    x, y = np.atleast_1d(x, y)
    artists = []
    for x0, y0, image in zip(x, y, images):
        im = OffsetImage(
            image,
            zoom=zoom,
            cmap=cmap,
            interpolation='nearest'
        )
        ab = AnnotationBbox(
            im,
            (x0, y0),
            xycoords='data',
            frameon=False
        )
        artists.append(ax.add_artist(ab))
    ax.update_datalim(np.column_stack([x, y]))
    ax.autoscale()
    return artists

Here it is:

fig, ax = plt.subplots()
imscatter(
    Z[:, 0],
    Z[:, 1],
    threes.reshape((threes.shape[0], 28, 28)),
    ax=ax,
    zoom=0.2
)
ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$')
sns.despine(trim=True);

In this plot, you can see the interpretation of the principal components. The first principal component seems to rotate the three about an axis coming out of the screen. The second principal component changes the thickness of the bottom of the three.
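
To see this effect directly, here is a small sketch (it reuses the pca model and the show_digit_image helper from above) that sets all principal components to zero except one and sweeps its value. You can use it for the questions below:

# Sketch: synthesize threes by varying a single principal component
i = 0                                    # which principal component to vary
for z_val in [-2.0, 0.0, 2.0]:
    z = np.zeros((1, pca.n_components_))
    z[0, i] = z_val                      # all other components stay at zero
    show_digit_image(pca.inverse_transform(z)[0])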

Questions#

  • Keeping the index \(i=0\) fixed, play with the corresponding \(z\) and observe that it rotates the three.

  • Change \(i\) to \(1,2\) and \(3\) and study the effect of the corresponding principal component.

Now, we will study the reconstruction error for the validation dataset. First, throw away everything that is not a three:

valid_threes = x_test[y_test == '3']
valid_threes.shape
(2117, 784)

We have about two thousand images for validation. Project all the validation points:

Z_valid = pca.transform(valid_threes)

Then we reconstruct one of them, using only the first few principal components, and compare it to the original:

# Reconstruct image
idx = 1
n_components = 1
x = pca.inverse_transform(
    np.hstack(
        [
            Z_valid[idx][:n_components],
            np.zeros((Z_valid.shape[1] - n_components,))
        ]
    )
)

# The original image
show_digit_image(valid_threes[idx])

# The reconstructed image
show_digit_image(x)
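For a more quantitative picture, here is a sketch (using pca, Z_valid, and valid_threes from above) that computes the mean squared reconstruction error over all validation threes as a function of how many components we keep:

# Sketch: mean squared reconstruction error vs. number of retained components
ks = [1, 2, 4, 8, 16, 32, 64, 128, 225]
errors = []
for k in ks:
    Z_k = Z_valid.copy()
    Z_k[:, k:] = 0.0                              # drop all components beyond the first k
    X_rec = pca.inverse_transform(Z_k)
    errors.append(np.mean((X_rec - valid_threes) ** 2))

fig, ax = plt.subplots()
ax.plot(ks, errors, 'o-')
ax.set_xlabel('number of components $k$')
ax.set_ylabel('mean squared reconstruction error')
sns.despine(trim=True);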

Questions#

  • Play with the code block above, increasing n_components to 2, 4, 8, and so on up to 225. Observe how the reconstruction becomes better (but not perfect).

  • Repeat the above question, but also change the idx variable so that you see some more examples of threes.

  • Go back a few code blocks and change your validation set to include only fives. That is, change this:

valid_threes = x_test[y_test == '3']
valid_threes.shape

to this:

valid_threes = x_test[y_test == '5']
valid_threes.shape

Don’t bother renaming valid_threes. Can the PCA model constructed with threes describe fives? Why or why not? (A sketch for checking this quantitatively appears after these questions.)

  • Repeat the previous question with a couple of other digits.
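
Here is a sketch you could adapt for the last two questions. It compares the mean squared reconstruction error of the threes-trained PCA model on test threes and test fives (the digit '5' is just an example; swap in any digit):

# Sketch: how well does the threes-PCA describe other digits?
for digit in ['3', '5']:
    X_d = x_test[y_test == digit]
    X_rec = pca.inverse_transform(pca.transform(X_d))
    mse = np.mean((X_rec - X_d) ** 2)
    print(f"digit {digit}: mean squared reconstruction error = {mse:.5f}")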