import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks");

!pip install diffrax==0.4.1
!pip install orthojax --upgrade
!pip install py-design --upgrade

Uncertainty Propagation in Dynamical Systems

Consider the following \(n\)-dimensional dynamical system:

\[ \dot{\mathbf{x}} = \mathbf{f}(t,\mathbf{x};\boldsymbol{\Xi}), \]

where \(\mathbf{x} \in \mathbb{R}^n\) is the state vector, \(\mathbf{f}\) is a vector-valued function, \(t\) is time, and \(\boldsymbol{\Xi}\) is a vector of uncertain parameters. The initial condition is:

\[ \mathbf{x}(0) = \mathbf{x}_0(\boldsymbol{\Xi}). \]

We assume that the components of

\[ \boldsymbol{\Xi} = (\Xi_1,\dots,\Xi_d) \]

are independent random variables. The goal is to propagate the uncertainty in \(\boldsymbol{\Xi}\) through the system to obtain the uncertainty in \(\mathbf{x}\).

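The space we work with is the space of functions of \(\boldsymbol{\Xi}\) that are square integrable with respect to the probability density \(p(\boldsymbol{\Xi})\) of the uncertain parameters:

\[ L^2(\mathbf{\Xi},\mathbb{R}^n) = \left\{\mathbf{g}:\mathbb{R}^d\to\mathbb{R}^n\mid \int_{\mathbb{R}^d} \|\mathbf{g}(\boldsymbol{\Xi})\|^2\,p(\boldsymbol{\Xi})\,d\boldsymbol{\Xi} < \infty\right\}, \]

with inner product given by the expectation:

\[ \langle \mathbf{g},\mathbf{h}\rangle = \int_{\mathbb{R}^d} \mathbf{g}(\boldsymbol{\Xi})\cdot\mathbf{h}(\boldsymbol{\Xi})\,p(\boldsymbol{\Xi})\,d\boldsymbol{\Xi}. \]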

Let \(\{\phi_\alpha\}\) be the tensor product orthonormal basis for \(L^2(\mathbf{\Xi})\) and \(\{\mathbf{e}_i\}\) be the standard basis for \(\mathbb{R}^n\). Then, the functions:

\[ \boldsymbol{\psi}_{i,\alpha} = \mathbf{e}_i\phi_\alpha, \]

form an orthonormal basis for \(L^2(\mathbf{\Xi},\mathbb{R}^n)\).
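As a quick numerical check of the orthonormality claim, here is a minimal one-dimensional sketch. It assumes \(\Xi \sim U[-1,1]\) and an inner product weighted by that uniform density (the setting of the Legendre example later on). The helper name `orthonormal_legendre` is hypothetical; it uses plain NumPy and is not part of orthojax.

```python
import numpy as np
from numpy.polynomial import legendre


def orthonormal_legendre(xi, degree):
    """Evaluate phi_0, ..., phi_degree at the points xi (shape (m,)).

    Returns an array of shape (m, degree + 1).
    """
    # Column k of legvander holds the classical Legendre polynomial P_k(xi).
    P = legendre.legvander(xi, degree)
    # Under the uniform density 1/2 on [-1, 1], E[P_k^2] = 1 / (2k + 1),
    # so multiplying by sqrt(2k + 1) gives unit norm.
    return P * np.sqrt(2 * np.arange(degree + 1) + 1)


# Gauss-Legendre rule; dividing by 2 turns Lebesgue weights on [-1, 1]
# into weights for the uniform probability density.
nodes, weights = legendre.leggauss(10)
weights = weights / 2.0

phis = orthonormal_legendre(nodes, 4)

# Gram matrix <phi_a, phi_b>; it should be the identity.
gram = np.einsum("m,ma,mb->ab", weights, phis, phis)
print(np.round(gram, 12))
```

The vector-valued functions \(\mathbf{e}_i\phi_\alpha\) inherit orthonormality because \(\mathbf{e}_i\cdot\mathbf{e}_j = \delta_{ij}\).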

We expand the dynamical system state in this basis (at each time):

\[ \mathbf{x}(t;\boldsymbol{\Xi}) = \sum_{i=1}^n \sum_{\alpha} \mathbf{x}_{i,\alpha}(t)\boldsymbol{\psi}_{i,\alpha}(\boldsymbol{\Xi}). \]

We plug this into the dynamical system to get:

\[ \dot{\mathbf{x}}(t;\boldsymbol{\Xi}) = \sum_{i=1}^n \sum_{\alpha} \dot{\mathbf{x}}_{i,\alpha}(t)\boldsymbol{\psi}_{i,\alpha}(\boldsymbol{\Xi}) = \mathbf{f}\left(t,\sum_{i=1}^n \sum_{\alpha} \mathbf{x}_{i,\alpha}(t)\boldsymbol{\psi}_{i,\alpha}(\boldsymbol{\Xi});\boldsymbol{\Xi}\right). \]

We project each side onto \(\boldsymbol{\psi}_{j,\beta}\) and use the orthonormality of the basis to get:

\[ \dot{\mathbf{x}}_{j,\beta}(t) = \left\langle \mathbf{f}\left(t,\sum_{i=1}^n \sum_{\alpha} \mathbf{x}_{i,\alpha}(t)\boldsymbol{\psi}_{i,\alpha}(\boldsymbol{\Xi});\boldsymbol{\Xi}\right),\boldsymbol{\psi}_{j,\beta}\right\rangle. \]

This is a differential equation that describes the evolution of the coefficients \(\mathbf{x}_{i,\alpha}(t)\). The initial condition is:

\[ \mathbf{x}_{i,\alpha}(0) = \left\langle \mathbf{x}_0(\boldsymbol{\Xi}),\boldsymbol{\psi}_{i,\alpha}\right\rangle. \]
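To make the projected equations concrete, here is a minimal sketch on a toy problem that is not part of the original text: the scalar decay \(\dot{x} = -(1 + 0.5\,\Xi)x\), \(x(0) = 1\), with \(\Xi \sim U[-1,1]\). The toy vector field, the Gauss-Legendre quadrature, and the hand-written RK4 integrator are all assumptions made for illustration. The exact mean at \(t=1\) is \(e^{-1}\sinh(0.5)/0.5\), which gives something to compare against.

```python
import numpy as np
from numpy.polynomial import legendre

degree = 6
nodes, weights = legendre.leggauss(20)
weights = weights / 2.0  # probability weights for U[-1, 1]

# Orthonormal Legendre basis at the quadrature nodes, shape (m, p).
phis = legendre.legvander(nodes, degree) * np.sqrt(2 * np.arange(degree + 1) + 1)


def f(t, x, xi):
    """Toy vector field (an assumption): decay with an uncertain rate."""
    return -(1.0 + 0.5 * xi) * x


def rhs(t, y):
    """Galerkin right-hand side for the coefficient vector y, shape (p,)."""
    x_at_nodes = phis @ y                   # reconstruct x at each node
    f_at_nodes = f(t, x_at_nodes, nodes)    # evaluate f there
    return phis.T @ (weights * f_at_nodes)  # project onto each basis function


def rk4(y, t0, t1, num_steps):
    """Classical fixed-step fourth-order Runge-Kutta integrator."""
    h = (t1 - t0) / num_steps
    t = t0
    for _ in range(num_steps):
        k1 = rhs(t, y)
        k2 = rhs(t + h / 2, y + h / 2 * k1)
        k3 = rhs(t + h / 2, y + h / 2 * k2)
        k4 = rhs(t + h, y + h * k3)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return y


# Project the deterministic initial condition x(0) = 1 onto the basis.
y0 = phis.T @ (weights * np.ones_like(nodes))
y_final = rk4(y0, 0.0, 1.0, 200)

pc_mean = y_final[0]  # coefficient of the constant basis function
exact_mean = np.exp(-1.0) * np.sinh(0.5) / 0.5
print(pc_mean, exact_mean)
```

The structure of `rhs` (reconstruct the state at the quadrature nodes, evaluate the vector field, project back) is the same pattern the general solver needs, only without the state dimension \(n\).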

Let’s write JAX code that solves this problem.

from collections import namedtuple

import orthojax as ojax
import design
import jax.numpy as jnp
from jax import vmap, jit


def make_sparse_grid(dim, level):
    """Make a sparse grid of dimension dim and a given level.
    We do it for the uniform cube [-1, 1]^d."""
    x, w = design.sparse_grid(dim, level, 'F2')
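    # Divide by the volume of [-1, 1]^dim so that the weights integrate
    # against the uniform probability density instead of Lebesgue measure
    w = w / (2 ** dim)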
    x = jnp.array(x, dtype=jnp.float32)
    w = jnp.array(w, dtype=jnp.float32)
    return ojax.QuadratureRule(x, w)


PCProblem = namedtuple("PCProblem", ["poly", "quad", "f", "x0", "phis", "y0", "rhs"])


def make_pc_problem(poly, quad, f, x0):
    """Make the PC dynamical system problem.

    Params:
        poly: The polynomial basis
        quad: The quadrature rule used to compute inner products
        f: The right hand side of the ODE, called as f(t, x, xi) and returning a vector in R^n
        x0: The initial condition (function of xi, from R^d -> R^n)
    """
    # The quadrature rule used to compute inner products
    xis, ws = quad
    # xis is m x d and ws is m

    # The polynomial basis functions on the collocation points
    phis = poly(xis)
    # this is m x p

    # The initial condition of the PC coefficients
    x0s = jit(vmap(x0))(xis) # this is m x n
    # The PC coefficients are n x p
    # ws is m
    # phis is m x p
    # x0s is m x n
    # y0 must be n x p
    y0 = jnp.einsum("m,mp,mn->np", ws, phis, x0s)
    
    # Vectorize the function f
    fv = vmap(f, in_axes=(None, 0, 0))
    
    # The right hand side of the PC ODE
    def rhs(t, y, phis):
        # y is n x p
        # phis is m x p
        # xs must be m x n
        xs = jnp.einsum("np,mp->mn", y, phis)
        # xs is m x n
        # xis is m x d
        # fs must be m x n
        fs = fv(t, xs, xis)
        # do the dot product with quadrature weights
        return jnp.einsum("m,mn,mp->np", ws, fs, phis)
    
    return PCProblem(poly, quad, f, x0, phis, y0, rhs)
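The einsum strings in the code above carry the whole projection, so it may help to see them unrolled. The sketch below uses random placeholder arrays (not real basis values or quadrature weights) and plain NumPy to check that `"m,mn,mp->np"` is the quadrature sum \(\sum_m w_m f_n(\boldsymbol{\xi}_m)\phi_p(\boldsymbol{\xi}_m)\), and that `"np,mp->mn"` reconstructs the state at each node.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, p = 7, 2, 4  # quadrature nodes, state dimension, basis size

ws = rng.random(m)               # placeholder quadrature weights
phis = rng.normal(size=(m, p))   # placeholder basis values at the nodes
fs = rng.normal(size=(m, n))     # placeholder vector field values at the nodes
y = rng.normal(size=(n, p))      # placeholder PC coefficients

# Projection, as written in rhs.
proj = np.einsum("m,mn,mp->np", ws, fs, phis)

# The same projection with explicit loops.
proj_loop = np.zeros((n, p))
for i in range(n):
    for a in range(p):
        for k in range(m):
            proj_loop[i, a] += ws[k] * fs[k, i] * phis[k, a]

# Reconstruction of the state at every node, as written in rhs.
xs = np.einsum("np,mp->mn", y, phis)
xs_matmul = phis @ y.T

print(np.allclose(proj, proj_loop), np.allclose(xs, xs_matmul))
```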

Example: Duffing Oscillator with Random Initial State

\[\begin{split}\begin{align} \dot{x} & = v \\ \dot{v} & = \gamma \cos(\omega t) - \delta v - \alpha x - \beta x^3, \end{align}\end{split}\]

With initial state:

\[ x(0) \sim N(\mu_x, \sigma_x^2), \quad v(0) \sim N(\mu_v, \sigma_v^2). \]

We are going to keep the parameters \(\alpha,\beta,\gamma,\delta, \omega\) fixed and only vary the initial state.

The first thing we are going to do is express the initial conditions in terms of independent random variables \(\Xi_1,\Xi_2 \sim U[-1,1]\). This will allow us to use Legendre polynomials. Let \(\Phi\) be the CDF of the standard normal distribution. Then:

\[ x(0) = \mu_x + \sigma_x \Phi^{-1}\left((\Xi_1 + 1) / 2\right), \quad v(0) = \mu_v + \sigma_v \Phi^{-1}\left((\Xi_2 + 1) / 2\right). \]
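The transformation above is inverse transform sampling. As a sanity check, here is a minimal sketch that draws uniform samples on \([-1,1]\), pushes them through the map, and compares the sample mean and standard deviation with \(\mu\) and \(\sigma\). It uses the standard library's `statistics.NormalDist` for \(\Phi^{-1}\); the function name `uniform_to_normal`, the clipping guard, and the sample size are illustrative assumptions.

```python
import numpy as np
from statistics import NormalDist

# Vectorized inverse CDF of the standard normal distribution.
_inv_cdf = np.vectorize(NormalDist().inv_cdf)


def uniform_to_normal(xi, mu, sigma):
    """Map xi, uniform on [-1, 1], to a N(mu, sigma^2) sample."""
    u = 0.5 * (np.asarray(xi, dtype=float) + 1.0)
    # Guard against u = 0 or u = 1, where the inverse CDF is infinite.
    u = np.clip(u, 1e-12, 1.0 - 1e-12)
    return mu + sigma * _inv_cdf(u)


rng = np.random.default_rng(0)
xi = rng.uniform(-1.0, 1.0, size=50_000)
x0_samples = uniform_to_normal(xi, mu=0.0, sigma=0.1)

print(x0_samples.mean(), x0_samples.std())
```

Note that \(\Phi^{-1}\) is unbounded near the ends of the interval, so the transformed initial condition is not a polynomial in \(\Xi\); a Legendre expansion of it converges more slowly than for a smooth bounded map.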

We are going to develop both a Monte Carlo solver and a polynomial chaos solver. But first, let’s get a bit organized. We are going to create some useful named tuples to hold the parameters and the initial conditions.

import equinox as eqx
from collections import namedtuple

NormalDistribution = namedtuple("NormalDistribution", ["mu", "sigma"])
Parameters = namedtuple("Parameters", ["alpha", "beta", "gamma", "delta", "omega"])

Duffing = namedtuple("Duffing", ["params", "X", "V"])

These can be used as follows:

X = NormalDistribution(0.0, 0.1)
V = NormalDistribution(0.0, 0.1)

params = Parameters(1.0, 5.0, 0.37, 0.1, 1.0)

duffing = Duffing(params, X, V)

Here is how they appear:

print(duffing)
Duffing(params=Parameters(alpha=1.0, beta=5.0, gamma=0.37, delta=0.1, omega=1.0), X=NormalDistribution(mu=0.0, sigma=0.1), V=NormalDistribution(mu=0.0, sigma=0.1))

Named tuples are pytrees in JAX, so we can pass the whole object through jit and vmap as a single argument. This keeps the function signatures short.

Now, let’s write code that implements the initial conditions and vector field:

from jax.scipy import stats as jstats
from functools import partial
from diffrax import diffeqsolve, Tsit5, SaveAt, ODETerm


def to_normal(xi : float, dist : NormalDistribution) -> float:
    """Map xi, uniform on [-1, 1], to a sample from the normal distribution dist."""
    return dist.mu + dist.sigma * jstats.norm.ppf(0.5 * (xi + 1))

def x0(xi, duffing : Duffing):
    """Initial state (position and velocity) as a function of xi."""
    return jnp.array(
        [to_normal(xi[0], duffing.X), to_normal(xi[1], duffing.V)]
    )

def vector_field(t, y, params):
    """Right hand side of the Duffing oscillator, with state y = (x, v)."""
    x = y[0]
    v = y[1]
    alpha = params.alpha
    beta = params.beta
    gamma = params.gamma
    delta = params.delta
    omega = params.omega
    return jnp.array(
        [
            v,
            - alpha * x - beta * x ** 3 - delta * v + gamma * jnp.cos(omega * t)
        ]
    )

@jit
@partial(vmap, in_axes=(0, None))
def solve_duffing(xi, duffing : Duffing):
    """Simple solver of the dynamical system."""
    solver = Tsit5()
    saveat = SaveAt(ts=jnp.linspace(0, 10, 2000))
    term = ODETerm(vector_field)
    sol = diffeqsolve(
        term,
        solver,
        t0=0,                       
        t1=10,                      
        dt0=0.1,                    
        y0=x0(xi, duffing),
        args=duffing.params,
        saveat=saveat
    )
    return sol.ys

Compute a Monte Carlo reference solution to compare against:

import numpy as np

num_samples = 100_000
xis = 2 * np.random.uniform(size=(num_samples, 2)) - 1
samples = solve_duffing(xis, duffing)

mc_mean = jnp.mean(samples, axis=0)
mc_var = jnp.var(samples, axis=0)

Now, let’s write a polynomial chaos solver. First, construct the polynomials and the quadrature rule:

from functools import partial

total_degree = 5
degrees = (5, 5)
poly = ojax.TensorProduct(
    total_degree,
    [ojax.make_legendre_polynomial(d) for d in degrees])
level = 5
quad = make_sparse_grid(2, level)

Now, make the polynomial chaos solver:

new_vector_field = lambda t, x, xi: vector_field(t, x, duffing.params)
new_x0 = lambda xi: x0(xi, duffing)
pc_problem = make_pc_problem(poly, quad, new_vector_field, new_x0)
@jit
def solve_duffing_pc(duffing, poly=poly, quad=quad):
    # Adhere to the PCProblem interface
    new_vector_field = lambda t, x, xi: vector_field(t, x, duffing.params)
    new_x0 = lambda xi: x0(xi, duffing)
    pc_problem = make_pc_problem(poly, quad, new_vector_field, new_x0)
    sol = diffeqsolve(
        ODETerm(pc_problem.rhs),
        Tsit5(),
        t0=0,
        t1=10,
        dt0=0.1,
        y0=pc_problem.y0,
        args=pc_problem.phis,
        saveat=SaveAt(ts=jnp.linspace(0, 10, 2000))
    )
    return sol

And now we can solve it as follows:

pc_sol = solve_duffing_pc(duffing)

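Let’s calculate the mean and the variance from the PC coefficients. The code below assumes that the first basis function is the constant \(\phi_0 = 1\). Orthonormality then gives the mean as the first coefficient and the variance as the sum of squares of the remaining ones: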

pc_mean = pc_sol.ys[:, :, 0]
pc_variance = np.sum(pc_sol.ys[:, :, 1:] ** 2, axis=2)
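Here is a small check of these two formulas on a function whose moments are known in closed form. It is a sketch under stated assumptions: one uniform variable on \([-1,1]\), an orthonormal Legendre basis built with NumPy (not orthojax), and the test function \(g(\xi) = \xi^2\), for which \(\mathbb{E}[g] = 1/3\) and \(\mathbb{V}[g] = 1/5 - 1/9 = 4/45\).

```python
import numpy as np
from numpy.polynomial import legendre

degree = 4
nodes, weights = legendre.leggauss(10)
weights = weights / 2.0  # probability weights for U[-1, 1]

# Orthonormal Legendre basis at the nodes, shape (m, p); column 0 is constant 1.
phis = legendre.legvander(nodes, degree) * np.sqrt(2 * np.arange(degree + 1) + 1)

# PC coefficients of g(xi) = xi^2, computed by quadrature.
g_at_nodes = nodes ** 2
coeffs = phis.T @ (weights * g_at_nodes)

mean = coeffs[0]                     # coefficient of the constant function
variance = np.sum(coeffs[1:] ** 2)   # sum of squares of the rest

print(mean, variance)
```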

Let’s compare the Monte Carlo solution with the polynomial chaos solution:

fig, ax = plt.subplots()
ax.plot(pc_sol.ts, pc_mean, label="PC mean")
ax.plot(pc_sol.ts, mc_mean, '--', label="MC mean")
ax.legend(loc="best")
sns.despine(trim=True);
[Figure: PC mean and MC mean of the state versus time]

fig, ax = plt.subplots()
ax.plot(pc_sol.ts, pc_variance, label="PC variance")
ax.plot(pc_sol.ts, mc_var, '--', label="MC variance")
ax.legend(loc="best")
sns.despine(trim=True);
[Figure: PC variance and MC variance of the state versus time]

The PC estimates of the mean and the variance agree closely with the Monte Carlo reference.

Of course, if you increase the initial variance, the polynomial chaos solution will be less accurate. You can experiment with this.

Let me demonstrate that the polynomial chaos solution is much faster than the Monte Carlo solution. This will take a while to run.

# Time the PC solver and the MC solver
%timeit solve_duffing_pc(duffing)
%timeit solve_duffing(xis, duffing)
5.01 ms ± 29.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.19 s ± 184 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

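Much faster: in the run above, the PC solve took about 5 ms, compared with about 5.2 s for the 100,000-sample Monte Carlo run, roughly a thousandfold speedup.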

Finally, let me demonstrate an added benefit of the PC solution. It can be used as a surrogate for the solution of the dynamical system. This means that you can evaluate the solution at any \(\xi\) without having to solve the dynamical system. Here is how:

@jit
def surrogate(xis, pc_coeff=pc_sol.ys, poly=poly):
    """Surrogate function for the PC solution."""
    phis = poly(xis)
    ys = jnp.einsum("tip,mp->mti", pc_coeff, phis)
    return ys

Let’s evaluate both the surrogate and the true solution at a few random points and compare them:

num_test = 2
xis_test = 2 * np.random.uniform(size=(num_test, 2)) - 1
preds = surrogate(xis_test)
true = solve_duffing(xis_test, duffing)

Compare the two:

names = ["x", "v"]
fig, ax = plt.subplots(1, 2)
for i in range(num_test):
    for k in range(2):
        ax[k].plot(pc_sol.ts, preds[i, :, k], label=f"PC {i+1} {names[k]}")
        ax[k].plot(pc_sol.ts, true[i, :, k], '--', label=f"True {i+1} {names[k]}")
# Label each panel once, after all curves have been drawn
for k in range(2):
    ax[k].set_xlabel("$t$")
    ax[k].set_ylabel("$" + names[k] + "$")
    ax[k].legend(loc="best")
sns.despine(trim=True);
[Figure: PC surrogate versus true trajectories of x (left) and v (right) at the test points]

You can use this surrogate to do all sorts of things, like getting higher order statistics, the PDF, or the Sobol indices with respect to \(\Xi\).
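As an illustration of the last point, here is a minimal sketch of how Sobol indices can be read off from PC coefficients. It is not tied to `pc_sol`: the ordering of the flattened multi-index in orthojax is not assumed here, so the sketch builds its own two-dimensional coefficient array `coeffs[a, b]` with NumPy. The test function \(g(\xi_1,\xi_2) = \xi_1 + 0.5\,\xi_2 + \xi_1\xi_2\) and the full tensor-product basis are assumptions chosen because the exact indices are \(12/19\), \(3/19\), and \(4/19\).

```python
import numpy as np
from numpy.polynomial import legendre

degree = 3
nodes_1d, weights_1d = legendre.leggauss(6)
weights_1d = weights_1d / 2.0  # probability weights for U[-1, 1]

# One-dimensional orthonormal Legendre basis at the nodes, shape (q, degree + 1).
phis_1d = legendre.legvander(nodes_1d, degree) * np.sqrt(2 * np.arange(degree + 1) + 1)

# Tensor-product grid and weights on [-1, 1]^2.
X1, X2 = np.meshgrid(nodes_1d, nodes_1d, indexing="ij")
W = np.outer(weights_1d, weights_1d)


def g(xi1, xi2):
    """Toy model (an assumption) with known Sobol indices."""
    return xi1 + 0.5 * xi2 + xi1 * xi2


# coeffs[a, b] = E[g(Xi) * phi_a(Xi_1) * phi_b(Xi_2)]
coeffs = np.einsum("ij,ij,ia,jb->ab", W, g(X1, X2), phis_1d, phis_1d)

# Total variance: all squared coefficients except the mean term.
total_var = np.sum(coeffs ** 2) - coeffs[0, 0] ** 2

# First-order indices use terms that depend on one variable only;
# the interaction index uses terms that depend on both.
S1 = np.sum(coeffs[1:, 0] ** 2) / total_var
S2 = np.sum(coeffs[0, 1:] ** 2) / total_var
S12 = np.sum(coeffs[1:, 1:] ** 2) / total_var

print(S1, S2, S12)
```

Applying the same idea to the Duffing solution requires knowing which multi-index each column of the coefficient array corresponds to, and it yields one set of indices per time step and state component.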