import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")
# Uncomment the next two lines if running for the book
# import warnings
# warnings.filterwarnings("ignore")
import jax
from jax import vmap
import jax.numpy as jnp
import jax.random as jrandom
import gpjax as gpx
import optax
import numpy as np
import pandas as pd
from functools import partial
import os
import scipy.stats.qmc as qmc
jax.config.update("jax_enable_x64", True)
key = jrandom.PRNGKey(0)
Multifidelity Gaussian process surrogates#
Example 1: Multi-fidelity regression of a synthetic function#
Suppose we have a high-fidelity model

$$
f_h(x_1, x_2) = \frac{1}{2}\sin^2\!\left(\tfrac{5}{2}x_1 + \tfrac{2}{3}x_2\right) + \frac{2}{3}e^{-x_1(x_2 - 0.5)^2}\cos^2(4x_1 + x_2),
$$

and a cheap, low-fidelity approximation

$$
f_\ell(x_1, x_2) = \frac{5}{2}f_h(x_1, x_2) + \frac{1}{3}\left(\sin(x_1 + x_2) + \frac{1}{2}e^{-x_1}\sin(x_1 + 7x_2)\right),
$$

both defined on $[0, 1]^2$. (These function definitions are modified from Perdikaris et al. (2015).) We want to create a surrogate for the high-fidelity model $f_h$ using evaluations of both models.
Datasets#
Suppose we evaluate the high-fidelity model at $N_h = 8$ input points $X_h$, giving observations $y_h$ with $y_h^{(i)} = f_h\big(x_h^{(i)}\big)$. This is our high-fidelity data. Suppose we also evaluate the low-fidelity model at $N_\ell = 50$ input points $X_\ell$, where $X_\ell$ includes all of the high-fidelity input points $X_h$ (this nesting will make the multi-fidelity construction below straightforward). This is our low-fidelity data. Let's generate the data and visualize the functions.
N_LOW_FIDELITY = 50
N_HIGH_FIDELITY = 8
N_TEST = 100
def high_fidelity_model(x):
x1, x2 = x[0], x[1]
return 1/2*jnp.sin(5/2*x1 + 2/3*x2)**2 + 2/3*jnp.exp(-x1*(x2 - 0.5)**2)*jnp.cos(4*x1 + x2)**2
def low_fidelity_model(x):
x1, x2 = x[0], x[1]
return 2.5*high_fidelity_model(x) + 1/3*( jnp.sin(x1 + x2) + 1/2*jnp.exp(-x1)*jnp.sin(x1 + 7*x2) )
def generate_synthetic_data(n_l, n_h, n_test, key=None):
"""Returns datasets necessary for training and testing a multi-fidelity GP."""
# Train data
x_train_l_only = jnp.array(qmc.LatinHypercube(2).random(n=n_l - n_h))
x_train_h = jnp.array(qmc.LatinHypercube(2).random(n_h))
y_train_l_only = vmap(low_fidelity_model)(x_train_l_only)
y_train_l_common = vmap(low_fidelity_model)(x_train_h)
y_train_h = vmap(high_fidelity_model)(x_train_h)
    x_train_l = jnp.concatenate([x_train_l_only, x_train_h], axis=0)  # low-fidelity inputs should also include all the high-fidelity input points
y_train_l = jnp.concatenate([y_train_l_only, y_train_l_common], axis=0)
# Test data
x_test = jrandom.uniform(key, (n_test, 2))
y_test_l = vmap(low_fidelity_model)(x_test)
y_test_h = vmap(high_fidelity_model)(x_test)
# D_low_only = gpx.Dataset(x_train_l_only, y_train_l_only[:, None])
D_low_common = gpx.Dataset(x_train_h, y_train_l_common[:, None])
D_low = gpx.Dataset(x_train_l, y_train_l[:, None])
D_high = gpx.Dataset(x_train_h, y_train_h[:, None])
D_low_test = gpx.Dataset(x_test, y_test_l[:, None])
D_high_test = gpx.Dataset(x_test, y_test_h[:, None])
return D_low_common, D_low, D_high, D_low_test, D_high_test
key, subkey = jrandom.split(key)
D_low_common, D_low, D_high, D_low_test, D_high_test = generate_synthetic_data(N_LOW_FIDELITY, N_HIGH_FIDELITY, N_TEST, key=subkey)
fig = plt.figure()
ax = fig.add_subplot(121)
ax.scatter(D_low.X[:,0], D_low.X[:, 1], label=r'Low-fidelity data, $X_\ell$', marker='o', facecolors='none', edgecolors='tab:blue')
ax.scatter(D_high.X[:,0], D_high.X[:, 1], label=r'High-fidelity data, $X_h$', marker='o', facecolors='tab:orange', edgecolors='none')
ax.set_aspect('equal')
ax.set_title('Data locations')
ax.set_xlabel(r'$x_1$')
ax.set_ylabel(r'$x_2$')
leg = ax.legend()
leg.get_frame().set_alpha(0.6)
x1, x2 = jnp.linspace(0, 1, 100), jnp.linspace(0, 1, 100)
X1, X2 = jnp.meshgrid(x1, x2)
Y_low = vmap(low_fidelity_model)(jnp.stack([X1.ravel(), X2.ravel()], axis=-1)).reshape(X1.shape)
Y_high = vmap(high_fidelity_model)(jnp.stack([X1.ravel(), X2.ravel()], axis=-1)).reshape(X1.shape)
ax = fig.add_subplot(122, projection='3d')
ax.plot_surface(X1, X2, Y_low, label='Low fidelity', lw=0.1, alpha=0.5)
ax.plot_surface(X1, X2, Y_high, label='High fidelity', lw=0.1)
ax.view_init(elev=20, azim=-58)
ax.set_title('True functions')
ax.set_xlabel(r'$x_1$')
ax.set_ylabel(r'$x_2$')
ax.set_zlabel(r'$y$')
ax.legend()
sns.despine(trim=True);
For ease of notation, let's group all the low-fidelity data together as $\mathcal{D}_\ell = (X_\ell, y_\ell)$, where $X_\ell$ contains all $N_\ell$ low-fidelity input points (including those shared with $X_h$) and $y_\ell$ the corresponding low-fidelity outputs.
Low-fidelity Gaussian process#
We start by making a Gaussian process surrogate $\hat{f}_\ell$ of the low-fidelity model, trained on the low-fidelity data $\mathcal{D}_\ell$. We just set the covariance kernel $k_\ell$ to a squared-exponential (RBF) kernel with a separate lengthscale for each input dimension,

$$
k_\ell(x, x') = \sigma_\ell^2 \exp\left(-\sum_{i=1}^{2} \frac{(x_i - x_i')^2}{2\ell_i^2}\right).
$$

Let's do this in GPJax:
# Construct low-fidelity GP
mean_low = gpx.mean_functions.Zero()
kernel_low = gpx.kernels.RBF(lengthscale=jnp.ones(2), variance=1.0)
prior_low = gpx.gps.Prior(mean_function=mean_low, kernel=kernel_low)
likelihood_low = gpx.likelihoods.Gaussian(
num_datapoints=D_low.n,
obs_stddev=1.0,
)
posterior_low = prior_low * likelihood_low
# Loss function
negative_mll = lambda p, d: -gpx.objectives.conjugate_mll(p, d)
# Optimize low-fidelity hyperparameters
posterior_low, history_low = gpx.fit(
model=posterior_low,
objective=negative_mll,
train_data=D_low,
optim=optax.adam(1e-2), # Adam optimizer
num_iters=2000,
batch_size=128,
key=subkey,
verbose=True
)
Let’s visualize the fit to the low-fidelity data with a parity plot:
def generate_predictive_dist(x, posterior, train_data):
latent_dist = posterior.predict(x, train_data=train_data)
return posterior.likelihood(latent_dist)
predictive_dist_low = generate_predictive_dist(D_low_test.X, posterior_low, D_low)
mean_low_test = predictive_dist_low.mean()
std_low_test = predictive_dist_low.stddev()
fig, ax = plt.subplots(figsize=(4,3))
ax.errorbar(D_low_test.y, mean_low_test, yerr=2 * std_low_test, fmt="o")
ax.plot(D_low_test.y, D_low_test.y, "k--")
ax.set_xlabel("Model (low-fidelity)")
ax.set_ylabel("Surrogate (low-fidelity)")
ax.set_title(r"Low-fidelity GP predictions vs low-fidelity data", fontsize=12)
sns.despine(trim=True);
It looks good. Now onto multi-fidelity.
Multi-fidelity Gaussian process#
The idea behind the multi-fidelity GP is this: Do GP regression on the high-fidelity data, but add a dimension to the input space which represents the output of the low-fidelity model.
More precisely, the multi-fidelity GP is

$$
\hat{f}_m(x) = \hat{g}\big(x_1, x_2, \hat{f}_\ell(x)\big), \qquad \hat{g} \sim \mathrm{GP}(0, k_m),
$$

where the multi-fidelity covariance kernel $k_m$ acts on the augmented, three-dimensional input $\big(x_1, x_2, \hat{f}_\ell(x)\big)$; here we simply take $k_m$ to be an RBF kernel with one lengthscale per augmented dimension. Note that the low-fidelity output at a new input $x$ is not known exactly, so at prediction time we will feed in the low-fidelity surrogate $\hat{f}_\ell(x)$ (more on this in the sampling section below).
Also, note that in GPJax we will actually implement the equivalent GP $\hat{g}(x_1, x_2, y_\ell) \sim \mathrm{GP}(0, k_m)$ defined directly on the three-dimensional augmented input space, and train it on the data

$$
\mathcal{D}_m = \Big( \big\{ \big(x_h^{(i)},\, f_\ell(x_h^{(i)})\big) \big\}_{i=1}^{N_h},\; y_h \Big),
$$

i.e., the high-fidelity inputs augmented with the corresponding low-fidelity outputs.
You can think of this as a form of data augmentation (we simply augment the high-fidelity dataset with some low-fidelity model evaluations at the same points).
Let's try this out. Here is a GP constructed with the new augmented dataset $\mathcal{D}_m$:
Xm = jnp.concatenate([D_high.X, D_low_common.y], axis=1)
ym = D_high.y
D_multi = gpx.Dataset(Xm, ym)
# Construct multi-fidelity GP
mean_multi = gpx.mean_functions.Zero()
kernel_multi = gpx.kernels.RBF(lengthscale=jnp.ones(3), variance=1.0)
prior_multi = gpx.gps.Prior(mean_function=mean_multi, kernel=kernel_multi)
likelihood_multi = gpx.likelihoods.Gaussian(
num_datapoints=D_multi.n,
obs_stddev=1.0
)
posterior_multi = prior_multi * likelihood_multi
# Optimize multi-fidelity hyperparameters
posterior_multi, history_multi = gpx.fit(
model=posterior_multi,
objective=negative_mll,
train_data=D_multi,
optim=optax.adam(1e-2), # Adam optimizer
num_iters=2000,
batch_size=128,
key=subkey,
verbose=True
)
Note that this is now a 3D-input GP instead of a 2D-input GP (due to the augmented dimension).
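To make the augmented dimension concrete, here is a quick optional check of the training input shapes (the row counts in the comments assume the default N_HIGH_FIDELITY = 8):

```python
# The multi-fidelity training inputs carry a third column holding the
# low-fidelity output at each high-fidelity input point.
print(D_high.X.shape)   # (8, 2): (x1, x2)
print(D_multi.X.shape)  # (8, 3): (x1, x2, f_low(x))
```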
Sampling from the posterior#
The next question is: how do we sample the posterior multi-fidelity GP $\hat{f}_m$ at a new input $x$? Since the low-fidelity output at $x$ is not known exactly, we do it in two stages.
First, sample the posterior low-fidelity GP:
x_test = jnp.array([[0.5, 0.5]])
predictive_dist_low_test = generate_predictive_dist(x_test, posterior_low, D_low)
key, subkey = jrandom.split(key)
samples_low_test = predictive_dist_low_test.sample((), subkey)
Next, augment the input with the low-fidelity sample, i.e., form $\big(x_1, x_2, \hat{f}_\ell(x)\big)$:
x_augmented = jnp.concatenate([x_test, samples_low_test[:, None]], axis=-1)
Finally, sample the posterior multi-fidelity (augmented) GP:
predictive_dist_multi_test = generate_predictive_dist(x_augmented, posterior_multi, D_multi)
key, subkey = jrandom.split(key)
samples_multi_test = predictive_dist_multi_test.sample((), subkey)
print(f'The sample is: {samples_multi_test.squeeze():.4f}')
print(f'The true value is: {high_fidelity_model(x_test[0]):.4f}')
The sample is: 0.9073
The true value is: 0.9278
And that’s it! We now know how to construct and sample a multi-fidelity GP surrogate. Next, we’ll check our surrogate’s accuracy.
Before we do that, however, let’s train a vanilla GP on the high-fidelity data for comparison:
# Construct high-fidelity GP
mean_high = gpx.mean_functions.Zero()
kernel_high = gpx.kernels.RBF(lengthscale=jnp.ones(2), variance=1.0)
prior_high = gpx.gps.Prior(mean_function=mean_high, kernel=kernel_high)
likelihood_high = gpx.likelihoods.Gaussian(
num_datapoints=D_high.n,
obs_stddev=1.0
)
posterior_high = prior_high * likelihood_high
# Optimize high-fidelity hyperparameters
posterior_high, history_high = gpx.fit(
model=posterior_high,
objective=negative_mll,
train_data=D_high,
optim=optax.adam(1e-2), # Adam optimizer
num_iters=2000,
batch_size=128,
key=subkey,
verbose=True
)
Let’s also wrap the sampling code into functions (for organization):
def sample_multi_fidelity_gp(
X,
key,
num_samples,
posterior_low,
posterior_multi,
D_low,
D_multi,
):
"""Sample from the multi-fidelity Gaussian process.
Parameters
----------
X : ndarray
The test points.
posterior_low : gpx.gps.Posterior
The low-fidelity GP posterior.
posterior_multi : gpx.gps.Posterior
The multi-fidelity GP posterior.
key : PRNGKey
The random key.
num_samples : int
The number of samples to draw.
D_low : gpx.Dataset
The low-fidelity dataset.
D_multi : gpx.Dataset
The multi-fidelity dataset.
Returns
-------
ndarray
The samples from the multi-fidelity GP. The shape is (num_samples, num_test_points).
"""
    key_low, key_multi = jrandom.split(key)  # separate keys for the two sampling stages
    predictive_dist_low = generate_predictive_dist(X, posterior_low, D_low)
    samples_low_test = predictive_dist_low.sample((num_samples,), key_low)
    single_augment_fn = lambda yl: jnp.concatenate([X, yl[:, None]], axis=-1)
    x_aug = vmap(single_augment_fn)(samples_low_test)  # Shape (num_samples, num_test_points, d + 1)
    x_aug_flat = x_aug.reshape(-1, X.shape[-1] + 1)  # Shape (num_samples*num_test_points, d + 1)
    predictive_dist_multi = generate_predictive_dist(x_aug_flat, posterior_multi, D_multi)
    samples_multi_test = predictive_dist_multi.sample((), key_multi).reshape(num_samples, X.shape[0])
return samples_multi_test
def sample_vanilla_gp(
X,
key,
num_samples,
posterior,
D
):
"""Sample from a vanilla GP.
Parameters
----------
X : ndarray
The test points.
key : PRNGKey
The random key.
num_samples : int
The number of samples to draw.
posterior : gpx.gps.Posterior
The GP posterior.
D : gpx.Dataset
The dataset.
"""
predictive_dist = generate_predictive_dist(X, posterior, D)
samples = predictive_dist.sample((num_samples,), key)
return samples
Now, we are ready to see how well the GPs predict the test data. Let’s start by making parity plots:
sample_low_fidelity_gp = partial(sample_vanilla_gp, posterior=posterior_low, D=D_low)
sample_high_fidelity_gp = partial(sample_vanilla_gp, posterior=posterior_high, D=D_high)
sample_multi_fidelity_gp = partial(sample_multi_fidelity_gp, posterior_low=posterior_low, posterior_multi=posterior_multi, D_low=D_low, D_multi=D_multi)
num_samples = 50
X_test = D_high_test.X
key, key_low, key_high, key_multi = jrandom.split(key, 4)
samples_low = sample_low_fidelity_gp(X_test, key_low, num_samples)
samples_high = sample_high_fidelity_gp(X_test, key_high, num_samples)
samples_multi = sample_multi_fidelity_gp(X_test, key_multi, num_samples)
mean_low_test = samples_low.mean(axis=0)
std_low_test = samples_low.std(axis=0)
mean_high_test = samples_high.mean(axis=0)
std_high_test = samples_high.std(axis=0)
mean_multi_test = samples_multi.mean(axis=0)
std_multi_test = samples_multi.std(axis=0)
rmse = lambda y, y_hat: jnp.sqrt(jnp.mean((y - y_hat)**2))
# Parity plot - add uncertainty on predictions using whiskers
fig, ax = plt.subplots(1, 3, figsize=(10,4), tight_layout=True)
fig.suptitle("Surrogate predictions vs. high-fidelity test data", fontsize=20)
ax[0].errorbar(D_high_test.y, mean_low_test, yerr=2 * std_low_test, fmt="o", markersize=3, lw=0.5)
ax[0].plot(D_high_test.y, D_high_test.y, "r-", lw=1)
ax[0].set_xlabel("True model")
ax[0].set_ylabel("Surrogate")
ax[0].set_title(r"Low fidelity GP, $\hat{f}_\ell$", fontsize=14)
ax[0].annotate(f"RMSE: {rmse(D_high_test.y.squeeze(-1), mean_low_test):.3f}", xy=(0.05, 0.9), xycoords='axes fraction')
ax[1].errorbar(D_high_test.y, mean_high_test, yerr=2 * std_high_test, fmt="o", markersize=3, lw=0.5)
ax[1].plot(D_high_test.y, D_high_test.y, "r-", lw=1)
ax[1].set_xlabel("True model")
ax[1].set_ylabel("Surrogate")
ax[1].set_title(r"High fidelity GP, $\hat{f}_h$", fontsize=14)
ax[1].annotate(f"RMSE: {rmse(D_high_test.y.squeeze(-1), mean_high_test):.3f}", xy=(0.05, 0.9), xycoords='axes fraction')
ax[2].errorbar(D_high_test.y, mean_multi_test, yerr=2 * std_multi_test, fmt="o", markersize=3, lw=0.5)
ax[2].plot(D_high_test.y, D_high_test.y, "r-", lw=1)
ax[2].set_xlabel("True model")
ax[2].set_ylabel("Surrogate")
ax[2].set_title(r"Multi fidelity GP, $\hat{f}_m$", fontsize=14)
ax[2].annotate(f"RMSE: {rmse(D_high_test.y.squeeze(-1), mean_multi_test):.3f}", xy=(0.05, 0.9), xycoords='axes fraction')
sns.despine(trim=True);
The multi-fidelity GP has the most accurate predictions. Let’s visualize the mean predictive surface against the ground truth:
# Surface plot
x1, x2 = jnp.linspace(0, 1, 20), jnp.linspace(0, 1, 20)
X1, X2 = jnp.meshgrid(x1, x2)
X_plt = jnp.stack([X1.ravel(), X2.ravel()], axis=-1)
Y_high_plt = vmap(high_fidelity_model)(X_plt).reshape(X1.shape)
num_samples = 20
mean_multi_plt = sample_multi_fidelity_gp(X_plt, key, num_samples).mean(axis=0).reshape(X1.shape)
fig = plt.figure(figsize=(5,4))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X1, X2, mean_multi_plt, label=r'Mean of multi-fidelity GP, $\mathbb{E}[\hat{f}_m]$', color='tab:green', lw=0.1, alpha=0.3)
ax.plot_surface(X1, X2, Y_high_plt, label=r'True function, $f_h$', color='tab:blue', lw=0.1, alpha=0.7)
ax.view_init(elev=20, azim=-58)
ax.set_title('Mean predictive surface vs ground truth', fontsize=14)
ax.set_xlabel(r'$x_1$')
ax.set_ylabel(r'$x_2$')
ax.set_zlabel(r'$y$')
ax.legend();
The surfaces are basically right on top of each other! Let's take a closer look at a vertical "slice" of the surface plot. Fix $x_2 = 0.4$ and vary $x_1$ over $[0, 1]$:
# Slice plot
x1_slice = jnp.linspace(0, 1, 40)
x2_slice = jnp.full_like(x1_slice, 0.4)
X_slice = jnp.stack([x1_slice, x2_slice], axis=-1)
num_samples = 120
key, key_low, key_high, key_multi = jrandom.split(key, 4)
samples_low = sample_low_fidelity_gp(X_slice, key_low, num_samples)
samples_high = sample_high_fidelity_gp(X_slice, key_high, num_samples)
samples_multi = sample_multi_fidelity_gp(X_slice, key_multi, num_samples)
mean_low_slice = samples_low.mean(axis=0)
std_low_slice = samples_low.std(axis=0)
mean_high_slice = samples_high.mean(axis=0)
std_high_slice = samples_high.std(axis=0)
mean_multi_slice = samples_multi.mean(axis=0)
std_multi_slice = samples_multi.std(axis=0)
fig, ax = plt.subplots(figsize=(4,3))
ax.set_title("Slice through the predictive surfaces", fontsize=14)
ax.plot(x1_slice, mean_low_slice, label=r"Low-fidelity GP, $\hat{f}_\ell$", lw=2)
ax.fill_between(x1_slice, mean_low_slice - 2 * std_low_slice, mean_low_slice + 2 * std_low_slice, alpha=0.2)
ax.plot(x1_slice, mean_high_slice, label=r"High-fidelity GP, $\hat{f}_h$", lw=2)
ax.fill_between(x1_slice, mean_high_slice - 2 * std_high_slice, mean_high_slice + 2 * std_high_slice, alpha=0.2)
ax.plot(x1_slice, mean_multi_slice, label=r"Multi-fidelity GP, $\hat{f}_m$", lw=2)
ax.fill_between(x1_slice, mean_multi_slice - 2 * std_multi_slice, mean_multi_slice + 2 * std_multi_slice, alpha=0.2)
ax.plot(x1_slice, vmap(high_fidelity_model)(X_slice), label=r"True function, $f_h$", linestyle="--", color="black")
ax.set_xlabel(r'$x_1$')
ax.legend()
sns.despine(trim=True);
Again, the multi-fidelity GP $\hat{f}_m$ follows the true function much more closely than either the low-fidelity or the high-fidelity-only GP.
Questions#
- Decrease N_LOW_FIDELITY. How does the multi-fidelity GP perform with less low-fidelity data?
- Decrease N_HIGH_FIDELITY. How does the multi-fidelity GP perform with less high-fidelity data?
- Increase N_HIGH_FIDELITY. At what point is the high-fidelity-only GP just as good as the multi-fidelity GP?
- Add more terms (sin/cos, exponential, quadratic, or whatever you want) to low_fidelity_model. How different can the low-fidelity model be from the high-fidelity model and still get a good surrogate?

A code sketch for exploring these questions is given below.
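Here is a minimal, optional sketch for exploring the questions above, reusing the helpers defined earlier (generate_synthetic_data, generate_predictive_dist, negative_mll, rmse). The names fit_gp and compare_fidelities are introduced here purely for illustration, and the hyperparameter settings simply mirror the ones used above; treat this as a starting point rather than a definitive experiment.

```python
# Hypothetical helpers (not part of the original notebook): refit the GPs for a
# given data budget and compare test RMSEs of the posterior means.
def fit_gp(D, n_dims, key):
    """Fit a zero-mean RBF GP to dataset D, using the same settings as above."""
    prior = gpx.gps.Prior(
        mean_function=gpx.mean_functions.Zero(),
        kernel=gpx.kernels.RBF(lengthscale=jnp.ones(n_dims), variance=1.0),
    )
    likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=1.0)
    posterior, _ = gpx.fit(
        model=prior * likelihood,
        objective=negative_mll,
        train_data=D,
        optim=optax.adam(1e-2),
        num_iters=2000,
        batch_size=128,
        key=key,
        verbose=False,
    )
    return posterior

def compare_fidelities(n_low, n_high, key):
    """Return (high-fidelity-only RMSE, multi-fidelity RMSE) on fresh test data."""
    key, k_data, k_low, k_multi, k_high = jrandom.split(key, 5)
    D_low_c, D_low, D_high, _, D_high_test = generate_synthetic_data(n_low, n_high, N_TEST, key=k_data)
    # Low-fidelity GP and augmented multi-fidelity dataset (as in the cells above)
    post_low = fit_gp(D_low, 2, k_low)
    D_multi = gpx.Dataset(jnp.concatenate([D_high.X, D_low_c.y], axis=1), D_high.y)
    post_multi = fit_gp(D_multi, 3, k_multi)
    post_high = fit_gp(D_high, 2, k_high)
    # Predict with posterior means; the low-fidelity GP mean feeds the multi-fidelity GP
    mu_low = generate_predictive_dist(D_high_test.X, post_low, D_low).mean()
    X_aug = jnp.concatenate([D_high_test.X, mu_low[:, None]], axis=1)
    mu_multi = generate_predictive_dist(X_aug, post_multi, D_multi).mean()
    mu_high = generate_predictive_dist(D_high_test.X, post_high, D_high).mean()
    y = D_high_test.y.squeeze(-1)
    return rmse(y, mu_high), rmse(y, mu_multi)

# Example usage: sweep the number of high-fidelity points.
# for n_h in [4, 8, 16, 32]:
#     key, subkey = jrandom.split(key)
#     print(n_h, compare_fidelities(N_LOW_FIDELITY, n_h, subkey))
```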
Example 2: Stochastic incompressible flow past a cylinder#
This example is taken from Perdikaris et al. (2015). Suppose you have flow past a cylinder, subject to random inflow boundary conditions parameterized by two quantities, $\sigma_1$ and $\sigma_2$ (the axes of the plots below), which will serve as the inputs to our surrogates.
Figure 9 from Perdikaris et al. (2015). The left figure shows the spatial discretization used for CFD, while the right figure shows the quantity of interest (i.e., the base pressure coefficient $C_\text{BP}$).
Quantity of interest: Superquantile risk of base pressure coefficient#
Note that for each inflow condition, the base pressure coefficient $C_\text{BP}$ is a random quantity. We therefore summarize it by its superquantile risk at level 0.6, $\mathcal{R}_{0.6}[C_\text{BP}]$, i.e., the average of $C_\text{BP}$ over its upper 40% tail. Computing the risk, however, is expensive: evaluating it once requires many CFD simulations. Our goal therefore is to speed up evaluation of $\mathcal{R}_{0.6}[C_\text{BP}]$ as a function of $(\sigma_1, \sigma_2)$ by building a multi-fidelity GP surrogate.
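To make the quantity of interest concrete, here is a small sketch (not from the original notebook) of how a superquantile risk could be estimated from Monte Carlo samples of $C_\text{BP}$. The function name superquantile and the upper-tail convention are illustrative assumptions; the sign convention would depend on how $C_\text{BP}$ is defined.

```python
import jax.numpy as jnp

def superquantile(samples, alpha=0.6):
    """Monte Carlo estimate of the superquantile (CVaR) at level alpha:
    the mean of the samples at or above the alpha-quantile."""
    q = jnp.quantile(samples, alpha)  # the alpha-quantile (value-at-risk)
    tail = samples[samples >= q]      # the upper (1 - alpha) tail
    return tail.mean()

# Example with dummy samples standing in for CFD outputs:
# superquantile(jrandom.normal(jrandom.PRNGKey(1), (10_000,)), alpha=0.6)
```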
Datasets: Probabilistic collocation (high-fidelity) and Monte Carlo (low-fidelity)#
As described in the paper, there are two different models that approximate $\mathcal{R}_{0.6}[C_\text{BP}]$: an accurate but expensive probabilistic-collocation estimate (the high-fidelity model) and a cheaper, less accurate Monte Carlo estimate (the low-fidelity model).
First, import and visualize the data:
data_folder = '../../data/mf_cylinder_flow'
high_fid_data = pd.read_csv(os.path.join(data_folder, 'high_fidelity_data.csv'))
low_fid_data = pd.read_csv(os.path.join(data_folder, 'low_fidelity_data.csv'))
ground_truth = pd.read_csv(os.path.join(data_folder, 'ground_truth.csv'))
D_high_cyl = gpx.Dataset(jnp.array(high_fid_data[['x1', 'x2']].values), jnp.array(high_fid_data['y'].values[:, None]))
D_low_cyl = gpx.Dataset(jnp.array(low_fid_data[['x1', 'x2']].values), jnp.array(low_fid_data['y'].values[:, None]))
gt = ground_truth.values.reshape(21, 17, 3)
fig = plt.figure()
ax = fig.add_subplot(121)
ax.scatter(D_low_cyl.X[:,0], D_low_cyl.X[:, 1], label=r'Low-fidelity data, $X_\ell$', marker='o', facecolors='none', edgecolors='tab:blue')
ax.scatter(D_high_cyl.X[:,0], D_high_cyl.X[:, 1], label=r'High-fidelity data, $X_h$', marker='o', facecolors='tab:orange', edgecolors='none')
ax.set_aspect(5)
ax.set_title('Data locations', fontsize=12)
ax.set_xlabel(r'$\sigma_1$')
ax.set_ylabel(r'$\sigma_2$')
leg = ax.legend(loc='upper right')
leg.get_frame().set_alpha(0.6)
ax = fig.add_subplot(122, projection='3d')
ax.plot_surface(gt[:,:,0], gt[:,:,1], gt[:,:,2], lw=0.1, cmap='viridis', alpha=0.5)
ax.scatter3D(D_low_cyl.X[:,0], D_low_cyl.X[:, 1], D_low_cyl.y[:, 0], s=4, label='Low fidelity', color='tab:blue', alpha=0.7)
ax.scatter3D(D_high_cyl.X[:,0], D_high_cyl.X[:, 1], D_high_cyl.y[:, 0], s=4, label='High fidelity', color='tab:orange', alpha=1)
ax.view_init(elev=20, azim=-75)
ax.set_title(r'High-fidelity model for $\mathcal{R}_{0.6}[C_\text{BP}]$', fontsize=12)
ax.set_xlabel(r'$\sigma_1$')
ax.set_ylabel(r'$\sigma_2$')
sns.despine(trim=True);
Multi-fidelity Gaussian process for superquantile risk#
As before, we first construct the low-fidelity GP surrogate and fit it to the low-fidelity (Monte Carlo) data:
# Construct low-fidelity GP
mean_low_cyl = gpx.mean_functions.Zero()
kernel_low_cyl = gpx.kernels.RBF(lengthscale=jnp.ones(2), variance=0.01)
prior_low_cyl = gpx.gps.Prior(mean_function=mean_low_cyl, kernel=kernel_low_cyl)
likelihood_low_cyl = gpx.likelihoods.Gaussian(
num_datapoints=D_low_cyl.n,
obs_stddev=0.0,
)
posterior_low_cyl = prior_low_cyl * likelihood_low_cyl
# Optimize low-fidelity hyperparameters
posterior_low_cyl, history_low_cyl = gpx.fit_scipy(
model=posterior_low_cyl,
objective=negative_mll,
train_data=D_low_cyl,
max_iters=1000
)
Optimization terminated successfully.
Current function value: -496.380070
Iterations: 25
Function evaluations: 30
Gradient evaluations: 30
Next, let's prepare the data for training the multi-fidelity GP surrogate. There is a complication: we do not have low-fidelity evaluations at all of the high-fidelity input points. There are two ways around this:

- Easier option: Evaluate the low-fidelity model (or the low-fidelity surrogate, if it's good enough) at the missing points.
- Harder option: Formulate the multi-fidelity GP in a way that does not require low-fidelity evaluations at the same points (see Le Gratiet (2013)).
We will do the easier option:
# Evaluate the low-fidelity surrogate at the high-fidelity data points
y_low_common_cyl = generate_predictive_dist(D_high_cyl.X, posterior_low_cyl, D_low_cyl).mean()
# Create the augmented dataset for training the multi-fidelity GP
Xm_cyl = jnp.concatenate([D_high_cyl.X, y_low_common_cyl[:, None]], axis=1)
D_multi_cyl = gpx.Dataset(Xm_cyl, D_high_cyl.y)
Now that we have the augmented dataset, let's construct and train the multi-fidelity GP. We'll also train a GP on the high-fidelity data alone for comparison:
# Construct multi-fidelity GP
mean_multi_cyl = gpx.mean_functions.Zero()
kernel_multi_cyl = gpx.kernels.RBF(lengthscale=jnp.ones(3), variance=0.01)
prior_multi_cyl = gpx.gps.Prior(mean_function=mean_multi_cyl, kernel=kernel_multi_cyl)
likelihood_multi_cyl = gpx.likelihoods.Gaussian(
num_datapoints=D_multi_cyl.n,
obs_stddev=0.1
)
posterior_multi_cyl = prior_multi_cyl * likelihood_multi_cyl
# Optimize multi-fidelity hyperparameters
posterior_multi_cyl, history_multi_cyl = gpx.fit_scipy(
model=posterior_multi_cyl,
objective=negative_mll,
train_data=D_multi_cyl,
max_iters=1000
)
# Construct high-fidelity GP
mean_high_cyl = gpx.mean_functions.Zero()
kernel_high_cyl = gpx.kernels.RBF(lengthscale=jnp.ones(2), variance=0.1)
prior_high_cyl = gpx.gps.Prior(mean_function=mean_high_cyl, kernel=kernel_high_cyl)
likelihood_high_cyl = gpx.likelihoods.Gaussian(
num_datapoints=D_high_cyl.n,
obs_stddev=0.1
)
posterior_high_cyl = prior_high_cyl * likelihood_high_cyl
# Optimize high-fidelity hyperparameters
posterior_high_cyl, history_high_cyl = gpx.fit_scipy(
model=posterior_high_cyl,
objective=negative_mll,
train_data=D_high_cyl,
max_iters=1000
)
Optimization terminated successfully.
Current function value: -22.822893
Iterations: 64
Function evaluations: 72
Gradient evaluations: 72
Optimization terminated successfully.
Current function value: -18.071678
Iterations: 34
Function evaluations: 37
Gradient evaluations: 37
As before, let’s visualize the predictive accuracy with some parity plots:
sample_low_fidelity_gp_cyl = partial(sample_vanilla_gp, posterior=posterior_low_cyl, D=D_low_cyl)
sample_high_fidelity_gp_cyl = partial(sample_vanilla_gp, posterior=posterior_high_cyl, D=D_high_cyl)
sample_multi_fidelity_gp_cyl = partial(sample_multi_fidelity_gp, posterior_low=posterior_low_cyl, posterior_multi=posterior_multi_cyl, D_low=D_low_cyl, D_multi=D_multi_cyl)
num_samples_cyl = 10
_ind_test_cyl = np.random.choice(ground_truth.index, 100, replace=False)
X_test_cyl = ground_truth[['x1', 'x2']].values[_ind_test_cyl]
y_test_cyl = ground_truth['y'].values[:, None][_ind_test_cyl]
key, key_low_cyl, key_high_cyl, key_multi_cyl = jrandom.split(key, 4)
samples_low_cyl = sample_low_fidelity_gp_cyl(X_test_cyl, key_low_cyl, num_samples_cyl)
samples_high_cyl = sample_high_fidelity_gp_cyl(X_test_cyl, key_high_cyl, num_samples_cyl)
samples_multi_cyl = sample_multi_fidelity_gp_cyl(X_test_cyl, key_multi_cyl, num_samples_cyl)
mean_low_test_cyl = samples_low_cyl.mean(axis=0)
std_low_test_cyl = samples_low_cyl.std(axis=0)
mean_high_test_cyl = samples_high_cyl.mean(axis=0)
std_high_test_cyl = samples_high_cyl.std(axis=0)
mean_multi_test_cyl = samples_multi_cyl.mean(axis=0)
std_multi_test_cyl = samples_multi_cyl.std(axis=0)
# Parity plot - add uncertainty on predictions using whiskers
fig, ax = plt.subplots(1, 3, figsize=(10,4), tight_layout=True)
fig.suptitle("Surrogate predictions vs. high-fidelity test data", fontsize=20)
ax[0].errorbar(y_test_cyl, mean_low_test_cyl, yerr=2 * std_low_test_cyl, fmt="o", markersize=3, lw=0.5)
ax[0].plot(y_test_cyl, y_test_cyl, "k--")
ax[0].set_xlabel("True model")
ax[0].set_ylabel("Surrogate")
ax[0].set_title("Low fidelity GP", fontsize=14)
ax[0].annotate(f"RMSE: {rmse(y_test_cyl.squeeze(-1), mean_low_test_cyl):.3f}", xy=(0.05, 0.9), xycoords='axes fraction')
ax[1].errorbar(y_test_cyl, mean_high_test_cyl, yerr=2 * std_high_test_cyl, fmt="o", markersize=3, lw=0.5)
ax[1].plot(y_test_cyl, y_test_cyl, "k--")
ax[1].set_xlabel("True model")
ax[1].set_ylabel("Surrogate")
ax[1].set_title("High fidelity GP", fontsize=14)
ax[1].annotate(f"RMSE: {rmse(y_test_cyl.squeeze(-1), mean_high_test_cyl):.3f}", xy=(0.05, 0.9), xycoords='axes fraction')
ax[2].errorbar(y_test_cyl, mean_multi_test_cyl, yerr=2 * std_multi_test_cyl, fmt="o", markersize=3, lw=0.5)
ax[2].plot(y_test_cyl, y_test_cyl, "k--")
ax[2].set_xlabel("True model")
ax[2].set_ylabel("Surrogate")
ax[2].set_title("Multi fidelity GP", fontsize=14)
ax[2].annotate(f"RMSE: {rmse(y_test_cyl.squeeze(-1), mean_multi_test_cyl):.3f}", xy=(0.05, 0.9), xycoords='axes fraction')
sns.despine(trim=True);
The multi-fidelity GP has the best predictive accuracy. Let’s visualize the response surface of the surrogate vs. the true high-fidelity model:
# Surface plot
x1_cyl, x2_cyl = jnp.linspace(ground_truth['x1'].min(), ground_truth['x1'].max(), 10), jnp.linspace(ground_truth['x2'].min(), ground_truth['x2'].max(), 10)
X1_cyl, X2_cyl = jnp.meshgrid(x1_cyl, x2_cyl)
X_plt_cyl = jnp.stack([X1_cyl.ravel(), X2_cyl.ravel()], axis=-1)
num_samples_cyl = 15
key, key_cyl = jrandom.split(key)
# TODO: Figure out why sampling is so slow!
mean_multi_plt_cyl = sample_multi_fidelity_gp_cyl(X_plt_cyl, key_cyl, num_samples_cyl).mean(axis=0).reshape(X1_cyl.shape)
fig = plt.figure(figsize=(5,4))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(gt[:,:,0], gt[:,:,1], gt[:,:,2], lw=0.1, color='tab:blue', alpha=0.7)
ax.plot_surface(X1_cyl, X2_cyl, mean_multi_plt_cyl, label='Multi fidelity', color='tab:green', lw=0.1, alpha=0.3)
ax.view_init(elev=20, azim=-75)
ax.set_title('Mean predictive surface vs ground truth', fontsize=14)
ax.set_xlabel(r'$\sigma_1$')
ax.set_ylabel(r'$\sigma_2$')
ax.set_zlabel(r'$y$')
ax.legend();
The surfaces are almost right on top of each other.