import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
# Ask the inline backend to emit figures as SVG.
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
# Global seaborn defaults: "paper" context and "ticks" style.
sns.set_context("paper")
sns.set_style("ticks");

!pip install diffrax==0.4.1

Example: Lorenz System

Consider the Lorenz system of differential equations:

\[\begin{aligned} \dot{x} &= \sigma(y-x),\\ \dot{y} &= x(\rho-z)-y,\\ \dot{z} &= xy-\beta z. \end{aligned}\]

This system has a chaotic attractor for \(\sigma=10\), \(\beta=8/3\), and \(\rho=28\). We are going to study the sensitivity of the system to the initial conditions. Our goal is to demonstrate that local sensitivity analysis is not appropriate for this system.

Let’s code it up and see what it looks like.

import jax.numpy as jnp
from jax import vmap, jit
from functools import partial
from diffrax import diffeqsolve, Tsit5, ODETerm, SaveAt


def solve_lorenz(u0, theta):
    """Integrate the Lorenz system from ``u0`` and return the saved states.

    Parameters
    ----------
    u0
        Initial state; entries 0, 1, 2 are read as ``x``, ``y``, ``z``.
    theta
        Parameters; entries 0, 1, 2 are read as ``sigma``, ``beta``, ``rho``.

    Returns
    -------
    The ``ys`` field of the diffrax solution, saved at 10,000 evenly spaced
    times on ``[0, 100]``. The solver is ``Tsit5`` with initial step 0.1.
    """
    save_times = jnp.linspace(0.0, 100.0, 10_000)

    def vector_field(t, u, theta):
        # Right-hand side of the Lorenz equations; ``t`` is not used.
        x, y, z = u[0], u[1], u[2]
        sigma, beta, rho = theta[0], theta[1], theta[2]
        return jnp.array([
            sigma * (y - x),
            x * (rho - z) - y,
            x * y - beta * z,
        ])

    solution = diffeqsolve(
        ODETerm(vector_field),
        Tsit5(),
        t0=0.0,
        t1=100.0,
        dt0=0.1,
        y0=u0,
        args=theta,
        saveat=SaveAt(ts=save_times),
    )
    return solution.ys
# Compiled, batched solver: maps over the leading axis of the initial
# conditions while sharing a single parameter vector across the batch.
monte_carlo_lorenz = jit(vmap(solve_lorenz, in_axes=(0, None)))

# Parameter values for which the system has a chaotic attractor.
sigma, beta, rho = 10.0, 8.0 / 3.0, 28.0
theta = jnp.array([sigma, beta, rho])
# A batch containing the single initial condition (1, 1, 1).
u0 = jnp.array([[1.0, 1.0, 1.0]])
ys = monte_carlo_lorenz(u0, theta)
ys.shape
(1, 10000, 3)

Plot the first time steps of the solution:

# Time grid matching the solver's save points; only the first
# `steps_to_plot` of them are drawn.
ts = jnp.linspace(0.0, 100.0, 10_000)
steps_to_plot = 1_000

fig, ax = plt.subplots()
t_window = ts[:steps_to_plot]
# One thin line per state component of the single trajectory in `ys`.
for component, component_label in enumerate(("x", "y", "z")):
    ax.plot(t_window, ys[0, :steps_to_plot, component], label=component_label, lw=0.5)
ax.legend(frameon=False)
ax.set(xlabel="t", ylabel="u(t)", title="Lorenz system")
sns.despine(trim=True)
../_images/ff6d7c58471aed5a189e8f04f46fe4ea1b65c937e59a41d5e1ff464b389bdc5d.svg

And here is the classic butterfly plot:

