Hide code cell source
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")

# Uncomment the next two lines if running for the book
# import warnings
# warnings.filterwarnings("ignore")

import jax
from jax import vmap
import jax.random as jrandom
import jax.numpy as jnp
import jax.scipy.stats as jstats
import pandas as pd
import numpy as np
import gpjax as gpx
from flax import nnx
import optax
import equinox as eqx
from typing import Optional
from jaxtyping import Float, Array

# Turn on JAX's 64-bit mode (the "jax_enable_x64" flag).
jax.config.update("jax_enable_x64", True)
# Root PRNG key for the notebook; later cells split it before use.
key = jrandom.PRNGKey(0);

Example of Gaussian Process Surrogate#

Model for subcutaneous autoinjectors#

In this notebook, we’ll repeat the neural network surrogate hands-on activity, but this time we’ll use a Gaussian process (GP) as the surrogate. Recall that our goal is to replicate Sree et al. (2023), i.e., create a surrogate of an expensive biomechanical model.

Just to recap: We have a high-fidelity, finite-element-based model \(f\) of drug autoinjections. The model takes inputs \(x\) like drug viscosity, tissue biomechanical properties, etc. and calculates various outputs \(y=f(x)\). To train a surrogate \(\hat{f}\) of this model, we have a dataset containing thousands of evaluations of \(f\). For further details about the model or dataset, see the neural network surrogate hands-on activity.

Let’s get the data imported. (We’ll follow the same preprocessing steps as the neural network surrogate hands-on activity, except we won’t re-explain the steps here—we’ll just do them.)

Hide code cell source
class StandardScaler:
    """A transform that standardizes the data to have mean=0 and stdev=1.

    The mean and standard deviation are computed per feature (column) from
    ``data`` after the optional pre-transform has been applied to each sample.
    Features whose standard deviation is zero (constant columns) are scaled
    by 1 instead, so ``forward`` maps them to 0 rather than producing
    NaN/inf through a division by zero.

    Parameters
    ----------
    data : Float[Array, "N d"]
        The data to be standardized. Shape is (N, d)=(number of samples, dimensionality).
    pretransform_forward : Optional[callable]
        An optional transformation to apply to the data before standardizing.
        It is applied to one sample at a time.
    pretransform_inverse : Optional[callable]
        The inverse of pretransform_forward.

    Raises
    ------
    ValueError
        If exactly one of ``pretransform_forward`` / ``pretransform_inverse``
        is given.
    """
    def __init__(
        self, 
        data: Float[Array, "N d"], 
        pretransform_forward: Optional[callable] = None, 
        pretransform_inverse: Optional[callable] = None
    ):
        # Set up pre-transform functions: either both are given or neither is.
        if (pretransform_forward is None) ^ (pretransform_inverse is None):
            raise ValueError("Both pretransform_forward and pretransform_inverse must be provided.")
        elif pretransform_forward is None:
            pretransform_forward = lambda x: x
            pretransform_inverse = lambda x: x
        self.pre_forward = pretransform_forward
        self.pre_inverse = pretransform_inverse
        
        # Compute parameters for the standardizer (per-feature statistics).
        pretransformed_data = vmap(self.pre_forward)(data)
        self.mean = jnp.mean(pretransformed_data, axis=0)
        std = jnp.std(pretransformed_data, axis=0)
        # A constant feature has std == 0; use a unit scale for it so that
        # forward() returns 0 for that feature instead of 0/0 = NaN.
        self.std = jnp.where(std > 0, std, 1.0)
    
    def forward(self, x):
        """Pre-transform ``x`` and standardize it with the stored mean/std."""
        return (self.pre_forward(x) - self.mean) / self.std
    
    def inverse(self, y):
        """Undo ``forward``: un-standardize ``y``, then apply the inverse pre-transform."""
        return self.pre_inverse(y * self.std + self.mean)

# Define new names for each input/output variable
old_input_names = ['mu', 'fill_volume', 'hGap0', 'lNeedle', 'dNeedle', 'FSpring0', 'kSpring', 'kappa5', 'kappa6', 'kappa7']
INPUT_NAMES = ['viscosity_cP', 'fill_volume_mL', 'air_gap_height_mm', 'needle_length_mm', 'needle_diameter_mm', 'spring_force_N', 'spring_constant_N_per_mm', 'kappa5', 'kappa6', 'kappa7']
old_output_names = ['Needle displacement (m)', 'Injection time (s)', 'max. acceleration (m/s^2)', 'max. deceleration (m/s^2)']
OUTPUT_NAMES = ['needle_displacement_m', 'injection_time_s', 'max_acceleration_m_per_s2', 'max_deceleration_m_per_s2']
column_name_mapper = dict(zip(old_input_names + old_output_names, INPUT_NAMES + OUTPUT_NAMES))

# Load the data, rename the columns, and shuffle the rows.
# The shuffle is seeded so that the row order (and therefore every `[:size]`
# training subset and the "first" test point used later) is the same on every run.
SHUFFLE_SEED = 0
train_data = (
    pd.read_excel("../../data/training_data.xlsx", index_col=0)
    .rename(columns=column_name_mapper)
    .sample(frac=1, random_state=SHUFFLE_SEED)
    .reset_index(drop=True)
)
test_data = (
    pd.read_excel("../../data/test_data.xlsx", index_col=0)
    .rename(columns=column_name_mapper)
    .sample(frac=1, random_state=SHUFFLE_SEED)
    .reset_index(drop=True)
)

# Split data. Pass in column names to ensure correct order.
X_train = pd.DataFrame(train_data[INPUT_NAMES], columns=INPUT_NAMES)  
y_train = pd.DataFrame(train_data[OUTPUT_NAMES], columns=OUTPUT_NAMES)
X_test = pd.DataFrame(test_data[INPUT_NAMES], columns=INPUT_NAMES)
y_test = pd.DataFrame(test_data[OUTPUT_NAMES], columns=OUTPUT_NAMES)

# Standardizer for the inputs, fitted on the training inputs only.
input_transform = StandardScaler(X_train.values)

def partial_log_transform(x):
    """Take the log of entries 1 and 2 of ``x``; pass entries 0 and 3 through.

    Only the first four entries of ``x`` are used.
    """
    first, second, third, fourth = x[0], x[1], x[2], x[3]
    pieces = [first, jnp.log(second), jnp.log(third), fourth]
    return jnp.hstack(pieces)

def partial_exp_transform(y):
    """Exponentiate entries 1 and 2 of ``y``; pass entries 0 and 3 through.

    Only the first four entries of ``y`` are used.
    """
    first, second, third, fourth = y[0], y[1], y[2], y[3]
    pieces = [first, jnp.exp(second), jnp.exp(third), fourth]
    return jnp.array(pieces)

# Output standardizer: entries 1 and 2 of each output vector are
# log-transformed before the mean/stdev are computed.
output_transform = StandardScaler(
    y_train.values,
    pretransform_forward=partial_log_transform,
    pretransform_inverse=partial_exp_transform,
)

# Apply the fitted transforms row by row to the train and test splits.
X_train_scaled, X_test_scaled = (
    vmap(input_transform.forward)(frame.values) for frame in (X_train, X_test)
)
y_train_scaled, y_test_scaled = (
    vmap(output_transform.forward)(frame.values) for frame in (y_train, y_test)
)

The data are now imported and standardized. Let’s visualize to make sure things look good:

Hide code cell source
# One histogram panel per standardized input variable.
fig, ax = plt.subplots(1, X_train_scaled.shape[1], figsize=(15, 2), tight_layout=True, sharey=True)
fig.suptitle("Transformed input data")
for panel, title, column in zip(ax, INPUT_NAMES, X_train_scaled.T):
    panel.hist(column, bins=20)
    panel.set_title(title)
    panel.set_yticks([])
    sns.despine(trim=True, ax=panel, left=True)

# One histogram panel per standardized output variable.
fig, ax = plt.subplots(1, y_train_scaled.shape[1], figsize=(8, 2), tight_layout=True, sharey=True)
fig.suptitle("Transformed output data")
for panel, title, column in zip(ax, OUTPUT_NAMES, y_train_scaled.T):
    panel.hist(column, bins=20)
    panel.set_title(title)
    panel.set_yticks([])
    sns.despine(trim=True, ax=panel, left=True)
../../_images/9dfb613dbdae49ca74edf0a306b25b2c4c508ef1b893e9db3cbf32202e62ace9.svg ../../_images/9cbac3ce60e7061e2e27b1bbf590367a3092d0aa03d606a63d2b5219eb30d802.svg

Nice! The input/output data are all more-or-less evenly distributed and standardized. Now on to building the Gaussian process surrogate.

Gaussian process regression#

We will use GPJax to construct 4 different GP surrogates, one for each output. We’ll write the math down for one of the outputs, but the same applies to the other outputs.

Prior#

First, let’s create the GP prior with zero mean, i.e.,

\[ m(x) = 0 \]

and a radial basis function (RBF) as the covariance kernel, i.e.,

\[ k(x, x') = \sigma^2 \exp\left(-\frac{1}{2} \sum_{i=1}^d \left(\frac{x_i - x'_i}{\ell_i}\right)^2\right) \]

where \(\sigma^2\) is the variance and \(\ell_i\) is the length scale for input dimension \(i\). We write the prior as

\[ f \sim \operatorname{GP}(m, k). \]

Here is how to do it in GPJax:

# Zero mean
# Zero mean function, m(x) = 0.
mean = gpx.mean_functions.Zero()

# RBF kernel with one lengthscale per input column and unit variance.
# These are only starting values - they will be optimized later.
kernel = gpx.kernels.RBF(lengthscale=jnp.ones(X_train_scaled.shape[1]), variance=1.0)

# GP prior built from the mean function and kernel above.
prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)

Likelihood#

We’ll use a Gaussian likelihood. For training data \(\mathcal{D}\equiv\{x_i, y_i = f(x_i)\}_{i=1}^n\), we have that

\[ y_i|f, x_i \sim \mathcal{N}(f(x_i), \sigma_n^2) \quad \text{independently for } i = 1, \ldots, n, \]

where \(f\) is the physical model and \(\sigma_n^2\) is the measurement noise variance.

Note that for surrogates of computer simulations, \(\sigma_n^2\) will typically be small. This is because, for most deterministic computer programs, running the program several times with the exact same inputs will yield identical results (up to machine precision).

Coding this in GPJax is straightforward:

# Gaussian likelihood. `num_datapoints` is set to the number of rows in the
# full training set, and `obs_stddev` starts at a small value.
likelihood = gpx.likelihoods.Gaussian(
    num_datapoints=y_train.shape[0],
    obs_stddev=0.001
)

Posterior#

We can now construct the GP posterior

\[ f | \mathcal{D} \sim \operatorname{GP}(m_{\text{post}}, k_{\text{post}}) \]
# Combine the prior and the likelihood into the (untrained) GP posterior object.
posterior = prior * likelihood

Optimizing hyperparameters#

For each GP posterior, the hyperparameters are the RBF kernel variance \(\sigma^2\), lengthscales \(\mathbf{\ell}=(\ell_1, \dots, \ell_d)\), and measurement noise variance \(\sigma_n^2\).

Let’s denote the hyperparameters as \(\psi=(\sigma^2, \mathbf{\ell}, \sigma_n^2)\). We’ll pick optimal \(\psi\) by maximizing the marginal log likelihood, i.e.,

\[ \psi^* = \arg\max_{\psi} ~ \log p(\mathbf{y} | \mathbf{X}, \psi). \]

GPJax provides the functions conjugate_mll and fit to do this. (Note that we do everything in a loop so each output gets its own GP surrogate.)

Hide code cell source
# Loss function: negative marginal log likelihood.
# `p` is the posterior and `d` the dataset; the sign is flipped so that
# minimizing this loss maximizes the marginal log likelihood.
negative_mll = nnx.jit(lambda p, d: -gpx.objectives.conjugate_mll(p, d))

def optimize_gps(posterior, datasets, num_iters=1000, batch_size=128, learning_rate=1e-2, verbose=True, *, key):
    """Fit the GP hyperparameters independently for every output.

    Starting from the same ``posterior``, one `gpx.fit` run is performed per
    name in ``OUTPUT_NAMES`` using that output's entry in ``datasets``, the
    negative marginal log likelihood as objective, and a fresh Adam optimizer.

    Returns two dictionaries keyed by output name: the fitted posteriors and
    the corresponding loss histories.
    """
    # Options shared by every per-output fit.
    shared_options = dict(
        model=posterior,
        objective=negative_mll,
        num_iters=num_iters,
        batch_size=batch_size,
        verbose=verbose,
    )

    posteriors, loss_histories = {}, {}
    for o_name in OUTPUT_NAMES:
        # Each output gets its own subkey.
        key, subkey = jrandom.split(key)
        fitted, history = gpx.fit(
            train_data=datasets[o_name],
            optim=optax.adam(learning_rate),  # Adam optimizer
            key=subkey,
            **shared_options,
        )
        posteriors[o_name] = fitted
        loss_histories[o_name] = history

    return posteriors, loss_histories

def create_datasets(size=None, X=X_train_scaled, y=y_train_scaled, output_names=OUTPUT_NAMES):
    """Build one GPJax dataset per output from the first ``size`` rows of the data.

    When ``size`` is None, all rows of ``X`` are used. Output ``k`` in
    ``output_names`` is paired with column ``k`` of ``y`` (kept two-dimensional).
    Returns a dictionary mapping output name to its `gpx.Dataset`.
    """
    n_rows = X.shape[0] if size is None else size
    return {
        o_name: gpx.Dataset(X[:n_rows], y[:n_rows, col:col + 1])
        for col, o_name in enumerate(output_names)
    }
# Create a separate dataset for each output (and put it all in a single dictionary).
# Only the first 200 rows of the training data are used here.
datasets = create_datasets(size=200)

# Optimize the hyperparameters, using a fresh subkey split off the notebook key.
key, subkey = jrandom.split(key)
posteriors, loss_histories = optimize_gps(posterior, datasets, key=subkey)

We now have an optimized GP for each output!

Predictive distribution: How to evaluate the GP?#

The question now is how do we use our GPs to make predictions? The predictive distribution for a new input \(\mathbf{x}^*\) is given by

\[\begin{split} p\big( f(\mathbf{x}^*) \big | \mathcal{D}) = \int \underbrace{p\big( f(\mathbf{x}^*) \big | f(\mathbf{x}) \big)}_{\substack{\text{conditional probability} \\ \text{of output at } \mathbf{x}^*}} \underbrace{p\big(f(\mathbf{x}) \big| \mathcal{D}\big)}_{\substack{\text{posterior GP} \\ \text{evaluated at } \mathbf{x}}} df(\mathbf{x}) \end{split}\]

GPJax provides the function predict which can be used to give the mean and variance of the predictive distribution \(p\big( f(\mathbf{x}^*) \big| \mathcal{D} \big)\). Or, we can sample from the predictive distribution using the function sample.

To facilitate making new predictions, let’s wrap predict and sample with the following helper class, which represents the surrogate:

Hide code cell source
class Surrogate(eqx.Module):
    """Multi-output surrogate made of one GP posterior per output.

    Methods whose name contains ``scaled`` or ``latent`` expect inputs that
    have already been passed through ``input_transform`` and return values in
    the standardized output space. ``__call__`` and ``sample`` take raw
    inputs and return raw (inverse-transformed) outputs.
    """
    # Fitted posterior for each output name.
    models: dict[str, gpx.gps.AbstractPosterior]
    # Training dataset each posterior is conditioned on, keyed by output name.
    train_data: dict[str, gpx.Dataset]
    # Input/output scalers; default to the notebook-level transforms.
    input_transform: StandardScaler = input_transform
    output_transform: StandardScaler = output_transform
    # Input/output variable names; OUT_NAMES fixes the output column order.
    IN_NAMES: list[str] = eqx.field(default_factory=lambda: INPUT_NAMES)
    OUT_NAMES: list[str] = eqx.field(default_factory=lambda: OUTPUT_NAMES)
    
    def __call__(self, X):
        """
        Evaluate the model at X.

        Scales each row of X, takes the predictive mean in the scaled output
        space, and maps it back through the inverse output transform.
        """
        Z = vmap(self.input_transform.forward)(X)
        y = self.get_scaled_predictive_mean_stdev(Z)[0]  # Just the mean
        return vmap(self.output_transform.inverse)(y)

    def sample(self, X, num_samples, *, key):
        """Sample the output. Contains both aleatoric and epistemic uncertainty.

        NOTE(review): the per-output samples are concatenated along axis 1, so
        the inverse transform below presumably only lines up with the output
        columns when X holds a single input point - confirm before passing
        several points.
        """
        Z = vmap(self.input_transform.forward)(X)
        y = self.sample_scaled_predictive(Z, num_samples, key=key)
        return vmap(self.output_transform.inverse)(y)
    
    def sample_scaled_predictive(self, X, num_samples, *, key):
        """Sample the scaled output. Contains both aleatoric and epistemic uncertainty."""
        return self._sample_per_output(self._get_predictive_dists(X), num_samples, key)
    
    def sample_latent(self, X, num_samples, *, key):
        """Sample the latent space of the model (i.e., a GP sample before passing through the likelihood). Contains only epistemic uncertainty."""
        return self._sample_per_output(self._get_latent_dists(X), num_samples, key)
    
    def get_scaled_predictive_mean_stdev(self, X):
        """Get the mean and standard deviation of the scaled output. Contains both aleatoric and epistemic uncertainty."""
        return self._mean_stdev_per_output(self._get_predictive_dists(X))
    
    def get_latent_mean_stdev(self, X):
        """Get the mean and standard deviation of the latent distribution. Contains only epistemic uncertainty."""
        return self._mean_stdev_per_output(self._get_latent_dists(X))
    
    def _sample_per_output(self, dists, num_samples, key):
        """Draw ``num_samples`` from each output's distribution and concatenate along axis 1."""
        # The per-output GPs are separate models, so each one gets its own
        # subkey. Reusing one key for all outputs would feed every sampler the
        # same random stream.
        subkeys = jrandom.split(key, len(self.OUT_NAMES))
        draws = [
            dists[o_name].sample(seed=subkey, sample_shape=(num_samples,))
            for o_name, subkey in zip(self.OUT_NAMES, subkeys)
        ]
        return jnp.concatenate(draws, axis=1)
    
    def _mean_stdev_per_output(self, dists):
        """Stack each output's mean and stddev into arrays with one column per output."""
        means = [dists[o_name].mean() for o_name in self.OUT_NAMES]
        stdevs = [dists[o_name].stddev() for o_name in self.OUT_NAMES]
        return jnp.stack(means, axis=1), jnp.stack(stdevs, axis=1)
    
    def _get_latent_dists(self, X):
        """Posterior (latent) distribution at X for every output, keyed by output name."""
        return {
            o_name: self.models[o_name].predict(X, self.train_data[o_name])
            for o_name in self.OUT_NAMES
        }
    
    def _get_predictive_dists(self, X):
        """Latent distributions at X passed through each model's likelihood."""
        latent_dists = self._get_latent_dists(X)
        return {
            o_name: self.models[o_name].likelihood(latent_dists[o_name])
            for o_name in self.OUT_NAMES
        }
# Surrogate built from the GPs fitted on the 200-point datasets above.
surrogate_ = Surrogate(
    models=posteriors,
    train_data=datasets,
    input_transform=input_transform,
    output_transform=output_transform
)

Now let’s make a prediction at a new input location.

# Take the first row of the (shuffled) test inputs, kept 2-D, and scale it.
x = X_test.values[0:1]
x_scaled = input_transform.forward(x)

# Predictive mean and standard deviation in the scaled output space.
y_scaled_mean, y_scaled_stdev = surrogate_.get_scaled_predictive_mean_stdev(x_scaled)

print('The scaled predictions are:')
for o_name, mu, sigma in zip(OUTPUT_NAMES, y_scaled_mean[0], y_scaled_stdev[0]):
    print(f"\t{o_name:<25}:  {mu:.2f} +/- {sigma:.2f}")
The scaled predictions are:
	needle_displacement_m    :  -1.38 +/- 0.32
	injection_time_s         :  -1.78 +/- 0.03
	max_acceleration_m_per_s2:  1.49 +/- 0.09
	max_deceleration_m_per_s2:  0.68 +/- 0.04

This is one advantage of GP surrogates—they provide uncertainty estimates!

Let’s visualize the (unscaled) output probability distributions. We can do this by sampling from the GP posterior:

# This samples the unscaled output (predictive distribution of the surrogate).
# Split first: the carried `key` is split again by later cells, so it must not
# be consumed directly by a sampler.
key, subkey = jrandom.split(key)
y_latent_samples = surrogate_.sample(x, num_samples=1000, key=subkey)
Hide code cell source
# Density histogram of the sampled (unscaled) values, one panel per output.
fig, ax = plt.subplots(1, len(OUTPUT_NAMES), figsize=(10, 2), tight_layout=True)
for col, title in enumerate(OUTPUT_NAMES):
    panel = ax[col]
    panel.hist(y_latent_samples[:, col], bins=16, density=True)
    panel.set_title(title)
    panel.set_yticks([])
    sns.despine(trim=True, ax=panel, left=True)
../../_images/4c1d19ab8eaf3a4510c346d4501fb271f8b111b4c49dee2249ac46e2d791250e.svg

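The uncertainty you see here is the predictive uncertainty of the surrogate. For a deterministic simulator the noise term should be small, so it is mostly epistemic uncertainty (or “lack-of-data” uncertainty). It should go down as we add more data points. We’ll look at this next.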

Convergence: How much data do we need?#

Ideally, we want to use enough data so that the epistemic uncertainty is low. Let’s test how much data we need for this problem. We’ll create training datasets of sizes \(N=50, 500, 1000, 1500, 2000\) and observe the convergence of the trained GP.

dataset_sizes = [50, 500, 1000, 1500, 2000]
# One dictionary of per-output datasets for every training-set size.
datasets_N = dict(zip(dataset_sizes, map(create_datasets, dataset_sizes)))

surrogates_N, posteriors_N, loss_histories_N = {}, {}, {}

for N in dataset_sizes:
    print(f'Now training GP with dataset of size {N:<3} ... ', end='')

    # Fresh subkey for this size's hyperparameter optimization.
    key, subkey = jrandom.split(key)
    fitted_N, history_N = optimize_gps(posterior, datasets_N[N], verbose=False, key=subkey)
    posteriors_N[N] = fitted_N
    loss_histories_N[N] = history_N

    # Wrap the fitted GPs and their training data in a surrogate.
    surrogates_N[N] = Surrogate(
        models=fitted_N,
        train_data=datasets_N[N],
        input_transform=input_transform,
        output_transform=output_transform,
    )

    print('done.')
Now training GP with dataset of size 50  ... done.
Now training GP with dataset of size 500 ... done.
Now training GP with dataset of size 1000 ... done.
Now training GP with dataset of size 1500 ... done.
Now training GP with dataset of size 2000 ... done.

We now have trained GPs for each dataset size, stored in surrogates_N. Note how the epistemic uncertainty decreases as we add more data:

# Latent (scaled) samples at the test point for every training-set size.
# The same key is passed to every surrogate.
y_latent_samples_N = {
    size: surrogates_N[size].sample_latent(x_scaled, num_samples=1000, key=key)
    for size in dataset_sizes
}
Hide code cell source
output_index = 1  # Select an output

# One KDE panel per training-set size for the selected output.
fig, ax = plt.subplots(len(dataset_sizes), 1, figsize=(4, 4), tight_layout=True, sharex=True, sharey=False)
# The title does not depend on the loop variable, so set it once.
fig.suptitle(f'Epistemic uncertainty for {OUTPUT_NAMES[output_index]} (scaled)')
for i, N in enumerate(dataset_sizes):
    sns.kdeplot(y_latent_samples_N[N][:,output_index], ax=ax[i], lw=2)
    ax[i].set_title(f'N={N}')
    ax[i].set_yticks([])
    ax[i].set_xlabel('')
    ax[i].set_ylabel('')
    sns.despine(trim=True, ax=ax[i], left=True)
../../_images/c0f25bf18d439f1b1c1f39a1ca2ab008dc86a5836ef80d87e86580d1abfdcf2b.svg

Let’s visualize the convergence for each (scaled) output. Below, we plot the epistemic uncertainty as a vertical bar:

# Latent mean / standard deviation at the test point for every training-set size.
means_N, stdevs_N = {}, {}
for N in dataset_sizes:
    latent_mean, latent_stdev = surrogates_N[N].get_latent_mean_stdev(x_scaled)
    means_N[N] = latent_mean
    stdevs_N[N] = latent_stdev
Hide code cell source
# One row per output: latent mean against dataset size, with the latent
# standard deviation drawn as the error bar.
fig, ax = plt.subplots(len(OUTPUT_NAMES), 1, figsize=(4, 4), tight_layout=True, sharey=False)
fig.suptitle('Convergence of epistemic uncertainty to zero')
bottom_row = len(OUTPUT_NAMES) - 1
for i, o_name in enumerate(OUTPUT_NAMES):
    panel = ax[i]
    color = f'C{i}'
    centers = [means_N[N][0, i] for N in dataset_sizes]
    spreads = [stdevs_N[N][0, i] for N in dataset_sizes]
    panel.errorbar(dataset_sizes, centers, yerr=spreads, ls='-', lw=0.5, elinewidth=3, label=f'{o_name} (scaled)', color=color)
    panel.set_title(o_name, color=color)
    if i != bottom_row:
        # Only the bottom panel keeps its x axis.
        sns.despine(trim=True, ax=panel, bottom=True)
        panel.set_xticks([])
    else:
        sns.despine(trim=True, ax=panel)
ax[-1].set_xlabel('Dataset size');
../../_images/5f47c2b6b2779cbeb52bf46028cc8efb7adb2f8f1ab77f928ea313dd0fbde64b.svg

For \(N \geq 1000\) the epistemic uncertainty is low for all outputs. We’ll use the GP trained with 1000 points for subsequent analysis.

# Use the surrogate trained on the first 1000 training rows from here on.
surrogate = surrogates_N[1000]

Diagnostics: How good is the surrogate?#

The next question is: How accurate is our surrogate? We will use our test dataset to find out.

Parity plot#

Let’s plot the predicted output against the true output (for the test dataset):

# Scaled predictive mean/stdev on the 1000 training rows the chosen surrogate
# was fitted on, and on the full test set.
means_train, stdevs_train = surrogate.get_scaled_predictive_mean_stdev(X_train_scaled[:1000])
means_test, stdevs_test = surrogate.get_scaled_predictive_mean_stdev(X_test_scaled)
Hide code cell source
def rmse(y, y_hat):
    """Root-mean-square error between `y` and `y_hat`."""
    return jnp.sqrt(jnp.mean((y - y_hat)**2))

# Parity plot on the training rows the surrogate was fitted on.
fig, ax = plt.subplots(1, len(OUTPUT_NAMES), figsize=(10, 3), tight_layout=True)
fig.suptitle('Verification with training data', fontsize=16)
for i, o_name in enumerate(OUTPUT_NAMES):
    y_true = y_train_scaled[:1000, i]
    ax[i].errorbar(y_true, means_train[:, i], yerr=2 * stdevs_train[:, i], fmt='o', ms=2, alpha=0.3, lw=0.1)
    # Reference diagonal (perfect prediction), spanning exactly the range of
    # the true values plotted in this panel.
    diag = jnp.array([y_true.min(), y_true.max()])
    ax[i].plot(diag, diag, "r-", lw=0.5, zorder=100)
    ax[i].annotate(f"RMSE: {rmse(y_true, means_train[:, i]):.3f}", xy=(0.05, 0.9), xycoords='axes fraction')
    ax[i].set_title(o_name)
    sns.despine(trim=True, ax=ax[i])
ax[0].set_xlabel('True')
ax[0].set_ylabel('Predicted')

# Parity plot on the held-out test data.
fig, ax = plt.subplots(1, len(OUTPUT_NAMES), figsize=(10, 3), tight_layout=True)
fig.suptitle('Validation with test data', fontsize=16)
for i, o_name in enumerate(OUTPUT_NAMES):
    y_true = y_test_scaled[:, i]
    ax[i].errorbar(y_true, means_test[:, i], yerr=2 * stdevs_test[:, i], fmt='o', ms=2, alpha=0.3, lw=0.1)
    diag = jnp.array([y_true.min(), y_true.max()])
    ax[i].plot(diag, diag, "r-", lw=0.5, zorder=100)
    ax[i].annotate(f"RMSE: {rmse(y_true, means_test[:, i]):.3f}", xy=(0.05, 0.9), xycoords='axes fraction')
    ax[i].set_title(o_name)
    sns.despine(trim=True, ax=ax[i])
ax[0].set_xlabel('True')
ax[0].set_ylabel('Predicted');
../../_images/ba36f49b1c58244fa5370698235106f10e278dbddcb77141aab80db5bd49caae.svg ../../_images/7f2c9bfc7768eb2b7c0565b5138d17d2dc163b4e170cce56973a6e760776e4df.svg

As expected, we’ve fit the training data perfectly (all points lie on the red line). However, we mostly care about how well we predict unseen (test) data. The root mean square error (RMSE) along with the parity plots show that some outputs are predicted better than others.

For GPs, parity plots and RMSE values are good for checking if our predictive mean function is correct. However, they do not say anything about whether our uncertainty estimates are correct. For that, we must look at the standardized errors.

Standardized errors#

For the \(i^\text{th}\) test data point, the standardized error \(e_i\) is

\[ e_i = \frac{y_i - m(x_i)}{\sigma(x_i)} \]

where \(y_i\) is the true output, and \(m(x_i)\) and \(\sigma(x_i)\) are the mean and standard deviation of the prediction, respectively.

# Test-set error divided by the predictive standard deviation, elementwise.
standardized_errors = (y_test_scaled - means_test) / stdevs_test

Let’s use these standardized errors to validate the GP’s uncertainty estimates.

Error histogram#

The basic GP model (which is what we have) assumes normally distributed errors. Let’s check this:

Hide code cell source
# Histogram of the standardized test errors, one panel per output.
fig, ax = plt.subplots(1, len(OUTPUT_NAMES), figsize=(10, 3), tight_layout=True)
fig.suptitle('Standardized errors', fontsize=16)
for col, title in enumerate(OUTPUT_NAMES):
    panel = ax[col]
    panel.hist(standardized_errors[:, col], bins=20)
    panel.set_title(title)
    panel.set_yticks([])
    sns.despine(trim=True, ax=panel, left=True)
../../_images/d22b75bdbf6641332218133380566e6ecd6fdb52b38ac9b3164d5619374d340e.svg

QQ plot#

The histograms above look like normal distributions, but it’s sometimes hard to tell. A quantile-quantile (QQ) plot is another way to check normality:

# 40 probability levels between 1% and 99%.
quantiles = jnp.linspace(0.01, 0.99, 40)
# Corresponding quantiles of the standard normal distribution.
normal_quantiles = jstats.norm.ppf(quantiles)
# Empirical quantiles of the standardized errors, computed per output column.
error_quantiles = jnp.quantile(standardized_errors, quantiles, axis=0)
Hide code cell source
# Empirical error quantiles against standard-normal quantiles, per output.
fig, ax = plt.subplots(1, len(OUTPUT_NAMES), figsize=(10, 3), tight_layout=True)
fig.suptitle('QQ plot', fontsize=16)
for col, title in enumerate(OUTPUT_NAMES):
    panel = ax[col]
    panel.scatter(normal_quantiles, error_quantiles[:, col])
    panel.set_title(title)
    sns.despine(trim=True, ax=panel)
first_panel = ax[0]
first_panel.set_xlabel('Normal quantiles')
first_panel.set_ylabel('Error quantiles');
../../_images/63c413dd0c91cc82d4d2476a12dd90f6acac37e6c294e70d748ccd6d43b9d33f.svg

The straighter the line, the more “normal” the errors are. They look decent.

Residuals plot#

Another assumption in our GP model is that the noise is homoscedastic—i.e., the same across all input/output values. We can check this assumption by plotting the error against the model prediction. This is called a residual plot. (Note that residuals is just another name for the errors.)

Hide code cell source
# Standardized error against the predicted (scaled) value, per output,
# with horizontal guide lines at -2 and +2.
fig, ax = plt.subplots(1, len(OUTPUT_NAMES), figsize=(10, 3), tight_layout=True)
fig.suptitle('Residual plot', fontsize=16)
for col, title in enumerate(OUTPUT_NAMES):
    panel = ax[col]
    panel.scatter(means_test[:, col], standardized_errors[:, col], 1)
    panel.set_title(title)
    # Only the lower line carries the legend label.
    panel.axhline(-2, color='r', lw=0.5, label=r'95% CI')
    panel.axhline(2, color='r', lw=0.5)
    sns.despine(trim=True, ax=panel)
first_panel = ax[0]
first_panel.set_xlabel('Predicted')
first_panel.set_ylabel('Standardized error')
first_panel.legend();
../../_images/0ac37f25496e62e8777237f221f74bcaec0f118fb36b1abcfe97a40290c9810a.svg

The variance in the errors seems to change with the output value. Similar to the parity plots, this basically tells us that the GP fit is somewhat off and we may need more training data. (Whether we actually use more data depends, of course, on how accurate we need our surrogate to be.)

Sensitivity analysis with surrogate#

Now that we have a trained and tested surrogate, there are many useful things we can do with it. As in the neural network surrogate hands-on activity, we’ll demonstrate using the surrogate to do Sobol sensitivity analysis.

Now, let’s create bounds for the inputs (based on physical intuition and/or literature values):

import SALib.sample.sobol as sobol
import SALib.analyze.sobol as analyze_sobol

# [lower, upper] bound for every input variable.
input_bounds_dict = dict(
    viscosity_cP=[1.0, 20.0],
    fill_volume_mL=[1.0, 1.05],
    air_gap_height_mm=[4.0, 5.0],
    needle_length_mm=[8.0, 15.9],
    needle_diameter_mm=[0.133, 0.21],
    spring_force_N=[18.0, 36.0],
    spring_constant_N_per_mm=[150.0, 250.0],
    kappa5=[0.0, 3.0],
    kappa6=[0.0, 4.0],
    kappa7=[0.0, 0.1],
)
# Stack the bounds into an array with one [lower, upper] row per input,
# ordered like INPUT_NAMES.
input_bounds = np.array([
    [lower, upper]
    for lower, upper in (input_bounds_dict[name][:2] for name in INPUT_NAMES)
])
print('The input bounds are:\n', input_bounds)
The input bounds are:
 [[1.00e+00 2.00e+01]
 [1.00e+00 1.05e+00]
 [4.00e+00 5.00e+00]
 [8.00e+00 1.59e+01]
 [1.33e-01 2.10e-01]
 [1.80e+01 3.60e+01]
 [1.50e+02 2.50e+02]
 [0.00e+00 3.00e+00]
 [0.00e+00 4.00e+00]
 [0.00e+00 1.00e-01]]

Next, we create Sobol samples of the inputs and pass them through the surrogate:

# Problem definition passed to the SALib sampler and analyzer.
problem = dict(
    num_vars=len(INPUT_NAMES),
    names=INPUT_NAMES,
    bounds=input_bounds,
)

# The number of samples to generate (should be a power of 2).
N = 512

# Generate the samples (second-order terms are not requested).
sobol_samples = sobol.sample(problem, N, calc_second_order=False)

# Evaluate the surrogate model at the Sobol samples.
sobol_outputs = surrogate(sobol_samples)

Finally, we calculate and plot the Sobol indices:

# Run the Sobol analysis separately on each output column of the surrogate results.
sobol_indices = {
    o_name: analyze_sobol.analyze(problem, sobol_outputs[:, col], calc_second_order=False, print_to_console=False)
    for col, o_name in enumerate(OUTPUT_NAMES)
}
Hide code cell source
# Plot sobol indices
fig, ax = plt.subplots(1, len(OUTPUT_NAMES), figsize=(15, 4), tight_layout=True, sharey=True)
fig.suptitle('Sobol indices', fontsize=16)
for i, o_name in enumerate(OUTPUT_NAMES):
    ax[i].bar(problem['names'], sobol_indices[o_name]['S1'])
    ax[i].set_title(o_name)
    for tick in ax[i].get_xticklabels():
        tick.set_rotation(90)
    sns.despine(trim=True, ax=ax[i])
../../_images/cbccde0c6db311763aa29e1735abff48b7e41cda1f479d346636ef9fbe5e2b84.svg

Excellent! We now know to which inputs the outputs are most sensitive. For example, we can see that the injection time is very sensitive to the drug viscosity and needle diameter, somewhat sensitive to needle length and injector spring force, and not sensitive to any other inputs. This information can be used, for example, to further study and understand the physics behind the model, or to investigate how identifiable each parameter is given an experimental dataset.

Remember that sensitivity analysis would not have been feasible without a surrogate model—the true physical model was just too expensive. With a surrogate however, we can do sensitivity analysis, design optimization, uncertainty quantification, etc. all at a reasonable computational cost.