import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")
Example - Fourier Neural Operators#
1D Burgers’ Equation#
The purpose of this notebook is to introduce you to coding the FNO using JAX. It is recommended that you run this notebook on Google Colab with a GPU, since training can take between 15 and 30 minutes on a CPU, depending on your RAM.
The Burgers' equation for the current analysis is described by the following PDE:
\[ \frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} = \nu \frac{\partial^2 u}{\partial x^2}, \qquad x \in \Omega,\ t \in (0, 1]. \]
This is a fundamental nonlinear partial differential equation of convection-diffusion type.
The domain, periodic boundary condition, and viscosity are given by:
\(\Omega = (0, 2\pi)\).
\(u(t,x=0)=u(t,x=2\pi)\).
\(\nu=0.1\).
Goal: We aim to learn the operator mapping the initial condition to the solution at time one, defined by \(u_0 \mapsto u(\cdot,1)\).
# Import Libraries
import jax
from jax import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
try:
import optax
except ImportError:
!pip install optax --quiet
import optax
from typing import Callable, List
import scipy
try:
import equinox as eqx
except ImportError:
!pip install equinox --quiet
import equinox as eqx
from tqdm import tqdm
The dataset consists of \(v=2048\) initial conditions \(u(t=0,x)\), each discretized at \(N=8192\) points, together with the corresponding solutions \(u(t=1,x)\). You can download the dataset by running the following cell.
Dataset#
# MathWorks (the creators of MATLAB) host the original Li et al. dataset in the .mat format
!curl -O https://ssd.mathworks.com/supportfiles/nnet/data/burgers1d/burgers_data_R10.mat
# Load the .mat file in python environment
data = scipy.io.loadmat('burgers_data_R10.mat')
Since our data \(v^{(i)}\) and \(u^{(i)}\) are, in general, functions, to work with them numerically, we assume access only to their point-wise evaluations. The FNO architecture does not depend on the way the functions \(v^{(i)}, u^{(i)}\) are discretized.
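For instance, subsampling a stored sample simply evaluates the same underlying function on a coarser grid. A minimal illustration (the stride of 32 is arbitrary):
# The same initial condition evaluated on two different grids (illustrative stride)
u0_fine = data['a'][0]          # all 8192 grid points
u0_coarse = data['a'][0, ::32]  # every 32nd point: 256 grid points, same function
print(u0_fine.shape, u0_coarse.shape)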
# Extract the initial conditions (inputs) and outputs from the dataset
a, u = data['a'], data['u'] # The dataset defines 'a' as the initial condition and 'u' as the output
print(f"Shape of a (Number of Samples, Discretization): {a.shape}")
print(f"Shape of u (Number of Samples, Discretization): {u.shape}")
Shape of a (Number of Samples, Discretization): (2048, 8192)
Shape of u (Number of Samples, Discretization): (2048, 8192)
# Plot the initial condition and the solution after 1 time step
fig, ax = plt.subplots()
ax.plot(a[0], label="Initial Condition: $a$")
ax.plot(u[0], label="Solution after 1 time step")
ax.legend()
ax.set_title("Initial Condition and Solution")
ax.grid(True)
plt.show()
# Add channel dimension
# This changes the shape of `a` and `u` from (batch_size, sequence_length)
# to (batch_size, 1, sequence_length), adding a channel dimension in the middle.
a = a[:, jnp.newaxis, :]
u = u[:, jnp.newaxis, :]
# The dataset does not include the spatial mesh, so we construct it ourselves.
# The mesh spans 0 to 2*pi
mesh = jnp.linspace(0, 2 * jnp.pi, u.shape[-1])
fig, ax = plt.subplots()
ax.plot(mesh, a[0, 0], label="initial condition")
ax.plot(mesh, u[0, 0], label="After 1 time unit")
ax.legend()
ax.grid()
# Adjust the shape of `mesh` to match the batch dimension of `u`
# - First, add two new axes to `mesh` using `jnp.newaxis` to reshape it to (1, 1, mesh_length)
# - Then, repeat `mesh` along the batch dimension (axis=0) to create a shape of (batch_size, 1, mesh_length)
# - This ensures `mesh_shape_corrected` matches the shape of `a` in the batch dimension
mesh_shape_corrected = jnp.repeat(mesh[jnp.newaxis, jnp.newaxis, :], u.shape[0], axis=0)
# Concatenate `a` and `mesh_shape_corrected` along the channel dimension (axis=1)
# - This combines `a` with the reshaped `mesh`, resulting in a shape of (batch_size, 2, sequence_length)
# - The second dimension now includes both `a` and `mesh` as separate "channels"
a_with_mesh = jnp.concatenate((a, mesh_shape_corrected), axis=1)
# Check shape
a_with_mesh.shape
(2048, 2, 8192)
# Extract training and test data
train_x, test_x = a_with_mesh[:1000], a_with_mesh[1000:1200]
train_y, test_y = u[:1000], u[1000:1200]
print(f"Shape of train_x: {train_x.shape}")
print(f"Shape of train_y: {test_x.shape}")
print(f"Shape of test_x: {test_x.shape}")
print(f"Shape of test_y: {test_y.shape}")
Shape of train_x: (1000, 2, 8192)
Shape of train_y: (1000, 1, 8192)
Shape of test_x: (200, 2, 8192)
Shape of test_y: (200, 1, 8192)
FNO Architecture#
class SpectralConv1d(eqx.Module):
real_weights: jax.Array # Real part of spectral weights
imag_weights: jax.Array # Imaginary part of spectral weights
in_channels: int # Number of input channels
out_channels: int # Number of output channels
modes: int # Number of frequency modes to use
def __init__(
self,
in_channels,
out_channels,
modes,
*,
key,
):
# Initialize input, output channels and number of modes
self.in_channels = in_channels
self.out_channels = out_channels
self.modes = modes
# Scaling factor for weight initialization
scale = 1.0 / (in_channels * out_channels)
# Split key for separate initialization of real and imaginary parts
real_key, imag_key = jax.random.split(key)
# Initialize real part of weights with uniform distribution
self.real_weights = jax.random.uniform(
real_key,
(in_channels, out_channels, modes),
minval=-scale,
maxval=+scale,
)
# Initialize imaginary part of weights with uniform distribution
self.imag_weights = jax.random.uniform(
imag_key,
(in_channels, out_channels, modes),
minval=-scale,
maxval=+scale,
)
def complex_mult1d(
self,
x_hat,
w,
):
# Perform complex multiplication along frequency modes
# x_hat has shape (in_channels, modes), w has shape (in_channels, out_channels, modes)
# Resulting shape is (out_channels, modes)
return jnp.einsum("iM,ioM->oM", x_hat, w)
def __call__(
self,
x,
):
# Input x shape: (in_channels, spatial_points)
channels, spatial_points = x.shape
# shape of x_hat is (in_channels, spatial_points//2+1)
# Compute FFT of input along the spatial dimension
x_hat = jnp.fft.rfft(x)
# shape of x_hat_under_modes is (in_channels, self.modes)
# Truncate the frequency representation to keep only the specified modes
x_hat_under_modes = x_hat[:, :self.modes]
# Combine real and imaginary weights to form complex weights
weights = self.real_weights + 1j * self.imag_weights
# shape of out_hat_under_modes is (out_channels, self.modes)
# Apply complex multiplication to obtain output in frequency space
out_hat_under_modes = self.complex_mult1d(x_hat_under_modes, weights)
# shape of out_hat is (out_channels, spatial_points//2+1)
# Initialize the full output in the frequency domain
out_hat = jnp.zeros(
(self.out_channels, x_hat.shape[-1]),
dtype=x_hat.dtype
)
# Insert the computed modes back into the output
# Only the first 'modes' entries are set, others remain zero
out_hat = out_hat.at[:, :self.modes].set(out_hat_under_modes)
# Perform the inverse FFT to return to spatial domain
# Output shape: (out_channels, spatial_points)
out = jnp.fft.irfft(out_hat, n=spatial_points)
return out
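As a quick sanity check of the spectral layer on a single sample, here is a minimal shape test; the channel counts, resolution, and key below are illustrative and not the ones used later in the notebook.
# Shape check for SpectralConv1d on dummy data (illustrative sizes)
demo_key = jax.random.PRNGKey(42)
demo_layer = SpectralConv1d(in_channels=2, out_channels=4, modes=16, key=demo_key)
demo_x = jax.random.normal(demo_key, (2, 256))  # (in_channels, spatial_points)
print(demo_layer(demo_x).shape)  # expected: (4, 256)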
class FNOBlock1d(eqx.Module):
# Define the module attributes
spectral_conv: SpectralConv1d # Spectral convolution layer
bypass_conv: eqx.nn.Conv1d # Bypass convolution layer for residual connection
activation: Callable # Activation function for the layer output
def __init__(
self,
in_channels,
out_channels,
modes,
activation,
*,
key,
):
# Split the random key for initializing the spectral and bypass convolutions separately
spectral_conv_key, bypass_conv_key = jax.random.split(key)
# Initialize spectral convolution layer with given parameters and modes
self.spectral_conv = SpectralConv1d(
in_channels,
out_channels,
modes,
key=spectral_conv_key,
)
# Initialize bypass convolution layer with a kernel size of 1 for pointwise convolution
self.bypass_conv = eqx.nn.Conv1d(
in_channels,
out_channels,
1, # Kernel size is one (pointwise convolution)
key=bypass_conv_key,
)
# Set the activation function (e.g., ReLU, tanh)
self.activation = activation
def __call__(
self,
x,
):
# Perform the forward pass
# Apply spectral convolution and bypass convolution to input x
# Add the results to create a residual connection, then apply activation function
return self.activation(
self.spectral_conv(x) + self.bypass_conv(x)
)
class FNO1d(eqx.Module):
# Define the module attributes
lifting: eqx.nn.Conv1d # Initial convolution to lift input to desired channel width
fno_blocks: List[FNOBlock1d] # Sequence of FNO blocks to apply transformations
projection: eqx.nn.Conv1d # Final convolution to project output to desired shape
def __init__(
self,
in_channels, # Number of input channels
out_channels, # Number of output channels
modes, # Number of frequency modes for spectral convolution
width, # Width (number of channels) in FNO blocks
activation, # Activation function to use in FNO blocks
n_blocks=4, # Number of FNO blocks
*,
key, # Random key for initialization
):
# Split the random key for initializing the lifting layer
key, lifting_key = jax.random.split(key)
# Initialize lifting layer (pointwise convolution) to map input to `width` channels
self.lifting = eqx.nn.Conv1d(
in_channels,
width,
1, # Kernel size of 1 (pointwise convolution)
key=lifting_key,
)
# Initialize the sequence of FNO blocks
self.fno_blocks = []
for i in range(n_blocks):
# Split key for each block to ensure unique initialization
key, subkey = jax.random.split(key)
# Append an FNO block with `width` channels and `modes` frequency modes
self.fno_blocks.append(FNOBlock1d(
width,
width,
modes,
activation,
key=subkey,
))
# Split key for initializing the projection layer
key, projection_key = jax.random.split(key)
# Initialize projection layer to map `width` channels to `out_channels`
self.projection = eqx.nn.Conv1d(
width,
out_channels,
1, # Kernel size of 1 (pointwise convolution)
key=projection_key,
)
def __call__(
self,
x, # Input data of shape (in_channels, spatial_points)
):
# Apply the lifting layer to increase input channels to `width`
x = self.lifting(x)
# Pass the input through each FNO block in sequence
for fno_block in self.fno_blocks:
x = fno_block(x)
# Apply the projection layer to map to `out_channels`
x = self.projection(x)
# Return the final output
return x
# Create the FNO architecture
fno = FNO1d(
in_channels=2,
out_channels=1,
modes=16,
width=64,
activation=jax.nn.relu,
key=jax.random.PRNGKey(0),
)
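As a quick sanity check, we can count the trainable parameters by filtering the model down to its array leaves. This is a minimal sketch using eqx.filter and jax.tree_util:
# Count the trainable parameters of the FNO (sanity check)
params = eqx.filter(fno, eqx.is_array)
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"Number of trainable parameters: {n_params}")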
def dataloader(
key, # Random key for shuffling the dataset
dataset_x, # Input dataset (features), shape: (n_samples, ...)
dataset_y, # Output dataset (labels/targets), shape: (n_samples, ...)
batch_size, # Size of each batch
):
# Get the number of samples in the dataset
n_samples = dataset_x.shape[0]
# Calculate the total number of batches needed
n_batches = int(jnp.ceil(n_samples / batch_size))
# Generate a random permutation of indices to shuffle the data
permutation = jax.random.permutation(key, n_samples)
# Yield each batch
for batch_id in range(n_batches):
# Calculate the start and end indices for the current batch
start = batch_id * batch_size
end = min((batch_id + 1) * batch_size, n_samples)
# Select the indices for the current batch from the permutation
batch_indices = permutation[start:end]
# Yield the shuffled batch of data for both x and y
yield dataset_x[batch_indices], dataset_y[batch_indices]
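A short illustration of how the dataloader is consumed; the batch size and key here are arbitrary:
# Peek at one shuffled batch (illustrative batch size and key)
loader_key = jax.random.PRNGKey(0)
batch_x, batch_y = next(dataloader(loader_key, train_x, train_y, batch_size=32))
print(batch_x.shape, batch_y.shape)  # (32, 2, 8192) and (32, 1, 8192)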
Loss Function#
def loss_fn(model, x, y):
# Apply the model to each example in x to get predictions
y_pred = jax.vmap(model)(x)
# Calculate mean squared error (MSE) between predictions and true values y
loss = jnp.mean(jnp.square(y_pred - y))
# Return the calculated loss
return loss
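Before training, we can evaluate this loss for the freshly initialized model on a small, subsampled slice of the training data. A minimal sketch; the slice size is arbitrary and the stride of 32 matches the subsampling used in the training loop below:
# Loss of the untrained model on a small subsampled slice (illustrative)
initial_loss = loss_fn(fno, train_x[:10, ..., ::32], train_y[:10, ..., ::32])
print(initial_loss)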
Training#
Optimizer#
# Initialize the Adam optimizer with a learning rate of 3e-4
optimizer = optax.adam(3e-4)
# Initialize the optimizer state, filtering the model to only include parameters that are arrays
opt_state = optimizer.init(eqx.filter(fno, eqx.is_array))
Supervisor#
@eqx.filter_jit
def make_step(model, state, x, y):
# Compute loss and gradients with respect to model parameters
loss, grad = eqx.filter_value_and_grad(loss_fn)(model, x, y)
# Calculate validation loss on a subset of test data
val_loss = loss_fn(model, test_x[..., ::32], test_y[..., ::32])
# Use optimizer to update model parameters with gradients
updates, new_state = optimizer.update(grad, state, model)
# Apply updates to the model parameters
new_model = eqx.apply_updates(model, updates)
# Return the updated model, new optimizer state, training loss, and validation loss
return new_model, new_state, loss, val_loss
Training Loop#
# Initialize lists to store loss history for training and validation
loss_history = []
val_loss_history = []
# Set random seed for data shuffling
shuffle_key = jax.random.PRNGKey(10)
# Training loop over a specified number of epochs
for epoch in tqdm(range(200)):
# Split the random key for each epoch to ensure different shuffling
shuffle_key, subkey = jax.random.split(shuffle_key)
# Loop over each batch from the data loader
for (batch_x, batch_y) in dataloader(
subkey,
train_x[..., ::32], # Use a subset of training data
train_y[..., ::32], # Use a subset of training labels
batch_size=100, # Define batch size
):
# Perform a single optimization step and get the updated model, optimizer state, and losses
fno, opt_state, loss, val_loss = make_step(fno, opt_state, batch_x, batch_y)
# Append the current training and validation losses to their respective histories
loss_history.append(loss)
val_loss_history.append(val_loss)
plt.plot(loss_history, label="train loss")
plt.plot(val_loss_history, label="val loss")
plt.legend()
plt.yscale("log")
plt.grid()
plt.show()
Inference#
plt.plot(test_x[1, 0, ::32], label="Initial condition")
plt.plot(test_y[1, 0, ::32], label="Ground Truth")
plt.plot(fno(test_x[1, :, ::32])[0], label="FNO prediction")
plt.legend()
plt.grid()
plt.show()
plt.plot(fno(test_x[1, :, ::32])[0] - test_y[1, 0, ::32], label="Difference")
plt.legend()
plt.show()
Discretization-Invariant Models#
Such a model should:
act on any discretization of the input function, i.e. accept any set of points in the input domain,
be evaluable at any point of the output domain,
converge to a continuum operator as the discretization is refined.
In the following code block we check this property (zero-shot super-resolution) by evaluating the trained model on a finer discretization than the one used during training.
# Zero-Shot superresolution
plt.plot(test_x[1, 0, ::4], label="Initial condition")
plt.plot(test_y[1, 0, ::4], label="Ground Truth")
plt.plot(fno(test_x[1, :, ::4])[0], label="FNO prediction")
plt.legend()
plt.grid()
plt.plot(fno(test_x[1, :, ::4])[0] - test_y[1, 0, ::4], label="Difference")
plt.legend()
plt.show()
# Compute the error as reported in the FNO paper
test_pred = jax.vmap(fno)(test_x)
def relative_l2_norm(pred, ref):
diff_norm = jnp.linalg.norm(pred - ref)
ref_norm = jnp.linalg.norm(ref)
return diff_norm / ref_norm
rel_l2_set = jax.vmap(relative_l2_norm)(test_pred, test_y)
jnp.mean(rel_l2_set) # ~1e-2
Array(0.01014479, dtype=float64)
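As an additional, optional check of discretization invariance, we can compare the mean relative L2 error at several subsampling rates of the same test set. This is a minimal sketch reusing the helpers above; the strides are arbitrary:
# Mean relative L2 error at different input/output resolutions (illustrative strides)
for stride in (32, 8, 4, 1):
    pred = jax.vmap(fno)(test_x[..., ::stride])
    rel = jax.vmap(relative_l2_norm)(pred, test_y[..., ::stride])
    print(f"stride {stride:2d} (resolution {8192 // stride}): mean relative L2 = {jnp.mean(rel):.4f}")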