# Phase-space view (x, y, z) of the single trajectory stored in `ys`.
fig = plt.figure()
ax = plt.axes(projection='3d')
# All three coordinates come from `ys`; there is no `sol` object in this
# notebook (the solver already returns the `.ys` array), so `sol.ys` would
# raise a NameError.
ax.plot(ys[0, :, 0], ys[0, :, 1], ys[0, :, 2], lw=0.5, alpha=0.5)
ax.set(xlabel="x", ylabel="y", zlabel="z", title="Lorenz system")
sns.despine(trim=True)
../_images/4de19b7d3222d20494d6606cc9424c92ad6cc121e095a08a6e959328763c2d85.svg

Okay. Now we are going to take a tiny blob of initial conditions and see how it evolves in time. We are going to color the points red.

import numpy as np

# Gaussian blob of initial conditions: mean (1, 1, 1) with standard
# deviation 1e-3 in each coordinate.
# NOTE: `sigma` now names these standard deviations, rebinding the Lorenz
# parameter of the same name; `theta` was built earlier and is unaffected.
num_samples = 1_000
mu = np.ones(3)
sigma = np.full(3, 0.001)
u0_samples = np.random.normal(loc=mu, scale=sigma, size=(num_samples, 3))
# Solve every sampled initial condition with the shared parameters.
ys_samples = monte_carlo_lorenz(u0_samples, theta)
ys_samples.shape
(1000, 10000, 3)

Here we go, the blob at different times:

# Snapshot of the whole sample cloud at selected save indices.
# The last snapshot uses -1 (the final saved time). The previous value,
# 10_000, is one past the end of the 10,000-long time axis and only worked
# because JAX silently clamps out-of-bounds indices to the last element.
for i in [0, 100, 500, 750, 1_000, 1_500, 2_000, 5_000, -1]:
    fig = plt.figure()
    ax = plt.axes(projection='3d')
    ax.plot(ys_samples[:, i, 0], ys_samples[:, i, 1], ys_samples[:, i, 2], '.', color='red', alpha=0.5, ms=0.5)
    ax.set(xlabel="x", ylabel="y", zlabel="z", title=f"Lorenz system at t={ts[i]:.2f}")
    sns.despine(trim=True)
    # Fixed limits so the snapshots are directly comparable.
    ax.set_xlim(-20, 20)
    ax.set_ylim(-20, 20)
    ax.set_zlim(0, 50)
../_images/077d35c19ce8c5e1dc1095a4ae4f484d16c5e76473ac74698f221865db6c5307.svg../_images/582f2a2db26e143fdfb6b979b1d8e6ffdad2b448e6c4343079cf789fc9f278d5.svg../_images/837d797e8a8dd25361a8e50969e1f0a519f06920439c890fa7847399e0fb31d5.svg../_images/c636d6a33885e27d9169efe1329a38a6c0171de0dcae02b5bc3f2347e86da61b.svg../_images/0079ca455530e3f6303ba91f2c56e9bd500fc33fdb7ba83675ef1463c1f11751.svg../_images/3dc322703413d4877d6e824125ab16b93341223a4bfdcc62de926803a09481ba.svg../_images/694cfe3d050f62d1605450aaabc10197b468f74d5b25922cd5a63017e754eca5.svg../_images/ef41502625aee9d7250b4d746e20140593ceca33f692b0c6e3e596c22de34a4d.svg../_images/9bde8a56b22afd0e1b47cd5c2ac439bf5b06e1e7b55c2261549e03180c2fae3a.svg

You see that the tiny blob moved everywhere in the attractor.

What does this mean? We cannot predict the future of the system from the initial conditions even if we know the parameters perfectly.

We know that we will fail, but let’s check where local sensitivity analysis will get us. How far does it predict correctly?

We need to get the Jacobian with respect to the initial conditions.

from jax import jacobian

# Compiled versions of the solver and of its Jacobian with respect to the
# first argument (the initial condition).
jit_solve_lorenz = jit(solve_lorenz)
jit_jac_solve_lorenz = jit(jacobian(solve_lorenz, argnums=0))

