{
"cells": [
{
"cell_type": "code",
"execution_count": 122,
"metadata": {
"tags": [
"hide-input",
"hide-output"
]
},
"outputs": [
],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import matplotlib_inline\n",
"matplotlib_inline.backend_inline.set_matplotlib_formats('svg')\n",
"import seaborn as sns\n",
"sns.set_context(\"paper\")\n",
"sns.set_style(\"ticks\");\n",
"\n",
"!pip install diffrax==0.4.1\n",
"!pip install orthojax --upgrade\n",
"!pip install py-design --upgrade"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Uncertainty Propagation in Dynamical Systems\n",
"\n",
"Consider the following $n$ dimensional dynamical system:\n",
"\n",
"$$\n",
"\\dot{\\mathbf{x}} = \\mathbf{f}(t,\\mathbf{x};\\boldsymbol{\\Xi}),\n",
"$$\n",
"\n",
"where $\\mathbf{x} \\in \\mathbb{R}^n$ is the state vector, $\\mathbf{f}$ is a vector valued function, $t$ is time, and $\\Xi$ is a vector of uncertain parameters.\n",
"The initial condition is:\n",
"\n",
"$$\n",
"\\mathbf{x}(0) = \\mathbf{x}_0(\\boldsymbol{\\Xi}).\n",
"$$\n",
"\n",
"We assume that:\n",
"\n",
"$$\n",
"\\boldsymbol{\\Xi} = (\\Xi_1,\\dots,\\Xi_d),\n",
"$$\n",
"\n",
"independent. \n",
"The goal is to propagate the uncertainty in $\\boldsymbol{\\Xi}$ through the system to obtain the uncertainty in $\\mathbf{x}$.\n",
"\n",
"The space we work with is:\n",
"\n",
"$$\n",
"L^2(\\mathbf{\\Xi},\\mathbb{R}^n) = \\left\\{\\mathbf{g}:\\mathbb{R}^d\\to\\mathbb{R}^n\\mid \\int_{\\mathbb{R}^d} \\|\\mathbf{g}(\\boldsymbol{\\Xi})\\|^2\\,d\\boldsymbol{\\Xi} < \\infty\\right\\},\n",
"$$\n",
"\n",
"with inner product:\n",
"\n",
"$$\n",
"\\langle \\mathbf{g},\\mathbf{h}\\rangle = \\int_{\\mathbb{R}^d} \\mathbf{g}(\\boldsymbol{\\Xi})\\cdot\\mathbf{h}(\\boldsymbol{\\Xi})\\,d\\boldsymbol{\\Xi}.\n",
"$$\n",
"\n",
"Let $\\{\\phi_\\alpha\\}$ be the tensor product orthonormal basis for $L^2(\\mathbf{\\Xi})$ and $\\{\\mathbf{e}_i\\}$ be the standard basis for $\\mathbb{R}^n$.\n",
"Then, the functions:\n",
"\n",
"$$\n",
"\\boldsymbol{\\psi}_{i,\\alpha} = \\mathbf{e}_i\\phi_\\alpha,\n",
"$$\n",
"\n",
"form an orthonormal basis for $L^2(\\mathbf{\\Xi},\\mathbb{R}^n)$.\n",
"\n",
"We expand the dynamical system state in this basis (at each time):\n",
"\n",
"$$\n",
"\\mathbf{x}(t;\\Xi) = \\sum_{i=1}^n \\sum_{\\alpha} \\mathbf{x}_{i,\\alpha}(t)\\boldsymbol{\\psi}_{i,\\alpha}(\\boldsymbol{\\Xi}).\n",
"$$\n",
"\n",
"We plug this into the dynamical system to get:\n",
"\n",
"$$\n",
"\\dot{\\mathbf{x}}(t;\\Xi) = \\sum_{i=1}^n \\sum_{\\alpha} \\dot{\\mathbf{x}}_{i,\\alpha}(t)\\boldsymbol{\\psi}_{i,\\alpha}(\\boldsymbol{\\Xi})\n",
"= \\mathbf{f}\\left(t,\\sum_{i=1}^n \\sum_{\\alpha} \\mathbf{x}_{i,\\alpha}(t)\\boldsymbol{\\psi}_{i,\\alpha}(\\boldsymbol{\\Xi});\\boldsymbol{\\Xi}\\right).\n",
"$$\n",
"\n",
"We project each side onto $\\psi_{j,\\beta}$ and use the orthogonality of the basis to get:\n",
"\n",
"$$\n",
"\\dot{\\mathbf{x}}_{j,\\beta}(t) = \\left\\langle \\mathbf{f}\\left(t,\\sum_{i=1}^n \\sum_{\\alpha} \\mathbf{x}_{i,\\alpha}(t)\\psi_{i,\\alpha}(\\boldsymbol{\\Xi});\\boldsymbol{\\Xi}\\right),\\boldsymbol{\\psi}_{j,\\beta}\\right\\rangle.\n",
"$$\n",
"\n",
"This is a differential equation that describes the evolution of the coefficients $\\mathbf{x}_{i,\\alpha}(t)$.\n",
"The initial condition is:\n",
"\n",
"$$\n",
"\\mathbf{x}_{i,\\alpha}(0) = \\left\\langle \\mathbf{x}_0(\\boldsymbol{\\Xi}),\\boldsymbol{\\psi}_{i,\\alpha}\\right\\rangle.\n",
"$$\n",
"\n",
"Let's write `jax` code that solves this problem."
]
},
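{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of how the projections can be evaluated numerically, suppose we have a quadrature rule with nodes $\\boldsymbol{\\xi}^{(k)}$ and weights $w_k$, $k=1,\\dots,m$, with the weights normalized to sum to one, as for a probability measure. The symbols $m$, $w_k$, $\\boldsymbol{\\xi}^{(k)}$, and $\\mathbf{x}^{(k)}$ are notation introduced here for illustration.\n",
"Writing $f_j$ and $x_{0,j}$ for the $j$-th components of $\\mathbf{f}$ and $\\mathbf{x}_0$, the inner products become weighted sums:\n",
"\n",
"$$\n",
"\\mathbf{x}_{j,\\beta}(0) \\approx \\sum_{k=1}^m w_k\\, x_{0,j}\\left(\\boldsymbol{\\xi}^{(k)}\\right)\\phi_\\beta\\left(\\boldsymbol{\\xi}^{(k)}\\right),\n",
"\\quad\n",
"\\dot{\\mathbf{x}}_{j,\\beta}(t) \\approx \\sum_{k=1}^m w_k\\, f_j\\left(t,\\mathbf{x}^{(k)}(t);\\boldsymbol{\\xi}^{(k)}\\right)\\phi_\\beta\\left(\\boldsymbol{\\xi}^{(k)}\\right),\n",
"$$\n",
"\n",
"where $\\mathbf{x}^{(k)}(t) = \\sum_{i=1}^n\\sum_{\\alpha}\\mathbf{x}_{i,\\alpha}(t)\\boldsymbol{\\psi}_{i,\\alpha}\\left(\\boldsymbol{\\xi}^{(k)}\\right)$ is the truncated expansion evaluated at the $k$-th node.\n",
"The three `einsum` calls in the code that follows correspond to these three sums."
]
},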
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [],
"source": [
"from collections import namedtuple\n",
"\n",
"import orthojax as ojax\n",
"import design\n",
"import jax.numpy as jnp\n",
"from jax import vmap, jit\n",
"\n",
"\n",
"def make_sparse_grid(dim, level):\n",
" \"\"\"Make a sparse grid of dimension dim and a given level.\n",
" We do it for the uniform cube [-1, 1]^d.\"\"\"\n",
" x, w = design.sparse_grid(dim, level, 'F2')\n",
" w = w / (2 ** dim)\n",
" x = jnp.array(x, dtype=jnp.float32)\n",
" w = jnp.array(w, dtype=jnp.float32)\n",
" return ojax.QuadratureRule(x, w)\n",
"\n",
"\n",
"PCProblem = namedtuple(\"PCProblem\", [\"poly\", \"quad\", \"f\", \"x0\", \"phis\", \"y0\", \"rhs\"])\n",
"\n",
"\n",
"def make_pc_problem(poly, quad, f, x0):\n",
" \"\"\"Make the PC dynamical system problem.\n",
"\n",
" Params:\n",
" poly: The polynomial basis\n",
" quad: The quadrature rule used to compute inner products\n",
" f: The function defining the right hand side of the ODE (function of x, t and xi) to R^n\n",
" x0: The initial condition (function of xi, from R^d -> R^n)\n",
" theta: The parameters of the ODE\n",
" \"\"\"\n",
" # The quadrature rule used to compute inner products\n",
" xis, ws = quad\n",
" # xis is m x d and ws is m\n",
"\n",
" # The polynomial basis functions on the collocation points\n",
" phis = poly(xis)\n",
" # this is m x p\n",
"\n",
" # The initial condition of the PC coefficients\n",
" x0s = jit(vmap(x0))(xis) # this is m x n\n",
" # The PC coefficients are n x p\n",
" # ws is m\n",
" # phis is m x p\n",
" # x0s is m x n\n",
" # y0 must be n x p\n",
" y0 = jnp.einsum(\"m,mp,mn->np\", ws, phis, x0s)\n",
" \n",
" # Vectorize the function f\n",
" fv = vmap(f, in_axes=(None, 0, 0))\n",
" \n",
" # The right hand side of the PC ODE\n",
" def rhs(t, y, phis):\n",
" # y is n x p\n",
" # phis is m x p\n",
" # xs must be m x n\n",
" xs = jnp.einsum(\"np,mp->mn\", y, phis)\n",
" # xs is m x n\n",
" # xis is m x d\n",
" # fs must be m x n\n",
" fs = fv(t, xs, xis)\n",
" # do the dot product with quadrature weights\n",
" return jnp.einsum(\"m,mn,mp->np\", ws, fs, phis)\n",
" \n",
" return PCProblem(poly, quad, f, x0, phis, y0, rhs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example: Duffing Oscillator with Random Initial State\n",
"\n",
"$$\\begin{align}\n",
"\\dot{x} & = v \\\\\n",
"\\dot{v} & = \\gamma \\cos(\\omega t) - \\delta v - \\alpha x - \\beta x^3,\n",
"\\end{align}$$\n",
"\n",
"With initial state:\n",
"\n",
"$$\n",
"x(0) \\sim N(\\mu_x, \\sigma_x^2), \\quad v(0) \\sim N(\\mu_y, \\sigma_v^2).\n",
"$$\n",
"\n",
"We are going to keep the parameters $\\alpha,\\beta,\\gamma,\\delta, \\omega$ fixed and only vary the initial state.\n",
"\n",
"The first thing we are going to do is express the initial conditions in terms of independent random variables $\\Xi_1,\\Xi_2 \\sim U[-1,1]$.\n",
"This will allow us to use Legendre polynomials.\n",
"Let $\\Phi$ be the CDF of the standard normal distribution.\n",
"Then:\n",
"\n",
"$$\n",
"x(0) = \\mu_x + \\sigma_x \\Phi^{-1}\\left((\\Xi_1 + 1) / 2\\right),\n",
"\\quad v(0) = \\mu_v + \\sigma_v \\Phi^{-1}\\left((\\Xi_2 + 1) / 2\\right).\n",
"$$\n",
"\n",
"We are going to develop both a Monte Carlo solver and a polynomial chaos solver.\n",
"But first, let's get a bit organized.\n",
"We are going to create some useful named tuples to hold the parameters and the initial conditions."
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {},
"outputs": [],
"source": [
"import equinox as eqx\n",
"from collections import namedtuple\n",
"\n",
"NormalDistribution = namedtuple(\"NormalDistribution\", [\"mu\", \"sigma\"])\n",
"Parameters = namedtuple(\"Parameters\", [\"alpha\", \"beta\", \"gamma\", \"delta\", \"omega\"])\n",
"\n",
"Duffing = namedtuple(\"Duffing\", [\"params\", \"X\", \"V\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These can be used as follows:"
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {},
"outputs": [],
"source": [
"X = NormalDistribution(0.0, 0.1)\n",
"V = NormalDistribution(0.0, 0.1)\n",
"\n",
"params = Parameters(1.0, 5.0, 0.37, 0.1, 1.0)\n",
"\n",
"duffing = Duffing(params, X, V)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is how they appear:"
]
},
{
"cell_type": "code",
"execution_count": 126,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Duffing(params=Parameters(alpha=1.0, beta=5.0, gamma=0.37, delta=0.1, omega=1.0), X=NormalDistribution(mu=0.0, sigma=0.1), V=NormalDistribution(mu=0.0, sigma=0.1))\n"
]
}
],
"source": [
"print(duffing)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is, of course, a `pytree` and it will help us write functions with not so many arguments.\n",
"\n",
"Now, let's write code that implements the initial conditions and vector field:"
]
},
{
"cell_type": "code",
"execution_count": 127,
"metadata": {},
"outputs": [],
"source": [
"from jax.scipy import stats as jstats\n",
"from functools import partial\n",
"from diffrax import diffeqsolve, Tsit5, SaveAt, ODETerm\n",
"\n",
"\n",
"def to_normal(xi : float, dist : NormalDistribution) -> float:\n",
" \"\"\"Transforms a [-1, 1] to a normal distribution.\"\"\"\n",
" return dist.mu + dist.sigma * jstats.norm.ppf(0.5 * (xi + 1))\n",
"\n",
"def x0(xi, duffing : Duffing):\n",
" \"\"\"Initial condition for the position.\"\"\"\n",
" return jnp.array(\n",
" [to_normal(xi[0], duffing.X), to_normal(xi[1], duffing.V)]\n",
" )\n",
"\n",
"def vector_field(t, y, params):\n",
" x = y[0]\n",
" v = y[1]\n",
" alpha = params.alpha\n",
" beta = params.beta\n",
" gamma = params.gamma\n",
" delta = params.delta\n",
" omega = params.omega\n",
" return jnp.array(\n",
" [\n",
" v,\n",
" - alpha * x - beta * x ** 3 - delta * v + gamma * jnp.cos(omega * t)\n",
" ]\n",
" )\n",
"\n",
"@jit\n",
"@partial(vmap, in_axes=(0, None))\n",
"def solve_duffing(xi, duffing : Duffing):\n",
" \"\"\"Simple solver of the dynamical system.\"\"\"\n",
" solver = Tsit5()\n",
" saveat = SaveAt(ts=jnp.linspace(0, 10, 2000))\n",
" term = ODETerm(vector_field)\n",
" sol = diffeqsolve(\n",
" term,\n",
" solver,\n",
" t0=0, \n",
" t1=10, \n",
" dt0=0.1, \n",
" y0=x0(xi, duffing),\n",
" args=duffing.params,\n",
" saveat=saveat\n",
" )\n",
" return sol.ys"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Develop the Monte Carlo ground truth:"
]
},
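{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the estimators computed in the next cell can be written as follows (the symbols $N$, $\\boldsymbol{\\xi}^{(s)}$, $\\hat{\\boldsymbol{\\mu}}$, and $\\hat{\\boldsymbol{\\sigma}}^2$ are notation introduced here). With $N$ independent samples $\\boldsymbol{\\xi}^{(s)}$ of $\\boldsymbol{\\Xi}$ and the corresponding trajectories $\\mathbf{x}(t;\\boldsymbol{\\xi}^{(s)})$:\n",
"\n",
"$$\n",
"\\hat{\\boldsymbol{\\mu}}(t) = \\frac{1}{N}\\sum_{s=1}^N \\mathbf{x}\\left(t;\\boldsymbol{\\xi}^{(s)}\\right),\n",
"\\quad\n",
"\\hat{\\boldsymbol{\\sigma}}^2(t) = \\frac{1}{N}\\sum_{s=1}^N \\left(\\mathbf{x}\\left(t;\\boldsymbol{\\xi}^{(s)}\\right) - \\hat{\\boldsymbol{\\mu}}(t)\\right)^2,\n",
"$$\n",
"\n",
"where the square is taken componentwise. The variance estimator divides by $N$ rather than $N-1$, which is the default behavior of `jnp.var`; for large $N$ the difference is negligible."
]
},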
{
"cell_type": "code",
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"num_samples = 100_000\n",
"xis = 2 * np.random.uniform(size=(num_samples, 2)) - 1\n",
"samples = solve_duffing(xis, duffing)\n",
"\n",
"mc_mean = jnp.mean(samples, axis=0)\n",
"mc_var = jnp.var(samples, axis=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's write a polynomial chaos solver. First, construct the polynomials and the quadrature rule:"
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"total_degree = 5\n",
"degrees = (5, 5)\n",
"poly = ojax.TensorProduct(\n",
" total_degree,\n",
" [ojax.make_legendre_polynomial(d) for d in degrees])\n",
"level = 5\n",
"quad = make_sparse_grid(2, level)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, make the polynomial chaos solver:"
]
},
{
"cell_type": "code",
"execution_count": 130,
"metadata": {},
"outputs": [],
"source": [
"new_vector_field = lambda t, x, xi: vector_field(t, x, duffing.params)\n",
"new_x0 = lambda xi: x0(xi, duffing)\n",
"pc_problem = make_pc_problem(poly, quad, new_vector_field, new_x0)"
]
},
{
"cell_type": "code",
"execution_count": 131,
"metadata": {},
"outputs": [],
"source": [
"@jit\n",
"def solve_duffing_pc(duffing, poly=poly, quad=quad):\n",
" # Adhere to the PCProblem interface\n",
" new_vector_field = lambda t, x, xi: vector_field(t, x, duffing.params)\n",
" new_x0 = lambda xi: x0(xi, duffing)\n",
" pc_problem = make_pc_problem(poly, quad, new_vector_field, new_x0)\n",
" sol = diffeqsolve(\n",
" ODETerm(pc_problem.rhs),\n",
" Tsit5(),\n",
" t0=0,\n",
" t1=10,\n",
" dt0=0.1,\n",
" y0=pc_problem.y0,\n",
" args=pc_problem.phis,\n",
" saveat=SaveAt(ts=jnp.linspace(0, 10, 2000))\n",
" )\n",
" return sol"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now we can solve it as follows:"
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {},
"outputs": [],
"source": [
"pc_sol = solve_duffing_pc(duffing)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's calculate the mean and the variance of PC:"
]
},
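{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a short derivation of the formulas used in the next cell, assume that the inner product is the expectation with respect to the distribution of $\\boldsymbol{\\Xi}$ and that the first basis function is the constant $\\phi_0 = 1$. This indexing convention is an assumption here, consistent with how the code slices the coefficients. Orthonormality then gives $\\mathbb{E}[\\phi_\\alpha(\\boldsymbol{\\Xi})] = \\langle\\phi_\\alpha,\\phi_0\\rangle = \\delta_{\\alpha 0}$, so for each state component $i$:\n",
"\n",
"$$\n",
"\\mathbb{E}\\left[x_i(t;\\boldsymbol{\\Xi})\\right] = \\mathbf{x}_{i,0}(t),\n",
"\\quad\n",
"\\mathbb{V}\\left[x_i(t;\\boldsymbol{\\Xi})\\right] = \\sum_{\\alpha\\neq 0} \\mathbf{x}_{i,\\alpha}^2(t).\n",
"$$\n",
"\n",
"The variance follows from $\\mathbb{E}[x_i^2] = \\sum_{\\alpha}\\mathbf{x}_{i,\\alpha}^2$, which holds by orthonormality, minus the squared mean."
]
},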
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
"pc_mean = pc_sol.ys[:, :, 0]\n",
"pc_variance = np.sum(pc_sol.ys[:, :, 1:] ** 2, axis=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's compare the Monte Carlo solution with the polynomial chaos solution:"
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"