import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")

Physics-informed regularization: Solving PDEs#

We learn how to solve partial differential equations (PDEs) with neural networks, working through problems from [Lagaris et al., 1998].

import numpy as np
import torch
import torch.nn as nn

# This is useful for taking derivatives:
def grad(outputs, inputs):
    return torch.autograd.grad(outputs, inputs, grad_outputs=torch.ones_like(outputs), create_graph=True)[0]
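
As a quick sanity check of this helper: for an element-wise function it returns the element-wise derivative, and because create_graph=True it can be applied twice to obtain second derivatives.

x = torch.linspace(0.0, 1.0, 5, requires_grad=True)
y = x ** 3
dy_dx = grad(y, x)        # equals 3 * x ** 2
d2y_dx2 = grad(dy_dx, x)  # equals 6 * x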

Example: Solving PDEs#

Consider a PDE of the form:

\[ \frac{\partial^2}{\partial x^2}\Psi(x,y) + \frac{\partial^2}{\partial y^2}\Psi(x,y) = f(x,y), \]

on \((x,y) \in [0,1]^2\) with Dirichlet boundary conditions:

\[ \Psi(0, y) = f_0(y), \]
\[ \Psi(1, y) = f_1(y), \]
\[ \Psi(x, 0) = g_0(x), \]

and

\[ \Psi(x, 1) = g_1(x). \]

We write:

\[ \hat{\Psi}(x,y;\theta) = A(x,y) + x(1-x)y(1-y)N(x,y;\theta), \]

where \(A(x,y)\) is chosen to satisfy the boundary conditions:

\[ A(x,y) = (1-x)f_0(y) + xf_1(y) + (1-y)\{g_0(x) - [(1-x)g_0(0)+xg_0(1)]\} + y\{g_1(x)-[(1-x)g_1(0) + xg_1(1)]\}. \]
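
The \(x(1-x)y(1-y)\) factor in \(\hat{\Psi}\) vanishes on the boundary, so the boundary values of \(\hat{\Psi}\) are determined entirely by \(A\). To see that \(A\) reproduces the boundary data (assuming the data agree at the corners, e.g. \(f_0(0) = g_0(0)\)), check for example the edge \(x=0\):

\[ A(0,y) = f_0(y) + (1-y)\left[g_0(0) - g_0(0)\right] + y\left[g_1(0) - g_1(0)\right] = f_0(y), \]

and similarly for the other three edges.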

The loss function that we need to minimize is:

\[ L(\theta) = \int_{[0,1]^2} \left\{\frac{\partial^2}{\partial x^2}\hat{\Psi}(x,y;\theta) + \frac{\partial^2}{\partial y^2}\hat{\Psi}(x,y;\theta) - f(x,y)\right\}^2dxdy. \]
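
In practice, we approximate this integral by the mean squared residual over a set of \(N\) collocation points \((x_i, y_i)\) in \([0,1]^2\):

\[ L(\theta) \approx \frac{1}{N}\sum_{i=1}^N \left\{\frac{\partial^2}{\partial x^2}\hat{\Psi}(x_i,y_i;\theta) + \frac{\partial^2}{\partial y^2}\hat{\Psi}(x_i,y_i;\theta) - f(x_i,y_i)\right\}^2. \]

This is exactly what the squared_residual_loss method below computes.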

Here is code that solves this problem:

class PDEProblemDC(object):
    """A class representing PDE with DC boundary.

    Arguments
    ---------
    rhs   -- The right hand side of the equation.
    f0    -- Left boundary condition.
    f1    -- Right boundary condition.
    g0    -- Bottom boundary condition.
    g1    -- Top boundary condition.
    net   -- A neural network for representing the solution. This must have
                two-dimensional input and one-dimensional output.
    """
    def __init__(self, rhs, f0, f1, g0, g1, net):
        self._rhs = rhs
        self._f0 = f0
        self._f1 = f1
        self._g0 = g0
        self._g1 = g1
        self._net = net
        # This implements a function that satisfies the boundary conditions exactly
        g00 = self.g0(torch.zeros((1,)))[0]
        g01 = self.g0(torch.ones((1,)))[0]
        g10 = self.g1(torch.zeros((1,)))[0]
        g11 = self.g1(torch.ones((1,)))[0]
        def A(x):
            res = (1.0 - x[:, 0]) * self.f0(x[:, 1])
            res += x[:, 0] * self.f1(x[:, 1])
            res += (1.0 - x[:, 1]) * (self.g0(x[:, 0]) - ((1.0 - x[:, 0]) * g00 + x[:, 0] * g01))
            res += x[:, 1] * (self.g1(x[:, 0]) - ((1.0 - x[:, 0]) * g10 + x[:, 0] * g11))
            return res
        self._A = A
        self._solution = lambda x: self.A(x) + x[:, 0] * (1.0 - x[:, 0]) * x[:, 1] * (1.0 - x[:, 1]) * self.net(x)[:, 0]
    
    @property
    def rhs(self):
        return self._rhs
    
    @property
    def f0(self):
        return self._f0
    
    @property
    def f1(self):
        return self._f1
    
    @property
    def g0(self):
        return self._g0
    
    @property
    def g1(self):
        return self._g1
    
    @property
    def A(self):
        return self._A
    
    @property
    def net(self):
        return self._net
    
    @property
    def solution(self):
        """Return the solution function."""
        return self._solution
    
    def squared_residual_loss(self, X):
        """Returns the squared residual loss at spatial locations X.
        
        Arguments
        ---------
        X     -- The spatial locations where the loss is evaluated.
        """
        X.requires_grad = True
        sol = self.solution(X)
        A = self.A(X)
        sol_x = grad(sol, X)
        # Get the second derivatives
        sol_xx = grad(sol_x[:, 0], X)[:, 0]
        sol_yy = grad(sol_x[:, 1], X)[:, 1]
        rhs = self.rhs(X)
        return torch.mean((sol_xx + sol_yy - rhs) ** 2)
    
    def solve_lbfgs(self, X_colloc, max_iter=10):
        """Solve the problem by minimizing the squared residual loss.
        
        Arguments
        ---------
        X_colloc -- The collocation points used to solve the problem.
        """
        optimizer = torch.optim.LBFGS(self.net.parameters())

        # Run the optimizer
        def closure():
            optimizer.zero_grad()
            l = self.squared_residual_loss(X_colloc)
            l.backward()
            return l
        for i in range(max_iter):
            res = optimizer.step(closure)
            print(res)
    
    def solve_sgd(self, bounds, max_iter=1000, batch_size=10, lr=0.01):
        """Solve the problem using stochastic gradient descent.
        
        Arguments
        ---------
        bounds     -- The bounds of the domain. A 2x2 array. The first column
                        is the lower bound and the second column is the upper
                        bound.
        max_iter   -- The maximum number of iterations to do. Default is 1000.
        batch_size -- The batch size to use. Default is 10.
        lr         -- The learning rate. Default is 0.01.
        """
        optimizer = torch.optim.Adam(self.net.parameters(), lr=lr)

        for i in range(max_iter):
            # Randomly pick batch_size points uniformly in the unit square
            X = torch.rand(batch_size, 2)
            # Rescale them to the domain specified by bounds
            X = bounds[:, 0] + (bounds[:, 1] - bounds[:, 0]) * X
            # Zero-out the gradient buffers
            optimizer.zero_grad()
            # Evaluate the loss
            l = self.squared_residual_loss(X)
            # Calculate the gradients
            l.backward()
            # Update the network
            optimizer.step()
            # Print the iteration number
            if i % 100 == 99:
                print(f"it = {i}, loss = {l.item()}")
            

def plot_contour(ex, true_sol):
    """Plot the contour of the true solution and the approximation."""
    xx = np.linspace(0, 1, 64)
    X, Y = np.meshgrid(xx, xx)
    X_flat = torch.Tensor(np.hstack([X.flatten()[:, None], Y.flatten()[:, None]]))
    Z_flat = ex.solution(X_flat).detach().numpy()
    Z_t_flat = true_sol(X_flat)
    Z_t_flat = Z_t_flat.detach().numpy()
    Z = Z_flat.reshape(64, 64)
    Z_t = Z_t_flat.reshape(64, 64)
    fig, ax = plt.subplots()
    c = ax.contourf(X, Y, Z)
    ax.set_title("Neural network solution")
    plt.colorbar(c)
    fig, ax = plt.subplots()
    c = ax.contourf(X, Y, Z_t)
    ax.set_title("True solution")
    plt.colorbar(c)

Here is how to solve Problem 5 of [Lagaris et al., 1998] with this class:

# Problem 5 of Lagaris
rhs = lambda x: torch.exp(-x[:, 0]) * (x[:, 0] - 2.0 + x[:, 1] ** 3 + 6.0 * x[:, 1])
f0 = lambda x2: x2 ** 3
f1 = lambda x2: (1.0 + x2 ** 3) * np.exp(-1.0)
g0 = lambda x1: x1 * torch.exp(-x1)
g1 = lambda x1: torch.exp(-x1) * (x1 + 1.0)
ex5 = PDEProblemDC(rhs, f0, f1, g0, g1,
                   nn.Sequential(nn.Linear(2, 10), nn.Sigmoid(), nn.Linear(10, 1, bias=False)))
x = np.linspace(0, 1, 10)
X, Y = np.meshgrid(x, x)
X_flat = torch.Tensor(np.hstack([X.flatten()[:, None], Y.flatten()[:, None]]))
ex5.solve_lbfgs(X_flat);
tensor(0.1201, grad_fn=<MeanBackward0>)
tensor(6.5121e-05, grad_fn=<MeanBackward0>)
tensor(5.6691e-06, grad_fn=<MeanBackward0>)
tensor(4.8895e-06, grad_fn=<MeanBackward0>)
tensor(4.8895e-06, grad_fn=<MeanBackward0>)
tensor(4.8895e-06, grad_fn=<MeanBackward0>)
tensor(4.8895e-06, grad_fn=<MeanBackward0>)
tensor(4.8895e-06, grad_fn=<MeanBackward0>)
tensor(4.8895e-06, grad_fn=<MeanBackward0>)
tensor(4.8895e-06, grad_fn=<MeanBackward0>)
ex5_true_sol = lambda x: torch.exp(-x[:, 0]) * (x[:, 0] + x[:, 1] ** 3)
plot_contour(ex5, ex5_true_sol);
[Contour plots of the neural network solution and the true solution.]
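
The two fields can also be compared quantitatively. Here is a minimal sketch that evaluates the maximum absolute error of the network solution on a dense grid (reusing ex5 and ex5_true_sol from above):

xx = np.linspace(0, 1, 64)
Xg, Yg = np.meshgrid(xx, xx)
X_test = torch.Tensor(np.hstack([Xg.flatten()[:, None], Yg.flatten()[:, None]]))
with torch.no_grad():
    err = ex5.solution(X_test) - ex5_true_sol(X_test)
print("max abs error:", err.abs().max().item())
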
# Problem 6 of Lagaris
a = 3.0
def rhs(x):
    tmp1 = torch.exp(-(a * x[:, 0] + x[:, 1]) / 5.0)
    tmp2 = (-4.0 / 5.0 * a ** 3 * x[:, 0] - 2.0 / 5.0 + 2.0 * a ** 2) * torch.cos(a ** 2 * x[:, 0] ** 2 + x[:, 1])
    tmp2 += (1.0 / 25.0 - 1.0 - 4.0 * a ** 4 * x[:, 0] ** 2 + a ** 2 / 25.0) * torch.sin(a ** 2 * x[:, 0] ** 2 + x[:, 1])
    return tmp1 * tmp2
ex6_true_sol = lambda x: torch.exp(-(a * x[:, 0] + x[:, 1]) / 5.0) * torch.sin(a ** 2 * x[:, 0] ** 2 + x[:, 1])
f0 = lambda x2: ex6_true_sol(torch.stack((torch.zeros_like(x2), x2), dim=1))
f1 = lambda x2: ex6_true_sol(torch.stack((torch.ones_like(x2), x2), dim=1))
g0 = lambda x1: ex6_true_sol(torch.stack((x1, torch.zeros_like(x1)), dim=1))
g1 = lambda x1: ex6_true_sol(torch.stack((x1, torch.ones_like(x1)), dim=1))
net = nn.Sequential(nn.Linear(2, 10), nn.Sigmoid(), nn.Linear(10, 1, bias=False))
ex6 = PDEProblemDC(rhs, f0, f1, g0, g1, net)
x = np.linspace(0, 1, 10)
X, Y = np.meshgrid(x, x)
X_flat = torch.Tensor(np.hstack([X.flatten()[:, None], Y.flatten()[:, None]]))
# Does not always work because of local minima.
# Try multiple times.
# ex6.solve_lbfgs(X_flat, max_iter=10)
# Bounds of the domain: each row is [lower, upper] for one coordinate.
ex6.solve_sgd(torch.from_numpy(np.array([[0, 1], [0, 1]], dtype=np.float32)),
              max_iter=10000, batch_size=100, lr=0.01)
it = 99, loss = 2.2737367544323206e-13
it = 199, loss = 2.2737367544323206e-13
it = 299, loss = 2.2737367544323206e-13
...
it = 9899, loss = 2.2737367544323206e-13
it = 9999, loss = 2.2737367544323206e-13
plot_contour(ex6, ex6_true_sol)
[Contour plots of the neural network solution and the true solution.]

Questions#

Feel free to skip this, as it can be challenging if you are not an expert in Python.

  • Add a method to the class PDEProblemDC that uses stochastic gradient descent to solve the same problems. Once you are done, rerun the problems above with your code.

  • According to Dirichlet's principle, the solution of the PDE:

\[ \frac{\partial^2}{\partial x^2}\Psi(x,y) + \frac{\partial^2}{\partial y^2}\Psi(x,y) = f(x,y), \]

minimizes the energy functional:

\[ J[\Psi] = \int_{[0,1]^2} \left[\frac{1}{2}\parallel \nabla \Psi\parallel^2 + \Psi f\right]dxdy, \]

subject to the boundary conditions. This means that you can solve the problem by minimizing the loss function:

\[ J(\theta) = \int_{[0,1]^2} \left[\frac{1}{2}\parallel \nabla \hat{\Psi}(x,y;\theta)\parallel^2 + \hat{\Psi}(x,y;\theta) f(x,y)\right]dxdy. \]

Add this functionality to the class PDEProblemDC.
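
As a starting point, here is a minimal sketch of what such an energy loss could look like, written as an extra method attached to PDEProblemDC (one possible implementation, reusing the grad helper defined above; the training loop is left to you):

def energy_loss(self, X):
    """Monte Carlo estimate of the energy functional at the points X."""
    X.requires_grad = True
    sol = self.solution(X)
    # Gradient of the trial solution with respect to x and y
    sol_grad = grad(sol, X)
    # Energy density: 0.5 * ||grad Psi||^2 + Psi * f
    density = 0.5 * torch.sum(sol_grad ** 2, dim=1) + sol * self.rhs(X)
    return torch.mean(density)

# Attach the method to the class so existing objects can use it:
PDEProblemDC.energy_loss = energy_loss

Because the domain \([0,1]^2\) has unit area, the mean over uniformly sampled points is a Monte Carlo estimate of the integral, so this loss can be minimized exactly like squared_residual_loss.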