import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")

def convert_seconds_to_hms(seconds):
    hours, remainder = divmod(seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return int(hours), int(minutes), float(seconds)

!pip install equinox

Training a neural network on the GPU

NOTE: For this notebook, you will need a GPU. If your local computer does not have a JAX-compatible GPU, you can use the GPUs provided by Google Colab Pro. See this guide on how to access GPUs on Google Colab.

Why GPUs?

In non-trivial scientific machine learning tasks, we often need to use graphics processing units (GPUs) to accelerate computation. GPUs are a type of hardware that is optimized for parallel operations such as matrix multiplication, which is core to most machine learning algorithms.

In this notebook, we’ll compare GPU vs. CPU runtimes for training a convolutional neural network (CNN) to classify handwritten digits from the MNIST dataset.

Creating a CNN classifier for the MNIST dataset

# Hyperparameters

BATCH_SIZE = 64
LEARNING_RATE = 3e-4
MAX_EPOCHS = 2
MAX_STEPS = 100
PRINT_EVERY = 10
SEED = 5678

import jax.random as jrandom
key = jrandom.PRNGKey(SEED)

Let’s load the MNIST dataset:

import torch, torchvision

normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=True,
    download=True,
    transform=normalise_data,
)
test_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=False,
    download=True,
    transform=normalise_data,
)
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True
)
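ToTensor converts each grayscale image into a float tensor of shape (1, 28, 28) with pixel values in [0, 1], and Normalize((0.5,), (0.5,)) then computes (x - 0.5) / 0.5, which maps pixels into [-1, 1]. Each batch from trainloader is therefore a tensor of shape (64, 1, 28, 28), matching the "batch 1 28 28" type hints used later. The NumPy sketch below only replays that arithmetic on three example pixel values; it does not call torchvision.

```python
import numpy as np

# Example raw 8-bit pixel values: black, mid-gray, white.
raw = np.array([0, 128, 255], dtype=np.uint8)

# ToTensor: scale uint8 values in [0, 255] to floats in [0, 1].
scaled = raw.astype(np.float32) / 255.0

# Normalize((0.5,), (0.5,)): subtract the mean 0.5, then divide by the std 0.5.
normalized = (scaled - 0.5) / 0.5

print(normalized)  # black maps to -1.0, white maps to 1.0
```

Centering the inputs around zero like this is a common convention that tends to make optimization better behaved.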

The MNIST dataset consists of 28-by-28-pixel grayscale images of handwritten digits (0-9). The first five training images look like this:

import numpy as np
import matplotlib.pyplot as plt

def show_images(images, labels):
    """Plot each image with its label as the title."""
    ncols = len(labels)
    nrows = int(np.ceil(len(images) / ncols))
    # squeeze=False keeps `axes` a 2-D array even when there is a single row or column.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 2, nrows * 2), squeeze=False)
    for ax in axes.ravel():
        ax.axis("off")
    for ax, image, label in zip(axes.ravel(), images, labels):
        ax.imshow(image.squeeze(), cmap="gray")  # drop the channel axis: (1, 28, 28) -> (28, 28)
        ax.set_title(label)
    plt.tight_layout()

# Visualize the first five training digits
images, labels = zip(*[train_dataset[i] for i in range(5)])
show_images(images, labels)

Let’s define the CNN:

import jax
import jax.numpy as jnp
from jax import vmap
import equinox as eqx
from jaxtyping import Array, Float, Int, PyTree

class CNN(eqx.Module):
    """Convolutional neural network for classifying handwritten digits from the MNIST dataset."""
    layers: list

    def __init__(self, key):
        key1, key2, key3, key4 = jrandom.split(key, 4)
        self.layers = [
            eqx.nn.Conv2d(1, 32, kernel_size=3, key=key1),
            jax.nn.relu,
            eqx.nn.MaxPool2d(kernel_size=2),
            eqx.nn.Conv2d(32, 64, kernel_size=3, key=key2),
            jax.nn.relu,
            eqx.nn.MaxPool2d(kernel_size=2),
            jnp.ravel,
            eqx.nn.Linear(30976, 128, key=key3),
            jax.nn.relu,
            eqx.nn.Linear(128, 10, key=key4),
            jax.nn.log_softmax,
        ]

    # Side note: these are shaped-array type hints, made possible by the package jaxtyping.
    def __call__(self, x: Float[Array, "1 28 28"]) -> Float[Array, "10"]:
        for layer in self.layers:
            x = layer(x)
        return x
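The input size 30976 of the first Linear layer follows from shape arithmetic. In Equinox, Conv2d defaults to stride 1 with no padding, and MaxPool2d also defaults to stride=1 (unlike PyTorch's MaxPool2d, whose stride defaults to the kernel size), so each 2-by-2 pooling shrinks each spatial side by only one pixel. The sketch below retraces the shapes with a small hypothetical helper, conv_out, which applies the standard output-size formula floor((size + 2*padding - kernel) / stride) + 1.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length along one spatial axis of a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1

side = 28
side = conv_out(side, kernel=3)  # Conv2d(1, 32, kernel_size=3): 28 -> 26
side = conv_out(side, kernel=2)  # MaxPool2d(kernel_size=2), stride 1: 26 -> 25
side = conv_out(side, kernel=3)  # Conv2d(32, 64, kernel_size=3): 25 -> 23
side = conv_out(side, kernel=2)  # MaxPool2d(kernel_size=2), stride 1: 23 -> 22

flat = 64 * side * side  # 64 channels, flattened by jnp.ravel
print(side, flat)  # 22 30976
```

If the pooling layers used stride 2 instead (e.g. MaxPool2d(kernel_size=2, stride=2)), the sides would shrink 28 -> 26 -> 13 -> 11 -> 5 and the first Linear layer would need 64 * 5 * 5 = 1600 inputs, a much smaller weight matrix.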

Finally, here is the training code (modified from this Equinox tutorial), wrapped in the function train for convenience:

from time import time
import optax

optim = optax.adamw(LEARNING_RATE)

@eqx.filter_jit
def loss(
    model: CNN, 
    x: Float[Array, "batch 1 28 28"],  
    y: Int[Array, "batch"]
) -> Float[Array, ""]:
    pred_y = vmap(model)(x)
    return cross_entropy(y, pred_y)

def cross_entropy(
    y: Int[Array, "batch"], 
    pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)

@eqx.filter_jit
def compute_accuracy(
    model: CNN, 
    x: Float[Array, "batch 1 28 28"], 
    y: Int[Array, "batch"]
) -> Float[Array, ""]:
    """This function takes as input the current model
    and computes the average accuracy on a batch.
    """
    pred_y = vmap(model)(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)

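def evaluate(model: CNN, testloader: torch.utils.data.DataLoader):
    """This function evaluates the model on the test dataset,
    computing both the average loss and the average accuracy.
    Each batch is weighted by its size, so the smaller final batch
    (10000 test images are not a multiple of 64) does not skew the averages.
    """
    total_loss = 0.0
    total_correct = 0.0
    n_samples = 0
    for x, y in testloader:
        x = x.numpy()
        y = y.numpy()
        batch_size = len(y)
        total_loss += loss(model, x, y) * batch_size
        total_correct += compute_accuracy(model, x, y) * batch_size
        n_samples += batch_size
    return total_loss / n_samples, total_correct / n_samples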

def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    max_epochs: int,
    max_steps: int,
    print_every: int,
) -> CNN:
    """Trains the CNN."""

    # Initialise the optimizer state.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Always wrap everything -- computing gradients, running the optimiser, updating
    # the model -- into a single JIT region. This ensures things run as fast as
    # possible.
    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state: PyTree,
        x: Float[Array, "batch 1 28 28"],
        y: Int[Array, " batch"],
    ):
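        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        # Pass only the array leaves of the model as `params` (AdamW uses them for weight decay).
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value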

    # Training loop
    step = 0
    max_steps_reached = False
    start_time = time()
    for epoch in range(max_epochs):
        for x, y in trainloader:
            x = x.numpy()
            y = y.numpy()
            model = eqx.nn.inference_mode(model, value=False)
            model, opt_state, train_loss = make_step(model, opt_state, x, y)
            if (step % print_every) == 0:
                model = eqx.nn.inference_mode(model, value=True)
                test_loss, test_accuracy = evaluate(model, testloader)
                # `mins` (rather than `min`) avoids shadowing the built-in min().
                hr, mins, sec = convert_seconds_to_hms(time() - start_time)
                print(
                    f"elapsed_time = {hr:1d}h {mins:1d}m {sec:.1f}s, {epoch=}, {step=}, train_loss={train_loss.item():.2f}, "
                    f"test_loss={test_loss.item():.2f}, test_accuracy={test_accuracy.item():.3f}"
                )
            step += 1
            if step >= max_steps:
                max_steps_reached = True
                print("Training complete. (Maximum number of steps reached.)")
                break
        if max_steps_reached:
            break
    if not max_steps_reached:
        print("Training complete.")
    hr, mins, sec = convert_seconds_to_hms(time() - start_time)
    print(f"Total elapsed_time = {hr:1d}h {mins:1d}m {sec:.1f}s")
    return model
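Because the network ends with jax.nn.log_softmax, its outputs are already log-probabilities, so cross_entropy only has to pick out the log-probability of the true class for each example (via take_along_axis) and average the negatives. The sketch below replays that indexing in NumPy, whose take_along_axis and expand_dims behave like the jnp functions used above, on a made-up batch of two examples and three classes.

```python
import numpy as np

# Made-up predicted probabilities for 2 examples and 3 classes, stored as
# log-probabilities (the form produced by the CNN's final log_softmax layer).
log_probs = np.log(np.array([[0.7, 0.2, 0.1],
                             [0.1, 0.1, 0.8]]))
y = np.array([0, 2])  # true class labels

# Select log p(true class) for each row; the result has shape (2, 1).
picked = np.take_along_axis(log_probs, np.expand_dims(y, 1), axis=1)

# Negative log-likelihood averaged over the batch.
nll = -np.mean(picked)
print(picked.ravel(), nll)
```

Here the loss is -(log 0.7 + log 0.8) / 2, about 0.29; it would reach 0 only if the model assigned probability 1 to every true class.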

JAX devices

Before we start training, let’s talk a bit about JAX devices.

JAX will automatically detect all available devices (e.g., CPU, GPU) to use for computation. Let’s look at what devices are available:

print("CPU devices: ", jax.devices('cpu'))
try:
    is_gpu_avail = True
    print("GPU devices: ", jax.devices('gpu'))
except RuntimeError as e:
    is_gpu_avail = False
    print("GPU devices: ", e)
CPU devices:  [CpuDevice(id=0)]
GPU devices:  [cuda(id=0)]

If a GPU is detected, JAX automatically makes it the default backend; otherwise, the CPU is the default backend. Unless otherwise specified, JAX performs all computations on the default backend.

print('The default backend is:', jax.default_backend())
The default backend is: gpu

In short, JAX automatically uses the GPU if one is available, provided that you have installed a build of JAX that supports your GPU. This is very convenient!
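Besides committing arrays (covered in the next section), JAX also provides the jax.default_device context manager, which changes the device used for newly created, uncommitted arrays and for JAX operations inside the block. A minimal sketch; it targets the CPU so that it also runs on machines without a GPU.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]
print("Default backend:", jax.default_backend())

# Arrays created inside this block are placed on the CPU,
# even when a GPU is the default backend.
with jax.default_device(cpu):
    x = jnp.arange(4.0)

print(x.devices())  # {CpuDevice(id=0)}
```

Unlike jax.device_put, this does not commit x, so later operations on x may still be moved to whichever device is the default at that point.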

Committing arrays to a device

In some cases, you may want to manually “force” computations to happen on devices/backends other than the default one. You do this by committing JAX arrays to a device with the function jax.device_put. All subsequent operations on a committed array then happen on the device it is committed to. See the JAX FAQ for more details.
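A minimal sketch of committing a single array with jax.device_put. It commits to the CPU so that it runs without a GPU; using jax.devices("gpu")[0] instead would commit the array to the first GPU. Operations on a committed array run on its device, and their results stay on that device.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

x = jnp.ones((3, 3))            # uncommitted: lives on the default device
x_cpu = jax.device_put(x, cpu)  # committed to the CPU

# This matrix product runs on the CPU because its inputs are committed there.
y = jnp.dot(x_cpu, x_cpu)
print(y.devices())  # {CpuDevice(id=0)}
```

Mixing arrays committed to different devices in one computation raises an error rather than silently copying data, which is one reason the model copies below are each committed to a single device.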

Let’s demonstrate this. First, we initialize our model:

model = CNN(key=key)  # An instance of our CNN class.

All the arrays in the pytree model are uncommitted, so subsequent operations on model will happen on the default device. Since our goal is to compare CPU and GPU runtimes, we will make two copies of model: one on CPU and the other on GPU.

Let’s do it:

# Helper functions for committing JAX arrays to a GPU or CPU
put_on_cpu = lambda x: jax.device_put(x, jax.devices('cpu')[0]) if isinstance(x, jax.Array) else x
put_on_gpu = lambda x: jax.device_put(x, jax.devices('gpu')[0]) if isinstance(x, jax.Array) else x

model_cpu = jax.tree.map(put_on_cpu, model)
if is_gpu_avail:
    model_gpu = jax.tree.map(put_on_gpu, model)
else:
    print("GPU not available.")

Great! All arrays in model_cpu are committed to CPU, and all arrays in model_gpu are committed to GPU. Let’s check one of the arrays just to be sure:

print('The CPU\'s copy of model is on: ', model_cpu.layers[0].weight.devices())
if is_gpu_avail:
    print('The GPU\'s copy of model is on: ', model_gpu.layers[0].weight.devices())
else:
    print("GPU not available.")
The CPU's copy of model is on:  {CpuDevice(id=0)}
The GPU's copy of model is on:  {cuda(id=0)}

Perfect. We are now ready to benchmark.
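One caveat worth knowing when timing JAX code yourself: JAX dispatches work asynchronously, so a call can return before the device has finished computing, and the first call of a jitted function also pays for compilation. A minimal micro-benchmark sketch that accounts for both, using a hypothetical helper time_matmul:

```python
import time

import jax
import jax.numpy as jnp


def time_matmul(device, n=1024, repeats=10):
    """Average seconds per n-by-n matrix multiplication on `device`."""
    key = jax.random.PRNGKey(0)
    a = jax.device_put(jax.random.normal(key, (n, n)), device)  # commit input to `device`
    matmul = jax.jit(lambda u, v: u @ v)

    # Warm-up call: triggers JIT compilation, which is excluded from the timing.
    matmul(a, a).block_until_ready()

    start = time.perf_counter()
    for _ in range(repeats):
        # block_until_ready() waits for the asynchronously dispatched result.
        matmul(a, a).block_until_ready()
    return (time.perf_counter() - start) / repeats


cpu_time = time_matmul(jax.devices("cpu")[0])
print(f"CPU: {cpu_time * 1e3:.2f} ms per matmul")
```

On a machine with a GPU, calling time_matmul(jax.devices("gpu")[0]) gives the corresponding GPU figure; the ratio depends heavily on the hardware and on n.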

CPU vs. GPU runtime

Let’s start by training on CPU (this may take a few minutes):

print("Begin training on CPU.")
trained_model_cpu = train(model_cpu, trainloader, testloader, optim, MAX_EPOCHS, MAX_STEPS, PRINT_EVERY)
Begin training on CPU.
elapsed_time = 0h 0m 32.0s, epoch=0, step=0, train_loss=2.31, test_loss=2.28, test_accuracy=0.193
elapsed_time = 0h 1m 3.2s, epoch=0, step=10, train_loss=1.11, test_loss=1.04, test_accuracy=0.721
elapsed_time = 0h 1m 37.8s, epoch=0, step=20, train_loss=0.55, test_loss=0.58, test_accuracy=0.815
elapsed_time = 0h 2m 7.9s, epoch=0, step=30, train_loss=0.40, test_loss=0.43, test_accuracy=0.862
elapsed_time = 0h 2m 49.0s, epoch=0, step=40, train_loss=0.44, test_loss=0.38, test_accuracy=0.884
elapsed_time = 0h 3m 19.7s, epoch=0, step=50, train_loss=0.46, test_loss=0.32, test_accuracy=0.909
elapsed_time = 0h 3m 52.6s, epoch=0, step=60, train_loss=0.24, test_loss=0.27, test_accuracy=0.922
elapsed_time = 0h 4m 22.7s, epoch=0, step=70, train_loss=0.14, test_loss=0.27, test_accuracy=0.919
elapsed_time = 0h 4m 53.1s, epoch=0, step=80, train_loss=0.28, test_loss=0.23, test_accuracy=0.935
elapsed_time = 0h 5m 23.1s, epoch=0, step=90, train_loss=0.38, test_loss=0.22, test_accuracy=0.939
Training complete. (Maximum number of steps reached.)
Total elapsed_time = 0h 5m 28.7s

This is fairly slow. Let’s try training on GPU:

print("Begin training on GPU.")
trained_model_gpu = train(model_gpu, trainloader, testloader, optim, MAX_EPOCHS, MAX_STEPS, PRINT_EVERY)
Begin training on GPU.
elapsed_time = 0h 0m 3.6s, epoch=0, step=0, train_loss=2.31, test_loss=2.30, test_accuracy=0.195
elapsed_time = 0h 0m 6.9s, epoch=0, step=10, train_loss=1.21, test_loss=1.13, test_accuracy=0.698
elapsed_time = 0h 0m 10.7s, epoch=0, step=20, train_loss=0.54, test_loss=0.54, test_accuracy=0.829
elapsed_time = 0h 0m 13.5s, epoch=0, step=30, train_loss=0.55, test_loss=0.41, test_accuracy=0.868
elapsed_time = 0h 0m 16.3s, epoch=0, step=40, train_loss=0.38, test_loss=0.40, test_accuracy=0.871
elapsed_time = 0h 0m 19.1s, epoch=0, step=50, train_loss=0.22, test_loss=0.34, test_accuracy=0.897
elapsed_time = 0h 0m 22.8s, epoch=0, step=60, train_loss=0.44, test_loss=0.28, test_accuracy=0.920
elapsed_time = 0h 0m 25.7s, epoch=0, step=70, train_loss=0.27, test_loss=0.27, test_accuracy=0.924
elapsed_time = 0h 0m 28.8s, epoch=0, step=80, train_loss=0.33, test_loss=0.23, test_accuracy=0.937
elapsed_time = 0h 0m 31.6s, epoch=0, step=90, train_loss=0.26, test_loss=0.22, test_accuracy=0.938
Training complete. (Maximum number of steps reached.)
Total elapsed_time = 0h 0m 31.7s

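Much faster! For this small example, training on the GPU was roughly 10x faster than on the CPU (0h 0m 31.7s vs. 0h 5m 28.7s in total). Both timings include JIT compilation and the periodic evaluation on the full test set, not just the 100 training steps. For larger models, the speedup is typically even larger. This is why virtually all modern deep learning models are trained on GPUs or other accelerators such as TPUs.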
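The overall speedup can be read directly off the "Total elapsed_time" lines of the two training logs above:

```python
cpu_seconds = 5 * 60 + 28.7  # "Total elapsed_time = 0h 5m 28.7s" (CPU run)
gpu_seconds = 0 * 60 + 31.7  # "Total elapsed_time = 0h 0m 31.7s" (GPU run)

speedup = cpu_seconds / gpu_seconds
print(f"{cpu_seconds:.1f}s / {gpu_seconds:.1f}s = {speedup:.1f}x")  # 328.7s / 31.7s = 10.4x
```

The per-interval timings tell the same story: about 30 s per 10 steps on the CPU versus about 3 s per 10 steps on the GPU.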