import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
import seaborn as sns

!pip install orthojax --upgrade
!pip install diffrax==0.4.1

Using Polynomial Chaos to Propagate Uncertainty through an ODE

Let \(\Xi\) be a random variable. Consider the stochastic ODE:

\[ \dot{x} = f(t,x;\Xi), \]

with initial conditions:

\[ x_0 = x(0;\Xi). \]

Notice that the solution at time \(t\) is a random variable that is a function of \(\Xi\):

\[ X_t = x(t;\Xi). \]

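So, assuming \(X_t\) has finite variance, it is an element of the Hilbert space \(L^2(\Xi)\) of square-integrable functions of \(\Xi\), with inner product \(\langle u, v \rangle = \mathbb{E}[u(\Xi)v(\Xi)]\).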

Take the orthonormal polynomial basis of \(L^2(\Xi)\), \(\{e_0, e_1,\dots\}\). We can expand \(X_t\) in this basis:

\[ X_t = \sum_{i=0}^\infty \alpha_i(t) e_i(\Xi). \]

Our goal is to show that the coefficients \(\alpha_i\) satisfy an initial value problem. By solving that initial value problem, we can compute them for all times.

We start by taking the inner product of both sides of the ODE with \(e_i\). From the left hand side, we have:

\[ \langle \dot{X}_t, e_i \rangle = \langle \sum_{j=0}^\infty \dot{\alpha}_j(t) e_j,e_i \rangle = \sum_{j=0}^\infty \dot{\alpha}_j(t) \langle e_j, e_i \rangle = \dot{\alpha}_i(t). \]
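The last step relies on the orthonormality \(\langle e_j, e_i \rangle = \delta_{ij}\). As a minimal sketch, this can be checked numerically for one concrete choice: \(\Xi\) uniform on \([-1,1]\), with the classical Legendre polynomials rescaled by \(\sqrt{2i+1}\). The helper name `orthonormal_legendre` is hypothetical, and the sketch uses NumPy's Gauss-Legendre rule rather than any particular polynomial chaos library.

```python
import numpy as np
from numpy.polynomial import legendre as L

def orthonormal_legendre(xi, degree):
    """Evaluate e_0, ..., e_degree at the points xi for Xi ~ U([-1, 1])."""
    # legvander returns the classical Legendre polynomials P_0..P_degree.
    # Scaling P_i by sqrt(2 i + 1) makes them orthonormal under the density 1/2.
    scale = np.sqrt(2.0 * np.arange(degree + 1) + 1.0)
    return L.legvander(xi, degree) * scale

degree = 5
# A rule with degree + 1 nodes integrates polynomials up to degree 2 * degree + 1 exactly.
nodes, weights = L.leggauss(degree + 1)
weights = weights / 2.0  # fold the uniform density 1/2 into the weights

E = orthonormal_legendre(nodes, degree)          # shape (num_nodes, degree + 1)
gram = np.einsum("n,ni,nj->ij", weights, E, E)   # gram[i, j] = <e_i, e_j>
print(np.allclose(gram, np.eye(degree + 1)))
```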

This is very convenient. The right hand side is a bit more complicated. First, we need to think of \(f(t,X_t;\Xi)\) as a function of \(\Xi\). This will introduce a dependence on the coefficients \(a = (a_0,a_1, \dots)\). We write:

\[ g(t, a;\Xi) = f\left(t, \sum_{i=0}^\infty a_i e_i(\Xi);\Xi\right). \]

Now, think of \(g(t,a;\cdot)\) as a function of \(\Xi\) for fixed \(t\) and \(a\) and take the inner product with \(e_i\):

\[ g_i(t,a) = \langle g(t,a;\cdot), e_i \rangle. \]
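In practice the inner product defining \(g_i\) is evaluated with a quadrature rule. The following is a minimal NumPy sketch of that projection, assuming \(\Xi\) is uniform on \([-1,1]\), a truncated orthonormal Legendre basis, and a scalar state. The function name `make_galerkin_rhs` and the default number of quadrature nodes are illustrative assumptions.

```python
import numpy as np
from numpy.polynomial import legendre as L

def make_galerkin_rhs(f, degree, num_nodes=None):
    """Return g(t, a) with components g_i(t, a) = <f(t, sum_j a_j e_j; Xi), e_i>."""
    if num_nodes is None:
        num_nodes = 2 * (degree + 1)  # assumed default; increase for strongly nonlinear f
    nodes, weights = L.leggauss(num_nodes)
    weights = weights / 2.0  # uniform density 1/2 on [-1, 1]
    # Orthonormal Legendre basis evaluated at the quadrature nodes
    E = L.legvander(nodes, degree) * np.sqrt(2.0 * np.arange(degree + 1) + 1.0)

    def g(t, a):
        x = E @ a                    # truncated expansion at each node
        fx = f(t, x, nodes)          # f(t, x; xi) at each node
        return E.T @ (weights * fx)  # quadrature approximation of <f, e_i>

    return g

# Usage: for f = -x, which does not depend on Xi, the projection returns -a.
g = make_galerkin_rhs(lambda t, x, xi: -x, degree=3)
a = np.array([1.0, 0.5, -0.2, 0.1])
print(g(0.0, a))
```

When \(f\) is polynomial in \(x\) and \(\Xi\), a sufficiently large rule makes the projection exact. Otherwise the number of nodes is a tuning parameter.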

Equating the left and right hand sides, we have:

\[ \dot{\alpha}_i(t) = g_i(t,\alpha), \]

for \(i=0,1,\dots\).

The initial conditions are given by:

\[ \alpha_i(0) = \langle x_0, e_i \rangle. \]

In practice we truncate the infinite sum at some \(N\) and solve the system of ODEs:

\[ \dot{\alpha}_i(t) = g_i(t,\alpha), \quad i=0,1,\dots,N, \]

with initial conditions:

\[ \alpha_i(0) = \langle x_0, e_i \rangle, \quad i=0,1,\dots,N. \]

One typically increases \(N\) until the solution converges.


Another name for what we have done above is Galerkin projection. There is really nothing special about the basis \(\{e_0,e_1,\dots\}\). It was orthonormal polynomials, but it could have been any basis.

Because the basis is orthonormal and \(e_0 = 1\), we can read the statistics of \(X_t\) directly from the coefficients. For example, the mean is:

\[ \mu_t = \mathbb{E}[X_t] = \langle X_t, e_0 \rangle = \alpha_0(t). \]

The variance is:

\[ \sigma_t^2 = \|X_t\|^2 - \mu_t^2 = \sum_{i=1}^\infty \alpha_i(t)^2 \approx \sum_{i=1}^N \alpha_i(t)^2. \]
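To illustrate these two formulas, here is a small sketch that projects a test variable onto the basis and compares the resulting mean and variance with closed-form values. The test variable \(X = e^{\Xi}\) with \(\Xi\) uniform on \([-1,1]\) is a hypothetical choice made only because its moments are known exactly: \(\mathbb{E}[X] = \sinh(1)\) and \(\mathbb{E}[X^2] = \sinh(2)/2\).

```python
import numpy as np
from numpy.polynomial import legendre as L

degree = 5
nodes, weights = L.leggauss(20)
weights = weights / 2.0  # uniform density 1/2 on [-1, 1]
# Orthonormal Legendre basis evaluated at the quadrature nodes
E = L.legvander(nodes, degree) * np.sqrt(2.0 * np.arange(degree + 1) + 1.0)

# Coefficients alpha_i = <X, e_i> of the hypothetical test variable X = exp(Xi)
alpha = E.T @ (weights * np.exp(nodes))

pc_mean = alpha[0]               # the mean is the zeroth coefficient
pc_var = np.sum(alpha[1:] ** 2)  # the variance is the sum of the remaining squares

exact_mean = np.sinh(1.0)
exact_var = np.sinh(2.0) / 2.0 - np.sinh(1.0) ** 2
print(pc_mean - exact_mean, pc_var - exact_var)
```

The mean is reproduced up to quadrature error, while the variance carries a small truncation error from the discarded coefficients with \(i > N\).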

Example: Propagating Uncertainty through an ODE

Consider the exponential decay equation:

\[ \dot{x} = -(0.5 + 0.1\Xi) x \]

with initial condition:

\[ x_0 = x(0) = 1, \]

and random variable \(\Xi\) uniformly distributed on \([-1,1]\).
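For this linear example the Galerkin system can be written out by hand. The following is a short derivation, assuming the orthonormal Legendre basis for uniform \(\Xi\) on \([-1,1]\), normalized so that \(e_0 = 1\). The matrix \(A\) is notation introduced here for convenience. Projecting the right hand side gives:

\[ g_i(t,a) = \left\langle -(0.5 + 0.1\Xi)\sum_{j=0}^N a_j e_j, e_i \right\rangle = -0.5\, a_i - 0.1\sum_{j=0}^N A_{ij} a_j, \quad A_{ij} = \langle \Xi e_j, e_i \rangle. \]

By the three-term recurrence of the Legendre polynomials, \(A\) is symmetric and tridiagonal with zero diagonal and off-diagonal entries:

\[ A_{i,i+1} = A_{i+1,i} = \frac{i+1}{\sqrt{4(i+1)^2 - 1}}. \]

So the coefficients satisfy the linear system \(\dot{\alpha} = -(0.5 I + 0.1 A)\alpha\). Since \(x_0 = 1 = e_0\), the initial condition is \(\alpha(0) = (1, 0, \dots, 0)\).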

Let’s construct the polynomials:

import orthojax as ojax

degree = 5
poly = ojax.make_legendre_polynomial(degree)

Now, let’s code up the right hand side of the ODE:

# Theta here are other parameters that are not random. None used here.
f = lambda t, x, xi, theta: -(0.5 + 0.1 * xi) * x

The initial condition is a constant function of \(\Xi\):

x0 = lambda xi: 1.0

Now, we need to write code that makes the right hand side and the initial conditions of the dynamical system that governs the coefficients \(\alpha_i\):

import jax.numpy as jnp
from jax import vmap, jit

def make_pc_problem(poly, f, x0, theta):
    # The quadrature rule used to compute inner products
    xis, ws = poly.quad

    # The polynomial basis functions on the quadrature points
    phis = poly(xis)

    # The initial condition of the PC coefficients: a0_j = <x0, e_j>
    x0s = vmap(x0)(xis)
    a0 = jnp.einsum("i,ij,i->j", x0s, phis, ws)

    # Evaluate f at every quadrature point (t and theta are shared)
    fv = vmap(f, in_axes=(None, 0, 0, None))

    # The right hand side of the PC ODE
    def rhs(t, a, theta_rhs):
        phis, theta = theta_rhs
        # The truncated expansion evaluated at the quadrature points
        xs = jnp.einsum("i,ni->n", a, phis)
        fs = fv(t, xs, xis, theta)
        # Project onto each basis function: g_j = <f, e_j>
        return jnp.einsum("i,ij,i->j", fs, phis, ws)
    return rhs, a0, (phis, theta)

Here is how to use the code:

rhs, a0, theta = make_pc_problem(poly, f, x0, None)

And now we have the code we need to solve an initial value problem using Diffrax:

from diffrax import diffeqsolve, Tsit5, ODETerm, SaveAt

solver = Tsit5()
saveat = SaveAt(ts=jnp.linspace(0, 10, 2000))
term = ODETerm(rhs)
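sol = diffeqsolve(
    term, solver, t0=0.0, t1=10.0, y0=a0, args=theta, saveat=saveat,
    dt0=0.1,  # initial step size (an assumed value)
)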

Here are the PC coefficients evolving over time:

fig, ax = plt.subplots()
ax.plot(sol.ts, sol.ys)
ax.set(xlabel="$t$", ylabel="$a(t)$", title="PC coefficients")

Here are the statistics with PC:

import numpy as np

pc_mean = sol.ys[:, 0]
pc_var = np.sum(sol.ys[:, 1:] ** 2, axis=1)

Let’s calculate the statistics using Monte Carlo and compare:

xis = -1.0 + 2.0 * np.random.rand(10_000)
true_solution = jit(vmap(lambda xi: x0(xi) * jnp.exp(-(0.5 + 0.1 * xi) * sol.ts)))
true_ys = true_solution(xis)
mc_mean = jnp.mean(true_ys, axis=0)
mc_var = jnp.var(true_ys, axis=0)

Here is the comparison:

fig, ax = plt.subplots()
ax.plot(sol.ts, pc_mean, label="PC mean")
ax.plot(sol.ts, mc_mean, '--', label="MC mean")
ax.set(xlabel="$t$", ylabel="$x(t)$", title="Mean")
ax.legend()
fig, ax = plt.subplots()
ax.plot(sol.ts, pc_var, label="PC variance")
ax.plot(sol.ts, mc_var, '--', label="MC variance")
ax.set(xlabel="$t$", ylabel=r"$\sigma_t^2$", title="Variance")
ax.legend()

Notice that the polynomial chaos model we have created can also serve as a parametric form of the solution of the stochastic ODE. We can evaluate it at any \(\Xi\) like this:

def pc_sol(xis):
    return jnp.einsum("ti,ni->tn", sol.ys, poly(xis))

Here are some samples:

xis = -1.0 + 2.0 * np.random.rand(5)
fig, ax = plt.subplots()
ax.plot(sol.ts, pc_sol(xis), 'r', lw=0.5)
ax.set(xlabel="$t$", ylabel="$x(t)$", title="PC solution")
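Because this example has the closed-form solution \(x(t;\Xi) = e^{-(0.5 + 0.1\Xi)t}\), the accuracy of the surrogate can be checked directly. The sketch below is self-contained and makes several simplifying assumptions: it uses NumPy instead of orthojax, a fixed-step classical Runge-Kutta (RK4) integrator instead of Diffrax's `Tsit5`, and hypothetical helper names `basis` and `solve_pc`. It reports the maximum error of the surrogate at \(t = 10\) for a few polynomial degrees.

```python
import numpy as np
from numpy.polynomial import legendre as L

def basis(xi, degree):
    """Orthonormal Legendre polynomials e_0..e_degree for Xi ~ U([-1, 1])."""
    return L.legvander(xi, degree) * np.sqrt(2.0 * np.arange(degree + 1) + 1.0)

def solve_pc(degree, t_final=10.0, num_steps=1000):
    """Integrate the Galerkin system for dx/dt = -(0.5 + 0.1 Xi) x, x(0) = 1."""
    # degree + 2 nodes integrate the products Xi * e_i * e_j exactly
    nodes, weights = L.leggauss(degree + 2)
    weights = weights / 2.0  # uniform density 1/2 on [-1, 1]
    E = basis(nodes, degree)

    def rhs(a):
        x = E @ a  # truncated expansion at the quadrature nodes
        return E.T @ (weights * (-(0.5 + 0.1 * nodes) * x))

    a = E.T @ (weights * np.ones_like(nodes))  # projection of x0 = 1
    dt = t_final / num_steps
    for _ in range(num_steps):  # classical RK4 with a fixed step
        k1 = rhs(a)
        k2 = rhs(a + 0.5 * dt * k1)
        k3 = rhs(a + 0.5 * dt * k2)
        k4 = rhs(a + dt * k3)
        a = a + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    return a

xi_test = np.linspace(-1.0, 1.0, 201)
exact = np.exp(-(0.5 + 0.1 * xi_test) * 10.0)

errors = {}
for degree in (1, 3, 5):
    a_final = solve_pc(degree)
    surrogate = basis(xi_test, degree) @ a_final
    errors[degree] = np.max(np.abs(surrogate - exact))
    print(degree, errors[degree])
```

The error should shrink as the degree grows, which is the kind of check one can use when increasing \(N\) until the solution converges.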