import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")
Physics-informed regularization: Solving PDEs#
We learn how to solve PDEs with neural networks, working through two problems from [Lagaris et al., 1998].
import numpy as np
import torch
import torch.nn as nn
# This is useful for taking derivatives:
def grad(outputs, inputs):
return torch.autograd.grad(outputs, inputs, grad_outputs=torch.ones_like(outputs), create_graph=True)[0]
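As a quick sanity check, we can use grad to differentiate a function whose derivative we know analytically; something like the following should work:

# Sanity check of `grad`: the derivative of sin(x) should be cos(x).
x = torch.linspace(0.0, 1.0, 5).reshape(-1, 1)
x.requires_grad = True
y = torch.sin(x)
print(torch.allclose(grad(y, x), torch.cos(x)))  # should print True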
Example: Solving PDEs#
Consider a PDE of the form:

$$
\nabla^2 \psi(x, y) = f(x, y),
$$

on \((x,y) \in [0,1]^2\) with Dirichlet boundary conditions:

$$
\psi(0, y) = f_0(y), \qquad \psi(1, y) = f_1(y),
$$

and

$$
\psi(x, 0) = g_0(x), \qquad \psi(x, 1) = g_1(x).
$$

We write:

$$
\hat{\psi}(x, y;\theta) = A(x, y) + x(1 - x)\,y(1 - y)\,N(x, y;\theta),
$$

where \(N(x, y;\theta)\) is a neural network with two inputs and one output, and \(A(x,y)\) is chosen to satisfy the boundary conditions:

$$
\begin{aligned}
A(x, y) ={}& (1 - x)\,f_0(y) + x\,f_1(y)\\
&+ (1 - y)\Big\{g_0(x) - \big[(1 - x)\,g_0(0) + x\,g_0(1)\big]\Big\}\\
&+ y\,\Big\{g_1(x) - \big[(1 - x)\,g_1(0) + x\,g_1(1)\big]\Big\}.
\end{aligned}
$$
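The factor \(x(1-x)y(1-y)\) vanishes on the whole boundary of the unit square, so the boundary values of \(\hat{\psi}\) are those of \(A\), and \(A\) reproduces the boundary data exactly, provided the boundary functions agree at the four corners. Here is a minimal numerical check of this property, using made-up boundary functions chosen only for this illustration:

# Illustrative boundary data (chosen to be compatible at the four corners).
f0 = lambda y: y ** 2            # psi(0, y)
f1 = lambda y: 1.0 + y           # psi(1, y)
g0 = lambda x: x ** 2            # psi(x, 0)
g1 = lambda x: 1.0 + x ** 2      # psi(x, 1)

def A(x, y):
    """The boundary-interpolating term of the ansatz."""
    res = (1.0 - x) * f0(y) + x * f1(y)
    res += (1.0 - y) * (g0(x) - ((1.0 - x) * g0(0.0) + x * g0(1.0)))
    res += y * (g1(x) - ((1.0 - x) * g1(0.0) + x * g1(1.0)))
    return res

t = np.linspace(0.0, 1.0, 11)
# All four checks should print True.
print(np.allclose(A(0.0, t), f0(t)), np.allclose(A(1.0, t), f1(t)))
print(np.allclose(A(t, 0.0), g0(t)), np.allclose(A(t, 1.0), g1(t)))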
The loss function that we need to minimize is the mean squared residual of the PDE at a set of collocation points \(\{(x_i, y_i)\}_{i=1}^n\):

$$
L(\theta) = \frac{1}{n}\sum_{i=1}^{n}\left[\nabla^2 \hat{\psi}(x_i, y_i;\theta) - f(x_i, y_i)\right]^2.
$$
Here is a class that implements this construction and solves problems of this form:
class PDEProblemDC(object):
"""A class representing PDE with DC boundary.
Arguments
---------
rhs -- The right hand side of the equation.
f0 -- Left boundary condition.
f1 -- Right boundary condition.
g0 -- Bottom boundary condition.
g1 -- Top boundary condition.
net -- A neural network for representing the solution. This must have
two-dimensional input and one-dimensional output.
"""
def __init__(self, rhs, f0, f1, g0, g1, net):
self._rhs = rhs
self._f0 = f0
self._f1 = f1
self._g0 = g0
self._g1 = g1
self._net = net
# This implements a function that satisfies the boundary conditions exactly
g00 = self.g0(torch.zeros((1,)))[0]
g01 = self.g0(torch.ones((1,)))[0]
g10 = self.g1(torch.zeros((1,)))[0]
g11 = self.g1(torch.ones((1,)))[0]
def A(x):
res = (1.0 - x[:, 0]) * self.f0(x[:, 1])
res += x[:, 0] * self.f1(x[:, 1])
res += (1.0 - x[:, 1]) * (self.g0(x[:, 0]) - ((1.0 - x[:, 0]) * g00 + x[:, 0] * g01))
res += x[:, 1] * (self.g1(x[:, 0]) - ((1.0 - x[:, 0]) * g10 + x[:, 0] * g11))
return res
self._A = A
self._solution = lambda x: self.A(x) + x[:, 0] * (1.0 - x[:, 0]) * x[:, 1] * (1.0 - x[:, 1]) * self.net(x)[:, 0]
@property
def rhs(self):
return self._rhs
@property
def f0(self):
return self._f0
@property
def f1(self):
return self._f1
@property
def g0(self):
return self._g0
@property
def g1(self):
return self._g1
@property
def A(self):
return self._A
@property
def net(self):
return self._net
@property
def solution(self):
"""Return the solution function."""
return self._solution
def squared_residual_loss(self, X):
"""Returns the squared residual loss at spatial locations X.
Arguments
---------
X -- The spatial locations where the loss is evaluated.
"""
X.requires_grad = True
sol = self.solution(X)
A = self.A(X)
sol_x = grad(sol, X)
# Get the second derivatives
sol_xx = grad(sol_x[:, 0], X)[:, 0]
sol_yy = grad(sol_x[:, 1], X)[:, 1]
rhs = self.rhs(X)
return torch.mean((sol_xx + sol_yy - rhs) ** 2)
def solve_lbfgs(self, X_colloc, max_iter=10):
"""Solve the problem by minimizing the squared residual loss.
Arguments
---------
X_colloc -- The collocation points used to solve the problem.
max_iter -- The maximum number of L-BFGS steps to take. Default is 10.
"""
optimizer = torch.optim.LBFGS(self.net.parameters())
# Run the optimizer
def closure():
optimizer.zero_grad()
l = self.squared_residual_loss(X_colloc)
l.backward()
return l
for i in range(max_iter):
res = optimizer.step(closure)
print(res)
def solve_sgd(self, bounds, max_iter=1000, batch_size=10, lr=0.01):
"""Solve the problem using stochastic gradient descent.
Arguments
---------
bounds -- The bounds of the domain. A 2x2 array. The first column
is the lower bound and the second column is the upper
bound.
max_iter -- The maximum number of iterations to do. Default is 1000.
batch_size -- The batch size to use. Default is 10.
lr -- The learning rate. Default is 0.01.
"""
optimizer = torch.optim.Adam(self.net.parameters(), lr=lr)
for i in range(max_iter):
# Randomly pick n_batch random x's:
X = torch.rand(batch_size, 2)
X = bounds[:, 0] + (bounds[:, 1] - bounds[:, 0]) * X
# Zero-out the gradient buffers
optimizer.zero_grad()
# Evaluate the loss
l = self.squared_residual_loss(X)
# Calculate the gradients
l.backward()
# Update the network
optimizer.step()
# Print the iteration number
if i % 100 == 99:
print(f"it = {i}, loss = {l.item()}")
def plot_contour(ex, true_sol):
"""Plot the contour of the true solution and the approximation."""
xx = np.linspace(0, 1, 64)
X, Y = np.meshgrid(xx, xx)
X_flat = torch.Tensor(np.hstack([X.flatten()[:, None], Y.flatten()[:, None]]))
Z_flat = ex.solution(X_flat).detach().numpy()
Z_t_flat = true_sol(X_flat)
Z_t_flat = Z_t_flat.detach().numpy()
Z = Z_flat.reshape(64, 64)
Z_t = Z_t_flat.reshape(64, 64)
fig, ax = plt.subplots()
c = ax.contourf(X, Y, Z)
ax.set_title("Neural network solution")
plt.colorbar(c)
fig, ax = plt.subplots()
c = ax.contourf(X, Y, Z_t)
ax.set_title("True solution")
plt.colorbar(c)
Here is how to solve a concrete example, Problem 5 of [Lagaris et al., 1998], with a neural network:
# Problem 5 of Lagaris
rhs = lambda x: torch.exp(-x[:, 0]) * (x[:, 0] - 2.0 + x[:, 1] ** 3 + 6.0 * x[:, 1])
f0 = lambda x2: x2 ** 3
f1 = lambda x2: (1.0 + x2 ** 3) * np.exp(-1.0)
g0 = lambda x1: x1 * torch.exp(-x1)
g1 = lambda x1: torch.exp(-x1) * (x1 + 1.0)
ex5 = PDEProblemDC(rhs, f0, f1, g0, g1,
nn.Sequential(nn.Linear(2, 10), nn.Sigmoid(), nn.Linear(10,1, bias=False)))
x = np.linspace(0, 1, 10)
X, Y = np.meshgrid(x, x)
X_flat = torch.Tensor(np.hstack([X.flatten()[:, None], Y.flatten()[:, None]]))
ex5.solve_lbfgs(X_flat);
tensor(0.1201, grad_fn=<MeanBackward0>)
tensor(6.5121e-05, grad_fn=<MeanBackward0>)
tensor(5.6691e-06, grad_fn=<MeanBackward0>)
tensor(4.8895e-06, grad_fn=<MeanBackward0>)
tensor(4.8895e-06, grad_fn=<MeanBackward0>)
tensor(4.8895e-06, grad_fn=<MeanBackward0>)
tensor(4.8895e-06, grad_fn=<MeanBackward0>)
tensor(4.8895e-06, grad_fn=<MeanBackward0>)
tensor(4.8895e-06, grad_fn=<MeanBackward0>)
tensor(4.8895e-06, grad_fn=<MeanBackward0>)
ex5_true_sol = lambda x: torch.exp(-x[:, 0]) * (x[:, 0] + x[:, 1] ** 3)
plot_contour(ex5, ex5_true_sol);
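Beyond eyeballing the two contour plots, we can quantify the accuracy by comparing the network solution with the true solution on a dense grid. Here is a small sketch (the grid size and the variable names are arbitrary choices):

# Maximum absolute error of the trained network on a dense test grid.
xx = np.linspace(0, 1, 64)
Xg, Yg = np.meshgrid(xx, xx)
X_test = torch.Tensor(np.hstack([Xg.flatten()[:, None], Yg.flatten()[:, None]]))
with torch.no_grad():
    err = torch.max(torch.abs(ex5.solution(X_test) - ex5_true_sol(X_test)))
print(f"maximum absolute error: {err.item():.2e}")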
# Problem 6 of Lagaris
a = 3.0
def rhs(x):
tmp1 = torch.exp(-(a * x[:, 0] + x[:, 1]) / 5.0)
tmp2 = (-4.0 / 5.0 * a ** 3 * x[:, 0] - 2.0 / 5.0 + 2.0 * a ** 2) * torch.cos(a ** 2 * x[:, 0] ** 2 + x[:, 1])
tmp2 += (1.0 / 25.0 - 1.0 - 4.0 * a ** 4 * x[:, 0] ** 2 + a ** 2 / 25.0) * torch.sin(a ** 2 * x[:, 0] ** 2 + x[:, 1])
return tmp1 * tmp2
ex6_true_sol = lambda x: torch.exp(-(a * x[:, 0] + x[:, 1]) / 5.0) * torch.sin(a ** 2 * x[:, 0] ** 2 + x[:, 1])
f0 = lambda x2: ex6_true_sol(torch.stack((torch.zeros_like(x2), x2), dim=1))
f1 = lambda x2: ex6_true_sol(torch.stack((torch.ones_like(x2), x2), dim=1))
g0 = lambda x1: ex6_true_sol(torch.stack((x1, torch.zeros_like(x1)), dim=1))
g1 = lambda x1: ex6_true_sol(torch.stack((x1, torch.ones_like(x1)), dim=1))
net = nn.Sequential(nn.Linear(2, 10), nn.Sigmoid(), nn.Linear(10, 1, bias=False))
ex6 = PDEProblemDC(rhs, f0, f1, g0, g1, net)
x = np.linspace(0, 1, 10)
X, Y = np.meshgrid(x, x)
X_flat = torch.Tensor(np.hstack([X.flatten()[:, None], Y.flatten()[:, None]]))
# Does not always work because of local minima.
# Try multiple times.
# ex6.solve_lbfgs(X_flat, max_iter=10)
# Each row of bounds is [lower, upper] for one coordinate (see the docstring).
ex6.solve_sgd(torch.from_numpy(np.array([[0, 1], [0, 1]], dtype=np.float32)),
              max_iter=10000, batch_size=100, lr=0.01)
it = 99, loss = 2.2737367544323206e-13
...
it = 9999, loss = 2.2737367544323206e-13
plot_contour(ex6, ex6_true_sol)
Questions#
Feel free to skip this, as it can be challenging if you are not an expert in Python.
Add a method to the class PDEProblemDC that uses stochastic gradient descent to solve the same problems. Once you are done, rerun the problems above with your code.

According to the Dirichlet principle, the solution of the PDE:

$$
\nabla^2 \psi(x, y) = f(x, y)
$$

minimizes the energy functional:

$$
E[\psi] = \int_{[0,1]^2} \left[\frac{1}{2}\left\|\nabla \psi(x, y)\right\|^2 + f(x, y)\,\psi(x, y)\right] dx\, dy
$$

subject to the boundary conditions. This means that you can solve the problem by minimizing the loss function:

$$
L(\theta) = \frac{1}{n}\sum_{i=1}^{n}\left[\frac{1}{2}\left\|\nabla \hat{\psi}(x_i, y_i;\theta)\right\|^2 + f(x_i, y_i)\,\hat{\psi}(x_i, y_i;\theta)\right],
$$

which is a Monte Carlo estimate of the energy over collocation points sampled uniformly in the domain. Add this functionality to the class PDEProblemDC.
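Hint: estimating the energy functional only requires first derivatives of the trial solution, which you can get with the grad helper defined at the top of this notebook. Below is a minimal sketch of what such a loss could look like; the name energy_loss and the standalone-function form are just one possible design, and you would want to turn it into a method of PDEProblemDC and pair it with a sampling loop like the one in solve_sgd:

def energy_loss(problem, X):
    """Monte Carlo estimate of the Dirichlet energy at the points X.

    Arguments
    ---------
    problem -- A PDEProblemDC instance.
    X       -- Collocation points sampled uniformly in the unit square.
    """
    X.requires_grad = True
    sol = problem.solution(X)
    # First derivatives of the trial solution with respect to x and y:
    sol_grad = grad(sol, X)
    # Energy density: 0.5 * |grad psi|^2 + f * psi
    density = 0.5 * torch.sum(sol_grad ** 2, dim=1) + problem.rhs(X) * sol
    return torch.mean(density)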