{
"cells": [
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import matplotlib_inline\n",
"matplotlib_inline.backend_inline.set_matplotlib_formats('svg')\n",
"import seaborn as sns\n",
"sns.set_context(\"paper\")\n",
"sns.set_style(\"ticks\");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Numerical Estimation of Orthogonal Polynomials\n",
"\n",
"Walter Gautschi (a computer science professor at Purdue) created a Fortran package, ORTHPOL, that constructs orthogonal polynomials for an arbitrary weight function. The package is documented [here](https://www.cs.purdue.edu/homes/wxg/selected_works/section_11/141.pdf).\n",
"I have reimplemented the package in Python and added a few features.\n",
"\n",
"The package is based on two important theorems.\n",
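"In what follows, $\\Xi$ is a random variable and we work in the Hilbert space $L^2(\\Xi)$ with the usual inner product $\\langle f, g \\rangle = \\mathbb{E}[f(\\Xi)\\, g(\\Xi)]$, which for a density $p(\\xi)$ equals $\\int f(\\xi)\\, g(\\xi)\\, p(\\xi)\\, d\\xi$.\n",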
"The polynomials we construct are named $\\pi_k(\\xi)$.\n",
"\n",
"**Theorem 1:** *There is a unique set of (monic) orthogonal polynomials:*\n",
"\n",
"$$\n",
"\\pi_k(\\xi) = \\xi^k + \\text{lower order terms},\n",
"$$\n",
"\n",
"*satisfying:*\n",
"\n",
"$$\n",
"\\langle \\pi_k, \\pi_l \\rangle = 0 \\quad \\text{if} \\quad k \\neq l.\n",
"$$\n",
"\n",
"The other important theorem is the *three-term recurrence relation*.\n",
"This relation enables us to construct and evaluate the polynomials efficiently.\n",
"\n",
"**Theorem 2:** *The polynomials satisfy the three-term recurrence relation:*\n",
"\n",
"$$\n",
"\\pi_{k+1}(\\xi) = (\\xi - \\alpha_k) \\pi_k(\\xi) - \\beta_k \\pi_{k-1}(\\xi).\n",
"$$\n",
"\n",
"*where:*\n",
"\n",
"$$\n",
"\\alpha_k = \\frac{\\langle \\xi \\pi_k, \\pi_k \\rangle}{\\langle \\pi_k, \\pi_k \\rangle}\n",
"$$\n",
"\n",
"*and:*\n",
"\n",
"$$\n",
"\\beta_k = \\frac{\\langle \\pi_k, \\pi_k \\rangle}{\\langle \\pi_{k-1}, \\pi_{k-1} \\rangle}\n",
"$$\n",
"\n",
"*and:*\n",
"\n",
"$$\n",
"\\pi_{-1}(\\xi) = 0 \\quad \\text{and} \\quad \\pi_0(\\xi) = 1.\n",
"$$\n",
"\n",
"Dividing each $\\pi_k$ by its norm $\\|\\pi_k\\| = \\sqrt{\\langle \\pi_k, \\pi_k \\rangle}$ gives orthonormal polynomials, which are called **polynomial chaos**.\n",
"The name has nothing to do with the chaos theory of dynamical systems."
]
},
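{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two theorems lead directly to an algorithm, the *discretized Stieltjes procedure*. It is one of the methods in ORTHPOL, although whether orthojax uses it is not stated here. The idea is to approximate the inner product with a quadrature rule, then alternate between computing $\\alpha_k, \\beta_k$ from the current polynomial and applying the three-term recurrence.\n",
"The NumPy sketch below assumes a weight `wf` on a finite interval and uses Gauss-Legendre quadrature with an arbitrarily chosen default of 100 nodes. The function `stieltjes` is a hypothetical helper, not part of orthojax.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def stieltjes(degree, left, right, wf, n_quad=100):\n",
"    # Monic recurrence coefficients alpha_k, beta_k (k = 0..degree) of Theorem 2\n",
"    # for the weight wf on [left, right], via the discretized Stieltjes procedure.\n",
"    # Gauss-Legendre nodes/weights on [-1, 1], mapped to [left, right].\n",
"    z, w = np.polynomial.legendre.leggauss(n_quad)\n",
"    x = 0.5 * (right - left) * z + 0.5 * (right + left)\n",
"    w = 0.5 * (right - left) * w * wf(x)  # discrete version of p(xi) d(xi)\n",
"\n",
"    def inner(f, g):\n",
"        # Quadrature approximation of <f, g>\n",
"        return np.sum(w * f * g)\n",
"\n",
"    alpha = np.zeros(degree + 1)\n",
"    beta = np.zeros(degree + 1)\n",
"    pi_prev = np.zeros_like(x)  # pi_{-1} = 0\n",
"    pi_curr = np.ones_like(x)   # pi_0 = 1\n",
"    norm_prev = 1.0\n",
"    for k in range(degree + 1):\n",
"        norm_curr = inner(pi_curr, pi_curr)\n",
"        alpha[k] = inner(x * pi_curr, pi_curr) / norm_curr\n",
"        # beta_0 multiplies pi_{-1} = 0; by convention store <pi_0, pi_0>\n",
"        beta[k] = norm_curr / norm_prev if k > 0 else norm_curr\n",
"        # Three-term recurrence: pi_{k+1} = (xi - alpha_k) pi_k - beta_k pi_{k-1}\n",
"        pi_prev, pi_curr = pi_curr, (x - alpha[k]) * pi_curr - beta[k] * pi_prev\n",
"        norm_prev = norm_curr\n",
"    return alpha, beta\n",
"\n",
"\n",
"# Uniform density on [0, 2], as in the example further down\n",
"alpha, beta = stieltjes(10, 0.0, 2.0, lambda xi: 0.5)\n",
"print(alpha)  # all entries close to 1, the mean of U([0, 2])\n",
"print(beta)   # beta_k close to k^2 / (4 k^2 - 1) for k >= 1\n",
"```\n",
"\n",
"For the uniform density on $[0,2]$ the sketch reproduces $\\alpha_k = 1$ and $\\beta_k = k^2/(4k^2-1)$ for $k \\ge 1$ up to round-off, with $\\beta_0 = \\langle \\pi_0, \\pi_0 \\rangle = 1$. It works with monic polynomials so that it matches Theorem 2 term by term. At high degrees the monic norms shrink or grow geometrically, which is one reason practical codes often switch to orthonormal polynomials."
]
},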
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Orthojax\n",
"\n",
"You will have to install the package [orthojax](https://pypi.org/project/orthojax/):"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"tags": [
"hide-output"
]
},
"outputs": [
],
"source": [
"!pip install orthojax --upgrade"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example: Uniform random variable on $[0, 2]$\n",
"\n",
"Let $\\Xi \\sim U([0,2])$. The PDF is:\n",
"\n",
"$$\n",
"p(\\xi) = \\frac{1}{2} \\quad \\text{for} \\quad \\xi \\in [0,2].\n",
"$$\n",
"\n",
"Here is how we can construct the polynomials:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"import orthojax as ojax\n",
"\n",
"degree = 10\n",
"pdf = lambda xi: 0.5\n",
"poly = ojax.make_orthogonal_polynomial(degree, left=0.0, right=2.0, wf=pdf)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is what the object looks like:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OrthogonalPolynomial(\n",
" alpha=f32[11],\n",
" beta=f32[11],\n",
" gamma=f32[11],\n",
" quad=QuadratureRule(x=f32[100], w=f32[100])\n",
")"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"poly"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here are the recurrence coefficients stored in the object:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Array([1. , 1.0000025 , 0.99999017, 1.0000262 , 0.999944 ,\n",
" 1.0001017 , 0.9998344 , 1.0002512 , 0.9996395 , 1.0004972 ,\n",
" 0.9993389 ], dtype=float32),\n",
" Array([1. , 0.5773435 , 0.51641864, 0.5070478 , 0.50403017,\n",
" 0.50240046, 0.5019109 , 0.5010617 , 0.5012555 , 0.50043714,\n",
" 0.50102293], dtype=float32),\n",
" Array([1. , 0.5773435 , 0.51641864, 0.5070478 , 0.50403017,\n",
" 0.50240046, 0.5019109 , 0.5010617 , 0.5012555 , 0.50043714,\n",
" 0.50102293], dtype=float32))"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"poly.alpha, poly.beta, poly.gamma"
]
},
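{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, the exact coefficients for $U([0,2])$ are known. This is the Legendre case shifted from $[-1,1]$ to $[0,2]$, so the monic coefficients are $\\alpha_k = 1$ and $\\beta_k = k^2/(4k^2-1)$ for $k \\ge 1$.\n",
"The `alpha` values above are close to $1$. The `beta` values, however, appear to match $\\sqrt{\\beta_k}$ rather than $\\beta_k$ (e.g. $0.5773 \\approx \\sqrt{1/3}$), and in this output `gamma` is identical to `beta`.\n",
"This suggests that the package stores the coefficients of the orthonormal recurrence $\\sqrt{\\beta_{k+1}}\\, p_{k+1} = (\\xi - \\alpha_k) p_k - \\sqrt{\\beta_k}\\, p_{k-1}$, where $p_k = \\pi_k / \\|\\pi_k\\|$. That is an inference from the numbers, not from the orthojax documentation.\n",
"The snippet below copies the printed values and compares them with the closed form:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Values copied from the poly.alpha and poly.beta output above (float32)\n",
"alpha_pkg = np.array([1.0, 1.0000025, 0.99999017, 1.0000262, 0.999944,\n",
"                      1.0001017, 0.9998344, 1.0002512, 0.9996395,\n",
"                      1.0004972, 0.9993389])\n",
"beta_pkg = np.array([1.0, 0.5773435, 0.51641864, 0.5070478, 0.50403017,\n",
"                     0.50240046, 0.5019109, 0.5010617, 0.5012555,\n",
"                     0.50043714, 0.50102293])\n",
"\n",
"# Closed-form monic coefficients for U([0, 2])\n",
"k = np.arange(1, 11)\n",
"beta_monic = k**2 / (4.0 * k**2 - 1.0)\n",
"\n",
"alpha_err = np.max(np.abs(alpha_pkg - 1.0))                       # vs alpha_k = 1\n",
"sqrt_err = np.max(np.abs(beta_pkg[1:] - np.sqrt(beta_monic)))     # vs sqrt(beta_k)\n",
"monic_err = np.max(np.abs(beta_pkg[1:] - beta_monic))             # vs beta_k itself\n",
"print(alpha_err, sqrt_err, monic_err)\n",
"```\n",
"\n",
"Both deviations from the closed form stay below $10^{-3}$ and grow with $k$. The distance to the monic $\\beta_k$, by contrast, is about $0.25$."
]
},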
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's evaluate the polynomials on a grid and plot them:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(200, 11)"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"\n",
"xis = np.linspace(0.0, 2.0, 200)\n",
"phi = poly(xis)\n",
"phi.shape"
]
},
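{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since $\\langle \\pi_k, \\pi_k \\rangle = \\beta_0 \\beta_1 \\cdots \\beta_k$ by the definition of $\\beta_k$, dividing Theorem 2 by the norms gives the orthonormal recurrence $\\sqrt{\\beta_{k+1}}\\, p_{k+1}(\\xi) = (\\xi - \\alpha_k)\\, p_k(\\xi) - \\sqrt{\\beta_k}\\, p_{k-1}(\\xi)$ with $p_0 = 1/\\sqrt{\\beta_0}$.\n",
"The `(200, 11)` shape above indicates one row per grid point and one column per degree $0, \\dots, 10$.\n",
"Before plotting, here is a minimal NumPy sketch of such an evaluation. It assumes orthonormal output, which is not confirmed here for orthojax, and uses the closed-form coefficients for $U([0,2])$. The helper `eval_orthonormal` is hypothetical. The sketch also checks orthonormality with a Gauss-Legendre rule:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def eval_orthonormal(xi, alpha, beta):\n",
"    # Evaluate p_0, ..., p_n at the points xi from the monic coefficients of\n",
"    # Theorem 2 (beta[0] = <pi_0, pi_0>). Returns shape (len(xi), n + 1).\n",
"    xi = np.asarray(xi, dtype=float)\n",
"    n = len(alpha) - 1\n",
"    out = np.zeros((xi.size, n + 1))\n",
"    out[:, 0] = 1.0 / np.sqrt(beta[0])\n",
"    if n >= 1:\n",
"        out[:, 1] = (xi - alpha[0]) * out[:, 0] / np.sqrt(beta[1])\n",
"    for k in range(1, n):\n",
"        # sqrt(beta_{k+1}) p_{k+1} = (xi - alpha_k) p_k - sqrt(beta_k) p_{k-1}\n",
"        out[:, k + 1] = ((xi - alpha[k]) * out[:, k]\n",
"                         - np.sqrt(beta[k]) * out[:, k - 1]) / np.sqrt(beta[k + 1])\n",
"    return out\n",
"\n",
"\n",
"# Closed-form monic coefficients for U([0, 2])\n",
"degree = 10\n",
"k = np.arange(1, degree + 1)\n",
"alpha = np.ones(degree + 1)\n",
"beta = np.concatenate([[1.0], k**2 / (4.0 * k**2 - 1.0)])\n",
"\n",
"phi_ref = eval_orthonormal(np.linspace(0.0, 2.0, 200), alpha, beta)\n",
"print(phi_ref.shape)  # (200, 11)\n",
"\n",
"# Orthonormality check: Gauss-Legendre rule for the density p(xi) = 1/2 on [0, 2]\n",
"z, w = np.polynomial.legendre.leggauss(50)\n",
"x, w = z + 1.0, 0.5 * w  # map [-1, 1] to [0, 2] and multiply by p(xi) = 1/2\n",
"P = eval_orthonormal(x, alpha, beta)\n",
"gram = P.T @ (w[:, None] * P)  # gram[i, j] approximates <p_i, p_j>\n",
"print(np.max(np.abs(gram - np.eye(degree + 1))))\n",
"```\n",
"\n",
"The matrix `gram` should equal the $11 \\times 11$ identity up to round-off, which is a direct numerical check of Theorem 1 for the normalized polynomials."
]
},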
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
"