# Nominal trajectory and its sensitivity, both evaluated at the mean initial
# condition `mu`. The compiled solver is used here; it was previously
# defined but never called.
mu_lorenz = jit_solve_lorenz(mu, theta)
jac_lorenz = jit_jac_solve_lorenz(mu, theta)

Here is the Jacobian.

# Show the Jacobian's shape: one 3x3 block per saved time point.
jac_lorenz.shape
(10000, 3, 3)

Notice that towards the end it has quite a few NaNs:

# Display the Jacobian values; the entries at late times print as NaN.
jac_lorenz
Hide code cell output
Array([[[ 1.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  1.00000000e+00,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00,  1.00000000e+00]],

       [[ 9.09877181e-01,  9.86922607e-02, -1.37051975e-04],
        [ 2.66055822e-01,  9.98531580e-01, -1.04108844e-02],
        [ 2.02484280e-02,  7.30568543e-03,  9.73287284e-01]],

       [[ 8.43284011e-01,  1.93439901e-01, -8.21596012e-04],
        [ 5.21020770e-01,  1.01683962e+00, -2.14508791e-02],
        [ 5.35953641e-02,  1.23710185e-02,  9.46779907e-01]],

       ...,

       [[            nan,             nan,             nan],
        [            nan,             nan,             nan],
        [            nan,             nan,             nan]],

       [[            nan,             nan,             nan],
        [            nan,             nan,             nan],
        [            nan,             nan,             nan]],

       [[            nan,             nan,             nan],
        [            nan,             nan,             nan],
        [            nan,             nan,             nan]]],      dtype=float32)

The NaNs are caused by numerical errors. You may have to move to 64-bit floats to get rid of them. Anyway, local sensitivity analysis will break before that point, so let’s proceed. I’m only going to calculate the variance of the stochastic process - the full covariance is too big.

# First-order variance propagation: for every saved time t and output
# component i, var[t, i] = sum_j jac[t, i, j]**2 * sigma[j]**2, i.e. the
# diagonal of J diag(sigma**2) J^T.
input_variances = sigma ** 2
var = np.einsum("tij,j,tij->ti", jac_lorenz, input_variances, jac_lorenz)

Again, I really hope you appreciate the magic of the einsum function. Try to do the above calculation without it.

Let’s look at the mean and the variance at specific times:

# Monte Carlo mean: average over the sample axis of the solved trajectories.
mc_mean = ys_samples.mean(axis=0)

fig, ax = plt.subplots()
t_window = ts[:steps_to_plot]
# Compare the x-component of the sample mean with the nominal trajectory
# started from `mu`.
ax.plot(t_window, mc_mean[:steps_to_plot, 0], label="x (Monte Carlo)", lw=1)
ax.plot(t_window, mu_lorenz[:steps_to_plot, 0], '--', label="x (Local sensitivity)", lw=1)
ax.set(xlabel="t", ylabel="u(t)", title="Lorenz system")
ax.legend(frameon=False)
sns.despine(trim=True)
../_images/761757650e39b726d2249c6de0fdc99f9635660efcc69e3abe79baa14f7d716f.svg

The variance breaks down even faster:

# Monte Carlo variance: spread over the sample axis of the solved trajectories.
mc_var = ys_samples.var(axis=0)

# Restrict the comparison to the first 550 saved steps.
steps_to_plot = 550
fig, ax = plt.subplots()
ax.plot(ts[:steps_to_plot], mc_var[:steps_to_plot, 0], label="x (Monte Carlo)", lw=1)
ax.plot(ts[:steps_to_plot], var[:steps_to_plot, 0], '--', label="x (Local sensitivity)", lw=1)
# Both curves are variances of the x-component, so the y-axis is labelled
# as a variance rather than as the state u(t).
ax.set(xlabel="t", ylabel="Var[x(t)]", title="Lorenz system")
ax.legend(frameon=False)
ax.set_xlim(0, 10)
sns.despine(trim=True)
../_images/71dcd180d8bd216e6cb92d85402d1e1937fcf05602fcd863dc20da2d4417347e.svg