import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks");

Example: The Duffing Oscillator

We are going to apply what we learned to the Duffing oscillator. The Duffing oscillator is a simple model of a forced, damped oscillator with a cubic nonlinearity in the restoring force. It is described by the following equation:

\[\ddot{x} + \delta \dot{x} + \alpha x + \beta x^3 = \gamma \cos(\omega t)\]

where \(\delta\), \(\alpha\), \(\beta\), \(\gamma\), and \(\omega\) are constants. We can rewrite this as a system of first order equations by introducing \(v = \dot{x}\):

\[\begin{split}\begin{align} \dot{x} & = v \\ \dot{v} & = \gamma \cos(\omega t) - \delta v - \alpha x - \beta x^3 \end{align}\end{split}\]

The initial conditions are \(x(0) = x_0\) and \(v(0) = v_0\). Denote by \(\theta\) the vector of all parameters, i.e. \(\theta = (\alpha, \beta, \gamma, \delta, \omega, x_0, v_0)\). The vector field \(f(x,v,t;\theta)\) is then given by:

\[\begin{split}f(x,v,t;\theta) = \begin{pmatrix} v \\ \gamma \cos(\omega t) - \delta v - \alpha x - \beta x^3 \end{pmatrix}.\end{split}\]

Let us start by writing some code that solves the problem for a given \(\theta\), purely in Jax. For simplicity, we fix the initial conditions to \(x_0 = v_0 = 0\) and treat only the five parameters \((\alpha, \beta, \gamma, \delta, \omega)\) as inputs. We will use the diffrax package written by Patrick Kidger.

!pip install diffrax==0.4.1

import numpy as np
import jax.numpy as jnp
from diffrax import diffeqsolve, Tsit5, ODETerm, SaveAt

def vector_field(t, y, theta):
    # Right-hand side of the Duffing oscillator as a first-order system.
    # y = (x, v), theta = (alpha, beta, gamma, delta, omega).
    alpha, beta, gamma, delta, omega = theta[:5]
    x = y[0]
    v = y[1]
    return jnp.array(
        [
            v,
            - alpha * x - beta * x ** 3 - delta * v + gamma * jnp.cos(omega * t)
        ]
    )


theta = jnp.array([
    1.0,  # alpha
    5.0,  # beta
    0.37, # gamma
    0.1,  # delta
    1.0,  # omega
])

# The numerical solver to use.
solver = Tsit5()
# At which timesteps to store the solution.
saveat = SaveAt(ts=jnp.linspace(0, 50, 2000))
# The differential equation term.
term = ODETerm(vector_field)
# The Solution for one theta.
sol = diffeqsolve(
    term,
    solver,
    t0=0,                       # Initial time
    t1=50,                      # Terminal time
    dt0=0.1,                    # Initial timestep - it will be adjusted
    y0=jnp.array([0.0, 0.0]),   # Initial value
    args=theta,
    saveat=saveat
)

And here is how you can extract the solution:

print(sol.ys.shape)
(2000, 2)

Let’s plot it:

fig, ax = plt.subplots()
ax.plot(sol.ts, sol.ys[:, 0], label="x")
ax.plot(sol.ts, sol.ys[:, 1], label="v")
ax.set(xlabel="t", ylabel="x(t), v(t)")
ax.legend(frameon=False)
sns.despine(trim=True);

fig, ax = plt.subplots()
ax.plot(sol.ys[:, 0], sol.ys[:, 1], lw=1)
ax.set(xlabel="x", ylabel="v")
sns.despine(trim=True);

We see that the solution converges to a stable limit cycle for this particular value of the parameters.

Let’s now add a bit of uncertainty in the parameters. Say we have about 5% uncertainty in the parameters \(\alpha\), \(\beta\), \(\gamma\), \(\delta\), and \(\omega\). We can model this by giving \(\theta\) a normal distribution whose mean is the nominal value above and whose standard deviations are 5% of that mean:

\[\begin{split} \theta \sim \mathcal{N}\left( \begin{pmatrix} 1 \\ 5 \\ 0.37 \\ 0.1 \\ 1 \end{pmatrix}, \begin{pmatrix} 0.05^2 & 0 & 0 & 0 & 0 \\ 0 & 0.25^2 & 0 & 0 & 0 \\ 0 & 0 & 0.0185^2 & 0 & 0 \\ 0 & 0 & 0 & 0.005^2 & 0 \\ 0 & 0 & 0 & 0 & 0.05^2 \end{pmatrix} \right) \end{split}\]
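Concretely, the mean vector and the (diagonal) covariance matrix look like this in NumPy (the same arrays we will reuse for the Monte Carlo comparison below):

mu = np.array([1.0, 5.0, 0.37, 0.1, 1.0])  # nominal (mean) parameter values
Sigma = np.diag((0.05 * mu) ** 2)          # 5% relative standard deviation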

To proceed, let’s make a function that takes a parameter vector \(\theta\) and returns the solution of the Duffing oscillator for that parameter vector:

def solve_duffing(theta):
    # The numerical solver to use.
    solver = Tsit5()
    # At which timesteps to store the solution.
    saveat = SaveAt(ts=jnp.linspace(0, 50, 2000))
    # The differential equation term.
    term = ODETerm(vector_field)
    # The Solution for one theta.
    sol = diffeqsolve(
        term,
        solver,
        t0=0,                       # Initial time
        t1=50,                      # Terminal time
        dt0=0.1,                    # Initial timestep - it will be adjusted
        y0=jnp.array([0.0, 0.0]),   # Initial value
        args=theta,
        saveat=saveat
    )
    return sol.ys

Okay, all we now need to do is differentiate and jit this function:

from jax import jacobian, jit

jit_solve_duffing = jit(solve_duffing)
jit_jac_solve_duffing = jit(jacobian(solve_duffing))

# Evaluate everything at the mean theta.
sol_mu = jit_solve_duffing(theta)
jac_sol_mu = jit_jac_solve_duffing(theta)

This is it. Now we can do all sorts of things with this information. Recall that the solution \(Y=y(t;\theta)=(x(t;\theta), v(t;\theta))\) is, to first order in the parameters, a Gaussian process:

\[ Y \sim \text{GP}(y(t;\mu),\nabla_{\theta}y(t;\mu)\Sigma\nabla_{\theta}y(t';\mu)^T). \]

Again, pay attention to the fact that this is a vector-valued Gaussian process - not a scalar-valued one. So the covariance function is a matrix-valued function. Here it is 2x2.
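Where does this come from? It is just the first-order Taylor expansion of the solution in the parameters about the mean \(\mu\):

\[ y(t;\theta) \approx y(t;\mu) + \nabla_{\theta}y(t;\mu)\,(\theta - \mu). \]

Since \(\theta \sim \mathcal{N}(\mu, \Sigma)\), the right-hand side is Gaussian with mean \(y(t;\mu)\) and cross-covariance \(\nabla_{\theta}y(t;\mu)\,\Sigma\,\nabla_{\theta}y(t';\mu)^T\) between times \(t\) and \(t'\).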

The quantity sol_mu is its mean function evaluated at 2,000 points in time between 0 and 50. Similarly, jac_sol_mu is the Jacobian of the mean function evaluated at the same points in time. The second can be used to evaluate the covariance function of the Gaussian process at the same points.
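To keep the indices straight, it helps to look at the shapes of these two arrays:

print(sol_mu.shape)      # expected: (2000, 2), i.e., (time step k, dimension i)
print(jac_sol_mu.shape)  # expected: (2000, 2, 5), i.e., (time step k, dimension i, parameter l)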

Let’s get the covariance. We can easily find it in one line using the function einsum from Jax or NumPy. We use NumPy here because we have no plans of differentiating it. If you don’t know what einsum does, you can read about it here. It is immensely useful. If you don’t believe me, watch.
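As a tiny warm-up (a toy example, not part of the calculation), einsum sums over repeated indices, so an ordinary matrix product can be written as:

A = np.random.randn(3, 4)
B = np.random.randn(4, 5)
C = np.einsum("ij,jk->ik", A, B)  # sum over the repeated index j
assert np.allclose(C, A @ B)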

We want to evaluate the covariance function of the vector-valued Gaussian process at all time steps. Consider times \(t_k\) and \(t_r\) and dimensions \(i\) and \(j\). The time indices \(k\) and \(r\) run over the 2,000 saved time steps, and the dimension indices \(i\) and \(j\) run from \(0\) to \(1\) (the dimension of the vector-valued Gaussian process, i.e., of the underlying ODE). We need to evaluate:

\[\begin{split} \begin{align*} c_{krij} = c_{ij}(t_k,t_r) &= \nabla_{\theta}y_i(t_k;\mu)\,\Sigma\,\nabla_{\theta}y_j(t_r;\mu)^T\\ &= \sum_{l=0}^4 \sum_{m=0}^4 \nabla_{\theta_l}y_i(t_k;\mu)\,\Sigma_{lm}\,\nabla_{\theta_m}y_j(t_r;\mu)\\ &= \nabla_{\theta_l}y_i(t_k;\mu)\,\Sigma_{lm}\,\nabla_{\theta_m}y_j(t_r;\mu)\;\text{(Einstein summation)} \end{align*} \end{split}\]

And here is the einsum call:

sol_cov = np.einsum("kil,lm,rjm->krij", jac_sol_mu, Sigma, jac_sol_mu)
print(sol_cov.shape)
(2000, 2000, 2, 2)

I hope you appreciate the beauty of this.

Okay, now let’s extract the diagonal of the covariance so that we can make some nice predictive intervals. We use einsum again, this time picking out the entries with \(k=r\) and \(i=j\) at every time step:

tmp = np.einsum("kkii->ki", sol_cov)
x_std = np.sqrt(tmp[:, 0])
v_std = np.sqrt(tmp[:, 1])
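As a quick sanity check, the uncertainty at \(t=0\) should be essentially zero, because the initial condition is fixed and does not depend on \(\theta\):

print(tmp.shape)           # expected: (2000, 2)
print(x_std[0], v_std[0])  # expected: approximately zero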

Now we can plot the solution and the predictive intervals:

ts = np.linspace(0, 50, 2000)
steps_to_plot = 200
fig, ax = plt.subplots()
ax.plot(ts[:steps_to_plot], sol_mu[:steps_to_plot, 0], label="x")
ax.fill_between(
    ts[:steps_to_plot],
    sol_mu[:steps_to_plot, 0] - 2 * x_std[:steps_to_plot],
    sol_mu[:steps_to_plot, 0] + 2 * x_std[:steps_to_plot],
    alpha=0.5,
    label="2 std",
)
ax.set(xlabel="t", ylabel="x(t)")
sns.despine(trim=True)

fig, ax = plt.subplots()
ax.plot(ts[:steps_to_plot], sol_mu[:steps_to_plot, 1], label="v")
ax.fill_between(
    ts[:steps_to_plot],
    sol_mu[:steps_to_plot, 1] - 2 * v_std[:steps_to_plot],
    sol_mu[:steps_to_plot, 1] + 2 * v_std[:steps_to_plot],
    alpha=0.5,
    label="2 std",
)
ax.set(xlabel="t", ylabel="v(t)")
sns.despine(trim=True)

You may have noticed that I didn’t plot the solution at all time steps. This is because the approximation becomes wrong after a certain point. I will demonstrate this below.

But how good is our result? Let’s do Monte Carlo sampling of the parameters so that we can establish a ground truth for the solution statistics. We will use vmap to vectorize the function solve_duffing over the first axis of an array of parameter vectors and then use jit to compile the result:

from jax import jit, vmap

many_solve_duffing = jit(vmap(solve_duffing, in_axes=(0,)))

Let’s do Monte Carlo:

num_thetas = 1_000_000
mu = np.array([1.0, 5.0, 0.37, 0.1, 1.0])
Sigma = np.diag((0.05 * mu) ** 2)
L = np.linalg.cholesky(Sigma)
thetas = np.random.randn(num_thetas, 5) @ L + mu
sols = many_solve_duffing(thetas)

Let’s extract some statistics to compare:

mean_sol = np.mean(sols, axis=0)
std_sol = np.std(sols, axis=0)

Let’s plot predictive intervals for ‘x’ and ‘v’ for the same time steps as before:

fig, ax = plt.subplots()
ax.plot(ts[:steps_to_plot], mean_sol[:steps_to_plot, 0], label="x")
ax.fill_between(
    ts[:steps_to_plot],
    mean_sol[:steps_to_plot, 0] - 2 * std_sol[:steps_to_plot, 0],
    mean_sol[:steps_to_plot, 0] + 2 * std_sol[:steps_to_plot, 0],
    alpha=0.5,
    label="2 std",
)
ax.set(xlabel="t", ylabel="x(t)")
sns.despine(trim=True);

fig, ax = plt.subplots()
ax.plot(ts[:steps_to_plot], mean_sol[:steps_to_plot, 1], label="v")
ax.fill_between(
    ts[:steps_to_plot],
    mean_sol[:steps_to_plot, 1] - 2 * std_sol[:steps_to_plot, 1],
    mean_sol[:steps_to_plot, 1] + 2 * std_sol[:steps_to_plot, 1],
    alpha=0.5,
    label="2 std",
)
ax.set(xlabel="t", ylabel="v(t)")
sns.despine(trim=True);

These look identical to the ones we got before. Let’s take a closer look. First, let’s compare the mean functions:

fig, ax = plt.subplots()
ax.plot(ts[:steps_to_plot], mean_sol[:steps_to_plot, 0], label="x (Monte Carlo)")
ax.plot(ts[:steps_to_plot], sol_mu[:steps_to_plot, 0], '--', label="x (Local sensitivity)")
ax.set(xlabel="t", ylabel="x(t)")
ax.legend(frameon=False)
sns.despine(trim=True);

fig, ax = plt.subplots()
ax.plot(ts[:steps_to_plot], mean_sol[:steps_to_plot, 1], label="v (Monte Carlo)")
ax.plot(ts[:steps_to_plot], sol_mu[:steps_to_plot, 1], '--', label="v (Local sensitivity)")
ax.set(xlabel="t", ylabel="v(t)")
ax.legend(frameon=False)
sns.despine(trim=True);

Pretty good. Now the standard deviations:

fig, ax = plt.subplots()
ax.plot(ts[:steps_to_plot], std_sol[:steps_to_plot, 0], label="x (Monte Carlo)")
ax.plot(ts[:steps_to_plot], x_std[:steps_to_plot], '--', label="x (Local sensitivity)")
ax.set(xlabel="t", ylabel="x(t)")
ax.legend(frameon=False)
sns.despine(trim=True);

fig, ax = plt.subplots()
ax.plot(ts[:steps_to_plot], std_sol[:steps_to_plot, 1], label="v (Monte Carlo)")
ax.plot(ts[:steps_to_plot], v_std[:steps_to_plot], '--', label="v (Local sensitivity)")
ax.set(xlabel="t", ylabel="v(t)")
ax.legend(frameon=False)
sns.despine(trim=True);

Also, pretty good - especially when we consider that it was very cheap to compute the local sensitivity results.
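If you want to quantify “cheap,” a rough comparison is easy to set up. Here is a minimal sketch (absolute timings depend on your hardware; run each compiled function once beforehand so that compilation time is excluded):

import time

# Local sensitivity: one solve at the mean plus one Jacobian.
start = time.perf_counter()
jit_solve_duffing(theta).block_until_ready()
jit_jac_solve_duffing(theta).block_until_ready()
print("Local sensitivity:", time.perf_counter() - start, "s")

# Monte Carlo: one solve per sampled theta.
start = time.perf_counter()
many_solve_duffing(thetas).block_until_ready()
print("Monte Carlo:", time.perf_counter() - start, "s")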

Okay, now let’s see what happens if we try to predict the solution over a longer time horizon. We have already calculated everything, so all we need to do is plot a longer stretch of the predictive intervals.

Here are the means:

steps_to_plot = 1_000

fig, ax = plt.subplots()
ax.plot(ts[:steps_to_plot], mean_sol[:steps_to_plot, 0], label="x (Monte Carlo)")
ax.plot(ts[:steps_to_plot], sol_mu[:steps_to_plot, 0], '--', label="x (Local sensitivity)")
ax.set(xlabel="t", ylabel="x(t)")
ax.legend(frameon=False)
sns.despine(trim=True);

fig, ax = plt.subplots()
ax.plot(ts[:steps_to_plot], mean_sol[:steps_to_plot, 1], label="v (Monte Carlo)")
ax.plot(ts[:steps_to_plot], sol_mu[:steps_to_plot, 1], '--', label="v (Local sensitivity)")
ax.set(xlabel="t", ylabel="v(t)")
ax.legend(frameon=False)
sns.despine(trim=True);

And here are the standard deviations:

fig, ax = plt.subplots()
ax.plot(ts[:steps_to_plot], std_sol[:steps_to_plot, 0], label="x (Monte Carlo)")
ax.plot(ts[:steps_to_plot], x_std[:steps_to_plot], '--', label="x (Local sensitivity)")
ax.set(xlabel="t", ylabel="x(t)")
ax.legend(frameon=False)
sns.despine(trim=True);

fig, ax = plt.subplots()
ax.plot(ts[:steps_to_plot], std_sol[:steps_to_plot, 1], label="v (Monte Carlo)")
ax.plot(ts[:steps_to_plot], v_std[:steps_to_plot], '--', label="v (Local sensitivity)")
ax.set(xlabel="t", ylabel="v(t)")
ax.legend(frameon=False)
sns.despine(trim=True);

We observe that the method does well at first, but then it progressively yields the wrong statistics. This is expected: local sensitivity analysis relies on a linearization in the parameters, and the solution becomes an increasingly nonlinear function of the parameters as time goes on.