import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")

# Uncomment the next two lines if running for the book
# import warnings
# warnings.filterwarnings("ignore")

import jax
from jax import lax, jit, vmap, value_and_grad
import jax.random as jr
import jax.numpy as jnp
import jax.scipy.stats as jstats
import pandas as pd
import numpy as np
from tinygp import GaussianProcess, kernels, transforms
import optax
import equinox as eqx
from functools import partial
from typing import NamedTuple, Callable, Optional
from jaxtyping import Float, Array

jax.config.update("jax_enable_x64", True)
key = jr.PRNGKey(0);

Example of Gaussian Process Surrogate#

Model for subcutaneous autoinjectors#

In this notebook, we’ll repeat the neural network surrogate hands-on activity, but this time we’ll use a Gaussian process (GP) as the surrogate. Recall that our goal is to replicate Sree et al. (2023), i.e., to create a surrogate of an expensive biomechanical model.

Just to recap: We have a high-fidelity, finite-element-based model \(f\) of drug autoinjections. The model takes inputs \(x\) like drug viscosity, tissue biomechanical properties, etc. and calculates various outputs \(y=f(x)\). To train a surrogate \(\hat{f}\) of this model, we have a dataset containing thousands of evaluations of \(f\). For further details about the model or dataset, see the neural network surrogate hands-on activity.

Let’s get the data imported. (We’ll follow the same preprocessing steps as the neural network surrogate hands-on activity, except we won’t re-explain the steps here—we’ll just do them.)

# Define new names for each input/output variable
old_input_names = ['mu', 'fill_volume', 'hGap0', 'lNeedle', 'dNeedle', 'FSpring0', 'kSpring', 'kappa5', 'kappa6', 'kappa7']
INPUT_NAMES = ['viscosity_cP', 'fill_volume_mL', 'air_gap_height_mm', 'needle_length_mm', 'needle_diameter_mm', 'spring_force_N', 'spring_constant_N_per_mm', 'kappa5', 'kappa6', 'kappa7']
old_output_names = ['Needle displacement (m)', 'Injection time (s)', 'max. acceleration (m/s^2)', 'max. deceleration (m/s^2)']
OUTPUT_NAMES = ['needle_displacement_m', 'injection_time_s', 'max_acceleration_m_per_s2', 'max_deceleration_m_per_s2']
column_name_mapper = dict(zip(old_input_names + old_output_names, INPUT_NAMES + OUTPUT_NAMES))

# Load the data
train_data = pd.read_excel("https://raw.githubusercontent.com/PredictiveScienceLab/advanced-scientific-machine-learning/refs/heads/main/book/data/autoinjector_surrogate/training_data.xlsx", index_col=0).rename(columns=column_name_mapper).sample(frac=1).reset_index(drop=True)
test_data = pd.read_excel("https://raw.githubusercontent.com/PredictiveScienceLab/advanced-scientific-machine-learning/refs/heads/main/book/data/autoinjector_surrogate/test_data.xlsx", index_col=0).rename(columns=column_name_mapper).sample(frac=1).reset_index(drop=True)

# Split data. Pass in column names to ensure correct order.
X_train = pd.DataFrame(train_data[INPUT_NAMES], columns=INPUT_NAMES)  
y_train = pd.DataFrame(train_data[OUTPUT_NAMES], columns=OUTPUT_NAMES)
X_test = pd.DataFrame(test_data[INPUT_NAMES], columns=INPUT_NAMES)
y_test = pd.DataFrame(test_data[OUTPUT_NAMES], columns=OUTPUT_NAMES)

# Set up scalers
X_train_mean = np.mean(X_train.values, axis=0)
X_train_std = np.std(X_train.values, axis=0)
y_train_mean = np.mean(y_train.values, axis=0)
y_train_std = np.std(y_train.values, axis=0)

def build_standard_scaler(data: Float[Array, "N d"]) -> tuple[Callable[[Array], Array], Callable[[Array], Array]]:
    """Factory function to build a standard scaler for the given data."""
    mean = np.mean(data, axis=0)
    std = np.std(data, axis=0)
    scale = lambda x: (x - mean) / std
    unscale = lambda x: x * std + mean
    return scale, unscale

def build_log_scaler(data: Float[Array, "N d"]) -> tuple[Callable[[Array], Array], Callable[[Array], Array]]:
    """Factory function to build a log scaler for the given data."""
    mean = np.mean(data, axis=0)
    std = np.std(data, axis=0)
    scale = lambda x: (jnp.log(x) - mean) / std
    unscale = lambda x: jnp.exp(x * std + mean)
    return scale, unscale

# Input scaler
scale_x, unscale_x = build_standard_scaler(X_train.values)
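
# Illustrative sanity check (not part of the original pipeline): scaling
# and then unscaling should round-trip the data up to float error.
assert np.allclose(unscale_x(scale_x(X_train.values)), X_train.values)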

# Separate scalers for each output
scale_y0, unscale_y0 = build_standard_scaler(y_train.values[:, 0].reshape(-1, 1))
scale_y1, unscale_y1 = build_standard_scaler(y_train.values[:, 1].reshape(-1, 1))
scale_y2, unscale_y2 = build_standard_scaler(y_train.values[:, 2].reshape(-1, 1))
scale_y3, unscale_y3 = build_standard_scaler(y_train.values[:, 3].reshape(-1, 1))

# Combine scalers for all outputs (for convenience)
scale_y = lambda y: jnp.column_stack([scale_y0(y[..., 0]), scale_y1(y[..., 1]), scale_y2(y[..., 2]), scale_y3(y[..., 3])])
unscale_y = lambda y: jnp.column_stack([unscale_y0(y[..., 0]), unscale_y1(y[..., 1]), unscale_y2(y[..., 2]), unscale_y3(y[..., 3])])
X_train_scaled = scale_x(X_train.values)
y_train_scaled = scale_y(y_train.values)
X_test_scaled = scale_x(X_test.values)
y_test_scaled = scale_y(y_test.values)

The data are now imported and standardized. Let’s visualize to make sure things look good:

fig, ax = plt.subplots(1, X_train_scaled.shape[1], figsize=(15, 2), tight_layout=True, sharey=True)
fig.suptitle("Transformed input data")
for i, i_name in enumerate(INPUT_NAMES):
    ax[i].hist(X_train_scaled[:, i], bins=20)
    ax[i].set_title(i_name)
    ax[i].set_yticks([])
    sns.despine(trim=True, ax=ax[i], left=True)

fig, ax = plt.subplots(1, y_train_scaled.shape[1], figsize=(8, 2), tight_layout=True, sharey=True)
fig.suptitle("Transformed output data")
for i, o_name in enumerate(OUTPUT_NAMES):
    ax[i].hist(y_train_scaled[:, i], bins=20)
    ax[i].set_title(o_name)
    ax[i].set_yticks([])
    sns.despine(trim=True, ax=ax[i], left=True)
[Figures: histograms of the transformed input data (one panel per input) and the transformed output data (one panel per output).]

Nice! The input/output data are all more-or-less evenly distributed and standardized. Now on to building the Gaussian process surrogate.

Gaussian process regression#

To demonstrate, we will use tinygp to build a GP surrogate for one of the outputs. Let’s pick injection time as the output to model. Our surrogate will be a zero-mean GP with a radial basis function (RBF) kernel:

\[ f \sim \operatorname{GP}(0, k), \]
\[ k(x, x') = \sigma^2 \exp\left(-\frac{1}{2} \sum_{i=1}^d \left(\frac{x_i - x'_i}{\ell_i}\right)^2\right), \]

where \(\sigma^2\) is the kernel variance and \(\ell_i\) is the length scale for input dimension \(i\). Because each input gets its own length scale, this is an automatic relevance determination (ARD) kernel: inputs with large length scales have little influence on the output.

Here is how to implement this in tinygp. First, we’ll create a function that generates a tinygp.GaussianProcess object.

def build_gp(params, X):
    """Build a Gaussian process with RBF kernel.
    
    Parameters
    ----------
    params : dict
        Hyperparameters of the GP.
    X : ndarray
        Training data input locations.
    
    Returns
    -------
    GaussianProcess
        The GP.
    """
    amplitude = jnp.exp(params['log_amplitude'])  # Kernel variance sigma^2
    lengthscale = jnp.exp(params['log_lengthscale'])  # One length scale per input dimension
    noise_variance = jnp.exp(params['log_noise_variance'])  # Measurement noise variance sigma_n^2
    k = amplitude * transforms.Linear(1 / lengthscale, kernels.ExpSquared())
    return GaussianProcess(kernel=k, X=X, diag=noise_variance)

Optimizing hyperparameters#

The hyperparameters are the logs of the kernel variance \(\sigma^2\), the length scales \(\ell = (\ell_1, \dots, \ell_d)\), and the noise variance \(\sigma_n^2\). Let’s set these hyperparameters to some arbitrary values for now. (They will be optimized later.)

init_params = {
    'log_amplitude': 1.0, 
    'log_lengthscale': -jnp.ones(10),  # Different lengthscale for each input dimension
    'log_noise_variance': -4.0
}
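
Before optimizing, it’s worth checking that these initial values produce a finite loss. A quick illustrative check on the first 50 (scaled) training points:

gp0 = build_gp(init_params, X_train_scaled[:50])
print(-gp0.log_probability(y_train_scaled[:50, 1]))  # Negative log marginal likelihood; should be finite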

Now we will optimize the hyperparameters by maximizing the marginal log-likelihood. Here are the loss and train functions:

def loss(params, X, y):
    """Negative marginal log likelihood of the GP."""
    gp = build_gp(params, X)
    return -gp.log_probability(y)
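
For reference, the quantity gp.log_probability(y) computes is the log marginal likelihood of the zero-mean GP. With kernel matrix \(K_{ij} = k(x_i, x_j)\) and noise variance \(\sigma_n^2\),

\[ \log p(y \mid X, \theta) = -\frac{1}{2} y^\top \left(K + \sigma_n^2 I\right)^{-1} y - \frac{1}{2} \log\left|K + \sigma_n^2 I\right| - \frac{N}{2}\log 2\pi, \]

where \(\theta\) collects all the hyperparameters. The first term rewards fitting the data; the log-determinant term penalizes model complexity.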

@eqx.filter_jit
def train_step_adam(carry, _, X, y, optim, batch_size):
    params, opt_state, key = carry
    key, subkey = jr.split(key)
    idx = jr.randint(subkey, (batch_size,), 0, X.shape[0])
    value, grads = value_and_grad(loss)(params, X[idx], y[idx])
    updates, opt_state = optim.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return (params, opt_state, key), value

def train_gp(init_params, X, y, num_iters, learning_rate, batch_size, key):
    """Optimize the hyperparameters of a GP using the Adam optimizer.
    
    Parameters
    ----------
    init_params : dict
        Initial values of the hyperparameters.
    X, y : ndarray
        Training data.
    num_iters : int
        Number of optimization steps.
    learning_rate : float
        Learning rate for the optimizer.
    batch_size : int
        Number of points in each minibatch.
    key : PRNGKey
        Random key for minibatch sampling.
    
    Returns
    -------
    dict
        The optimized hyperparameters.
    ndarray
        The loss values at each iteration.
    """
    
    # Initialize the optimizer
    optim = optax.adam(learning_rate)

    # Initialize the optimizer state
    init_carry = (init_params, optim.init(init_params), key)

    # Do optimization
    train_step = partial(train_step_adam, X=X, y=y, optim=optim, batch_size=batch_size)
    carry, losses = lax.scan(train_step, init_carry, None, num_iters)

    return carry[0], losses  # (optimized params, loss values)

We’ll only use 200 data points. (Exact GP inference scales as \(\mathcal{O}(N^3)\) in the number of data points, so standard GPs are slow with large datasets; we’ll use a subset of the data for demonstration purposes.) Here we go:

# Grab a subset of the training data
idx_output = 1  # Corresponds to injection time
idx_subsample = np.random.choice(X_train_scaled.shape[0], size=200, replace=False)  # Without replacement, to avoid duplicated points
X_train_scaled_sub = X_train_scaled[idx_subsample]
y_train_scaled_injection_time = y_train_scaled[idx_subsample, idx_output]

# Optimize the hyperparameters of the GP
key, subkey = jr.split(key)
trained_params, losses = train_gp(init_params, X_train_scaled_sub, y_train_scaled_injection_time, num_iters=1000, learning_rate=1e-2, batch_size=100, key=subkey)
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel("Iteration")
ax.set_ylabel("Negative marginal log likelihood")
ax.set_title("Convergence of loss for optimizating the GP hyperparameters", fontsize=16)
sns.despine(trim=True);
[Figure: convergence of the negative marginal log likelihood over optimization iterations.]

Great! We now have a set of optimized hyperparameters.
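
Since each input dimension has its own length scale (the ARD kernel from above), it is instructive to inspect the optimized values: the larger \(\ell_i\) is, the less the output varies along input \(i\). A quick illustrative look:

# Print the learned length scale for each input (larger = less relevant)
for name, ell in zip(INPUT_NAMES, jnp.exp(trained_params['log_lengthscale'])):
    print(f'{name}: {ell:.2f}')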

Conditioning the GP#

Let’s condition the GP on the training data and compute its mean at some arbitrary point. Here is how to do it with tinygp:

# First, build the GP with the trained hyperparameters
gp = build_gp(trained_params, X_train_scaled_sub)

# Next, pick a point at which to evaluate the GP
X_eval = jnp.zeros((1, 10))  # Must be shape (n_pts, n_dim)

# Finally, condition the GP on the training data
_, cond_gp = gp.condition(y_train_scaled_injection_time, X_eval)  # Returns (log marginal probability, conditioned GP)

# The surrogate model is the mean of the conditioned GP
mean = cond_gp.mean

# We can also draw samples of the GP at this point (which will just be samples from a univariate Gaussian)
key, subkey = jr.split(key)
samples = cond_gp.sample(subkey, shape=(5000,))
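
The conditioned GP also exposes the predictive variance, so we can report an approximate 95% credible interval alongside the mean. A minimal sketch, using the cond_gp and mean computed above:

std = jnp.sqrt(cond_gp.variance)
lower, upper = mean - 2 * std, mean + 2 * std
print(f'Scaled injection time: {mean[0]:.3f}, 95% interval: [{lower[0]:.3f}, {upper[0]:.3f}]')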

Note that cond_gp is just another tinygp.GaussianProcess object. Let’s plot the samples we’ve just collected in a histogram:

def samples_hist_plot(samples, mean, xlabel):
    fig, ax = plt.subplots(figsize=(6, 4), tight_layout=True)
    ax.hist(samples, bins=20, density=True, alpha=0.5, label="GP samples")
    ax.axvline(mean, color='red', label="GP mean")
    ax.set_title(r"Surrogate model of injection time evaluated at $x_\text{scaled}=(0,0,0,0,0,0,0,0,0,0)$", fontsize=12)
    ax.set_xlabel(xlabel)
    # ax.set_ylabel("Density")
    ax.legend()
    sns.despine(trim=True, ax=ax, left=True);

samples_hist_plot(samples, mean, "Scaled injection time")
# samples_hist_plot(unscale_y1(samples), np.mean(unscale_y1(samples)), "Injection time (s)")
[Figure: histogram of GP samples at \(x_\text{scaled}=0\), with the GP mean marked in red.]

Diagnostics: How good is the surrogate?#

The next question is: How accurate is our surrogate? We will use our test dataset to find out.

Parity plot#

Let’s plot the predicted output against the true output (for the test dataset):

_, cond_gp = gp.condition(y_train_scaled_injection_time, X_train_scaled_sub)
means_train = cond_gp.mean
stdevs_train = jnp.sqrt(cond_gp.variance)

_, cond_gp = gp.condition(y_train_scaled_injection_time, X_test_scaled)
means_test = cond_gp.mean
stdevs_test = jnp.sqrt(cond_gp.variance)
rmse = lambda y, y_hat: jnp.sqrt(jnp.mean((y - y_hat)**2))

fig, ax = plt.subplots()
fig.suptitle('Model predictions vs. training data', fontsize=16)
ax.errorbar(y_train_scaled_injection_time, means_train, yerr=2 * stdevs_train, fmt='o', ms=4, alpha=0.5, lw=0.1)
ax.plot([y_train_scaled_injection_time.min(), y_train_scaled_injection_time.max()], [y_train_scaled_injection_time.min(), y_train_scaled_injection_time.max()], "r-", lw=1, zorder=100)
ax.annotate(f"RMSE: {rmse(y_train_scaled_injection_time, means_train):.3f}", xy=(0.05, 0.9), xycoords='axes fraction')
ax.set_title('Scaled injection time')
sns.despine(trim=True, ax=ax)
ax.set_xlabel('True')
ax.set_ylabel('Predicted')

y_test_scaled_injection_time = y_test_scaled[:, 1]

fig, ax = plt.subplots()
fig.suptitle('Model predictions vs. test data', fontsize=16)
ax.errorbar(y_test_scaled_injection_time, means_test, yerr=2 * stdevs_test, fmt='o', ms=4, alpha=0.5, lw=0.1)
ax.plot([y_test_scaled_injection_time.min(), y_test_scaled_injection_time.max()], [y_test_scaled_injection_time.min(), y_test_scaled_injection_time.max()], "r-", lw=1, zorder=100)
ax.annotate(f"RMSE: {rmse(y_test_scaled_injection_time, means_test):.3f}", xy=(0.05, 0.9), xycoords='axes fraction')
ax.set_title('Scaled injection time')
sns.despine(trim=True, ax=ax)
ax.set_xlabel('True')
ax.set_ylabel('Predicted');
[Figures: parity plots of predicted vs. true scaled injection time, for the training data and for the test data.]

As expected, we’ve fit the training data nearly perfectly (almost all points lie on the red line). However, we mostly care about how well we predict unseen (test) data. The fit looks decent.

For GPs, parity plots and RMSE values are good for checking whether the predictive mean is accurate. However, they say nothing about whether our uncertainty estimates are correct. For that, we must look at the standardized errors.

Standardized errors#

For the \(i^\text{th}\) test data point, the standardized error \(e_i\) is

\[ e_i = \frac{y_i - m(x_i)}{\sigma(x_i)} \]

where \(y_i\) is the true output, and \(m(x_i)\) and \(\sigma(x_i)\) are the mean and standard deviation of the prediction, respectively. If the GP’s uncertainty is well calibrated, the standardized errors should look like independent draws from a standard normal distribution.

standardized_errors = (y_test_scaled_injection_time - means_test) / stdevs_test
fig, ax = plt.subplots()
fig.suptitle('Standardized errors', fontsize=16)
ax.hist(standardized_errors, bins=30, density=True)
x = jnp.linspace(-5, 5, 100)
y = jnp.exp(-0.5 * x**2) / jnp.sqrt(2 * jnp.pi)
ax.plot(x, y, 'r--', lw=2)
ax.set_title('Scaled injection time')
ax.set_yticks([])
sns.despine(trim=True, ax=ax, left=True);
[Figure: histogram of the standardized errors with the standard normal density overlaid.]

QQ plot#

The histogram above looks kind of like a normal distribution, but it’s sometimes hard to tell. A quantile-quantile (QQ) plot is another way to check normality:

quantiles = jnp.linspace(0.01, 0.99, 40)
normal_quantiles = jstats.norm.ppf(quantiles)
error_quantiles = jnp.quantile(standardized_errors, quantiles, axis=0)
fig, ax = plt.subplots()
ax.plot(normal_quantiles, error_quantiles, 'o')
ax.plot(normal_quantiles, normal_quantiles, 'r--')
ax.set_xlabel('Theoretical quantiles')
ax.set_ylabel('Sample quantiles')
ax.set_title('Normal Q-Q plot', fontsize=16)
sns.despine(trim=True, ax=ax);
[Figure: normal Q-Q plot of the standardized errors.]

The closer the points are to the red line, the more “normal” the errors are. They look decent.

Residuals plot#

Another assumption in our GP model is that the noise is homoscedastic—i.e., the same across all input/output values. We can check this assumption by plotting the errors against the model predictions. This is called a residual plot. (Residuals are just another name for the errors.)

fig, ax = plt.subplots()
fig.suptitle('Residual plot', fontsize=16)
ax.scatter(means_test, standardized_errors, 1)
ax.set_title('Scaled injection time')
ax.axhline(-2, color='r', lw=0.5, label=r'95% CI')
ax.axhline(2, color='r', lw=0.5)
sns.despine(trim=True, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Standardized error')
ax.legend();
[Figure: standardized errors vs. predicted values, with the red lines at \(\pm 2\) marking the 95% band.]

Not too bad.

Multiple outputs#

Let’s now build a surrogate for all four outputs. We will do this by building a GP for each output independently. This is implemented below:

def create_dataset(size, X=X_train_scaled, y=y_train_scaled):
    """Randomly subsample a dataset of the given size (without replacement)."""
    idx = np.random.choice(X.shape[0], size=size, replace=False)
    return X[idx], y[idx]

def train_gps(init_params, X, Y, num_iters, learning_rate, batch_size, key):
    """Optimize a different set of GP hyperparameters for each output variable."""
    keys = jr.split(key, len(OUTPUT_NAMES))
    trained_params = {}
    loss_histories = {}
    for i, o_name in enumerate(OUTPUT_NAMES):
        print(f"Training GP for {o_name}")
        y = Y[:, i]
        trained_params_i, losses_i = train_gp(init_params, X, y, num_iters, learning_rate, batch_size, keys[i])
        trained_params[o_name] = trained_params_i
        loss_histories[o_name] = losses_i
    return trained_params, loss_histories

class Surrogate(eqx.Module):
    """A Gaussian process surrogate with multiple inputs and outputs.
    
    Parameters
    ----------
    gps : list[GaussianProcess]
        A list of tinygp.GaussianProcess objects, one per output variable, built with the trained hyperparameters.
    y_obs : Float[Array, "n_data n_outputs"]
        The observed output data.
    """
    gps: list[GaussianProcess]
    y_obs: Float[Array, "n_data n_outputs"]
    
    @eqx.filter_jit
    def mean(self, X):
        """Evaluate the unscaled model."""
        y = []
        for i in range(self.num_outputs):
            _, cond_gp = self.gps[i].condition(self.y_obs[:, i], X)
            y.append( cond_gp.mean )
        return jnp.stack(y, axis=-1)
    
    @eqx.filter_jit
    def sample(self, X, num_samples, *, key):
        """Sample the output."""
        samples = []
        for i in range(self.num_outputs):
            key, subkey = jr.split(key)
            _, cond_gp = self.gps[i].condition(self.y_obs[:, i], X)
            samples.append( cond_gp.sample(subkey, shape=(num_samples,)) )
        return jnp.stack(samples, axis=-1)
    
    @property
    def num_outputs(self):
        return len(self.gps)

X_train_scaled_sub, y_train_scaled_sub = create_dataset(size=200)
key, subkey = jr.split(key)
trained_params, _ = train_gps(init_params, X_train_scaled_sub, y_train_scaled_sub, num_iters=1000, learning_rate=1e-2, batch_size=100, key=subkey)

# Create a GP for each output variable
gps = [build_gp(trained_params[o_name], X_train_scaled_sub) for o_name in OUTPUT_NAMES]

surrogate = Surrogate(gps=gps, y_obs=y_train_scaled_sub)
Training GP for needle_displacement_m
Training GP for injection_time_s
Training GP for max_acceleration_m_per_s2
Training GP for max_deceleration_m_per_s2

Though not necessary, we have packaged the surrogate into an equinox module, Surrogate, for convenience. Here is how we can evaluate the surrogate at some \(X\):

# Pick an input point
X = jnp.array([
    10.0,  # viscosity_cP
    1.02,  # fill_volume_mL
    4.5,  # air_gap_height_mm
    12.0,  # needle_length_mm
    0.17,  # needle_diameter_mm
    27.0, # spring_force_N
    200.0,  # spring_constant_N_per_mm
    1.5,  # kappa5
    2.0,  # kappa6
    0.05  # kappa7
])[None, :]  # Shape must be (n_pts, n_dim)

# Scale the input
X_s = scale_x(X)

# Evaluate the surrogate mean
Y_s = surrogate.mean(X_s)  # Must be a scaled X!

# Transform the output back to physical space
Y = unscale_y(Y_s)

print(f'needle_displacement_m:     {Y.squeeze(0)[0]:.5f}')
print(f'injection_time_s:          {Y.squeeze(0)[1]:.2f}')
print(f'max_acceleration_m_per_s2: {Y.squeeze(0)[2]:.0f}')
print(f'max_deceleration_m_per_s2: {Y.squeeze(0)[3]:.1f}')
needle_displacement_m:     0.00741
injection_time_s:          9.67
max_acceleration_m_per_s2: 42790
max_deceleration_m_per_s2: 40246.4

We can also sample the surrogate output:

# Sample the scaled outputs
scaled_samples = surrogate.sample(X_s, 1000, key=key)

# Transform to physical space
samples = vmap(unscale_y)(scaled_samples)
fig, ax = plt.subplots(1, len(OUTPUT_NAMES), figsize=(10, 2), tight_layout=True)
for i, o_name in enumerate(OUTPUT_NAMES):
    ax[i].hist(samples.squeeze(1)[:, i], bins=16, density=True)
    ax[i].set_title(o_name)
    ax[i].set_yticks([])
    sns.despine(trim=True, ax=ax[i], left=True)
[Figure: histograms of sampled surrogate outputs, one panel per output variable.]

The uncertainty you see here is the epistemic uncertainty (or “lack-of-data” uncertainty). It should go down as we add more data points. We’ll look at this next.

Convergence: How much data do we need?#

Ideally, we want to use enough data so that the epistemic uncertainty is low. Let’s test how much data we need for this problem. We’ll create training datasets of sizes \(N=50, 100, 200, 500, 1000, 2000\) and observe the convergence of the trained GP.

dataset_sizes = [50, 100, 200, 500, 1000, 2000]
datasets_N = {size: create_dataset(size) for size in dataset_sizes}

surrogates_N = {}
trained_params_N = {}
loss_histories_N = {}

for N in dataset_sizes:
    print(f'Now training GP with dataset of size {N:<3} ... ')

    X, y = datasets_N[N]
    key, subkey = jr.split(key)
    trained_params_N[N], loss_histories_N[N] = train_gps(init_params, X, y, num_iters=1000, learning_rate=1e-2, batch_size=100, key=subkey)
    gps = [build_gp(trained_params_N[N][o_name], X) for o_name in OUTPUT_NAMES]
    surrogates_N[N] = Surrogate(gps=gps, y_obs=y)

    print('done.')
Now training GP with dataset of size 50  ... 
Training GP for needle_displacement_m
Training GP for injection_time_s
Training GP for max_acceleration_m_per_s2
Training GP for max_deceleration_m_per_s2
done.
Now training GP with dataset of size 100 ... 
Training GP for needle_displacement_m
Training GP for injection_time_s
Training GP for max_acceleration_m_per_s2
Training GP for max_deceleration_m_per_s2
done.
Now training GP with dataset of size 200 ... 
Training GP for needle_displacement_m
Training GP for injection_time_s
Training GP for max_acceleration_m_per_s2
Training GP for max_deceleration_m_per_s2
done.
Now training GP with dataset of size 500 ... 
Training GP for needle_displacement_m
Training GP for injection_time_s
Training GP for max_acceleration_m_per_s2
Training GP for max_deceleration_m_per_s2
done.
Now training GP with dataset of size 1000 ... 
Training GP for needle_displacement_m
Training GP for injection_time_s
Training GP for max_acceleration_m_per_s2
Training GP for max_deceleration_m_per_s2
done.
Now training GP with dataset of size 2000 ... 
Training GP for needle_displacement_m
Training GP for injection_time_s
Training GP for max_acceleration_m_per_s2
Training GP for max_deceleration_m_per_s2
done.
y_latent_samples_N = {}
for N in dataset_sizes:
    y_latent_samples_N[N] = surrogates_N[N].sample(X_s, num_samples=1000, key=key)  # Shape is (n_samples, n_points, n_outputs)
output_index = 1  # Select an output

fig, ax = plt.subplots(len(dataset_sizes), 1, figsize=(4, 4), tight_layout=True, sharex=True, sharey=False)
fig.suptitle(f'Epistemic uncertainty for {OUTPUT_NAMES[output_index]} (scaled)')
for i, N in enumerate(dataset_sizes):
    sns.kdeplot(y_latent_samples_N[N].squeeze(1)[:, output_index], ax=ax[i], lw=2)
    ax[i].set_title(f'N={N}')
    ax[i].set_yticks([])
    ax[i].set_xlabel('')
    ax[i].set_ylabel('')
    sns.despine(trim=True, ax=ax[i], left=True)
[Figure: kernel density estimates of the sampled scaled injection time, one panel per dataset size N.]
means_N = {}
stdevs_N = {}
for N in dataset_sizes:
    s = surrogates_N[N].sample(X_s, num_samples=1000, key=key)
    means_N[N] = jnp.mean(s, axis=0)
    stdevs_N[N] = jnp.std(s, axis=0)
fig, ax = plt.subplots()
fig.suptitle('Predictive standard deviation vs. dataset size', fontsize=16)
for i, o_name in enumerate(OUTPUT_NAMES):
    ax.plot(dataset_sizes, [stdevs_N[N][0, i] for N in dataset_sizes], 'o-', label=f'{o_name} (scaled)', color=f'C{i}', alpha=0.7)
ax.set_xlabel('Dataset size')
ax.set_ylabel('Predictive standard deviation')
ax.legend()
sns.despine(trim=True, ax=ax)
[Figure: predictive standard deviation vs. dataset size, one curve per output.]

For most of the outputs, we only need a few hundred data points. (Predicting needle displacement accurately, however, will require more data or some inductive biases built into the surrogate.)

Finally, let’s look at the parity plots for the GP trained on the most data.

surrogate = surrogates_N[2000]
samples = surrogate.sample(X_test_scaled, num_samples=1000, key=key)
means_test = samples.mean(axis=0)
stdevs_test = samples.std(axis=0)
rmse = lambda y, y_hat: jnp.sqrt(jnp.mean((y - y_hat)**2))

fig, ax = plt.subplots(1, len(OUTPUT_NAMES), figsize=(10, 3), tight_layout=True)
fig.suptitle('Validation with test data', fontsize=16)
for i, o_name in enumerate(OUTPUT_NAMES):
    ax[i].errorbar(y_test_scaled[:, i], means_test[:, i], yerr=2 * stdevs_test[:, i], fmt='o', ms=2, alpha=0.3, lw=0.1)
    ax[i].plot([y_test_scaled[:, i].min(), y_test_scaled[:, i].max()], [y_test_scaled[:, i].min(), y_test_scaled[:, i].max()], "r-", lw=0.5, zorder=100)
    ax[i].annotate(f"RMSE: {rmse(y_test_scaled[:, i], means_test[:, i]):.3f}", xy=(0.05, 0.9), xycoords='axes fraction')
    ax[i].set_title(o_name)
    sns.despine(trim=True, ax=ax[i])
ax[0].set_xlabel('True')
ax[0].set_ylabel('Predicted');
[Figure: parity plots against the test data for all four outputs, using the N=2000 surrogate.]

For most outputs, the fit is excellent.

Sensitivity analysis with surrogate#

Now that we have a trained and tested surrogate, there are many useful things we can do with it. As in the neural network surrogate hands-on activity, we’ll demonstrate using the surrogate to do Sobol sensitivity analysis.
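
Recall that the first-order Sobol index of input \(x_i\) measures the fraction of the output variance explained by \(x_i\) alone:

\[ S_i = \frac{\operatorname{Var}_{x_i}\left(\mathbb{E}_{x_{\sim i}}[y \mid x_i]\right)}{\operatorname{Var}(y)}, \]

where \(x_{\sim i}\) denotes all inputs except \(x_i\). Estimating these indices requires many model evaluations, which is exactly where a cheap surrogate shines.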

First, let’s define bounds for the inputs (based on physical intuition and/or literature values):

import SALib.sample.sobol as sobol
import SALib.analyze.sobol as analyze_sobol

input_bounds_dict = {
    'viscosity_cP': [1.0, 20.0],
    'fill_volume_mL': [1.0, 1.05],
    'air_gap_height_mm': [4.0, 5.0],
    'needle_length_mm': [8.0, 15.9],
    'needle_diameter_mm': [0.133, 0.21],
    'spring_force_N': [18.0, 36.0],
    'spring_constant_N_per_mm': [150.0, 250.0],
    'kappa5': [0.0, 3.0],
    'kappa6': [0.0, 4.0],
    'kappa7': [0.0, 0.1]
}
input_bounds = np.array([[input_bounds_dict[i][0], input_bounds_dict[i][1]] for i in INPUT_NAMES])
print('The input bounds are:\n', input_bounds)
The input bounds are:
 [[1.00e+00 2.00e+01]
 [1.00e+00 1.05e+00]
 [4.00e+00 5.00e+00]
 [8.00e+00 1.59e+01]
 [1.33e-01 2.10e-01]
 [1.80e+01 3.60e+01]
 [1.50e+02 2.50e+02]
 [0.00e+00 3.00e+00]
 [0.00e+00 4.00e+00]
 [0.00e+00 1.00e-01]]

Next, we create Sobol samples of the inputs and pass them through the surrogate:

problem = {
    'num_vars': len(INPUT_NAMES),
    'names': INPUT_NAMES,
    'bounds': input_bounds
}

# The number of base samples to generate (should be a power of 2). With
# calc_second_order=False, SALib generates N*(D+2) total input points.
N = 512

# Generate the samples.
sobol_samples = sobol.sample(problem, N, calc_second_order=False)
sobol_samples_scaled = scale_x(sobol_samples)

# Evaluate the surrogate model at the Sobol samples.
sobol_outputs = surrogate.mean(sobol_samples_scaled)

Finally, we calculate and plot the Sobol indices:

sobol_indices = {}
for i, o_name in enumerate(OUTPUT_NAMES):
    sobol_indices[o_name] = analyze_sobol.analyze(problem, sobol_outputs[:, i], calc_second_order=False, print_to_console=False)
# Plot sobol indices
fig, ax = plt.subplots(1, len(OUTPUT_NAMES), figsize=(15, 4), tight_layout=True, sharey=True)
fig.suptitle('Sobol indices', fontsize=16)
for i, o_name in enumerate(OUTPUT_NAMES):
    ax[i].bar(problem['names'], sobol_indices[o_name]['S1'])
    ax[i].set_title(o_name)
    for tick in ax[i].get_xticklabels():
        tick.set_rotation(90)
    sns.despine(trim=True, ax=ax[i])
[Figure: bar charts of the first-order Sobol indices, one panel per output.]

Excellent! We now know to which inputs the outputs are most sensitive. For example, we can see that the injection time is very sensitive to the drug viscosity and needle diameter, somewhat sensitive to needle length and injector spring force, and not sensitive to any other inputs. This information can be used, for example, to further study and understand the physics behind the model, or to investigate how identifiable each parameter is given an experimental dataset.

Remember that this sensitivity analysis would not have been feasible without a surrogate model—the true physical model was just too expensive. With a surrogate, however, we can do sensitivity analysis, design optimization, uncertainty quantification, etc., all at a reasonable computational cost.