Solving Parametric Problems using Physics-informed Neural Networks

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks");

!pip show equinox || (echo "equinox not found. Installing..." && pip install equinox 2> /dev/null)
Name: equinox
Version: 0.11.2
Summary: Elegant easy-to-use neural networks in JAX.
import jax
import equinox as eqx

class FourierEncoding(eqx.Module):
    B: jax.Array

    @property
    def num_fourier_features(self) -> int:
        return self.B.shape[0]

    @property
    def in_size(self) -> int:
        return self.B.shape[1]
    
    @property
    def out_size(self) -> int:
        return self.B.shape[0] * 2

    def __init__(self, 
                 in_size: int, 
                 num_fourier_features: int, 
                 key: jax.random.PRNGKey, 
                 sigma: float = 1.0):
        self.B = jax.random.normal(
            key, shape=(num_fourier_features, in_size),
            dtype=jax.numpy.float32) * sigma
    
    def __call__(self, x: jax.Array, **kwargs) -> jax.Array:
        return jax.numpy.concatenate(
            [jax.numpy.cos(jax.numpy.dot(self.B, x)),
             jax.numpy.sin(jax.numpy.dot(self.B, x))],
            axis=0)

Solving Parametric Problems using Physics-informed Neural Networks

As you may have noticed, physics-informed neural networks (PINNs) are not a good choice for solving a single forward problem: standard numerical methods are much faster and more accurate. One place where PINNs can potentially beat traditional approaches is the solution of parametric problems. In these problems, one has an ODE or PDE that depends on some parameters, and one wants to solve it for many different values of those parameters. PINNs can help us learn, in one shot, the map from the parameters to the solution of the problem.

It is straightforward to extend the standard PINN approach to parametric problems. Suppose that the parameters are \(\boldsymbol{\xi}\), an \(\mathbb{R}^d\)-valued random vector with PDF \(p(\boldsymbol{\xi})\). The loss function, e.g., the integrated squared residual, now depends on \(\boldsymbol{\xi}\) as well as on the network parameters \(\theta\):

\[ \boldsymbol{\xi} \mapsto \mathcal{L}_{\boldsymbol{\xi}}(\theta). \]

The idea is three-fold. First, try to promote inductive biases by baking in the boundary/initial conditions. Second, parameterize the unknown part using a neural network \(N(\mathbf{x}, \boldsymbol{\xi};\theta)\) with parameters \(\theta\). A good choice for the parameterization is this structure:

\[ N(\mathbf{x}, \boldsymbol{\xi};\theta) = \sum_{i=1}^m b_i(\boldsymbol{\xi};\theta)\phi_i(\mathbf{x};\theta). \]

This is a good choice because it resembles an expansion in a basis of functions \(\phi_i(\mathbf{x};\theta)\) with coefficients \(b_i(\boldsymbol{\xi};\theta)\). The difference here is that the basis functions are learned rather than fixed. Third, we find the parameters \(\theta\) by minimizing the expected loss:

\[ \mathcal{L}(\theta) = \int_{\mathbb{R}^d} \mathcal{L}_{\boldsymbol{\xi}}(\theta) p(\boldsymbol{\xi}) d\boldsymbol{\xi}. \]

We can do this using Adam, or any other stochastic optimization algorithm. All we need is to express the loss as an expectation and construct an unbiased estimator of the gradient of the loss with respect to the parameters \(\theta\).
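As a minimal sketch of this last step, the snippet below forms a Monte Carlo estimate of the expected loss and differentiates it with `jax.grad`. The per-parameter loss `loss_xi` is a hypothetical toy stand-in, \((\theta - \xi)^2\), chosen only because its expected gradient is known in closed form. The standard normal \(p(\xi)\) and the names `minibatch_loss` and `grad_estimate` are likewise assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def loss_xi(theta, xi):
    """Hypothetical stand-in for the per-parameter loss L_xi(theta)."""
    return (theta - xi) ** 2


def minibatch_loss(theta, xis):
    """Sample average of the per-parameter loss over a batch of xi's."""
    return jnp.mean(jax.vmap(loss_xi, in_axes=(None, 0))(theta, xis))


# The gradient of the sample average is an unbiased estimate of the
# gradient of the expected loss (assuming the xi's are i.i.d. draws from p).
grad_estimate = jax.jit(jax.grad(minibatch_loss))

key = jrandom.PRNGKey(0)
theta = jnp.array(0.5)
xis = jrandom.normal(key, shape=(10_000,))  # assumed p(xi) = N(0, 1)
g = grad_estimate(theta, xis)
# For this toy loss the exact gradient is 2 * (theta - E[xi]) = 1.0;
# the estimate fluctuates around that value.
print(g)
```

In a training loop, one would typically draw a fresh batch of \(\xi\) at every iteration and pass the resulting gradient to Adam or another stochastic optimizer.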

Example: Parametric Poisson Equation

Let’s consider the following parametric Poisson equation:

\[ \begin{aligned} -\Delta u(\mathbf{x};\xi) &= f(\mathbf{x};\xi), \quad \mathbf{x} \in \Omega, \\ u(\mathbf{x};\xi) &= 0, \quad \mathbf{x}\in \partial\Omega, \end{aligned} \]

where \(\Omega = [0,1]^2\), \(\mathbf{x} = (x, y)\), and \(\xi\) is a scalar parameter. The source term is:

\[ f(\mathbf{x};\xi) = 2\pi^2\xi\sin(\pi x) \sin(\pi y). \]

The exact solution to this problem is:

\[ u(\mathbf{x};\xi) = \xi\sin(\pi x) \sin(\pi y). \]

Let’s verify that this is indeed the solution to the problem.

import sympy as sp
from sympy import symbols

x, y, xi = symbols('x y xi')
f = 2 * sp.pi ** 2 * xi * sp.sin(sp.pi * x) * sp.sin(sp.pi * y)
u = xi * sp.sin(sp.pi * x) * sp.sin(sp.pi * y)
res = sp.simplify(sp.diff(u, x, x) + sp.diff(u, y, y) + f)
res
\[\displaystyle 0\]
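The check above covers the PDE only. A similar sketch, under the same symbol definitions, can confirm that this \(u\) also vanishes on the four edges of \(\Omega\), so that the boundary condition holds as well. The names `edges` and `boundary_values` are introduced here just for illustration.

```python
import sympy as sp

x, y, xi = sp.symbols('x y xi')
u = xi * sp.sin(sp.pi * x) * sp.sin(sp.pi * y)

# Evaluate u on each of the four edges of the unit square.
edges = [u.subs(x, 0), u.subs(x, 1), u.subs(y, 0), u.subs(y, 1)]
boundary_values = [sp.simplify(e) for e in edges]
print(boundary_values)
```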

I have purposely chosen a simple problem, so that we don’t have to deal with non-dimensionalization and scaling. A parameterization that satisfies the boundary conditions is:

\[ u(x,y;\xi;\theta) = x (1 - x) y (1 - y) N(x,y;\xi;\theta), \]

where \(N(x,y;\xi;\theta)\) is a neural network with parameters \(\theta\). For the neural network, we will use the structure suggested above.
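To see why this parameterization satisfies the boundary conditions regardless of the network, here is a small sketch in which `N` is a hypothetical smooth stand-in for the neural network (not the model used in this example). The prefactor \(x(1-x)y(1-y)\) forces the product to zero on every edge of the unit square.

```python
import jax.numpy as jnp


def N(x, y, xi):
    """Hypothetical stand-in for the neural network; any smooth function works."""
    return jnp.cos(3.0 * x) + xi * jnp.exp(y)


def u(x, y, xi):
    """Ansatz that vanishes on the boundary of the unit square by construction."""
    return x * (1.0 - x) * y * (1.0 - y) * N(x, y, xi)


s = jnp.linspace(0.0, 1.0, 11)  # points along one edge
xi = 0.7
edge_values = jnp.stack([
    u(0.0, s, xi),  # left edge, x = 0
    u(1.0, s, xi),  # right edge, x = 1
    u(s, 0.0, xi),  # bottom edge, y = 0
    u(s, 1.0, xi),  # top edge, y = 1
])
print(jnp.abs(edge_values).max())
```

Inside the domain the prefactor is strictly positive, so the ansatz does not restrict the values that \(u\) can take away from the boundary.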

First, some useful classes to avoid code repetition.

import equinox as eqx
import jax.numpy as jnp

class ParametricModel(eqx.Module):
    """This model captures a simple structure made out of branches and trunks."""
    branch: list  # These are the b's
    trunk: list   # These are the phi's

    def __init__(self, branch, trunk):
        self.branch = branch
        self.trunk = trunk

    def __call__(self, x, xi, **kwargs):
        res = 0.0
        for b, t in zip(self.branch, self.trunk):
            res += b(xi) * t(x)
        return res


class EnforceDirichletZeroBoundarySquare(eqx.Module):
    """This is a model that enforces zero Dirichlet boundary conditions on the unit square."""
    neural_net: eqx.Module

    def __init__(self, neural_net):
        self.neural_net = neural_net
    
    def __call__(self, x, y, xi, **kwargs):
        return self.neural_net(jnp.array([x, y]), xi) * (1 - x) * x * (1 - y) * y

Let’s now make the actual model we will use:

import jax
import jax.numpy as jnp
import jax.random as jrandom

key = jrandom.PRNGKey(0)
key1, key2, key = jrandom.split(key, 3)

# Make the parameterization
m = 10 # number of branches and trunks
# Branches are all MLP's
branch_width = 32
branch_depth = 3
branch = [eqx.nn.MLP('scalar', 'scalar', branch_width, branch_depth, jax.nn.tanh, key=k) for k in jrandom.split(key1, m)]
# Trunks have a FourierEncoding followed by an MLP
trunk_num_fourier_features = 32
trunk_width = 32
trunk_depth = 3
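trunk = []
for k in jrandom.split(key2, m):
    # Use independent keys for the encoding and the MLP instead of reusing k
    k_enc, k_mlp = jrandom.split(k)
    trunk.append(eqx.nn.Sequential([
        FourierEncoding(2, trunk_num_fourier_features, key=k_enc),
        eqx.nn.MLP(trunk_num_fourier_features * 2, 'scalar', trunk_width, trunk_depth, jax.nn.tanh, key=k_mlp)
    ]))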
# Combine the branches and trunks into a ParametricModel
model = EnforceDirichletZeroBoundarySquare(ParametricModel(branch, trunk))
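As a hedged sketch of how a function with the call signature `(x, y, xi)` could be plugged into the Poisson residual, the snippet below differentiates twice with `jax.grad` and batches over collocation points with `jax.vmap`. The helper `pde_residual`, the stand-in `u_exact`, and the sampling range for `xi` are assumptions made for illustration. The exact solution replaces the network only so that the result can be checked.

```python
import functools

import jax
import jax.numpy as jnp


def pde_residual(u, x, y, xi):
    """Residual -Laplacian(u) - f at one point, for any u(x, y, xi)."""
    u_xx = jax.grad(jax.grad(u, argnums=0), argnums=0)(x, y, xi)
    u_yy = jax.grad(jax.grad(u, argnums=1), argnums=1)(x, y, xi)
    f = 2.0 * jnp.pi ** 2 * xi * jnp.sin(jnp.pi * x) * jnp.sin(jnp.pi * y)
    return -(u_xx + u_yy) - f


def u_exact(x, y, xi):
    """Exact solution, used here only as a stand-in for the model."""
    return xi * jnp.sin(jnp.pi * x) * jnp.sin(jnp.pi * y)


# Evaluate the residual at a batch of random collocation points.
key_x, key_y, key_xi = jax.random.split(jax.random.PRNGKey(0), 3)
xs = jax.random.uniform(key_x, (100,))
ys = jax.random.uniform(key_y, (100,))
xis = jax.random.uniform(key_xi, (100,))  # assumed range [0, 1] for xi
batched = jax.vmap(functools.partial(pde_residual, u_exact))
residuals = batched(xs, ys, xis)
print(jnp.abs(residuals).max())
```

For the exact solution the residual is zero up to floating-point error. Since `model` is called as `model(x, y, xi)`, passing it in place of `u_exact` should work the same way, and the mean of the squared residuals would then serve as a Monte Carlo estimate of the integrated squared residual.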

Here is what our model looks like as a PyTree:

model
EnforceDirichletZeroBoundarySquare(
  neural_net=ParametricModel(
    branch=[
      MLP(
        layers=(
          Linear(
            weight=f32[32,1],
            bias=f32[32],
            in_features='scalar',
            out_features=32,
            use_bias=True
          ),
          Linear(
            weight=f32[32,32],
            bias=f32[32],
            in_features=32,
            out_features=32,
            use_bias=True
          ),
          Linear(
            weight=f32[32,32],
            bias=f32[32],
            in_features=32,
            out_features=32,
            use_bias=True
          ),
          Linear(
            weight=f32[1,32],
            bias=f32[1],
            in_features=32,
            out_features='scalar',
            use_bias=True
          )
        ),
        activation=<wrapped function <lambda>>,
        final_activation=<function <lambda>>,
        use_bias=True,
        use_final_bias=True,
        in_size='scalar',
        out_size='scalar',
        width_size=32,
        depth=3
      ),
      ...  (9 more MLP entries with the same structure as the one above)
    ],
    trunk=[
      Sequential(
        layers=(
          FourierEncoding(B=f32[32,2]),
          MLP(
            layers=(
              Linear(
                weight=f32[32,64],
                bias=f32[32],
                in_features=64,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[1,32],
                bias=f32[1],
                in_features=32,
                out_features='scalar',
                use_bias=True
              )
            ),
            activation=<wrapped function <lambda>>,
            final_activation=<function <lambda>>,
            use_bias=True,
            use_final_bias=True,
            in_size=64,
            out_size='scalar',
            width_size=32,
            depth=3
          )
        )
      ),
      Sequential(
        layers=(
          FourierEncoding(B=f32[32,2]),
          MLP(
            layers=(
              Linear(
                weight=f32[32,64],
                bias=f32[32],
                in_features=64,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[1,32],
                bias=f32[1],
                in_features=32,
                out_features='scalar',
                use_bias=True
              )
            ),
            activation=<wrapped function <lambda>>,
            final_activation=<function <lambda>>,
            use_bias=True,
            use_final_bias=True,
            in_size=64,
            out_size='scalar',
            width_size=32,
            depth=3
          )
        )
      ),
      Sequential(
        layers=(
          FourierEncoding(B=f32[32,2]),
          MLP(
            layers=(
              Linear(
                weight=f32[32,64],
                bias=f32[32],
                in_features=64,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[1,32],
                bias=f32[1],
                in_features=32,
                out_features='scalar',
                use_bias=True
              )
            ),
            activation=<wrapped function <lambda>>,
            final_activation=<function <lambda>>,
            use_bias=True,
            use_final_bias=True,
            in_size=64,
            out_size='scalar',
            width_size=32,
            depth=3
          )
        )
      ),
      Sequential(
        layers=(
          FourierEncoding(B=f32[32,2]),
          MLP(
            layers=(
              Linear(
                weight=f32[32,64],
                bias=f32[32],
                in_features=64,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[1,32],
                bias=f32[1],
                in_features=32,
                out_features='scalar',
                use_bias=True
              )
            ),
            activation=<wrapped function <lambda>>,
            final_activation=<function <lambda>>,
            use_bias=True,
            use_final_bias=True,
            in_size=64,
            out_size='scalar',
            width_size=32,
            depth=3
          )
        )
      ),
      Sequential(
        layers=(
          FourierEncoding(B=f32[32,2]),
          MLP(
            layers=(
              Linear(
                weight=f32[32,64],
                bias=f32[32],
                in_features=64,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[1,32],
                bias=f32[1],
                in_features=32,
                out_features='scalar',
                use_bias=True
              )
            ),
            activation=<wrapped function <lambda>>,
            final_activation=<function <lambda>>,
            use_bias=True,
            use_final_bias=True,
            in_size=64,
            out_size='scalar',
            width_size=32,
            depth=3
          )
        )
      ),
      Sequential(
        layers=(
          FourierEncoding(B=f32[32,2]),
          MLP(
            layers=(
              Linear(
                weight=f32[32,64],
                bias=f32[32],
                in_features=64,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[1,32],
                bias=f32[1],
                in_features=32,
                out_features='scalar',
                use_bias=True
              )
            ),
            activation=<wrapped function <lambda>>,
            final_activation=<function <lambda>>,
            use_bias=True,
            use_final_bias=True,
            in_size=64,
            out_size='scalar',
            width_size=32,
            depth=3
          )
        )
      ),
      Sequential(
        layers=(
          FourierEncoding(B=f32[32,2]),
          MLP(
            layers=(
              Linear(
                weight=f32[32,64],
                bias=f32[32],
                in_features=64,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[1,32],
                bias=f32[1],
                in_features=32,
                out_features='scalar',
                use_bias=True
              )
            ),
            activation=<wrapped function <lambda>>,
            final_activation=<function <lambda>>,
            use_bias=True,
            use_final_bias=True,
            in_size=64,
            out_size='scalar',
            width_size=32,
            depth=3
          )
        )
      ),
      Sequential(
        layers=(
          FourierEncoding(B=f32[32,2]),
          MLP(
            layers=(
              Linear(
                weight=f32[32,64],
                bias=f32[32],
                in_features=64,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[1,32],
                bias=f32[1],
                in_features=32,
                out_features='scalar',
                use_bias=True
              )
            ),
            activation=<wrapped function <lambda>>,
            final_activation=<function <lambda>>,
            use_bias=True,
            use_final_bias=True,
            in_size=64,
            out_size='scalar',
            width_size=32,
            depth=3
          )
        )
      ),
      Sequential(
        layers=(
          FourierEncoding(B=f32[32,2]),
          MLP(
            layers=(
              Linear(
                weight=f32[32,64],
                bias=f32[32],
                in_features=64,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[1,32],
                bias=f32[1],
                in_features=32,
                out_features='scalar',
                use_bias=True
              )
            ),
            activation=<wrapped function <lambda>>,
            final_activation=<function <lambda>>,
            use_bias=True,
            use_final_bias=True,
            in_size=64,
            out_size='scalar',
            width_size=32,
            depth=3
          )
        )
      ),
      Sequential(
        layers=(
          FourierEncoding(B=f32[32,2]),
          MLP(
            layers=(
              Linear(
                weight=f32[32,64],
                bias=f32[32],
                in_features=64,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[32,32],
                bias=f32[32],
                in_features=32,
                out_features=32,
                use_bias=True
              ),
              Linear(
                weight=f32[1,32],
                bias=f32[1],
                in_features=32,
                out_features='scalar',
                use_bias=True
              )
            ),
            activation=<wrapped function <lambda>>,
            final_activation=<function <lambda>>,
            use_bias=True,
            use_final_bias=True,
            in_size=64,
            out_size='scalar',
            width_size=32,
            depth=3
          )
        )
      )
    ]
  )
)

Let’s ensure our model works:

model(0.5, 0.5, 0.5)
Array(0.00502437, dtype=float32)

Let’s create our loss function:

from jax import vmap, grad

def loss_density(model, x, y, xi):
    u_xx = grad(grad(model, 0), 0)(x, y, xi)
    u_yy = grad(grad(model, 1), 1)(x, y, xi)
    f = 2 * jnp.pi ** 2 * xi * jnp.sin(jnp.pi * x) * jnp.sin(jnp.pi * y)
    return (u_xx + u_yy + f) ** 2

# This evaluates the loss at multiple xs and ys, but at a single xi
loss_density_single_xi = vmap(loss_density, in_axes=(None, 0, 0, None))

# This evaluates the loss at multiple xs, ys, and xis.
# But the xs and ys are the same for all xis
loss_density_many_xis = vmap(loss_density_single_xi, in_axes=(None, None, None, 0))

# And the final loss
loss = lambda model, xs, ys, xis: jnp.mean(loss_density_many_xis(model, xs, ys, xis))
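
One way to sanity-check the residual is to feed it a function that should make it vanish. The sketch below assumes the equation being enforced is $u_{xx} + u_{yy} + f = 0$ with the source term coded above. Under that assumption, the candidate $u(x, y, \xi) = \xi \sin(\pi x)\sin(\pi y)$ gives $u_{xx} + u_{yy} = -2\pi^2 \xi \sin(\pi x)\sin(\pi y) = -f$, so the loss density should be zero up to floating-point error. The name `u_candidate` and the test points are introduced here for illustration only, and `loss_density` is repeated so the snippet runs on its own.

```python
import jax.numpy as jnp
from jax import grad, vmap

def loss_density(model, x, y, xi):
    # Same squared residual as above: (u_xx + u_yy + f)^2
    u_xx = grad(grad(model, 0), 0)(x, y, xi)
    u_yy = grad(grad(model, 1), 1)(x, y, xi)
    f = 2 * jnp.pi ** 2 * xi * jnp.sin(jnp.pi * x) * jnp.sin(jnp.pi * y)
    return (u_xx + u_yy + f) ** 2

def u_candidate(x, y, xi):
    # Hypothetical closed-form candidate (not a trained network)
    return xi * jnp.sin(jnp.pi * x) * jnp.sin(jnp.pi * y)

# Arbitrary interior points and an arbitrary parameter value
xs_check = jnp.linspace(0.1, 0.9, 5)
ys_check = jnp.linspace(0.2, 0.8, 5)
xi_check = 0.7

# Evaluate the residual at each (x, y) pair for the fixed xi
residuals = vmap(lambda x, y: loss_density(u_candidate, x, y, xi_check))(xs_check, ys_check)
max_residual = float(jnp.max(residuals))
print(max_residual)
```

If the printed value is not close to zero, the sign or scaling of `f` in the residual is worth rechecking before training.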

Let’s make sure it works:

xs = jnp.linspace(0, 1, 100)
ys = jnp.linspace(0, 1, 100)
xis = jnp.linspace(0, 1, 20)
res = loss_density_many_xis(model, xs, ys, xis)
res.shape
(20, 100)

And:

loss(model, xs, ys, xis)
Array(48.726887, dtype=float32)
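
The `(20, 100)` shape follows from how the two `vmap` calls are stacked: the inner one maps over paired `(x, y)` points, and the outer one adds a leading axis for `xi`. Here is a small sketch of the same pattern. The function `toy_density` is a stand-in introduced only for this illustration; it replaces the network so that the shapes are easy to check.

```python
import jax.numpy as jnp
from jax import vmap

def toy_density(x, y, xi):
    # Stand-in for loss_density (no model argument)
    return xi * (x + y)

# Inner map: x and y are batched together, xi is held fixed
inner = vmap(toy_density, in_axes=(0, 0, None))
# Outer map: the same xs and ys are reused for every xi
outer = vmap(inner, in_axes=(None, None, 0))

xs_demo = jnp.linspace(0, 1, 100)
ys_demo = jnp.linspace(0, 1, 100)
xis_demo = jnp.linspace(0, 1, 20)

out = outer(xs_demo, ys_demo, xis_demo)
print(out.shape)  # (20, 100)
```

Note that `xs` and `ys` are paired element by element, not expanded into a grid. With two identical `linspace` arrays, as in the check above, the 100 points all lie on the diagonal `x = y`. That is fine for a smoke test, and the training code samples `xs` and `ys` independently.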

We are not going to worry about the Fourier features. We will let the model learn them. Here is the training code:

def train_parametric_pinn(
        loss,
        model,
        key,
        optimizer,
        Lx=1.0,
        Ly=1.0,
        num_collocation_residual=128,
        num_xis=16,
        num_iter=10_000,
        freq=1,
    ):
    """Train a parametric PINN. Assumes the xi's are scalar and uniformly distributed in [0, 1]."""

    @eqx.filter_jit
    def step(opt_state, model, xs, ys, xis):
        value, grads = eqx.filter_value_and_grad(loss)(model, xs, ys, xis)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, value
    
    opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))
    
    losses = []
    for i in range(num_iter):
        key1, key2, key3, key = jrandom.split(key, 4)
        xs = jrandom.uniform(key1, (num_collocation_residual,), maxval=Lx)
        ys = jrandom.uniform(key2, (num_collocation_residual,), maxval=Ly)
        xis = jrandom.uniform(key3, (num_xis,))
        model, opt_state, value = step(opt_state, model, xs, ys, xis)
        if i % freq == 0:
            losses.append(value)
            print(f"Step {i}, residual loss {value:.3e}")
    return model, losses
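
The loop above draws fresh collocation points and fresh parameter values at every iteration by splitting the PRNG key. Here is a minimal sketch of that pattern in isolation. It assumes `jrandom` is `jax.random`, as the name suggests (the import is not shown in this section). The seed and batch sizes are arbitrary.

```python
import jax.random as jrandom

key = jrandom.PRNGKey(0)  # arbitrary seed for the illustration
batches = []
for i in range(2):
    # Three subkeys for this iteration, plus a new key carried to the next one
    key1, key2, key3, key = jrandom.split(key, 4)
    xs_batch = jrandom.uniform(key1, (4,), maxval=1.0)
    ys_batch = jrandom.uniform(key2, (4,), maxval=1.0)
    xis_batch = jrandom.uniform(key3, (2,))
    batches.append((xs_batch, ys_batch, xis_batch))

for xs_batch, ys_batch, xis_batch in batches:
    print(xs_batch, ys_batch, xis_batch)
```

Because every iteration sees a different random batch, the logged loss is a noisy estimate of the expected residual, and it gets noisier as `num_xis` gets smaller.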

Let’s train. Be patient: compiling the model takes a while, but once the iterations start, each step is much faster. The whole run takes about 2.5 minutes on my machine.

import optax

optimizer = optax.adam(1e-3)

model, losses = train_parametric_pinn(
    loss, model, key, optimizer,
    num_collocation_residual=128, num_xis=2,
    num_iter=1_000, freq=1)
Step 0, residual loss 5.746e-02
Step 1, residual loss 4.884e-01
Step 2, residual loss 4.151e-02
Step 3, residual loss 6.356e-02
Step 4, residual loss 4.841e-01
Step 5, residual loss 2.277e-01
Step 6, residual loss 3.806e-02
Step 7, residual loss 1.023e-01
Step 8, residual loss 2.768e-01
Step 9, residual loss 2.563e-01
Step 10, residual loss 1.278e-01
Step 11, residual loss 2.503e-02
Step 12, residual loss 9.534e-02
Step 13, residual loss 1.992e-01
Step 14, residual loss 1.146e-01
Step 15, residual loss 1.077e-01
Step 16, residual loss 6.966e-03
Step 17, residual loss 9.710e-03
Step 18, residual loss 6.682e-02
Step 19, residual loss 1.755e-01
Step 20, residual loss 2.037e-01
Step 21, residual loss 9.842e-02
Step 22, residual loss 1.372e-02
Step 23, residual loss 3.400e-03
Step 24, residual loss 1.113e-02
Step 25, residual loss 5.384e-02
Step 26, residual loss 9.686e-02
Step 27, residual loss 1.681e-02
Step 28, residual loss 4.020e-02
Step 29, residual loss 2.934e-02
Step 30, residual loss 4.750e-02
Step 31, residual loss 1.868e-03
Step 32, residual loss 7.215e-02
Step 33, residual loss 4.529e-03
Step 34, residual loss 5.794e-03
Step 35, residual loss 9.468e-03
Step 36, residual loss 3.039e-02
Step 37, residual loss 3.862e-02
Step 38, residual loss 6.138e-03
Step 39, residual loss 3.440e-03
Step 40, residual loss 5.203e-02
Step 41, residual loss 8.518e-03
Step 42, residual loss 4.352e-02
Step 43, residual loss 2.826e-02
Step 44, residual loss 7.345e-03
Step 45, residual loss 2.369e-03
Step 46, residual loss 9.317e-02
Step 47, residual loss 6.692e-03
Step 48, residual loss 2.145e-03
Step 49, residual loss 7.465e-03
Step 50, residual loss 9.474e-03
Step 51, residual loss 1.743e-02
Step 52, residual loss 1.776e-02
Step 53, residual loss 1.649e-02
Step 54, residual loss 1.155e-02
Step 55, residual loss 7.490e-03
Step 56, residual loss 1.922e-02
Step 57, residual loss 2.145e-03
Step 58, residual loss 4.775e-02
Step 59, residual loss 1.389e-02
Step 60, residual loss 1.581e-02
Step 61, residual loss 1.537e-02
Step 62, residual loss 1.438e-02
Step 63, residual loss 8.039e-03
Step 64, residual loss 4.789e-03
Step 65, residual loss 4.748e-03
Step 66, residual loss 2.187e-02
Step 67, residual loss 7.972e-03
Step 68, residual loss 3.059e-02
Step 69, residual loss 3.683e-02
Step 70, residual loss 4.065e-03
Step 71, residual loss 6.698e-03
Step 72, residual loss 2.020e-03
Step 73, residual loss 7.261e-02
Step 74, residual loss 2.657e-03
Step 75, residual loss 1.979e-02
Step 76, residual loss 3.570e-02
Step 77, residual loss 1.695e-03
Step 78, residual loss 1.001e-03
Step 79, residual loss 4.178e-03
Step 80, residual loss 4.020e-02
Step 81, residual loss 3.045e-02
Step 82, residual loss 1.574e-02
Step 83, residual loss 2.879e-02
Step 84, residual loss 1.134e-03
Step 85, residual loss 1.071e-02
Step 86, residual loss 1.154e-02
Step 87, residual loss 1.385e-02
Step 88, residual loss 5.302e-03
Step 89, residual loss 5.111e-03
Step 90, residual loss 1.249e-02
Step 91, residual loss 1.942e-02
Step 92, residual loss 6.095e-02
Step 93, residual loss 3.569e-03
Step 94, residual loss 3.117e-02
Step 95, residual loss 2.568e-02
Step 96, residual loss 9.341e-03
Step 97, residual loss 5.175e-03
Step 98, residual loss 9.163e-03
Step 99, residual loss 1.957e-02
Step 100, residual loss 3.101e-03
Step 101, residual loss 1.011e-03
Step 102, residual loss 3.847e-03
Step 103, residual loss 7.183e-03
Step 104, residual loss 5.838e-03
Step 105, residual loss 5.146e-03
Step 106, residual loss 1.959e-03
Step 107, residual loss 2.035e-02
Step 108, residual loss 8.989e-02
Step 109, residual loss 1.917e-02
Step 110, residual loss 2.116e-02
Step 111, residual loss 8.374e-02
Step 112, residual loss 2.246e-02
Step 113, residual loss 4.037e-03
Step 114, residual loss 1.309e-01
Step 115, residual loss 6.506e-02
Step 116, residual loss 1.440e-02
Step 117, residual loss 1.780e-02
Step 118, residual loss 2.072e-03
Step 119, residual loss 5.282e-02
Step 120, residual loss 8.265e-02
Step 121, residual loss 1.052e-02
Step 122, residual loss 6.682e-03
Step 123, residual loss 1.039e-01
Step 124, residual loss 1.430e-02
Step 125, residual loss 1.615e-03
Step 126, residual loss 7.159e-03
Step 127, residual loss 4.547e-02
Step 128, residual loss 7.674e-03
Step 129, residual loss 3.268e-02
Step 130, residual loss 5.254e-02
Step 131, residual loss 1.249e-02
Step 132, residual loss 1.120e-02
Step 133, residual loss 5.041e-02
Step 134, residual loss 5.764e-03
Step 135, residual loss 9.048e-03
Step 136, residual loss 1.923e-02
Step 137, residual loss 1.010e-02
Step 138, residual loss 4.080e-02
Step 139, residual loss 7.653e-02
Step 140, residual loss 1.318e-02
Step 141, residual loss 4.564e-03
Step 142, residual loss 1.545e-01
Step 143, residual loss 9.994e-02
Step 144, residual loss 9.469e-03
Step 145, residual loss 1.931e-02
Step 146, residual loss 7.228e-02
Step 147, residual loss 8.018e-02
Step 148, residual loss 4.356e-03
Step 149, residual loss 1.444e-02
Step 150, residual loss 7.911e-02
Step 151, residual loss 8.005e-02
Step 152, residual loss 1.700e-02
Step 153, residual loss 4.045e-02
Step 154, residual loss 7.199e-02
Step 155, residual loss 4.796e-02
Step 156, residual loss 3.663e-02
Step 157, residual loss 6.941e-02
Step 158, residual loss 6.634e-02
Step 159, residual loss 3.795e-01
Step 160, residual loss 4.992e-02
Step 161, residual loss 4.795e-02
Step 162, residual loss 6.500e-02
Step 163, residual loss 3.562e-01
Step 164, residual loss 7.286e-02
Step 165, residual loss 4.861e-02
Step 166, residual loss 3.488e-02
Step 167, residual loss 2.559e-01
Step 168, residual loss 2.826e-02
Step 169, residual loss 1.268e-01
Step 170, residual loss 2.375e-02
Step 171, residual loss 5.011e-02
Step 172, residual loss 7.292e-02
Step 173, residual loss 6.487e-02
Step 174, residual loss 2.825e-02
Step 175, residual loss 2.879e-02
Step 176, residual loss 1.848e-02
Step 177, residual loss 8.314e-02
Step 178, residual loss 6.772e-02
Step 179, residual loss 5.128e-02
Step 180, residual loss 2.830e-02
Step 181, residual loss 9.147e-02
Step 182, residual loss 6.876e-02
Step 183, residual loss 5.308e-02
Step 184, residual loss 9.269e-03
Step 185, residual loss 5.564e-02
Step 186, residual loss 3.709e-02
Step 187, residual loss 2.145e-02
Step 188, residual loss 5.759e-03
Step 189, residual loss 2.348e-02
Step 190, residual loss 2.592e-02
Step 191, residual loss 3.925e-02
Step 192, residual loss 1.839e-03
Step 193, residual loss 1.270e-02
Step 194, residual loss 4.844e-02
Step 195, residual loss 1.436e-02
Step 196, residual loss 7.471e-03
Step 197, residual loss 2.330e-02
Step 198, residual loss 5.054e-03
Step 199, residual loss 1.844e-02
Step 200, residual loss 1.062e-02
Step 201, residual loss 4.255e-02
Step 202, residual loss 2.615e-02
Step 203, residual loss 3.604e-02
Step 204, residual loss 2.991e-02
Step 205, residual loss 8.713e-03
Step 206, residual loss 4.696e-03
Step 207, residual loss 1.909e-02
Step 208, residual loss 1.848e-02
Step 209, residual loss 5.206e-03
Step 210, residual loss 9.892e-03
Step 211, residual loss 6.036e-03
Step 212, residual loss 2.916e-02
Step 213, residual loss 2.258e-02
Step 214, residual loss 7.668e-03
Step 215, residual loss 5.130e-03
Step 216, residual loss 6.794e-03
Step 217, residual loss 3.035e-02
Step 218, residual loss 1.515e-01
Step 219, residual loss 2.558e-03
Step 220, residual loss 5.182e-02
Step 221, residual loss 1.511e-02
Step 222, residual loss 4.930e-02
Step 223, residual loss 6.024e-03
Step 224, residual loss 2.821e-02
Step 225, residual loss 1.645e-02
Step 226, residual loss 1.506e-02
Step 227, residual loss 2.830e-02
Step 228, residual loss 1.765e-02
Step 229, residual loss 1.403e-02
Step 230, residual loss 9.785e-03
Step 231, residual loss 7.159e-03
Step 232, residual loss 8.578e-04
Step 233, residual loss 2.905e-03
Step 234, residual loss 1.291e-03
Step 235, residual loss 1.924e-03
Step 236, residual loss 5.199e-03
Step 237, residual loss 3.874e-02
Step 238, residual loss 1.016e-02
Step 239, residual loss 6.630e-03
Step 240, residual loss 3.365e-02
Step 241, residual loss 2.899e-02
Step 242, residual loss 2.276e-02
Step 243, residual loss 1.109e-02
Step 244, residual loss 7.194e-03
Step 245, residual loss 2.966e-02
Step 246, residual loss 1.499e-02
Step 247, residual loss 4.221e-03
Step 248, residual loss 1.790e-02
Step 249, residual loss 3.227e-02
Step 250, residual loss 1.007e-02
Step 251, residual loss 8.839e-03
Step 252, residual loss 2.215e-02
Step 253, residual loss 1.817e-02
Step 254, residual loss 3.197e-03
Step 255, residual loss 3.349e-03
Step 256, residual loss 1.974e-03
Step 257, residual loss 1.697e-02
Step 258, residual loss 3.050e-02
Step 259, residual loss 8.071e-03
Step 260, residual loss 1.392e-02
Step 261, residual loss 5.500e-02
Step 262, residual loss 1.721e-03
Step 263, residual loss 3.728e-02
Step 264, residual loss 1.841e-02
Step 265, residual loss 1.558e-02
Step 266, residual loss 4.278e-02
Step 267, residual loss 6.543e-02
Step 268, residual loss 4.384e-02
Step 269, residual loss 1.161e-01
Step 270, residual loss 8.294e-02
Step 271, residual loss 7.818e-02
Step 272, residual loss 2.551e-02
Step 273, residual loss 8.430e-02
Step 274, residual loss 6.553e-02
Step 275, residual loss 4.106e-02
Step 276, residual loss 3.539e-02
Step 277, residual loss 1.004e-02
Step 278, residual loss 1.161e-01
Step 279, residual loss 8.207e-02
Step 280, residual loss 1.468e-02
Step 281, residual loss 2.419e-02
Step 282, residual loss 4.599e-02
Step 283, residual loss 3.409e-02
Step 284, residual loss 7.606e-03
Step 285, residual loss 7.092e-03
Step 286, residual loss 2.161e-02
Step 287, residual loss 3.806e-02
Step 288, residual loss 2.674e-02
Step 289, residual loss 4.406e-03
Step 290, residual loss 1.494e-02
Step 291, residual loss 2.684e-02
Step 292, residual loss 1.787e-02
Step 293, residual loss 5.313e-02
Step 294, residual loss 1.536e-02
Step 295, residual loss 1.439e-02
Step 296, residual loss 3.659e-03
Step 297, residual loss 4.517e-03
Step 298, residual loss 1.278e-03
Step 299, residual loss 2.078e-02
Step 300, residual loss 1.163e-02
Step 301, residual loss 2.170e-02
Step 302, residual loss 1.136e-02
Step 303, residual loss 5.355e-03
Step 304, residual loss 9.612e-03
Step 305, residual loss 1.019e-03
Step 306, residual loss 5.046e-03
Step 307, residual loss 2.898e-03
Step 308, residual loss 4.896e-03
Step 309, residual loss 5.056e-04
Step 310, residual loss 2.493e-03
Step 311, residual loss 4.543e-03
Step 312, residual loss 3.962e-03
Step 313, residual loss 1.825e-03
Step 314, residual loss 8.930e-04
Step 315, residual loss 4.074e-02
Step 316, residual loss 2.462e-02
Step 317, residual loss 1.797e-02
Step 318, residual loss 1.101e-01
Step 319, residual loss 5.199e-03
Step 320, residual loss 5.337e-02
Step 321, residual loss 1.259e-01
Step 322, residual loss 4.916e-02
Step 323, residual loss 1.323e-02
Step 324, residual loss 1.857e-02
Step 325, residual loss 1.422e-01
Step 326, residual loss 7.292e-03
Step 327, residual loss 9.222e-03
Step 328, residual loss 1.444e-02
Step 329, residual loss 5.332e-02
Step 330, residual loss 2.825e-04
Step 331, residual loss 3.610e-02
Step 332, residual loss 4.389e-02
Step 333, residual loss 3.292e-03
Step 334, residual loss 6.634e-03
Step 335, residual loss 3.601e-02
Step 336, residual loss 1.724e-02
Step 337, residual loss 2.315e-02
Step 338, residual loss 1.395e-03
Step 339, residual loss 5.296e-03
Step 340, residual loss 6.034e-02
Step 341, residual loss 1.582e-02
Step 342, residual loss 6.491e-03
Step 343, residual loss 2.297e-02
Step 344, residual loss 3.877e-02
Step 345, residual loss 1.528e-02
Step 346, residual loss 7.997e-03
Step 347, residual loss 2.959e-02
Step 348, residual loss 1.888e-02
Step 349, residual loss 6.948e-03
Step 350, residual loss 5.571e-04
Step 351, residual loss 4.280e-03
Step 352, residual loss 7.512e-03
Step 353, residual loss 5.442e-03
Step 354, residual loss 4.049e-03
Step 355, residual loss 6.885e-03
Step 356, residual loss 5.400e-03
Step 357, residual loss 1.401e-02
Step 358, residual loss 1.472e-02
Step 359, residual loss 1.456e-03
Step 360, residual loss 6.274e-03
Step 361, residual loss 9.698e-03
Step 362, residual loss 8.534e-03
Step 363, residual loss 4.425e-02
Step 364, residual loss 1.447e-02
Step 365, residual loss 5.505e-02
Step 366, residual loss 3.813e-03
Step 367, residual loss 1.837e-02
Step 368, residual loss 1.304e-02
Step 369, residual loss 1.891e-02
Step 370, residual loss 3.833e-02
Step 371, residual loss 5.800e-02
Step 372, residual loss 3.883e-03
Step 373, residual loss 3.553e-02
Step 374, residual loss 2.498e-02
Step 375, residual loss 3.499e-02
Step 376, residual loss 4.557e-04
Step 377, residual loss 4.132e-03
Step 378, residual loss 3.490e-02
Step 379, residual loss 7.374e-04
Step 380, residual loss 1.994e-03
Step 381, residual loss 8.756e-03
Step 382, residual loss 8.640e-03
Step 383, residual loss 2.569e-02
Step 384, residual loss 2.940e-02
Step 385, residual loss 1.340e-02
Step 386, residual loss 1.504e-02
Step 387, residual loss 9.394e-03
Step 388, residual loss 8.377e-03
Step 389, residual loss 1.389e-03
Step 390, residual loss 5.376e-04
Step 391, residual loss 6.166e-03
Step 392, residual loss 4.918e-02
Step 393, residual loss 1.758e-02
Step 394, residual loss 2.209e-02
Step 395, residual loss 3.300e-02
Step 396, residual loss 3.794e-03
Step 397, residual loss 2.819e-03
Step 398, residual loss 1.018e-03
Step 399, residual loss 5.520e-03
Step 400, residual loss 1.666e-02
Step 401, residual loss 2.283e-02
Step 402, residual loss 4.877e-02
Step 403, residual loss 6.181e-02
Step 404, residual loss 1.415e-02
Step 405, residual loss 2.331e-02
Step 406, residual loss 1.936e-02
Step 407, residual loss 7.252e-03
Step 408, residual loss 9.806e-03
Step 409, residual loss 1.733e-03
Step 410, residual loss 1.051e-02
Step 411, residual loss 1.239e-03
Step 412, residual loss 1.902e-03
Step 413, residual loss 2.189e-03
Step 414, residual loss 4.213e-03
Step 415, residual loss 1.329e-03
Step 416, residual loss 1.208e-02
Step 417, residual loss 2.982e-02
Step 418, residual loss 6.433e-03
Step 419, residual loss 6.159e-03
Step 420, residual loss 2.871e-03
Step 421, residual loss 6.953e-03
Step 422, residual loss 1.369e-02
Step 423, residual loss 2.883e-03
Step 424, residual loss 1.489e-03
Step 425, residual loss 3.180e-03
Step 426, residual loss 1.955e-03
Step 427, residual loss 1.733e-02
Step 428, residual loss 1.581e-03
Step 429, residual loss 4.138e-03
Step 430, residual loss 2.963e-03
Step 431, residual loss 1.589e-03
Step 432, residual loss 2.873e-02
Step 433, residual loss 3.491e-03
Step 434, residual loss 8.196e-03
Step 435, residual loss 7.946e-02
Step 436, residual loss 8.102e-03
Step 437, residual loss 2.767e-02
Step 438, residual loss 2.798e-02
Step 439, residual loss 7.460e-03
Step 440, residual loss 4.991e-03
Step 441, residual loss 6.363e-03
Step 442, residual loss 3.460e-03
Step 443, residual loss 3.882e-04
Step 444, residual loss 2.823e-02
Step 445, residual loss 6.463e-03
Step 446, residual loss 1.079e-02
Step 447, residual loss 1.639e-02
Step 448, residual loss 3.138e-02
Step 449, residual loss 7.585e-03
Step 450, residual loss 6.128e-02
Step 451, residual loss 2.796e-02
Step 452, residual loss 2.478e-02
Step 453, residual loss 9.585e-03
Step 454, residual loss 2.990e-02
Step 455, residual loss 4.221e-03
Step 456, residual loss 1.211e-02
Step 457, residual loss 1.111e-02
Step 458, residual loss 2.605e-02
Step 459, residual loss 4.161e-03
Step 460, residual loss 1.035e-02
Step 461, residual loss 6.654e-03
Step 462, residual loss 8.555e-03
Step 463, residual loss 4.770e-03
Step 464, residual loss 3.544e-03
Step 465, residual loss 6.358e-03
Step 466, residual loss 8.989e-04
Step 467, residual loss 6.297e-04
Step 468, residual loss 4.231e-04
Step 469, residual loss 1.161e-02
Step 470, residual loss 1.195e-03
Step 471, residual loss 5.171e-03
Step 472, residual loss 5.102e-03
Step 473, residual loss 1.377e-03
Step 474, residual loss 2.624e-02
Step 475, residual loss 1.862e-02
Step 476, residual loss 1.394e-02
Step 477, residual loss 2.209e-02
Step 478, residual loss 4.368e-03
Step 479, residual loss 2.433e-02
Step 480, residual loss 2.203e-02
Step 481, residual loss 2.282e-02
Step 482, residual loss 1.583e-03
Step 483, residual loss 3.504e-02
Step 484, residual loss 5.329e-03
Step 485, residual loss 2.493e-02
Step 486, residual loss 1.988e-03
Step 487, residual loss 8.857e-03
Step 488, residual loss 1.289e-02
Step 489, residual loss 4.636e-02
Step 490, residual loss 3.610e-02
Step 491, residual loss 5.524e-02
Step 492, residual loss 6.033e-02
Step 493, residual loss 5.537e-02
Step 494, residual loss 9.656e-03
Step 495, residual loss 2.641e-02
Step 496, residual loss 6.350e-02
Step 497, residual loss 6.365e-02
Step 498, residual loss 1.070e-03
Step 499, residual loss 7.117e-02
Step 500, residual loss 1.314e-01
Step 501, residual loss 2.083e-03
Step 502, residual loss 4.491e-02
Step 503, residual loss 9.527e-02
Step 504, residual loss 2.034e-02
Step 505, residual loss 3.494e-02
Step 506, residual loss 3.742e-02
Step 507, residual loss 3.082e-02
Step 508, residual loss 2.461e-01
Step 509, residual loss 5.793e-02
Step 510, residual loss 7.950e-02
Step 511, residual loss 4.050e-01
Step 512, residual loss 1.173e-01
Step 513, residual loss 9.066e-02
Step 514, residual loss 1.117e-01
Step 515, residual loss 1.451e-01
Step 516, residual loss 1.359e-01
Step 517, residual loss 2.991e-01
Step 518, residual loss 4.687e-02
Step 519, residual loss 6.781e-02
Step 520, residual loss 2.206e-02
Step 521, residual loss 2.349e-02
Step 522, residual loss 2.624e-02
Step 523, residual loss 7.387e-02
Step 524, residual loss 3.589e-02
Step 525, residual loss 1.087e-01
Step 526, residual loss 2.858e-02
Step 527, residual loss 5.013e-02
Step 528, residual loss 7.361e-02
Step 529, residual loss 1.392e-02
Step 530, residual loss 2.018e-02
Step 531, residual loss 3.346e-01
Step 532, residual loss 1.200e-01
Step 533, residual loss 1.209e-01
Step 534, residual loss 1.846e-01
Step 535, residual loss 1.210e-01
Step 536, residual loss 6.095e-03
Step 537, residual loss 9.751e-03
Step 538, residual loss 2.711e-02
Step 539, residual loss 2.683e-01
Step 540, residual loss 1.054e-01
Step 541, residual loss 1.531e-02
Step 542, residual loss 7.417e-02
Step 543, residual loss 1.649e-01
Step 544, residual loss 3.098e-02
Step 545, residual loss 2.194e-03
Step 546, residual loss 2.567e-02
Step 547, residual loss 1.018e-01
Step 548, residual loss 4.143e-02
Step 549, residual loss 1.796e-02
Step 550, residual loss 8.992e-03
Step 551, residual loss 2.230e-03
Step 552, residual loss 2.398e-02
Step 553, residual loss 5.278e-02
Step 554, residual loss 3.224e-02
Step 555, residual loss 1.037e-02
Step 556, residual loss 2.094e-02
Step 557, residual loss 3.215e-02
Step 558, residual loss 3.843e-02
Step 559, residual loss 1.936e-02
Step 560, residual loss 5.845e-03
Step 561, residual loss 2.821e-02
Step 562, residual loss 3.872e-02
Step 563, residual loss 2.653e-03
Step 564, residual loss 3.291e-03
Step 565, residual loss 5.028e-03
Step 566, residual loss 5.108e-03
Step 567, residual loss 1.422e-02
Step 568, residual loss 5.223e-03
Step 569, residual loss 1.616e-02
Step 570, residual loss 1.021e-02
Step 571, residual loss 1.190e-02
Step 572, residual loss 1.063e-02
Step 573, residual loss 6.292e-02
Step 574, residual loss 2.054e-02
Step 575, residual loss 1.561e-02
Step 576, residual loss 2.106e-02
Step 577, residual loss 3.324e-02
Step 578, residual loss 2.645e-02
Step 579, residual loss 4.013e-03
Step 580, residual loss 2.627e-02
Step 581, residual loss 7.311e-02
Step 582, residual loss 1.191e-03
Step 583, residual loss 2.692e-02
Step 584, residual loss 3.939e-03
Step 585, residual loss 4.622e-02
Step 586, residual loss 1.505e-02
Step 587, residual loss 2.866e-03
Step 588, residual loss 9.558e-03
Step 589, residual loss 1.436e-02
Step 590, residual loss 5.647e-02
Step 591, residual loss 9.618e-03
Step 592, residual loss 8.672e-03
Step 593, residual loss 1.982e-02
Step 594, residual loss 5.086e-02
Step 595, residual loss 3.307e-03
Step 596, residual loss 9.158e-03
Step 597, residual loss 4.007e-02
Step 598, residual loss 8.114e-03
Step 599, residual loss 1.564e-02
Step 600, residual loss 4.495e-03
Step 601, residual loss 2.982e-03
Step 602, residual loss 6.259e-03
Step 603, residual loss 3.505e-03
Step 604, residual loss 1.700e-02
Step 605, residual loss 9.295e-03
Step 606, residual loss 1.517e-02
Step 607, residual loss 1.268e-02
Step 608, residual loss 6.131e-03
Step 609, residual loss 3.856e-03
Step 610, residual loss 1.482e-02
Step 611, residual loss 2.895e-02
Step 612, residual loss 2.233e-03
Step 613, residual loss 2.071e-02
Step 614, residual loss 1.030e-02
Step 615, residual loss 7.309e-03
Step 616, residual loss 1.123e-02
Step 617, residual loss 4.863e-03
Step 618, residual loss 2.574e-02
Step 619, residual loss 1.683e-02
Step 620, residual loss 1.855e-03
Step 621, residual loss 1.772e-02
Step 622, residual loss 5.604e-03
Step 623, residual loss 1.308e-02
Step 624, residual loss 6.386e-03
Step 625, residual loss 6.150e-02
Step 626, residual loss 1.021e-03
Step 627, residual loss 2.925e-02
Step 628, residual loss 3.893e-02
Step 629, residual loss 2.573e-02
Step 630, residual loss 7.597e-03
Step 631, residual loss 1.444e-02
Step 632, residual loss 3.973e-03
Step 633, residual loss 1.486e-02
Step 634, residual loss 1.085e-02
Step 635, residual loss 8.408e-04
Step 636, residual loss 4.126e-03
Step 637, residual loss 4.638e-04
Step 638, residual loss 3.002e-03
Step 639, residual loss 4.319e-04
Step 640, residual loss 1.061e-02
Step 641, residual loss 3.917e-04
Step 642, residual loss 1.611e-03
Step 643, residual loss 6.492e-03
Step 644, residual loss 1.340e-03
Step 645, residual loss 5.305e-03
Step 646, residual loss 4.682e-03
Step 647, residual loss 5.091e-03
Step 648, residual loss 3.118e-03
Step 649, residual loss 4.580e-03
Step 650, residual loss 8.131e-03
Step 651, residual loss 2.554e-03
Step 652, residual loss 2.165e-03
Step 653, residual loss 3.485e-03
Step 654, residual loss 2.919e-03
Step 655, residual loss 1.190e-03
Step 656, residual loss 2.320e-03
Step 657, residual loss 2.630e-03
Step 658, residual loss 1.177e-03
Step 659, residual loss 2.442e-03
Step 660, residual loss 2.532e-02
Step 661, residual loss 3.193e-03
Step 662, residual loss 6.967e-03
Step 663, residual loss 2.989e-03
Step 664, residual loss 3.006e-03
Step 665, residual loss 1.959e-02
Step 666, residual loss 2.305e-03
Step 667, residual loss 7.092e-03
Step 668, residual loss 6.731e-03
Step 669, residual loss 1.558e-03
Step 670, residual loss 1.961e-03
Step 671, residual loss 1.172e-02
Step 672, residual loss 2.275e-02
Step 673, residual loss 6.148e-03
Step 674, residual loss 4.150e-03
Step 675, residual loss 1.796e-03
Step 676, residual loss 1.152e-03
Step 677, residual loss 9.779e-04
Step 678, residual loss 4.985e-04
Step 679, residual loss 1.761e-04
Step 680, residual loss 5.332e-03
Step 681, residual loss 1.566e-02
Step 682, residual loss 1.152e-02
Step 683, residual loss 1.338e-02
Step 684, residual loss 5.961e-02
Step 685, residual loss 2.735e-03
Step 686, residual loss 5.391e-03
Step 687, residual loss 8.978e-03
Step 688, residual loss 1.079e-01
Step 689, residual loss 1.359e-02
Step 690, residual loss 5.425e-03
Step 691, residual loss 3.840e-02
Step 692, residual loss 4.813e-02
Step 693, residual loss 5.234e-03
Step 694, residual loss 1.163e-02
Step 695, residual loss 4.197e-02
Step 696, residual loss 4.449e-02
Step 697, residual loss 9.894e-03
Step 698, residual loss 2.000e-03
Step 699, residual loss 4.060e-03
Step 700, residual loss 1.250e-01
Step 701, residual loss 5.290e-03
Step 702, residual loss 3.253e-03
Step 703, residual loss 1.915e-02
Step 704, residual loss 3.045e-02
Step 705, residual loss 1.282e-02
Step 706, residual loss 1.034e-03
Step 707, residual loss 3.439e-03
Step 708, residual loss 4.327e-02
Step 709, residual loss 3.900e-02
Step 710, residual loss 2.700e-02
Step 711, residual loss 9.410e-04
Step 712, residual loss 2.685e-02
Step 713, residual loss 4.246e-02
Step 714, residual loss 7.359e-02
Step 715, residual loss 3.838e-02
Step 716, residual loss 1.270e-02
Step 717, residual loss 2.408e-02
Step 718, residual loss 1.562e-02
Step 719, residual loss 4.264e-03
Step 720, residual loss 2.806e-02
Step 721, residual loss 7.452e-03
Step 722, residual loss 3.647e-03
Step 723, residual loss 7.937e-03
Step 724, residual loss 1.647e-02
Step 725, residual loss 1.041e-02
Step 726, residual loss 2.486e-03
Step 727, residual loss 1.039e-02
Step 728, residual loss 9.411e-02
Step 729, residual loss 1.324e-02
Step 730, residual loss 5.376e-02
Step 731, residual loss 9.469e-02
Step 732, residual loss 6.227e-02
Step 733, residual loss 1.812e-02
Step 734, residual loss 1.036e-01
Step 735, residual loss 9.123e-02
Step 736, residual loss 1.714e-02
Step 737, residual loss 7.469e-02
Step 738, residual loss 8.045e-02
Step 739, residual loss 2.391e-01
Step 740, residual loss 6.206e-02
Step 741, residual loss 4.847e-02
Step 742, residual loss 2.904e-02
Step 743, residual loss 1.532e-01
Step 744, residual loss 1.374e-02
Step 745, residual loss 1.215e-02
Step 746, residual loss 1.972e-02
Step 747, residual loss 8.017e-02
Step 748, residual loss 9.181e-02
Step 749, residual loss 1.532e-02
Step 750, residual loss 5.447e-02
Step 751, residual loss 2.062e-01
Step 752, residual loss 2.743e-02
Step 753, residual loss 3.864e-02
Step 754, residual loss 4.070e-02
Step 755, residual loss 9.738e-02
Step 756, residual loss 3.823e-03
Step 757, residual loss 1.073e-01
Step 758, residual loss 8.210e-02
Step 759, residual loss 3.893e-02
Step 760, residual loss 1.020e-02
Step 761, residual loss 4.463e-02
Step 762, residual loss 8.701e-02
Step 763, residual loss 1.508e-02
Step 764, residual loss 2.790e-02
Step 765, residual loss 8.999e-03
Step 766, residual loss 5.616e-02
Step 767, residual loss 2.087e-02
Step 768, residual loss 2.541e-02
Step 769, residual loss 4.608e-02
Step 770, residual loss 7.793e-03
Step 771, residual loss 1.053e-02
Step 772, residual loss 2.831e-02
Step 773, residual loss 3.800e-03
Step 774, residual loss 1.550e-02
Step 775, residual loss 3.443e-02
Step 776, residual loss 1.856e-02
Step 777, residual loss 3.621e-02
Step 778, residual loss 2.449e-03
Step 779, residual loss 7.897e-02
Step 780, residual loss 4.681e-03
Step 781, residual loss 2.370e-02
Step 782, residual loss 5.024e-02
Step 783, residual loss 4.745e-02
Step 784, residual loss 2.033e-03
Step 785, residual loss 1.280e-02
Step 786, residual loss 1.137e-02
Step 787, residual loss 5.340e-02
Step 788, residual loss 2.714e-02
Step 789, residual loss 7.207e-02
Step 790, residual loss 4.036e-02
Step 791, residual loss 2.084e-02
Step 792, residual loss 4.085e-02
Step 793, residual loss 9.816e-02
Step 794, residual loss 5.760e-02
Step 795, residual loss 6.644e-02
Step 796, residual loss 3.140e-02
Step 797, residual loss 3.974e-02
Step 798, residual loss 9.903e-02
Step 799, residual loss 3.435e-02
Step 800, residual loss 1.734e-02
Step 801, residual loss 3.376e-02
Step 802, residual loss 5.184e-02
Step 803, residual loss 2.778e-02
Step 804, residual loss 2.585e-02
Step 805, residual loss 4.884e-02
Step 806, residual loss 2.887e-02
Step 807, residual loss 2.667e-02
Step 808, residual loss 2.873e-03
Step 809, residual loss 1.793e-02
Step 810, residual loss 1.576e-02
Step 811, residual loss 9.390e-02
Step 812, residual loss 9.335e-02
Step 813, residual loss 5.566e-02
Step 814, residual loss 2.008e-02
Step 815, residual loss 3.747e-02
Step 816, residual loss 9.227e-02
Step 817, residual loss 2.753e-02
Step 818, residual loss 3.015e-03
Step 819, residual loss 4.074e-02
Step 820, residual loss 1.425e-01
Step 821, residual loss 1.142e-03
Step 822, residual loss 1.285e-02
Step 823, residual loss 1.603e-01
Step 824, residual loss 3.522e-02
Step 825, residual loss 1.465e-02
Step 826, residual loss 8.759e-02
Step 827, residual loss 4.834e-02
Step 828, residual loss 2.284e-02
Step 829, residual loss 2.496e-02
Step 830, residual loss 3.406e-02
Step 831, residual loss 5.534e-02
Step 832, residual loss 7.083e-03
Step 833, residual loss 1.501e-02
Step 834, residual loss 9.737e-02
Step 835, residual loss 1.189e-02
Step 836, residual loss 8.556e-03
Step 837, residual loss 2.412e-02
Step 838, residual loss 1.530e-02
Step 839, residual loss 1.489e-03
Step 840, residual loss 2.268e-02
Step 841, residual loss 7.924e-03
Step 842, residual loss 7.530e-03
Step 843, residual loss 3.328e-02
Step 844, residual loss 4.360e-02
Step 845, residual loss 1.150e-02
Step 846, residual loss 2.166e-03
Step 847, residual loss 5.919e-03
Step 848, residual loss 1.116e-02
Step 849, residual loss 1.675e-02
Step 850, residual loss 6.515e-03
Step 851, residual loss 4.318e-03
Step 852, residual loss 2.280e-03
Step 853, residual loss 1.017e-02
Step 854, residual loss 4.940e-03
Step 855, residual loss 1.210e-02
Step 856, residual loss 3.901e-03
Step 857, residual loss 1.026e-02
Step 858, residual loss 3.326e-02
Step 859, residual loss 3.153e-02
Step 860, residual loss 8.601e-03
Step 861, residual loss 1.074e-03
Step 862, residual loss 4.938e-03
Step 863, residual loss 1.693e-03
Step 864, residual loss 9.706e-03
Step 865, residual loss 2.326e-02
Step 866, residual loss 1.646e-02
Step 867, residual loss 5.882e-03
Step 868, residual loss 1.038e-01
Step 869, residual loss 1.099e-02
Step 870, residual loss 5.345e-03
Step 871, residual loss 2.325e-02
Step 872, residual loss 1.010e-02
Step 873, residual loss 3.461e-04
Step 874, residual loss 1.321e-03
Step 875, residual loss 1.538e-03
Step 876, residual loss 1.355e-02
Step 877, residual loss 1.978e-02
Step 878, residual loss 8.598e-03
Step 879, residual loss 9.555e-04
Step 880, residual loss 5.617e-03
Step 881, residual loss 8.810e-03
Step 882, residual loss 5.576e-03
Step 883, residual loss 3.000e-03
Step 884, residual loss 3.692e-04
Step 885, residual loss 9.053e-04
Step 886, residual loss 5.325e-03
Step 887, residual loss 2.105e-02
Step 888, residual loss 1.892e-03
Step 889, residual loss 1.876e-03
Step 890, residual loss 9.719e-03
Step 891, residual loss 3.464e-02
Step 892, residual loss 3.074e-03
Step 893, residual loss 3.995e-03
Step 894, residual loss 3.667e-02
Step 895, residual loss 3.830e-02
Step 896, residual loss 6.732e-03
Step 897, residual loss 1.859e-02
Step 898, residual loss 4.289e-02
Step 899, residual loss 3.985e-02
Step 900, residual loss 1.846e-02
Step 901, residual loss 1.051e-02
Step 902, residual loss 2.409e-02
Step 903, residual loss 3.390e-02
Step 904, residual loss 3.432e-03
Step 905, residual loss 1.616e-02
Step 906, residual loss 5.849e-02
Step 907, residual loss 2.775e-02
Step 908, residual loss 3.624e-03
Step 909, residual loss 1.617e-02
Step 910, residual loss 5.856e-02
Step 911, residual loss 5.475e-03
Step 912, residual loss 9.437e-03
Step 913, residual loss 3.022e-03
Step 914, residual loss 1.402e-02
Step 915, residual loss 1.093e-02
Step 916, residual loss 2.066e-03
Step 917, residual loss 3.877e-03
Step 918, residual loss 3.212e-02
Step 919, residual loss 1.765e-02
Step 920, residual loss 7.718e-03
Step 921, residual loss 2.088e-02
Step 922, residual loss 3.641e-03
Step 923, residual loss 9.751e-04
Step 924, residual loss 1.982e-03
Step 925, residual loss 1.935e-03
Step 926, residual loss 6.966e-03
Step 927, residual loss 5.587e-03
Step 928, residual loss 3.955e-03
Step 929, residual loss 1.041e-03
Step 930, residual loss 6.161e-03
Step 931, residual loss 2.786e-02
Step 932, residual loss 3.515e-03
Step 933, residual loss 1.681e-02
Step 934, residual loss 5.316e-03
Step 935, residual loss 3.046e-02
Step 936, residual loss 1.118e-02
Step 937, residual loss 1.115e-02
Step 938, residual loss 1.379e-02
Step 939, residual loss 1.597e-02
Step 940, residual loss 8.111e-04
Step 941, residual loss 2.146e-02
Step 942, residual loss 3.365e-02
Step 943, residual loss 9.069e-04
Step 944, residual loss 4.293e-03
Step 945, residual loss 2.510e-02
Step 946, residual loss 4.000e-02
Step 947, residual loss 7.903e-03
Step 948, residual loss 5.667e-03
Step 949, residual loss 2.910e-02
Step 950, residual loss 8.749e-03
Step 951, residual loss 6.676e-02
Step 952, residual loss 7.253e-02
Step 953, residual loss 3.775e-02
Step 954, residual loss 8.557e-02
Step 955, residual loss 3.645e-02
Step 956, residual loss 1.741e-02
Step 957, residual loss 1.025e-03
Step 958, residual loss 1.680e-02
Step 959, residual loss 3.040e-03
Step 960, residual loss 8.427e-02
Step 961, residual loss 2.819e-02
Step 962, residual loss 2.564e-03
Step 963, residual loss 7.431e-02
Step 964, residual loss 3.191e-02
Step 965, residual loss 3.171e-03
Step 966, residual loss 2.272e-02
Step 967, residual loss 1.215e-02
Step 968, residual loss 5.868e-03
Step 969, residual loss 2.455e-02
Step 970, residual loss 7.002e-03
Step 971, residual loss 2.552e-02
Step 972, residual loss 3.528e-03
Step 973, residual loss 2.521e-03
Step 974, residual loss 5.296e-03
Step 975, residual loss 1.859e-02
Step 976, residual loss 5.122e-03
Step 977, residual loss 1.485e-02
Step 978, residual loss 5.264e-03
Step 979, residual loss 3.192e-04
Step 980, residual loss 1.339e-02
Step 981, residual loss 7.737e-03
Step 982, residual loss 2.263e-02
Step 983, residual loss 1.867e-02
Step 984, residual loss 2.431e-02
Step 985, residual loss 5.042e-03
Step 986, residual loss 1.328e-02
Step 987, residual loss 1.320e-03
Step 988, residual loss 2.036e-02
Step 989, residual loss 5.195e-03
Step 990, residual loss 1.642e-03
Step 991, residual loss 9.079e-02
Step 992, residual loss 2.940e-03
Step 993, residual loss 8.472e-03
Step 994, residual loss 1.654e-02
Step 995, residual loss 2.340e-03
Step 996, residual loss 3.054e-03
Step 997, residual loss 2.223e-02
Step 998, residual loss 2.702e-02
Step 999, residual loss 2.110e-02

Let’s visualize the loss:

fig, ax = plt.subplots()
ax.plot(losses)
ax.set_yscale('log')
ax.set_xlabel('Iteration')
ax.set_ylabel('Residual Loss')
sns.despine(trim=True);
[Figure: residual loss versus iteration, with a logarithmic y-axis]
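
The raw curve is noisy: the logged values jump between roughly \(10^{-4}\) and \(10^{-1}\) from one step to the next, which makes the trend hard to read. One common remedy is to overlay a running average computed in log space. Here is a minimal sketch, assuming `losses` is a sequence of positive floats as plotted above. The helper name `smooth_log_losses`, the window size of 50, and the synthetic stand-in data are illustrative choices, not part of the original notebook.

```python
import numpy as np

def smooth_log_losses(losses, window=50):
    """Running geometric mean of positive losses (moving average in log10 space)."""
    log_l = np.log10(np.asarray(losses, dtype=float))
    kernel = np.ones(window) / window
    # mode="valid" keeps only windows that fit entirely inside the data,
    # so the output has len(losses) - window + 1 entries.
    return 10.0 ** np.convolve(log_l, kernel, mode="valid")

# Synthetic stand-in for the notebook's `losses` (hypothetical data):
# a slowly decaying trend with multiplicative noise.
rng = np.random.default_rng(0)
trend = -1.0 - np.linspace(0.0, 1.5, 1000)
demo_losses = 10.0 ** (trend + 0.5 * rng.standard_normal(1000))

window = 50
smoothed = smooth_log_losses(demo_losses, window=window)
# x-coordinates that center each averaging window on the original step axis
steps = np.arange(len(smoothed)) + (window - 1) / 2
```

With the real `losses` in place of `demo_losses`, calling `ax.plot(steps, smoothed)` after `ax.plot(losses)` would draw the smoothed trend on top of the raw curve. Averaging in log space keeps a single large spike from dominating the window.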

Let’s make some predictions at random \(\xi\)’s:

import numpy as np
np.random.seed(1234)

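# Exact solution, vectorized over the (x, y) points for a fixed xi
u_true = eqx.filter_jit(vmap(lambda x, y, xi: xi * jnp.sin(jnp.pi * x) * jnp.sin(jnp.pi * y), (0, 0, None)))

# Trained network, vectorized the same way
u_pred = eqx.filter_jit(vmap(model, (0, 0, None)))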

xs = np.linspace(0, 1, 100)
ys = np.linspace(0, 1, 100)
X, Y = np.meshgrid(xs, ys)
X_flat = X.flatten()
Y_flat = Y.flatten()

levels = np.linspace(0, 1.0, 20)
error_levels = np.linspace(0, 0.2, 20)

for xi in np.random.rand(5):
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    ax[0].set_title(f"Prediction for $\\xi$={xi:.2f}")
    # Evaluate the network on the flattened grid and reshape back to 2D
    Z_p = u_pred(X_flat, Y_flat, xi).reshape(X.shape)
    c = ax[0].contourf(X, Y, Z_p, levels=levels)
    fig.colorbar(c, ax=ax[0])
    ax[0].set_xlabel('$x$')
    ax[0].set_ylabel('$y$')
    sns.despine(trim=True)
    ax[1].set_title(f"True for $\\xi$={xi:.2f}")
    Z_t = u_true(X_flat, Y_flat, xi).reshape(X.shape)
    c = ax[1].contourf(X, Y, Z_t, levels=levels)
    fig.colorbar(c, ax=ax[1])
    ax[1].set_xlabel('$x$')
    ax[1].set_ylabel('$y$')
    sns.despine(trim=True)
    ax[2].set_title(f"Relative error for $\\xi$={xi:.2f}")
    # Pointwise absolute error, normalized by the maximum of the true solution
    E = np.abs(Z_p - Z_t) / np.max(Z_t)
    c = ax[2].contourf(X, Y, E, levels=error_levels)
    fig.colorbar(c, ax=ax[2])
    ax[2].set_xlabel('$x$')
    ax[2].set_ylabel('$y$')
    sns.despine(trim=True)
    plt.tight_layout()
[Figures: prediction, true solution, and relative error contour plots, one three-panel figure for each of the five sampled values of \(\xi\)]

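Okay, pretty good! The residual loss is still fluctuating between roughly \(10^{-4}\) and \(10^{-1}\) at the end of training, so we would have to train for more iterations to get a better fit.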
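
To put a number on “pretty good”, one option is to compute a single relative \(L^2\) error per value of \(\xi\) instead of reading it off a contour plot. The sketch below assumes the same exact solution as `u_true` above, \(u(x, y; \xi) = \xi \sin(\pi x) \sin(\pi y)\). The names `u_exact`, `relative_l2_error`, and `perturbed_model` are hypothetical, and `perturbed_model` is only a stand-in for the trained network so that the snippet runs on its own.

```python
import numpy as np

def u_exact(x, y, xi):
    # Same closed-form solution as `u_true` in the notebook.
    return xi * np.sin(np.pi * x) * np.sin(np.pi * y)

def relative_l2_error(predict, xi, n=100):
    """Relative L2 error of predict(X, Y, xi) on an n-by-n grid over [0, 1]^2."""
    xs = np.linspace(0.0, 1.0, n)
    X, Y = np.meshgrid(xs, xs)
    Z_t = u_exact(X, Y, xi)
    Z_p = predict(X, Y, xi)
    # For 2D arrays np.linalg.norm defaults to the Frobenius norm,
    # i.e. the discrete L2 norm over all grid points.
    return np.linalg.norm(Z_p - Z_t) / np.linalg.norm(Z_t)

def perturbed_model(x, y, xi):
    # Hypothetical stand-in for the trained network:
    # the exact solution scaled by a factor of 1.05.
    return 1.05 * u_exact(x, y, xi)

for xi in [0.25, 0.5, 1.0]:
    err = relative_l2_error(perturbed_model, xi)
    print(f"xi={xi:.2f}, relative L2 error {err:.3e}")
```

To use this with the trained network, `predict` would need to wrap `u_pred`, for example by flattening `X` and `Y`, calling `u_pred`, and reshaping the result as in the plotting loop above. The denominator vanishes at \(\xi = 0\), so that value should be excluded or handled with an absolute error instead.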