lululxvi / deepxde

A library for scientific machine learning and physics-informed learning
https://deepxde.readthedocs.io
GNU Lesser General Public License v2.1

Use network output in boundary condition #603

Closed: filipbojovic closed this issue 2 years ago.

filipbojovic commented 2 years ago

Dear Dr Lu Lu,

First, I would like to thank you for making the library; it is great. I searched the issues but could not find out whether it is possible to use the network's output in a boundary condition. The inputs are x and y, and the output is phi. The boundary condition is valid only where y = phi. It is a 2D free-surface problem. Is it possible to do this in DeepXDE? Where the boundary condition applies, its residual is given by:

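import numpy as np
import tensorflow as tf
import deepxde as dde

k = 1.0  # permeability coefficient

def free_surface_pde(x, y, _):
    # Gradient of the output phi with respect to the inputs (x, y)
    dp = tf.gradients(y, x)[0]
    dp_dx = dp[:, 0:1]
    dp_dy = dp[:, 1:2]

    # Normal direction computed from the slope dphi/dx
    alpha = tf.math.atan(dp_dx) + np.pi / 2
    nx = tf.math.cos(alpha)
    ny = tf.math.sin(alpha)

    # Normal derivative dphi/dn
    return k * (dp_dx * nx + dp_dy * ny)

def boundary_fs(x, on_boundary):
    # phi? : this should be the network output at x, which is not available here
    return on_boundary and np.isclose(x[1], phi)

bc_fs = dde.icbc.OperatorBC(geom, func = free_surface_pde, on_boundary = boundary_fs)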

I can't find in the documentation how to define boundary_fs so that free_surface_pde is applied only where y = phi.

Best regards, Filip.

praksharma commented 2 years ago

Do you have ground-truth or simulation data to compare with? I think your implementation is correct.

filipbojovic commented 2 years ago

@praksharma I did not define the boundary conditions well. Below are the problem definition and the boundary conditions. The Neumann condition should hold only where y = phi. How do I define that?

[Image: problem definition and boundary conditions]

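import numpy as np
import tensorflow as tf
import deepxde as dde

def pde(x, p):
    dp_x = tf.gradients(p, x)[0]
    dp_x, dp_y = dp_x[:, 0:1], dp_x[:, 1:2]

    dp_xx = tf.gradients(dp_x, x)[0][:, 0:1]
    dp_yy = tf.gradients(dp_y, x)[0][:, 1:2]

    # Apply the PDE only below the free surface (where phi > y)
    below_fs = tf.cast(tf.greater(p, x[:, 1:2]), tf.float32)
    return below_fs * (dp_xx + dp_yy)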

x_start = 0
x_end = 2.0
y_start = 0
y_end = 2.0

phi_value = 2.0

phi_start_x = 2.0
phi_end_x = 1.0
phi_start_y = phi_value
phi_end_y = phi_value

def boundary_xs(x, on_boundary):
    return on_boundary and np.isclose(x[0], x_start)

def boundary_xe(x, on_boundary):
    return on_boundary and np.isclose(x[0], x_end)

def boundary_ys(x, on_boundary):
    return on_boundary and np.isclose(x[1], y_start)

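def boundary_ye(x, on_boundary):
    # <-- This is the open question: output should be the network prediction at x,
    # which is not available inside an on_boundary function
    return on_boundary and np.isclose(x[1], output)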

def func_xs(x):
    return phi_start_x * np.ones((len(x), 1), dtype = np.float32)

def func_xe(x):
    return phi_end_x * np.ones((len(x), 1), dtype = np.float32)

def func_ys(x, y, X):
    return tf.gradients(y, x)[0][:, 1:2]

def func_ye(x):
    return np.zeros((len(x), 1), dtype = np.float32)

def true_phi(x):
    return np.sqrt(4 - 1.5 * x)

geom = dde.geometry.Rectangle(xmin = [.0, .0], xmax = [x_end, y_end])

bc_xs = dde.icbc.DirichletBC(geom, func = func_xs, on_boundary = boundary_xs)
bc_xe = dde.icbc.DirichletBC(geom, func = func_xe, on_boundary = boundary_xe)
bc_ys = dde.icbc.OperatorBC(geom, func = func_ys, on_boundary = boundary_ys)
bc_ye = dde.icbc.NeumannBC(geom, func = func_ye, on_boundary = boundary_ye)

praksharma commented 2 years ago

Actually there is a problem. Are you sure this problem is well posed? Also, how do you implement this Neumann BC in finite elements?

filipbojovic commented 2 years ago

The problem is posed the same way as in the SciANN library. The Neumann BC is not defined the same way in finite elements as in a PINN. Thanks for the answer.

lululxvi commented 2 years ago

There is no example of free-surface problems. We don't know where the BC is located, and we don't know its normal direction. At the moment, I have no idea how to do it. Do you know how other people solve it?

praksharma commented 2 years ago

That is why I asked him how people solve it using FEM, or how other PINN implementations handle it. I am really interested in how people model this problem in ANSYS Mechanical/Fluent. Please let us know which ANSYS BC they use.

forxltk commented 2 years ago

I don't know whether this paper on free boundary problems with PINNs can help you: https://arxiv.org/abs/2006.05311.

filipbojovic commented 2 years ago

Dear all,

I am sorry for the very late reply. Thank you for your interest.

Two-dimensional steady flow through a porous medium is driven by the difference in potential between two surfaces. The potential above the free surface (FS) is always 0, and the potential below the FS is governed, via Darcy's law, by the mass-conservation equation $k_x \frac{\partial^2 \phi}{\partial x^2} + k_y \frac{\partial^2 \phi}{\partial y^2} = 0$, where kx = ky = 1 are the permeability coefficients. This is the main PDE.
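Assuming the mass equation has the standard form $k_x \phi_{xx} + k_y \phi_{yy} = 0$, a quick way to sanity-check a candidate potential outside any PINN framework is to evaluate the discrete residual with second-order central differences on a uniform grid. The sketch below does that; the function name `seepage_residual` and the test functions are illustrative choices, not part of the thread.

```python
import numpy as np

def seepage_residual(phi, dx, dy, kx=1.0, ky=1.0):
    """Residual kx*phi_xx + ky*phi_yy at the interior nodes of a uniform grid.

    phi is indexed as phi[j, i], with j along y and i along x (np.meshgrid default).
    """
    phi_xx = (phi[1:-1, 2:] - 2 * phi[1:-1, 1:-1] + phi[1:-1, :-2]) / dx**2
    phi_yy = (phi[2:, 1:-1] - 2 * phi[1:-1, 1:-1] + phi[:-2, 1:-1]) / dy**2
    return kx * phi_xx + ky * phi_yy

# Grid over the 2 m x 2 m domain
x = np.linspace(0.0, 2.0, 41)
y = np.linspace(0.0, 2.0, 41)
X, Y = np.meshgrid(x, y)

# x^2 - y^2 is harmonic, so the residual should vanish up to round-off
res = seepage_residual(X**2 - Y**2, x[1] - x[0], y[1] - y[0])
print(np.abs(res).max())
```

Central differences are exact for quadratics, so any nonzero output here is pure floating-point round-off.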

The way I tried to solve the problem with SciANN is as follows. The domain is the rectangle from (x_start, y_start) = (0 m, 0 m) to (x_end, y_end) = (2 m, 2 m). The potential in the seepage region (below the FS) is governed by the PDE above and the following conditions: phi(x = x_start, y) = y_end; phi(x = x_end, y) = 1 m; dphi/dy = 0 at y = 0; and finally dphi/dn = 0 where phi is close to y, with phi = 0 elsewhere (above the FS). This last one is the FS condition, and it is the one I could not find a way to implement in DeepXDE: I can't access the output of the approximator network (phi in this case) when defining a boundary condition.

At the end, I simply go through all the outputs and keep only those points (x, y) for which the output phi is close to y. Those are the points on the free surface.
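The point-selection step described above can be sketched in vectorized NumPy, together with an RMSE computed over the selected points against the analytic surface y = sqrt(4 - 1.5 x) used in the scripts. The helper names and the synthetic `phi_pred` (the analytic surface shifted by 0.01 m) are hypothetical; in practice `phi_pred` would come from the trained model.

```python
import numpy as np

def extract_free_surface(x, y, phi, tol):
    """Keep only the points where the predicted potential phi is within tol of y."""
    x, y, phi = np.ravel(x), np.ravel(y), np.ravel(phi)
    on_fs = np.abs(y - phi) < tol
    return x[on_fs], y[on_fs], phi[on_fs]

def free_surface_rmse(x_fs, phi_fs):
    """RMSE between phi on the extracted points and the analytic surface sqrt(4 - 1.5 x)."""
    if x_fs.size == 0:
        return float("nan")  # no point was close enough to the surface
    return float(np.sqrt(np.mean((np.sqrt(4.0 - 1.5 * x_fs) - phi_fs) ** 2)))

# Hypothetical stand-in for a network prediction on the test grid
x_test, y_test = np.meshgrid(np.linspace(0.0, 2.0, 101), np.linspace(0.0, 2.1, 101))
phi_pred = np.sqrt(4.0 - 1.5 * x_test) + 0.01

x_fs, y_fs, phi_fs = extract_free_surface(x_test, y_test, phi_pred, tol=0.015)
print(x_fs.size, free_surface_rmse(x_fs, phi_fs))
```

Using a boolean mask instead of nested loops keeps the selection and the error computation in one place and avoids bookkeeping the point count by hand.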

[Image: PINN free surface compared with the analytic free surface]

The result is not perfect, but I think it could be even better in DeepXDE thanks to the way it generates collocation points, the RAR method, etc. One important caveat: the result relied on knowing the analytic solution, because additional training points were generated around the area where the FS is expected to be. I am not sure this is the right way to solve the problem with PINNs, but it does give some results. What is your opinion of this methodology?
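Generating the extra points near the expected FS amounts to keeping grid points inside a band around a guessed curve. A vectorized sketch of that idea, using the same straight-line guess y = 2 - 0.5 x and band half-width 0.15 as the SciANN script in this thread, is below. The function name `band_points` is hypothetical. In DeepXDE, an array like this could presumably be passed as extra training points through the `anchors` argument of `dde.data.PDE`; treat that usage as an assumption to check against the DeepXDE docs.

```python
import numpy as np

def band_points(x_max, y_max, n, curve, half_width):
    """Grid points lying within half_width (in y) of an expected surface y = curve(x)."""
    x, y = np.meshgrid(
        np.linspace(0.0, x_max, n),
        np.linspace(0.0, y_max, int(n * y_max / x_max)),
    )
    x, y = x.ravel(), y.ravel()
    near = np.abs(y - curve(x)) < half_width
    return np.column_stack((x[near], y[near]))  # shape (num_points, 2)

# Straight-line guess of the FS from (0, 2) to (2, 1), as in the SciANN script
pts = band_points(2.0, 2.1, 801, lambda x: 2.0 - 0.5 * x, 0.15)
print(pts.shape)
```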

lululxvi commented 2 years ago

It is doable to check if phi is close to y, but how do you compute the normal vector n at each point?
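Both scripts in this thread take n as the tangent of slope dphi/dx rotated by 90 degrees, which treats dphi/dx as the slope of the surface. An alternative, not taken from the thread, is to treat the free surface as the zero level set of F(x, y) = phi - y and use n = grad F / |grad F|; with that normal, dphi/dn = 0 reduces to phi_x^2 + phi_y^2 - phi_y = 0. The two normals coincide (up to orientation, which does not matter for dphi/dn = 0) only where dphi/dy = 0. A NumPy sketch of both, with hypothetical function names:

```python
import numpy as np

def normal_from_slope(dphi_dx):
    """Normal used in the scripts: rotate the tangent of slope dphi/dx by 90 degrees."""
    alpha = np.arctan(dphi_dx) + np.pi / 2
    return np.cos(alpha), np.sin(alpha)

def normal_from_level_set(dphi_dx, dphi_dy):
    """Unit normal to the level set F(x, y) = phi(x, y) - y = 0, i.e. grad F / |grad F|."""
    fx, fy = dphi_dx, dphi_dy - 1.0
    norm = np.hypot(fx, fy)
    return fx / norm, fy / norm

def normal_derivative(dphi_dx, dphi_dy, nx, ny):
    """dphi/dn = grad(phi) . n, the quantity the FS condition drives to zero."""
    return dphi_dx * nx + dphi_dy * ny

# Hypothetical gradient values at a point on the surface
nx_s, ny_s = normal_from_slope(0.3)
nx_l, ny_l = normal_from_level_set(0.3, 0.2)
print(normal_derivative(0.3, 0.2, nx_s, ny_s), normal_derivative(0.3, 0.2, nx_l, ny_l))
```

In a PINN, the same formulas would be written with the framework's gradient ops (for example, the Jacobian of the network output) instead of NumPy scalars.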

forxltk commented 2 years ago

@filipbojovic May I have your SciANN script for the FS implementation? I am quite interested in it. Many thanks!

filipbojovic commented 2 years ago

@forxltk @lululxvi here is the script.

import numpy as np
import matplotlib.pyplot as plt 
import sciann as sn
from sciann.utils.math import diff, sign, sin, cos, atan, sqrt, abs

sn.reset_session()

""" --------------- HYPERPARAMETERS --------------- """
activation_function = 'tanh'
num_of_epochs = 300
batch_size = 1024
lr = 0.001
hidden_layers = 8 * [20]
num_of_train_points = 210
optimizer = 'adam'

x_max = 2.0
y_max = 2.1

""" --------------- TRAIN DATASET --------------- """
x_train, y_train = np.meshgrid(
    np.linspace(0, x_max, num_of_train_points), 
    np.linspace(0, y_max, int(num_of_train_points * y_max / x_max))
)
x_train=x_train.flatten()
y_train=y_train.flatten()
x_train, y_train = np.array(x_train), np.array(y_train)

""" --------------- GENERATING POINTS WHERE FS IS EXPECTED --------------- """
num_of_fs_points = 801
x1, y1 = np.meshgrid(
    np.linspace(0, x_max, num_of_fs_points),
    np.linspace(0, y_max, int(num_of_fs_points * y_max / x_max))
)
x1=x1.flatten()
y1=y1.flatten()

x_region=[]
y_region=[]
for i in range(x1.shape[0]):
    if np.abs(y1[i] - (2-0.5*x1[i])) < 0.15:
        x_region.append(x1[i])
        y_region.append(y1[i])

x_train = np.concatenate((x_train, x_region))
y_train = np.concatenate((y_train, y_region))

x = sn.Variable('x')
y = sn.Variable('y')
phi = sn.Functional('phi', [x, y], hidden_layers, activation_function)

k = 1.
TOL = 0.015
fun1 =  diff(phi, x, order=2) + diff(phi, y, order=2)
C1 = (1-sign(x - (0+TOL))) * (phi-2)
C2 = (1+sign(x - (2-TOL))) * (phi-1) 
N1 = (1-sign(y - (0+TOL))) * diff(phi,y)

""" --------------- dn_dx = 0 --------------- """
k1 = diff(phi,x)
alpha = atan(k1)+np.pi/2
nx = cos(alpha)
ny = sin(alpha)
FS1 = (abs(y-phi)<0.009) * k * (diff(phi,x)*nx + diff(phi,y)*ny)

""" --------------- DEFINING MODEL --------------- """
m2 = sn.SciModel([x, y], [fun1, C1, C2, N1, FS1],  optimizer = optimizer)

""" --------------- TRAINING MODEL --------------- """
pinn_model = m2.train([x_train, y_train], 5 * ['zero'], learning_rate = lr, batch_size = batch_size, epochs = num_of_epochs, stop_loss_value = 1E-15, verbose = 2)

""" --------------- TEST DATASET --------------- """
num_of_test_points = 101
x_test, y_test = np.meshgrid(
    np.linspace(0, x_max, num_of_test_points), 
    np.linspace(0, y_max, num_of_test_points)
)
x_test, y_test = np.array(x_test).reshape(-1, 1), np.array(y_test).reshape(-1, 1)

""" --------------- MODEL EVALUATION --------------- """
phi_pred = phi.eval(m2, [x_test, y_test])

""" --------------- TAKE POINTS ON FREE SURFACE (phi_pred(x_test, y_test) = y_test) --------------- """
x_graph = []
y_graph = []
phi_graph = []
sum_points = 0.0
n = 0  # number of test points found on the free surface
for i in range(phi_pred.shape[0]):
    for j in range(phi_pred.shape[1]):
        if np.abs(y_test[i, j] - phi_pred[i, j]) < TOL:
            x_graph.append(x_test[i, j])
            y_graph.append(y_test[i, j])
            phi_graph.append(phi_pred[i, j])
            real = np.sqrt(4 - 1.5 * x_test[i, j])  # analytic solution where FS is expected to be
            sum_points += (real - phi_pred[i, j]) ** 2
            n += 1
rmse = np.sqrt(sum_points / max(n, 1))

print(rmse)

""" --------------- PLOT RESULTS --------------- """
plt.clf()
plt.plot(x_graph, phi_graph, 'red', label = "PINN free surface")
x_axes = np.linspace(0, x_max, num_of_test_points)
plt.plot(x_axes, np.sqrt(4 - 1.5 * x_axes), linewidth = 3, color= 'blue', linestyle = 'dotted', label = "analytic free surface")
plt.xlabel("x [m]")
plt.ylabel(r"$\Phi$ [m]")
plt.legend()
plt.savefig("results_free_surface_phi.png")
plt.show()
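In the SciANN script, every condition is imposed as a residual over all training points, multiplied by a sign-based indicator that is nonzero only within TOL of the relevant boundary (C1, C2 and N1). The NumPy sketch below mirrors the masks of C1 and C2 to show which points actually contribute; the function names and the sample `phi` values are hypothetical.

```python
import numpy as np

TOL = 0.015  # same tolerance as in the SciANN script

def left_mask(x):
    """(1 - sign(x - TOL)): 2 for x < TOL, 0 for x > TOL."""
    return 1.0 - np.sign(x - TOL)

def right_mask(x, x_end=2.0):
    """(1 + sign(x - (x_end - TOL))): 2 for x > x_end - TOL, 0 below it."""
    return 1.0 + np.sign(x - (x_end - TOL))

x = np.array([0.0, 0.01, 0.5, 1.99, 2.0])
phi = np.array([2.0, 1.9, 1.5, 1.2, 1.0])  # hypothetical predictions

# Residuals of C1 (phi = 2 on the left) and C2 (phi = 1 on the right)
c1 = left_mask(x) * (phi - 2.0)
c2 = right_mask(x) * (phi - 1.0)
print(c1, c2)
```

The masks take the value 2 rather than 1 inside the boundary strip, which only rescales the corresponding loss term.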

Regarding "It is doable to check if phi is close to y": could you please tell me how to do that in DeepXDE?

forxltk commented 2 years ago

@filipbojovic

import matplotlib.pyplot as plt
import numpy as np
import deepxde as dde
import tensorflow as tf

tol = 0.001
#dde.config.real.set_float64()

def pde(x, p):
    dpxx = dde.grad.hessian(p, x, i=0, j=0)
    dpyy = dde.grad.hessian(p, x, i=1, j=1)
    main_pde = dpxx+dpyy

    fs_condition = tf.less(tf.abs(x[:, 1:2] - p), tol)
    lmbd = tf.where(fs_condition, 1.0, 0.0)

    dpx = dde.grad.jacobian(p, x, j=0)
    dpy = dde.grad.jacobian(p, x, j=1)
    alpha = tf.math.atan(dpx) + np.pi / 2
    nx = tf.math.cos(alpha)
    ny = tf.math.sin(alpha)

    # Test  tf.cast(tf.greater(p, x[:, 1:2]), tf.float32)*main_pde
    return [main_pde,  lmbd*(dpx*nx+dpy*ny)]

def boundary_left(x, on_boundary):
    return on_boundary and np.isclose(x[0], 0.0)

def boundary_right(x, on_boundary):
    return on_boundary and np.isclose(x[0], 2.0)

def boundary_bot(x, on_boundary):
    return on_boundary and np.isclose(x[1], 0.0)

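def func_fs(x, y, _):
    # dphi/dy; used for the no-flux condition dphi/dy = 0 on the bottom boundary
    return dde.grad.jacobian(y, x, j=1)

geom = dde.geometry.Rectangle([0.0, 0.0], [2.0, 2.0])

bc_left = dde.icbc.DirichletBC(geom, lambda x: 2, boundary_left)
bc_right = dde.icbc.DirichletBC(geom, lambda x: 1, boundary_right)
# Bottom boundary (y = 0); the free-surface condition itself is imposed in pde()
bc_free_surface = dde.icbc.OperatorBC(geom, func_fs, boundary_bot)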

BC = [bc_left, bc_right, bc_free_surface]

data = dde.data.PDE(
    geom, pde, BC,
    num_domain=100000,
    num_boundary=2000,
    )

layer_size = [2] + [20] * 8 + [1]
activation = "tanh"
initializer = "Glorot uniform"

net = dde.nn.FNN(layer_size, activation, initializer)
model = dde.Model(data, net)

# Loss order: [main PDE, FS residual, left BC, right BC, bottom BC]
loss_weights = [1, 1e5, 1, 1, 1]
model.compile('adam', lr=0.001, loss_weights=loss_weights)
checkpointer = dde.callbacks.ModelCheckpoint(
     "model/model.ckpt", verbose=1, save_better_only=True, period=1000)
#model.restore("model/model.ckpt-8000.ckpt", verbose=1)
model.train(iterations=10000, display_every=1000, callbacks=[checkpointer])

num_of_test_points = 101
x_test, y_test = np.meshgrid(
    np.linspace(0, 2.0, num_of_test_points),
    np.linspace(0, 2.1, num_of_test_points)
)
x_test, y_test = np.array(x_test).reshape(-1, 1), np.array(y_test).reshape(-1, 1)
X = np.vstack((np.ravel(x_test), np.ravel(y_test))).T

phi_pred = model.predict(X)

""" --------------- TAKE POINTS ON FREE SURFACE (phi_pred(x_test, y_test) = y_test) --------------- """
x_graph = []
y_graph = []
phi_graph = []
sum_points = 0.0
n = 0  # number of test points found on the free surface
for i in range(phi_pred.shape[0]):
    for j in range(phi_pred.shape[1]):
        if np.abs(y_test[i, j] - phi_pred[i, j]) < 0.015:
            x_graph.append(x_test[i, j])
            y_graph.append(y_test[i, j])
            phi_graph.append(phi_pred[i, j])
            real = np.sqrt(4 - 1.5 * x_test[i, j])  # analytic solution where FS is expected to be
            sum_points += (real - phi_pred[i, j]) ** 2
            n += 1
rmse = np.sqrt(sum_points / max(n, 1))

print(rmse)

""" --------------- PLOT RESULTS --------------- """
plt.clf()
plt.plot(x_graph, phi_graph, 'red', label = "PINN free surface")
x_axes = np.linspace(0, 2.0, num_of_test_points)
plt.plot(x_axes, np.sqrt(4 - 1.5 * x_axes), linewidth = 3, color= 'blue', linestyle = 'dotted', label = "analytic free surface")
plt.xlabel("x [m]")
plt.ylabel(r"$\Phi$ [m]")
plt.legend()
plt.savefig("results_free_surface_phi.png")
plt.show()

This is my implementation in DeepXDE based on your script, with two differences: 1) the FS condition is defined in the PDE rather than as a boundary condition; 2) the training points are sampled randomly in the domain. You could certainly generate FS points and pass them to the model as anchors, as in your SciANN script, but I don't think that is feasible for more complicated problems, so I did not add extra sample points around the FS.
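Because the FS residual in this script is multiplied by the 0/1 mask lmbd, only collocation points that happen to fall within tol of the predicted surface contribute to that loss term. A rough estimate: the band |y - phi| < tol has area about 2 · tol · 2 m = 4 · tol, i.e. a fraction tol of the 4 m² domain, so with tol = 0.001 only about 100 of the 100000 random points are active. This may be one reason the FS term needs such a large loss weight. The NumPy sketch below checks that estimate; the stand-in `phi` (the analytic surface height, independent of y) is a hypothetical substitute for the network output.

```python
import numpy as np

rng = np.random.default_rng(0)
tol = 0.001
num_domain = 100_000

# Uniform collocation points in the 2 m x 2 m rectangle
pts = rng.uniform(0.0, 2.0, size=(num_domain, 2))
x, y = pts[:, 0:1], pts[:, 1:2]

# Hypothetical stand-in for the network output: with phi = sqrt(4 - 1.5 x),
# |y - phi| < tol selects a thin band around the analytic free surface
phi = np.sqrt(4.0 - 1.5 * x)

# Same construction as lmbd in the DeepXDE script: 1 near the free surface, 0 elsewhere
lmbd = np.where(np.abs(y - phi) < tol, 1.0, 0.0)
print(int(lmbd.sum()), "of", num_domain, "points activate the FS residual")
```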

The RMSE of the FS is around 0.09.

[Image: PINN free surface compared with the analytic free surface]

You could tune tol, num_domain, loss_weights and the other hyperparameters for a better result, but it will take some effort. Best wishes!

filipbojovic commented 2 years ago

@forxltk Thank you so much for this! Best wishes too!