Show code cell source
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
# Render inline figures as SVG.
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
# Seaborn plotting context/style used by every figure in this notebook.
sns.set_context("paper")
sns.set_style("ticks")
# Uncomment the next two lines if running for the book
# import warnings
# warnings.filterwarnings("ignore")
import jax
from jax import vmap
import jax.numpy as jnp
import jax.random as jrandom
import gpjax as gpx
import optax
import numpy as np
import pandas as pd
from functools import partial
import os
import scipy.stats.qmc as qmc
# Enable 64-bit floats in JAX (GP marginal-likelihood computations are
# typically better conditioned in double precision).
jax.config.update("jax_enable_x64", True)
# Root PRNG key for the whole notebook; later cells derive fresh keys from it
# with jrandom.split.
key = jax.random.PRNGKey(0);
Multifidelity Gaussian process surrogates#
Example 1: Multi-fidelity regression of a synthetic function#
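Suppose we have a high-fidelity model \(f_h\) and a low-fidelity model \(f_\ell\) of some phenomenon, given by
\[ f_h(x) = \tfrac{1}{2}\sin^2\!\left(\tfrac{5}{2}x_1 + \tfrac{2}{3}x_2\right) + \tfrac{2}{3}\exp\!\left(-x_1\left(x_2 - \tfrac{1}{2}\right)^2\right)\cos^2(4x_1 + x_2), \]
\[ f_\ell(x) = \tfrac{5}{2} f_h(x) + \tfrac{1}{3}\left(\sin(x_1 + x_2) + \tfrac{1}{2}e^{-x_1}\sin(x_1 + 7x_2)\right), \]
where \(x = (x_1, x_2)\) lies in the unit square.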
(These function definitions are modified from Perdikaris et al. (2015).)
We want to create a surrogate for \(f_h\).
Datasets#
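Suppose we evaluate \(f_h\) at a limited number of points \(x_h \in \mathbb{R}^{N_h \times d}\), giving the dataset
\[ \mathcal{D}_h = (x_h, y_h), \qquad y_{h,i} = f_h(x_{h,i}), \quad i = 1, \dots, N_h, \]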
where \(N_h\) is the number of evaluations, \(d\) is the input dimension, and \(x_{h,i} \in \mathbb{R}^d\) and \(y_{h,i} \in \mathbb{R}\) are the input and output for the \(i^\text{th}\) evaluation, respectively. This is our high-fidelity data.
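Suppose we also evaluate \(f_\ell\) at the same points \(x_h\) and at some different points \(x_\ell' \in \mathbb{R}^{N_\ell' \times d}\):
\[ y_{\ell,i} = f_\ell(x_{h,i}),\ i = 1, \dots, N_h, \qquad y'_{\ell,j} = f_\ell(x'_{\ell,j}),\ j = 1, \dots, N_\ell'. \]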
This is our low fidelity data. Let’s generate the data and visualize the functions.
# Sizes for the synthetic example. N_LOW_FIDELITY counts *all* low-fidelity
# evaluations, including those taken at the N_HIGH_FIDELITY high-fidelity
# locations; N_TEST is the number of held-out test inputs.
N_LOW_FIDELITY, N_HIGH_FIDELITY, N_TEST = 50, 8, 100
def high_fidelity_model(x):
    """Evaluate the synthetic high-fidelity model f_h at one 2-D input.

    Parameters
    ----------
    x : array_like
        Input point; only ``x[0]`` (x_1) and ``x[1]`` (x_2) are read.

    Returns
    -------
    The model output (a scalar when ``x`` is a length-2 vector).
    """
    first, second = x[0], x[1]
    oscillation = 1/2*jnp.sin(5/2*first + 2/3*second)**2
    damped = 2/3*jnp.exp(-first*(second - 0.5)**2)*jnp.cos(4*first + second)**2
    return oscillation + damped


def low_fidelity_model(x):
    """Evaluate the synthetic low-fidelity model f_l at one 2-D input.

    f_l is a scaled copy of the high-fidelity model plus a smooth
    sinusoidal discrepancy term.

    Parameters
    ----------
    x : array_like
        Input point; only ``x[0]`` (x_1) and ``x[1]`` (x_2) are read.
    """
    first, second = x[0], x[1]
    discrepancy = jnp.sin(first + second) + 1/2*jnp.exp(-first)*jnp.sin(first + 7*second)
    return 2.5*high_fidelity_model(x) + 1/3*discrepancy
def generate_synthetic_data(n_l, n_h, n_test, key=None):
    """Return datasets necessary for training and testing a multi-fidelity GP.

    All randomness (both Latin-hypercube training designs and the uniform
    test inputs) is derived from ``key``, so calling this twice with the same
    key returns identical datasets.

    Parameters
    ----------
    n_l : int
        Total number of low-fidelity training points, *including* the ``n_h``
        points shared with the high-fidelity design (so ``n_l >= n_h``).
    n_h : int
        Number of high-fidelity training points.
    n_test : int
        Number of test inputs, drawn uniformly in the unit square.
    key : PRNGKey, optional
        JAX random key. ``None`` falls back to ``jax.random.PRNGKey(0)``.

    Returns
    -------
    D_low_common : gpx.Dataset
        Low-fidelity outputs at the high-fidelity inputs (``n_h`` points).
    D_low : gpx.Dataset
        All low-fidelity training data (``n_l`` points, shared points last).
    D_high : gpx.Dataset
        High-fidelity training data (``n_h`` points).
    D_low_test, D_high_test : gpx.Dataset
        Low- and high-fidelity outputs at the same ``n_test`` test inputs.

    Raises
    ------
    ValueError
        If ``n_l < n_h``.
    """
    if n_l < n_h:
        raise ValueError(
            f"n_l ({n_l}) counts all low-fidelity points, including the n_h ({n_h}) "
            "high-fidelity locations, so it must be >= n_h."
        )
    if key is None:
        key = jrandom.PRNGKey(0)
    key_l, key_h, key_test = jrandom.split(key, 3)

    def _qmc_seed(k):
        # Convert a JAX key into an integer seed so scipy's Latin-hypercube
        # sampler is reproducible (with seed=None it would be unseeded).
        return int(jrandom.randint(k, (), 0, np.iinfo(np.int32).max))

    # Train data: two independent Latin-hypercube designs in the unit square.
    x_train_l_only = jnp.array(qmc.LatinHypercube(d=2, seed=_qmc_seed(key_l)).random(n=n_l - n_h))
    x_train_h = jnp.array(qmc.LatinHypercube(d=2, seed=_qmc_seed(key_h)).random(n=n_h))
    y_train_l_only = vmap(low_fidelity_model)(x_train_l_only)
    y_train_l_common = vmap(low_fidelity_model)(x_train_h)
    y_train_h = vmap(high_fidelity_model)(x_train_h)
    # The low-fidelity inputs must also include every high-fidelity input point.
    x_train_l = jnp.concatenate([x_train_l_only, x_train_h], axis=0)
    y_train_l = jnp.concatenate([y_train_l_only, y_train_l_common], axis=0)

    # Test data
    x_test = jrandom.uniform(key_test, (n_test, 2))
    y_test_l = vmap(low_fidelity_model)(x_test)
    y_test_h = vmap(high_fidelity_model)(x_test)

    # Outputs are stored as column vectors, shape (n, 1).
    D_low_common = gpx.Dataset(x_train_h, y_train_l_common[:, None])
    D_low = gpx.Dataset(x_train_l, y_train_l[:, None])
    D_high = gpx.Dataset(x_train_h, y_train_h[:, None])
    D_low_test = gpx.Dataset(x_test, y_test_l[:, None])
    D_high_test = gpx.Dataset(x_test, y_test_h[:, None])
    return D_low_common, D_low, D_high, D_low_test, D_high_test
# Split the root key; `subkey` seeds the synthetic datasets below. Later cells
# in this example also pass this same `subkey` to gpx.fit.
key, subkey = jrandom.split(key)
D_low_common, D_low, D_high, D_low_test, D_high_test = generate_synthetic_data(N_LOW_FIDELITY, N_HIGH_FIDELITY, N_TEST, key=subkey)
Show code cell source
fig = plt.figure()
# Left panel: where each fidelity level was evaluated.
ax = fig.add_subplot(121)
for data, label, face, edge in (
    (D_low, r'Low-fidelity data, $X_\ell$', 'none', 'tab:blue'),
    (D_high, r'High-fidelity data, $X_h$', 'tab:orange', 'none'),
):
    ax.scatter(data.X[:, 0], data.X[:, 1], label=label, marker='o', facecolors=face, edgecolors=edge)
ax.set_aspect('equal')
ax.set_title('Data locations')
ax.set_xlabel(r'$x_1$')
ax.set_ylabel(r'$x_2$')
leg = ax.legend()
leg.get_frame().set_alpha(0.6)

# Right panel: both true functions on a 100x100 grid over the unit square.
x1 = jnp.linspace(0, 1, 100)
x2 = jnp.linspace(0, 1, 100)
X1, X2 = jnp.meshgrid(x1, x2)
grid_points = jnp.stack([X1.ravel(), X2.ravel()], axis=-1)
Y_low = vmap(low_fidelity_model)(grid_points).reshape(X1.shape)
Y_high = vmap(high_fidelity_model)(grid_points).reshape(X1.shape)
ax = fig.add_subplot(122, projection='3d')
ax.plot_surface(X1, X2, Y_low, label='Low fidelity', lw=0.1, alpha=0.5)
ax.plot_surface(X1, X2, Y_high, label='High fidelity', lw=0.1)
ax.view_init(elev=20, azim=-58)
ax.set_title('True functions')
ax.set_xlabel(r'$x_1$')
ax.set_ylabel(r'$x_2$')
ax.set_zlabel(r'$y$')
ax.legend()
sns.despine(trim=True);
For ease of notation, let’s group all the low-fidelity data together:
where \(N_\ell = N_h + N_\ell'\) is the total number of low-fidelity simulations.
Low-fidelity Gaussian process#
We start by making a Gaussian process surrogate \(\hat{f}_\ell\) using just the low-fidelity data \(\mathcal{D}_\ell = (\mathbf{X}_\ell, \mathbf{y}_\ell)\):
We just set the covariance kernel \(k\) as a radial basis function (RBF) kernel whose variance, \(\sigma^2\), and lengthscale, \(\ell_d\), are found by maximizing the marginal log likelihood, i.e.,
Let’s do this in GPJax:
# Construct low-fidelity GP
mean_low = gpx.mean_functions.Zero()
kernel_low = gpx.kernels.RBF(lengthscale=jnp.ones(2), variance=1.0)
prior_low = gpx.gps.Prior(mean_function=mean_low, kernel=kernel_low)
likelihood_low = gpx.likelihoods.Gaussian(
num_datapoints=D_low.n,
obs_stddev=1.0,
)
posterior_low = prior_low * likelihood_low
# Loss function
negative_mll = lambda p, d: -gpx.objectives.conjugate_mll(p, d)
# Optimize low-fidelity hyperparameters
posterior_low, history_low = gpx.fit(
model=posterior_low,
objective=negative_mll,
train_data=D_low,
optim=optax.adam(1e-2), # Adam optimizer
num_iters=2000,
batch_size=128,
key=subkey,
verbose=True
)
Show code cell output
Let’s visualize the fit to the low-fidelity data with a parity plot:
def generate_predictive_dist(x, posterior, train_data):
    """Return the posterior predictive distribution at inputs ``x``.

    The latent GP posterior at ``x`` (conditioned on ``train_data``) is passed
    through ``posterior.likelihood``, so the result is a distribution over
    observations rather than over the latent function.
    """
    return posterior.likelihood(posterior.predict(x, train_data=train_data))
predictive_dist_low = generate_predictive_dist(
    x=D_low_test.X, posterior=posterior_low, train_data=D_low
)
# Pointwise mean and standard deviation of the low-fidelity GP at the test inputs.
mean_low_test, std_low_test = predictive_dist_low.mean(), predictive_dist_low.stddev()
Show code cell source
fig, ax = plt.subplots(figsize=(4,3))
y_true_low = D_low_test.y
# Whiskers span +-2 predictive standard deviations.
ax.errorbar(y_true_low, mean_low_test, yerr=2 * std_low_test, fmt="o")
# Identity line: points on it are predicted perfectly.
ax.plot(y_true_low, y_true_low, "k--")
ax.set(xlabel="Model (low-fidelity)", ylabel="Surrogate (low-fidelity)")
ax.set_title(r"Low-fidelity GP predictions vs low-fidelity data", fontsize=12)
sns.despine(trim=True);
It looks good. Now onto multi-fidelity.
Multi-fidelity Gaussian process#
The idea behind the multi-fidelity GP is this: Do GP regression on the high-fidelity data, but add a dimension to the input space which represents the output of the low-fidelity model.
More precisely, the multi-fidelity GP is
where the multi-fidelity covariance kernel \(k_m\) is
Note that \(k_m\) itself has a Gaussian process \((\hat{f}_\ell)\) inside of it! This is how low-fidelity information is passed to \(\hat{f}_m\). Intuitively, the multi-fidelity GP automatically “learns” how to correct the low-fidelity model’s predictions.
Also, note that in GPJax we will actually implement the equivalent \((d+1)\)-dimensional augmented GP
and train it on the data
You can think of this as a form of data augmentation (we simply augment the high-fidelity dataset with some low-fidelity model evaluations at the same points).
Let’s try this out. Here is a GP constructed with the new augmented dataset \(\mathcal{D}_m = (\mathbf{X}_m\), \(\mathbf{y}_m)\):
# Augmented inputs: append the low-fidelity output at each high-fidelity input
# as an extra (d+1)-th input column.
Xm = jnp.concatenate([D_high.X, D_low_common.y], axis=1)
ym = D_high.y
D_multi = gpx.Dataset(Xm, ym)
# Construct multi-fidelity GP (one RBF lengthscale per augmented input dimension)
mean_multi = gpx.mean_functions.Zero()
kernel_multi = gpx.kernels.RBF(lengthscale=jnp.ones(Xm.shape[1]), variance=1.0)
prior_multi = gpx.gps.Prior(mean_function=mean_multi, kernel=kernel_multi)
likelihood_multi = gpx.likelihoods.Gaussian(
    num_datapoints=D_multi.n,
    obs_stddev=1.0
)
posterior_multi = prior_multi * likelihood_multi
# Optimize multi-fidelity hyperparameters.
# Full-batch training (batch_size=-1): D_multi has only N_HIGH_FIDELITY points,
# so a resampled batch of 128 would consist almost entirely of duplicates.
# The loss history goes into `history_multi` so the low-fidelity
# `history_low` is not overwritten.
posterior_multi, history_multi = gpx.fit(
    model=posterior_multi,
    objective=negative_mll,
    train_data=D_multi,
    optim=optax.adam(1e-2),  # Adam optimizer
    num_iters=2000,
    batch_size=-1,
    key=subkey,  # only used for mini-batching
    verbose=True
)
Show code cell output
Note that this is now a 3D-input GP instead of a 2D-input GP (due to the augmented dimension).
Sampling from the posterior#
The next question is: how do we sample the posterior multi-fidelity GP \(\hat{f}_m | \mathcal{D}_m\) at some test input \(x_\text{test}\)? Here is how:
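First, draw a sample \(\tilde{f}_\ell\) from the posterior low-fidelity GP \(\hat{f}_\ell \mid \mathcal{D}_\ell\).
Then, evaluate it at the test point, \(y_{\ell,\text{test}} = \tilde{f}_\ell(x_\text{test})\):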
# A single 2-D test input, shape (1, 2).
x_test = jnp.array([[0.5, 0.5]])
# Low-fidelity posterior predictive distribution at x_test.
predictive_dist_low_test = generate_predictive_dist(x_test, posterior_low, D_low)
# Fresh subkey for this draw.
key, subkey = jrandom.split(key)
# One draw of y_{l,test}; used below as a column next to x_test, so its shape
# is presumably (1,) -- one value per test point.
samples_low_test = predictive_dist_low_test.sample((), subkey)
Next, augment the input like \(x_\text{test}^\text{aug} = [ x_\text{test} \quad y_{\ell,\text{test}} ]\).
# Augmented test input [x_test, y_l,test]: the low-fidelity sample becomes an extra column.
x_augmented = jnp.hstack([x_test, samples_low_test.reshape(-1, 1)])
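Finally, draw a sample \(\tilde{f}_m\) from the posterior multi-fidelity (augmented) GP \(\hat{f}_m \mid \mathcal{D}_m\).
Then, evaluate it at the augmented test point, \(y_\text{test} = \tilde{f}_m(x_\text{test}^\text{aug})\):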
# Multi-fidelity posterior predictive at the augmented input [x_test, y_l,test].
predictive_dist_multi_test = generate_predictive_dist(x_augmented, posterior_multi, D_multi)
key, subkey = jrandom.split(key)
# Sample with the fresh `subkey`, not the parent `key`. `key` is split again by
# later cells, and reusing it here would tie this draw to those streams.
samples_multi_test = predictive_dist_multi_test.sample((), subkey)
print(f'The sample is: {samples_multi_test.squeeze():.4f}')
print(f'The true value is: {high_fidelity_model(x_test[0]):.4f}')
The sample is: 0.9073
The true value is: 0.9278
And that’s it! We now know how to construct and sample a multi-fidelity GP surrogate. Next, we’ll check our surrogate’s accuracy.
Before we do that, however, let’s train a vanilla GP on the high-fidelity data for comparison:
# Construct high-fidelity GP
mean_high = gpx.mean_functions.Zero()
kernel_high = gpx.kernels.RBF(lengthscale=jnp.ones(2), variance=1.0)
prior_high = gpx.gps.Prior(mean_function=mean_high, kernel=kernel_high)
likelihood_high = gpx.likelihoods.Gaussian(
num_datapoints=D_high.n,
obs_stddev=1.0
)
posterior_high = prior_high * likelihood_high
# Optimize multi-fidelity hyperparameters
posterior_high, history_high = gpx.fit(
model=posterior_high,
objective=negative_mll,
train_data=D_high,
optim=optax.adam(1e-2), # Adam optimizer
num_iters=2000,
batch_size=128,
key=subkey,
verbose=True
)
Show code cell output
Let’s also wrap the sampling code into functions (for organization):
def sample_multi_fidelity_gp(
    X,
    key,
    num_samples,
    posterior_low,
    posterior_multi,
    D_low,
    D_multi,
):
    """Sample from the multi-fidelity Gaussian process by nested sampling.

    For each of ``num_samples`` draws:
    1. sample the low-fidelity GP jointly at all test points;
    2. append that sample to ``X`` as an extra input column;
    3. draw an *independent* joint sample of the multi-fidelity GP at those
       augmented inputs.

    Parameters
    ----------
    X : ndarray, shape (num_test_points, d)
        The test points.
    key : PRNGKey
        The random key. It is split internally so the low- and multi-fidelity
        draws use independent randomness.
    num_samples : int
        The number of samples to draw.
    posterior_low : gpx.gps.Posterior
        The low-fidelity GP posterior.
    posterior_multi : gpx.gps.Posterior
        The multi-fidelity GP posterior (trained on (d+1)-dimensional inputs).
    D_low : gpx.Dataset
        The low-fidelity dataset.
    D_multi : gpx.Dataset
        The multi-fidelity (augmented) dataset.

    Returns
    -------
    ndarray
        The samples from the multi-fidelity GP. The shape is (num_samples, num_test_points).
    """
    if num_samples == 0:
        return jnp.zeros((0, X.shape[0]))
    key_low, key_multi = jrandom.split(key)

    predictive_dist_low = generate_predictive_dist(X, posterior_low, D_low)
    # Shape (num_samples, num_test_points).
    samples_low_test = predictive_dist_low.sample((num_samples,), key_low)

    single_augment_fn = lambda yl: jnp.concatenate([X, yl[:, None]], axis=-1)
    x_aug = vmap(single_augment_fn)(samples_low_test)  # shape (num_samples, num_test_points, d + 1)

    # One independent multi-fidelity draw per low-fidelity sample. A single
    # joint draw over all num_samples * num_test_points augmented inputs would
    # make every "sample" share one realization of the multi-fidelity GP,
    # losing its posterior uncertainty. It would also need a factorization of
    # a (num_samples * num_test_points)-square covariance, which is slow.
    keys_multi = jrandom.split(key_multi, num_samples)
    samples_multi_test = jnp.stack([
        generate_predictive_dist(x_aug_s, posterior_multi, D_multi).sample((), key_s)
        for x_aug_s, key_s in zip(x_aug, keys_multi)
    ])
    return samples_multi_test
def sample_vanilla_gp(
    X,
    key,
    num_samples,
    posterior,
    D
):
    """Draw joint samples from a single-fidelity ("vanilla") GP posterior predictive.

    Parameters
    ----------
    X : ndarray
        The test points.
    key : PRNGKey
        The random key.
    num_samples : int
        The number of samples to draw.
    posterior : gpx.gps.Posterior
        The GP posterior.
    D : gpx.Dataset
        The dataset the posterior is conditioned on.

    Returns
    -------
    ndarray
        Samples with a leading axis of length ``num_samples`` (callers treat
        the result as shape (num_samples, num_test_points)).
    """
    sample_shape = (num_samples,)
    return generate_predictive_dist(X, posterior, D).sample(sample_shape, key)
Now, we are ready to see how well the GPs predict the test data. Let’s start by making parity plots:
# Bind each GP's posterior and training data so all three samplers share the
# call signature (X, key, num_samples).
sample_low_fidelity_gp = partial(sample_vanilla_gp, posterior=posterior_low, D=D_low)
sample_high_fidelity_gp = partial(sample_vanilla_gp, posterior=posterior_high, D=D_high)
# NOTE: this rebinds `sample_multi_fidelity_gp` to a partial of the original
# function; later cells call the 3-argument form. Wrapping it again with other
# keyword arguments still works, because functools.partial lets later keywords
# override earlier ones.
sample_multi_fidelity_gp = partial(sample_multi_fidelity_gp, posterior_low=posterior_low, posterior_multi=posterior_multi, D_low=D_low, D_multi=D_multi)
num_samples = 50
X_test = D_high_test.X
# Independent keys per surrogate (and advance the root key), so the three sets
# of samples are not drawn from the same random stream.
key, key_low, key_high, key_multi = jrandom.split(key, 4)
samples_low = sample_low_fidelity_gp(X_test, key_low, num_samples)
samples_high = sample_high_fidelity_gp(X_test, key_high, num_samples)
samples_multi = sample_multi_fidelity_gp(X_test, key_multi, num_samples)
# Monte Carlo mean/std across samples (axis 0) at each test point.
mean_low_test = samples_low.mean(axis=0)
std_low_test = samples_low.std(axis=0)
mean_high_test = samples_high.mean(axis=0)
std_high_test = samples_high.std(axis=0)
mean_multi_test = samples_multi.mean(axis=0)
std_multi_test = samples_multi.std(axis=0)
Show code cell source
def rmse(y, y_hat):
    """Root-mean-square error between targets ``y`` and predictions ``y_hat``."""
    return jnp.sqrt(jnp.mean((y - y_hat)**2))


# Parity plot: surrogate mean vs. true high-fidelity value, +-2 std whiskers.
fig, ax = plt.subplots(1, 3, figsize=(10,4), tight_layout=True)
fig.suptitle("Surrogate predictions vs. high-fidelity test data", fontsize=20)
y_true = D_high_test.y
parity_panels = (
    (mean_low_test, std_low_test, r"Low fidelity GP, $\hat{f}_\ell$"),
    (mean_high_test, std_high_test, r"High fidelity GP, $\hat{f}_h$"),
    (mean_multi_test, std_multi_test, r"Multi fidelity GP, $\hat{f}_m$"),
)
for panel_ax, (mu, sd, title) in zip(ax, parity_panels):
    panel_ax.errorbar(y_true, mu, yerr=2 * sd, fmt="o", markersize=3, lw=0.5)
    panel_ax.plot(y_true, y_true, "r-", lw=1)  # identity line = perfect prediction
    panel_ax.set_xlabel("True model")
    panel_ax.set_ylabel("Surrogate")
    panel_ax.set_title(title, fontsize=14)
    panel_ax.annotate(f"RMSE: {rmse(y_true.squeeze(-1), mu):.3f}", xy=(0.05, 0.9), xycoords='axes fraction')
sns.despine(trim=True);
The multi-fidelity GP has the most accurate predictions. Let’s visualize the mean predictive surface against the ground truth:
Show code cell source
# Surface plot: multi-fidelity mean vs. true f_h on a 20x20 grid.
x1, x2 = jnp.linspace(0, 1, 20), jnp.linspace(0, 1, 20)
X1, X2 = jnp.meshgrid(x1, x2)
X_plt = jnp.stack([X1.ravel(), X2.ravel()], axis=-1)
Y_high_plt = vmap(high_fidelity_model)(X_plt).reshape(X1.shape)
num_samples = 20
# Draw with a fresh subkey (and advance the root key) so this cell has its own
# random stream instead of reusing the root key directly.
key, key_plt = jrandom.split(key)
mean_multi_plt = sample_multi_fidelity_gp(X_plt, key_plt, num_samples).mean(axis=0).reshape(X1.shape)
fig = plt.figure(figsize=(5,4))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X1, X2, mean_multi_plt, label=r'Mean of multi-fidelity GP, $\mathbb{E}[\hat{f}_m]$', color='tab:green', lw=0.1, alpha=0.3)
ax.plot_surface(X1, X2, Y_high_plt, label=r'True function, $f_h$', color='tab:blue', lw=0.1, alpha=0.7)
ax.view_init(elev=20, azim=-58)
ax.set_title('Mean predictive surface vs ground truth', fontsize=14)
ax.set_xlabel(r'$x_1$')
ax.set_ylabel(r'$x_2$')
ax.set_zlabel(r'$y$')
ax.legend();
The surfaces are basically right on top of each other! Let’s take a closer look at a vertical “slice” of the surface plot. Fix \(x_2=0.4\), and vary \(x_1\) along the \(x\)-axis:
# Slice plot: fix x_2 = 0.4 and vary x_1 over [0, 1].
x1_slice = jnp.linspace(0, 1, 40)
x2_slice = jnp.full_like(x1_slice, 0.4)
X_slice = jnp.stack([x1_slice, x2_slice], axis=-1)
num_samples = 120
# Independent keys per surrogate; passing one key to all three samplers would
# draw their random numbers from the same stream.
key, key_low, key_high, key_multi = jrandom.split(key, 4)
samples_low = sample_low_fidelity_gp(X_slice, key_low, num_samples)
samples_high = sample_high_fidelity_gp(X_slice, key_high, num_samples)
samples_multi = sample_multi_fidelity_gp(X_slice, key_multi, num_samples)
# Monte Carlo mean/std across samples at each slice point.
mean_low_slice = samples_low.mean(axis=0)
std_low_slice = samples_low.std(axis=0)
mean_high_slice = samples_high.mean(axis=0)
std_high_slice = samples_high.std(axis=0)
mean_multi_slice = samples_multi.mean(axis=0)
std_multi_slice = samples_multi.std(axis=0)
Show code cell source
fig, ax = plt.subplots(figsize=(4,3))
ax.set_title("Slice through the predictive surfaces", fontsize=14)
slice_curves = (
    (mean_low_slice, std_low_slice, r"Low-fidelity GP, $\hat{f}_\ell$"),
    (mean_high_slice, std_high_slice, r"High-fidelity GP, $\hat{f}_h$"),
    (mean_multi_slice, std_multi_slice, r"Multi-fidelity GP, $\hat{f}_m$"),
)
for mu, sd, label in slice_curves:
    ax.plot(x1_slice, mu, label=label, lw=2)
    # Shaded band: mean +- 2 sample standard deviations.
    ax.fill_between(x1_slice, mu - 2 * sd, mu + 2 * sd, alpha=0.2)
ax.plot(x1_slice, vmap(high_fidelity_model)(X_slice), label=r"True function, $f_h$", linestyle="--", color="black")
ax.set_xlabel(r'$x_1$')
ax.legend()
sns.despine(trim=True);
Again, the multi-fidelity GP \(\hat{f}_m\) approximates the high-fidelity model \(f_h\) very well! This is a significant improvement over naively fitting to the high-fidelity data alone.
Questions#
- Decrease `N_LOW_FIDELITY` (keeping it at least `N_HIGH_FIDELITY`). How does the multi-fidelity GP \(\hat{f}_m\) perform with less low-fidelity data?
- Decrease `N_HIGH_FIDELITY`. How does the multi-fidelity GP \(\hat{f}_m\) perform with less high-fidelity data?
- Increase `N_HIGH_FIDELITY`. At what point is the high-fidelity-only GP \(\hat{f}_h\) just as good as the multi-fidelity GP \(\hat{f}_m\)?
- Add more terms (sin/cos, exponential, quadratic, or whatever you want) to `low_fidelity_model`. How different can the low-fidelity model \(f_\ell\) be from the high-fidelity model \(f_h\) while still giving a good surrogate \(\hat{f}_m\)?
Example 2: Stochastic incompressible flow past a cylinder#
This example is taken from Perdikaris et al. (2015). Suppose you have a flow past a cylinder, subject to random inflow boundary conditions of the form
where \(x \equiv (\sigma_1, \sigma_2)\) are controllable design variables representing the amplitude and skewness of the inflow noise, and \(\xi \equiv (\xi_1, \xi_2)\) are standard Gaussian random variables. Let \(C_\text{BP}\) be the base pressure coefficient at the rear of the cylinder (see figure below).
Figure 9 from Perdikaris et al. (2015). The left figure shows the spatial discretization used for CFD, while the right figure shows the quantity of interest (i.e., the base pressure coefficient \(C_\text{BP}\)).
Quantity of interest: Superquantile risk of base pressure coefficient#
Note that for each inflow condition \(x\), we will get a distribution for \(C_\text{BP}\). Our quantity of interest (QOI) is the mean of the upper 40% of this distribution, also known as the superquantile risk and denoted as \(\mathcal{R}_{0.6}[C_\text{BP}]\). It’s kind of like an expectation that places more weight on “riskier” outcomes. Minimizing the superquantile risk, therefore, translates to finding a risk-averse design, since it penalizes inflows based on their worst 40% of outcomes for \(C_\text{BP}\).
Computing the risk, however, is expensive—evaluating the risk once requires many CFD simulations. Our goal therefore is to somehow speed up evaluation of \(\mathcal{R}_{0.6}[C_\text{BP}](x)\).
Datasets: Probabilistic collocation (high-fidelity) and Monte Carlo (low-fidelity)#
As described in the paper, there are two different models that approximate \(\mathcal{R}_{0.6}[C_\text{BP}]\). For some input \(x\), the cheap low-fidelity model \(f_\ell\) uses coarse Monte Carlo (MC) estimation, while the expensive high-fidelity model \(f_h\) uses the more accurate (but in this case ~16 times slower) probabilistic collocation method (PCM). As in example 1, we have many evaluations of \(f_\ell\) and only a few of \(f_h\). And we want to approximate \(f_h\). This is a classic setup for multi-fidelity GP surrogate modeling. Let’s do it.
First, import and visualize the data:
data_folder = '../../data/mf_cylinder_flow'


def _read_cyl_csv(filename):
    """Read one CSV file from `data_folder`."""
    return pd.read_csv(os.path.join(data_folder, filename))


def _cyl_dataset(frame):
    """Wrap columns x1, x2 (inputs) and y (output, as a column) of `frame` in a gpx.Dataset."""
    return gpx.Dataset(jnp.array(frame[['x1', 'x2']].values), jnp.array(frame['y'].values[:, None]))


high_fid_data = _read_cyl_csv('high_fidelity_data.csv')
low_fid_data = _read_cyl_csv('low_fidelity_data.csv')
ground_truth = _read_cyl_csv('ground_truth.csv')
D_high_cyl = _cyl_dataset(high_fid_data)
D_low_cyl = _cyl_dataset(low_fid_data)
# Ground truth as a (21, 17, 3) grid for plot_surface. Assumes the CSV holds
# 21*17 rows in grid order with columns (x1, x2, y) -- TODO confirm.
gt = ground_truth.values.reshape(21, 17, 3)
Show code cell source
fig = plt.figure()
# Left panel: where each fidelity level was evaluated in (sigma_1, sigma_2).
ax = fig.add_subplot(121)
ax.scatter(D_low_cyl.X[:,0], D_low_cyl.X[:, 1], label=r'Low-fidelity data, $X_\ell$', marker='o', facecolors='none', edgecolors='tab:blue')
ax.scatter(D_high_cyl.X[:,0], D_high_cyl.X[:, 1], label=r'High-fidelity data, $X_h$', marker='o', facecolors='tab:orange', edgecolors='none')
ax.set_aspect(5)
ax.set_title('Data locations', fontsize=12)
ax.set_xlabel(r'$\sigma_1$')
ax.set_ylabel(r'$\sigma_2$')
leg = ax.legend(loc='upper right')
# Semi-transparent legend frame, matching the synthetic-example figure.
leg.get_frame().set_alpha(0.6)
# Right panel: ground-truth surface with both datasets overlaid.
ax = fig.add_subplot(122, projection='3d')
ax.plot_surface(gt[:,:,0], gt[:,:,1], gt[:,:,2], lw=0.1, cmap='viridis', alpha=0.5)
ax.scatter3D(D_low_cyl.X[:,0], D_low_cyl.X[:, 1], D_low_cyl.y[:, 0], s=4, label='Low fidelity', color='tab:blue', alpha=0.7)
ax.scatter3D(D_high_cyl.X[:,0], D_high_cyl.X[:, 1], D_high_cyl.y[:, 0], s=4, label='High fidelity', color='tab:orange', alpha=1)
ax.view_init(elev=20, azim=-75)
ax.set_title(r'High-fidelity model for $\mathcal{R}_{0.6}[C_\text{BP}]$', fontsize=12)
ax.set_xlabel(r'$\sigma_1$')
ax.set_ylabel(r'$\sigma_2$')
sns.despine(trim=True);
Multi-fidelity Gaussian process for superquantile risk#
As before, we first construct the low-fidelity GP surrogate \(\hat{f}_\ell\):
# Low-fidelity GP for the cylinder data: zero mean, RBF kernel with one
# lengthscale per design variable.
kernel_low_cyl = gpx.kernels.RBF(variance=0.01, lengthscale=jnp.ones(2))
mean_low_cyl = gpx.mean_functions.Zero()
prior_low_cyl = gpx.gps.Prior(kernel=kernel_low_cyl, mean_function=mean_low_cyl)
# NOTE(review): obs_stddev starts at 0.0 (noise-free). The low-fidelity targets
# are coarse Monte Carlo estimates, so confirm the noise parameter can actually
# move away from 0 during optimization.
likelihood_low_cyl = gpx.likelihoods.Gaussian(obs_stddev=0.0, num_datapoints=D_low_cyl.n)
posterior_low_cyl = prior_low_cyl * likelihood_low_cyl
# Fit the hyperparameters with scipy's optimizer (deterministic, full batch).
posterior_low_cyl, history_low_cyl = gpx.fit_scipy(
    model=posterior_low_cyl, objective=negative_mll, train_data=D_low_cyl, max_iters=1000,
)
Show code cell output
Optimization terminated successfully.
Current function value: -496.380070
Iterations: 25
Function evaluations: 30
Gradient evaluations: 30
Next let’s prepare the data for training the multi-fidelity GP surrogate \(\hat{f}_m\). This time, however, there is an issue. We do not have low-fidelity evaluations at the same points as the high fidelity evaluations (which is required by our formulation). We have two options:
Easier option: Evaluate the low-fidelity model (or the low-fidelity surrogate, if it’s good enough) at the missing points.
Harder option: Formulate the multi-fidelity GP in a way that does not require low-fidelity evaluations at the same points (see Le Gratiet (2013)).
We will do the easier option:
# Evaluate the low-fidelity *cylinder* surrogate (posterior_low_cyl, trained on
# D_low_cyl) at the high-fidelity data points. Its predictive mean stands in
# for the missing low-fidelity evaluations there. This must be the same
# low-fidelity GP that sample_multi_fidelity_gp_cyl uses at prediction time;
# the example-1 `posterior_low` has hyperparameters fitted to different data.
y_low_common_cyl = generate_predictive_dist(D_high_cyl.X, posterior_low_cyl, D_low_cyl).mean()
# Create the augmented dataset for training the multi-fidelity GP
Xm_cyl = jnp.concatenate([D_high_cyl.X, y_low_common_cyl[:, None]], axis=1)
D_multi_cyl = gpx.Dataset(Xm_cyl, D_high_cyl.y)
Now that we have the augmented dataset, let’s construct the multi-fidelity GP \(\hat{f}_m\) (and a vanilla high-fidelity GP \(\hat{f}_h\) for comparison):
# Multi-fidelity GP on the augmented (sigma_1, sigma_2, y_low) inputs.
kernel_multi_cyl = gpx.kernels.RBF(variance=0.01, lengthscale=jnp.ones(3))
mean_multi_cyl = gpx.mean_functions.Zero()
prior_multi_cyl = gpx.gps.Prior(kernel=kernel_multi_cyl, mean_function=mean_multi_cyl)
likelihood_multi_cyl = gpx.likelihoods.Gaussian(obs_stddev=0.1, num_datapoints=D_multi_cyl.n)
posterior_multi_cyl = prior_multi_cyl * likelihood_multi_cyl
posterior_multi_cyl, history_multi_cyl = gpx.fit_scipy(
    model=posterior_multi_cyl, objective=negative_mll, train_data=D_multi_cyl, max_iters=1000,
)

# Vanilla GP on the high-fidelity data only, for comparison.
kernel_high_cyl = gpx.kernels.RBF(variance=0.1, lengthscale=jnp.ones(2))
mean_high_cyl = gpx.mean_functions.Zero()
prior_high_cyl = gpx.gps.Prior(kernel=kernel_high_cyl, mean_function=mean_high_cyl)
likelihood_high_cyl = gpx.likelihoods.Gaussian(obs_stddev=0.1, num_datapoints=D_high_cyl.n)
posterior_high_cyl = prior_high_cyl * likelihood_high_cyl
posterior_high_cyl, history_high_cyl = gpx.fit_scipy(
    model=posterior_high_cyl, objective=negative_mll, train_data=D_high_cyl, max_iters=1000,
)
Show code cell output
Optimization terminated successfully.
Current function value: -22.822893
Iterations: 64
Function evaluations: 72
Gradient evaluations: 72
Optimization terminated successfully.
Current function value: -18.071678
Iterations: 34
Function evaluations: 37
Gradient evaluations: 37
As before, let’s visualize the predictive accuracy with some parity plots:
Show code cell source
# Bind each cylinder GP's posterior and data so all samplers take (X, key, num_samples).
sample_low_fidelity_gp_cyl = partial(sample_vanilla_gp, posterior=posterior_low_cyl, D=D_low_cyl)
sample_high_fidelity_gp_cyl = partial(sample_vanilla_gp, posterior=posterior_high_cyl, D=D_high_cyl)
sample_multi_fidelity_gp_cyl = partial(sample_multi_fidelity_gp, posterior_low=posterior_low_cyl, posterior_multi=posterior_multi_cyl, D_low=D_low_cyl, D_multi=D_multi_cyl)
num_samples_cyl = 10
# Pick 100 distinct ground-truth rows (by position) as the test set. Drawing the
# indices with a JAX key, not the unseeded global NumPy RNG, makes the test set
# and the reported RMSEs reproducible from the notebook's root key.
key, key_test_cyl, key_low_cyl, key_high_cyl, key_multi_cyl = jrandom.split(key, 5)
_ind_test_cyl = np.asarray(jrandom.choice(key_test_cyl, len(ground_truth), shape=(100,), replace=False))
X_test_cyl = ground_truth[['x1', 'x2']].values[_ind_test_cyl]
y_test_cyl = ground_truth['y'].values[:, None][_ind_test_cyl]
samples_low_cyl = sample_low_fidelity_gp_cyl(X_test_cyl, key_low_cyl, num_samples_cyl)
samples_high_cyl = sample_high_fidelity_gp_cyl(X_test_cyl, key_high_cyl, num_samples_cyl)
samples_multi_cyl = sample_multi_fidelity_gp_cyl(X_test_cyl, key_multi_cyl, num_samples_cyl)
mean_low_test_cyl = samples_low_cyl.mean(axis=0)
std_low_test_cyl = samples_low_cyl.std(axis=0)
mean_high_test_cyl = samples_high_cyl.mean(axis=0)
std_high_test_cyl = samples_high_cyl.std(axis=0)
mean_multi_test_cyl = samples_multi_cyl.mean(axis=0)
std_multi_test_cyl = samples_multi_cyl.std(axis=0)
# Parity plot - add uncertainty on predictions using whiskers (+-2 std)
fig, ax = plt.subplots(1, 3, figsize=(10,4), tight_layout=True)
fig.suptitle("Surrogate predictions vs. high-fidelity test data", fontsize=20)
cyl_panels = (
    (mean_low_test_cyl, std_low_test_cyl, "Low fidelity GP"),
    (mean_high_test_cyl, std_high_test_cyl, "High fidelity GP"),
    (mean_multi_test_cyl, std_multi_test_cyl, "Multi fidelity GP"),
)
for panel_ax, (mu, sd, title) in zip(ax, cyl_panels):
    panel_ax.errorbar(y_test_cyl, mu, yerr=2 * sd, fmt="o", markersize=3, lw=0.5)
    panel_ax.plot(y_test_cyl, y_test_cyl, "k--")  # identity line
    panel_ax.set_xlabel("True model")
    panel_ax.set_ylabel("Surrogate")
    panel_ax.set_title(title, fontsize=14)
    panel_ax.annotate(f"RMSE: {rmse(y_test_cyl.squeeze(-1), mu):.3f}", xy=(0.05, 0.9), xycoords='axes fraction')
sns.despine(trim=True);
The multi-fidelity GP has the best predictive accuracy. Let’s visualize the response surface of the surrogate vs. the true high-fidelity model:
Show code cell source
# Surface plot: multi-fidelity mean on a 10x10 grid spanning the ground-truth
# design range, overlaid on the ground-truth surface.
x1_cyl, x2_cyl = jnp.linspace(ground_truth['x1'].min(), ground_truth['x1'].max(), 10), jnp.linspace(ground_truth['x2'].min(), ground_truth['x2'].max(), 10)
X1_cyl, X2_cyl = jnp.meshgrid(x1_cyl, x2_cyl)
X_plt_cyl = jnp.stack([X1_cyl.ravel(), X2_cyl.ravel()], axis=-1)
num_samples_cyl = 15
key, key_cyl = jrandom.split(key)
# TODO: Figure out why sampling is so slow!
mean_multi_plt_cyl = sample_multi_fidelity_gp_cyl(X_plt_cyl, key_cyl, num_samples_cyl).mean(axis=0).reshape(X1_cyl.shape)
fig = plt.figure(figsize=(5,4))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(gt[:,:,0], gt[:,:,1], gt[:,:,2], label='Ground truth', lw=0.1, color='tab:blue', alpha=0.7)
ax.plot_surface(X1_cyl, X2_cyl, mean_multi_plt_cyl, label='Multi fidelity', color='tab:green', lw=0.1, alpha=0.3)
ax.view_init(elev=20, azim=-75)
ax.set_title('Mean predictive surface vs ground truth', fontsize=14)
ax.set_xlabel(r'$\sigma_1$')
ax.set_ylabel(r'$\sigma_2$')
ax.set_zlabel(r'$\mathcal{R}_{0.6}[C_\text{BP}]$')
ax.legend();
The surfaces are almost right on top of each other.