lululxvi / deepxde

A library for scientific machine learning and physics-informed learning
https://deepxde.readthedocs.io
GNU Lesser General Public License v2.1

strict periodic boundary condition #44

Closed: smao-astro closed this issue 4 years ago.

smao-astro commented 4 years ago

Hi Lu,

Thank you for sharing your excellent work!

I am currently working on a 2D disk hydrodynamics problem. In Cartesian coordinates the domain is x**2+y**2<1. However, since the problem and its boundary conditions are simpler in polar coordinates (r, theta), one usually wants to define the problem in that coordinate system, so the domain becomes 0<r<1 and 0<theta<2*pi. In that case one needs to enforce a periodic boundary condition on theta in order to constrain the problem and obtain the solution. I am not comfortable with this, for the following reasons:

  1. Physically, the points (r, theta=0) and (r, theta=2 pi) are exactly the same point. There is no boundary or step when crossing the "boundary" from 2 pi to 0.
  2. Because of this, the physical quantities (and all their derivatives of any order) should be strictly equal at theta=2*pi and theta=0. A soft constraint, such as adding sample points and penalizing the mismatch in the loss function, cannot guarantee strict equality, because it is almost impossible to drive the loss to exactly 0, although it can get arbitrarily close.
  3. Adding one more boundary could make the neural network harder to train.
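To make point 2 concrete, here is a minimal NumPy sketch (not DeepXDE code; the tiny random tanh network and the helper names `init_mlp`/`mlp` are hypothetical) of the soft approach: a network fed the raw (r, theta) has no built-in link between theta=0 and theta=2*pi, so periodicity only enters as an extra mean-squared penalty that training can shrink but not force to exactly zero.

```python
import numpy as np

rng = np.random.default_rng(0)


def init_mlp(sizes):
    """Random (untrained) weights for a small tanh MLP, for illustration only."""
    return [(rng.normal(size=(m, n)) / np.sqrt(m), rng.normal(size=n))
            for m, n in zip(sizes[:-1], sizes[1:])]


def mlp(params, x):
    for W, b in params[:-1]:
        x = np.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b


# A network fed the raw polar coordinates (r, theta): nothing in its structure
# ties theta = 0 to theta = 2*pi.
params = init_mlp([2, 20, 20, 1])

r = np.linspace(0.0, 1.0, 50)[:, None]
u_at_0 = mlp(params, np.hstack([r, np.zeros_like(r)]))
u_at_2pi = mlp(params, np.hstack([r, np.full_like(r, 2 * np.pi)]))

# Soft periodic constraint: a mean-squared mismatch added to the loss. Training can
# shrink it, but nothing forces it to be exactly 0. Matching derivatives would need
# further penalty terms of the same kind.
periodic_loss = np.mean((u_at_0 - u_at_2pi) ** 2)
print(periodic_loss)
```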

I noticed that the work https://doi.org/10.1137/19M1260141 implements a method that might avoid the problems mentioned above; see the last paragraph of Section 4.3.2:

Note that the periodic boundary condition can be strictly imposed by modifying the neural nets unn and Unn by replacing the input x with the combination of sin(2πx/L) and cos(2πx/L), where L is the length of domain D. This is because any continuous 2π-periodic function can be written as a nonlinear function of sin(x) and cos(x). This modification simplifies the loss function by removing the loss due to the periodic boundary condition.

From my point of view, this is like adding one additional layer directly after the input layer that maps theta to sin(theta) and cos(theta). The point is that:

after the transformation, the neural network cannot distinguish 0 from 2 pi, since sin(0)=sin(2 pi) and cos(0)=cos(2 pi).

As a result, there is no need to sample points for, or enforce, a periodic boundary condition.
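A minimal NumPy sketch of the quoted trick, under the assumption of a tiny random tanh network standing in for u_nn (the helpers `init_mlp`, `mlp`, and `periodic_features` are hypothetical names, not DeepXDE API): once x enters only through sin(2πx/L) and cos(2πx/L), the outputs at x=0 and x=L coincide up to floating-point round-off, with no boundary samples at all.

```python
import numpy as np

rng = np.random.default_rng(1)


def init_mlp(sizes):
    """Random (untrained) weights for a small tanh MLP, for illustration only."""
    return [(rng.normal(size=(m, n)) / np.sqrt(m), rng.normal(size=n))
            for m, n in zip(sizes[:-1], sizes[1:])]


def mlp(params, x):
    for W, b in params[:-1]:
        x = np.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b


L = 2 * np.pi  # length of the periodic domain D; for theta this is 2*pi


def periodic_features(x):
    """Replace x by (sin(2*pi*x/L), cos(2*pi*x/L)), the input map described in the quote."""
    return np.hstack([np.sin(2 * np.pi * x / L), np.cos(2 * np.pi * x / L)])


params = init_mlp([2, 20, 20, 1])
x = np.linspace(0.0, L, 9)[:, None]
u = mlp(params, periodic_features(x))
u_shifted = mlp(params, periodic_features(x + L))  # one full period later

print(abs(u[0, 0] - u[-1, 0]))  # difference between x = 0 and x = L: round-off level
```

Because the network is a smooth function of a smooth L-periodic map, its derivatives of every order with respect to x are L-periodic as well, which is exactly the strict condition asked for in point 2.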

So I am wondering:

  1. What do you think of the periodic boundary condition problem I described above?
  2. What do you think of the method from https://doi.org/10.1137/19M1260141?
  3. How can the method from https://doi.org/10.1137/19M1260141 be implemented in DeepXDE?

Thank you for any help!

Shunyuan

smao-astro commented 4 years ago

Hi there,

OK, I made an implementation in the file deepxde/maps/fnn.py:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from . import activations
from . import initializers
from . import regularizers
from .map import Map
from .. import config
from ..backend import tf
from ..utils import timing

class FNN(Map):
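    """Feed-forward neural networks.

    If periodic_layer is True, the inputs (r, theta) are mapped to
    (r*sin(theta), r*cos(theta)) before the first dense layer, so the
    network output is 2*pi-periodic in theta.
    """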

    def __init__(
        self,
        layer_size,
        activation,
        kernel_initializer,
        regularization=None,
        dropout_rate=0,
        batch_normalization=None,
        periodic_layer=False,
    ):
        self.layer_size = layer_size
        self.activation = activations.get(activation)
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.regularizer = regularizers.get(regularization)
        self.dropout_rate = dropout_rate
        self.batch_normalization = batch_normalization
        self.periodic_layer = periodic_layer

        self._modify = None

        super(FNN, self).__init__()

    @property
    def inputs(self):
        return self.x

    @property
    def outputs(self):
        return self.y

    @property
    def targets(self):
        return self.y_

    @timing
    def build(self):
        print("Building feed-forward neural network...")
        self.x = tf.placeholder(config.real(tf), [None, self.layer_size[0]])

        y = self.x
        if self.periodic_layer:
            # map r, theta to r*sin(theta) and r*cos(theta)
            y = tf.concat([y[:, 0:1]*tf.sin(y[:, 1:2]),
                           y[:, 0:1]*tf.cos(y[:, 1:2])],
                          axis=1
                          )
        for i in range(len(self.layer_size) - 2):
            if self.batch_normalization is None:
                y = self.dense(y, self.layer_size[i + 1], activation=self.activation)
            elif self.batch_normalization == "before":
                y = self.dense_batchnorm_v1(y, self.layer_size[i + 1])
            elif self.batch_normalization == "after":
                y = self.dense_batchnorm_v2(y, self.layer_size[i + 1])
            else:
                raise ValueError("batch_normalization")
            if self.dropout_rate > 0:
                y = tf.layers.dropout(y, rate=self.dropout_rate, training=self.dropout)
        self.y = self.dense(y, self.layer_size[-1])

        if self._modify is not None:
            self.y = self._modify(self.x, self.y)

        self.y_ = tf.placeholder(config.real(tf), [None, self.layer_size[-1]])
        self.built = True

    def outputs_modify(self, modify):
        self._modify = modify

    def dense(self, inputs, units, activation=None, use_bias=True):
        return tf.layers.dense(
            inputs,
            units,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=self.kernel_initializer,
            kernel_regularizer=self.regularizer,
        )

    @staticmethod
    def dense_weightnorm(inputs, units, activation=None, use_bias=True):
        shape = inputs.get_shape().as_list()
        fan_in = shape[1]
        W = tf.Variable(tf.random_normal([fan_in, units], stddev=math.sqrt(2 / fan_in)))
        g = tf.Variable(tf.ones(units))
        W = tf.nn.l2_normalize(W, axis=0) * g
        y = tf.matmul(inputs, W)
        if use_bias:
            b = tf.Variable(tf.zeros(units))
            y += b
        if activation is not None:
            return activation(y)
        return y

    def dense_batchnorm_v1(self, inputs, units):
        # FC - BN - activation
        y = self.dense(inputs, units, use_bias=False)
        y = tf.layers.batch_normalization(y, training=self.training)
        return self.activation(y)

    def dense_batchnorm_v2(self, inputs, units):
        # FC - activation - BN
        y = self.dense(inputs, units, activation=self.activation)
        return tf.layers.batch_normalization(y, training=self.training)

Search for the keyword periodic_layer to find my modifications.

Then I applied the method to a simple 2D Laplace equation problem:

[Screenshot from 2020-05-19, not reproduced]
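Since the screenshot is not available, here is the problem as it can be inferred from the code that follows (a reading of `pde`, `bc_rad`, and `gen_testdata`, not a transcription of the image): Laplace's equation on the unit disk in polar coordinates, multiplied through by r^2 so the residual stays finite at r = 0, with Dirichlet data cos(theta) on r = 1 and reference solution u = r cos(theta).

```latex
u_{rr} + \frac{1}{r}\,u_r + \frac{1}{r^2}\,u_{\theta\theta} = 0
\;\xrightarrow{\;\times r^2\;}\;
r\,u_r + r^2\,u_{rr} + u_{\theta\theta} = 0,
\qquad 0 < r < 1,\quad 0 < \theta < 2\pi,

u(1,\theta) = \cos\theta,
\qquad
u(r,\theta) = r\cos\theta
\;\text{ since }\;
r\cos\theta + r^2\cdot 0 + (-r\cos\theta) = 0.
```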

The implementation is here:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import deepxde as dde
from deepxde.backend import tf

def gen_testdata():

    r = np.linspace(start=0., stop=1.)
    theta = np.linspace(start=0., stop=2*np.pi)

    r, theta = np.meshgrid(r, theta)
    r = r.reshape((-1,1))
    theta = theta.reshape((-1,1))

    y = r*np.cos(theta)

    X = np.hstack([r, theta])

    return X, y

def main():
    def pde(x, y):
        dy_x = tf.gradients(y, x)[0]
        dy_r, dy_theta = dy_x[:, 0:1], dy_x[:, 1:2]
        dy_rr = tf.gradients(dy_r, x)[0][:, 0:1]
        dy_thetatheta = tf.gradients(dy_theta, x)[0][:, 1:2]
        return x[:, 0:1]*dy_r+ x[:, 0:1]**2*dy_rr+dy_thetatheta

    geom = dde.geometry.Rectangle(xmin=[0., 0.], xmax=[1., 2.*np.pi])

    bc_rad = dde.DirichletBC(
        geom, lambda x: np.cos(x[:, 1:2]), lambda x, on_boundary: on_boundary and np.isclose(x[0], 1)
    )

    data = dde.data.PDE(
        geom, pde, bc_rad, num_domain=2540, num_boundary=80
    )
    net = dde.maps.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal", periodic_layer=True)
    model = dde.Model(data, net)

    model.compile("adam", lr=1e-3)
    model.train(epochs=15000)
    model.compile("L-BFGS-B")
    losshistory, train_state = model.train()
    dde.saveplot(losshistory, train_state, issave=True, isplot=True)

    X, y_true = gen_testdata()
    y_pred = model.predict(X)
    f = model.predict(X, operator=pde)
    print("Mean residual:", np.mean(np.absolute(f)))
    print("L2 relative error:", dde.metrics.l2_relative_error(y_true, y_pred))
    np.savetxt("test.dat", np.hstack((X, y_true, y_pred)))

if __name__ == "__main__":
    main()
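As an independent sanity check on the reference data (not part of the original script), a short NumPy sketch can confirm with central finite differences that u = r*cos(theta), the formula used in `gen_testdata`, makes the residual r*u_r + r**2*u_rr + u_thetatheta from `pde()` vanish, and that it matches the Dirichlet data cos(theta) on r = 1. The step size `h` is an arbitrary choice.

```python
import numpy as np


def exact(r, theta):
    """Reference solution used in gen_testdata: u = r*cos(theta)."""
    return r * np.cos(theta)


def residual_fd(u, r, theta, h=1e-4):
    """Central-difference estimate of r*u_r + r**2*u_rr + u_thetatheta (same form as pde())."""
    u_r = (u(r + h, theta) - u(r - h, theta)) / (2 * h)
    u_rr = (u(r + h, theta) - 2 * u(r, theta) + u(r - h, theta)) / h**2
    u_tt = (u(r, theta + h) - 2 * u(r, theta) + u(r, theta - h)) / h**2
    return r * u_r + r**2 * u_rr + u_tt


# Interior grid in (r, theta), away from r = 0 and r = 1 so r +/- h stays inside.
r, theta = np.meshgrid(np.linspace(0.1, 0.9, 9), np.linspace(0.0, 2 * np.pi, 17))
max_res = np.max(np.abs(residual_fd(exact, r, theta)))

# The Dirichlet data on r = 1 should equal cos(theta), as imposed by bc_rad.
bc_err = np.max(np.abs(exact(1.0, theta) - np.cos(theta)))
print(max_res, bc_err)
```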

The output is:

Step      Train loss              Test loss               Test metric
0         [1.84e-02, 8.30e-02]    [1.84e-02, 0.00e+00]    []
1000      [7.87e-06, 1.39e-06]    [7.87e-06, 0.00e+00]    []
2000      [4.02e-06, 7.71e-07]    [4.02e-06, 0.00e+00]    []
3000      [1.95e-06, 2.09e-07]    [1.95e-06, 0.00e+00]    []
4000      [1.06e-06, 2.01e-07]    [1.06e-06, 0.00e+00]    []
5000      [6.89e-07, 9.25e-08]    [6.89e-07, 0.00e+00]    []
6000      [5.10e-07, 1.32e-07]    [5.10e-07, 0.00e+00]    []
7000      [4.46e-07, 8.74e-07]    [4.46e-07, 0.00e+00]    []
8000      [3.35e-07, 1.32e-06]    [3.35e-07, 0.00e+00]    []
9000      [2.71e-07, 8.52e-08]    [2.71e-07, 0.00e+00]    []
10000     [2.31e-07, 1.23e-07]    [2.31e-07, 0.00e+00]    []
11000     [2.03e-07, 1.40e-07]    [2.03e-07, 0.00e+00]    []
12000     [1.70e-07, 1.80e-08]    [1.70e-07, 0.00e+00]    []
13000     [1.50e-07, 9.29e-09]    [1.50e-07, 0.00e+00]    []
14000     [1.37e-07, 1.06e-08]    [1.37e-07, 0.00e+00]    []
15000     [2.07e-07, 1.57e-05]    [2.07e-07, 0.00e+00]    []

Best model at step 14000:
  train loss: 1.47e-07
  test loss: 1.37e-07
  test metric: []

'train' took 88.334735 s

... (some unrelated output omitted) ...

Step      Train loss              Test loss               Test metric
15000     [2.07e-07, 1.57e-05]    [2.07e-07, 0.00e+00]    []
15011     [1.24e-07, 5.58e-09]    [1.24e-07, 0.00e+00]    []

Best model at step 15011:
  train loss: 1.30e-07
  test loss: 1.24e-07
  test metric: []

'train' took 1.379301 s

Saving loss history to loss.dat ...
Saving training data to train.dat ...
Saving test data to test.dat ...
Predicting...
'predict' took 0.058991 s

Predicting...
'predict' took 0.760356 s

Mean residual: 0.00021337881
L2 relative error: 0.0001060648099557595

It seems to work!
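For reading the last two numbers above, a minimal NumPy sketch of what they measure, assuming `dde.metrics.l2_relative_error` uses the usual definition ||y_true - y_pred||_2 / ||y_true||_2 (check the DeepXDE source for your version); "Mean residual" is simply the mean absolute PDE residual over the test points, as the script computes it. The toy arrays are made up for illustration.

```python
import numpy as np


def mean_abs_residual(f):
    """What the script prints as "Mean residual": mean |PDE residual| over test points."""
    return np.mean(np.abs(f))


def l2_relative_error(y_true, y_pred):
    """Assumed definition: ||y_true - y_pred||_2 / ||y_true||_2 over all test points."""
    return np.linalg.norm(y_true - y_pred) / np.linalg.norm(y_true)


y_true = np.array([[1.0], [0.0], [-1.0]])
y_pred = np.array([[1.0], [0.0], [-0.99]])
err = l2_relative_error(y_true, y_pred)
print(err)  # about 0.00707, i.e. 0.01 / sqrt(2)
```

A relative error of about 1e-4, as reported above, therefore means the prediction error is roughly 0.01% of the size of the exact solution in the L2 sense.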

lululxvi commented 4 years ago

Great! But I have one question, in the code:

# map r, theta to r*sin(theta) and r*cos(theta)
y = tf.concat([y[:, 0:1]*tf.sin(y[:, 1:2]), y[:, 0:1]*tf.cos(y[:, 1:2])], axis=1)

Why do you remove r? In your case, it might be OK, because the solution only depends on r sin(theta) and r cos(theta). It might be more general to use r, sin(theta), and cos(theta), i.e.,

y = tf.concat([y[:, 0:1], tf.sin(y[:, 1:2]), tf.cos(y[:, 1:2])], axis=1)

What do you think?
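A small NumPy sketch comparing the two input maps discussed here (the function names are hypothetical; they mirror the two `tf.concat` lines). Both are exactly periodic in theta, but they differ in width (2 vs 3 features fed to the first dense layer) and in behavior at the disk center: (r*sin(theta), r*cos(theta)) sends every angle at r = 0 to the same point, so the network is automatically single-valued there, whereas (r, sin(theta), cos(theta)) still lets the output at r = 0 depend on theta.

```python
import numpy as np


def features_xy(x):
    """(r, theta) -> (r*sin(theta), r*cos(theta)): the periodic_layer in the FNN above."""
    r, theta = x[:, 0:1], x[:, 1:2]
    return np.hstack([r * np.sin(theta), r * np.cos(theta)])


def features_r_sin_cos(x):
    """(r, theta) -> (r, sin(theta), cos(theta)): the alternative suggested in this comment."""
    r, theta = x[:, 0:1], x[:, 1:2]
    return np.hstack([r, np.sin(theta), np.cos(theta)])


def spread(a):
    """Per-column range (max - min)."""
    return a.max(axis=0) - a.min(axis=0)


theta = np.linspace(0.0, 2 * np.pi, 5)[:, None]
origin = np.hstack([np.zeros_like(theta), theta])  # r = 0, several angles
rim = np.hstack([np.ones_like(theta), theta])      # r = 1, several angles

# Both maps give the same features at theta = 0 and theta = 2*pi (up to round-off).
for phi in (features_xy, features_r_sin_cos):
    assert np.allclose(phi(rim)[0], phi(rim)[-1])

print(features_xy(rim).shape, features_r_sin_cos(rim).shape)  # (5, 2) (5, 3)

# At r = 0 the first map collapses every angle to (0, 0); the second keeps theta.
print(spread(features_xy(origin)), spread(features_r_sin_cos(origin)))
```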

smao-astro commented 4 years ago


Hi Lu,

  1. Why did I use r*cos(theta) and r*sin(theta)?

Well, I guess the main reason is that I was thinking about the inverse transformation x = r*cos(theta), y = r*sin(theta) when I was working on the code. (Note that it is just a trick to force the neural network to output periodic values; at least so far, I do not think these features have a physical meaning.)

  2. Is (r*cos(theta), r*sin(theta)) suitable for any periodic problem?
  3. Is (r, cos(theta), sin(theta)) preferred over (r*cos(theta), r*sin(theta))?

smao-astro commented 4 years ago

@lululxvi

Do you think I inserted this piece of code

        if self.periodic_layer:
            # map r, theta to r*sin(theta) and r*cos(theta)
            y = tf.concat([y[:, 0:1]*tf.sin(y[:, 1:2]),
                           y[:, 0:1]*tf.cos(y[:, 1:2])],
                          axis=1
                          )

at the right place and in the right way, without hurting the speed? Sorry, I am new to TensorFlow.

lululxvi commented 4 years ago

I agree that more experiments are required to verify this. I also think sin(n*theta) and cos(n*theta) could be helpful in some cases.

Yes, your implementation is correct and proper. I will add an interface to the PDE class to allow users to add "features" to the inputs, like this transform for the periodic problem.
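One way to read the sin(n*theta), cos(n*theta) suggestion is as a small bank of Fourier features. The sketch below is hypothetical (the helper `fourier_features` and the cutoff `n_max` are illustrative choices, not DeepXDE API): every feature has period 2*pi/n, which divides 2*pi, so any network built on these inputs stays exactly 2*pi-periodic while higher harmonics become easier to represent. For a (r, theta) input, r would be concatenated (or multiplied in) as in the maps discussed above.

```python
import numpy as np


def fourier_features(theta, n_max):
    """Map theta (shape (N, 1)) to [sin(n*theta), cos(n*theta)] for n = 1..n_max."""
    n = np.arange(1, n_max + 1)          # shape (n_max,), broadcasts against (N, 1)
    return np.hstack([np.sin(theta * n), np.cos(theta * n)])


theta = np.array([[0.0], [np.pi / 3], [2 * np.pi]])
F = fourier_features(theta, n_max=3)
print(F.shape)  # (3, 6)
```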
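To illustrate the kind of interface described here, a minimal NumPy sketch of a network with a pluggable input-feature hook. The class `TinyFNN` is hypothetical and is not DeepXDE code; in the DeepXDE versions I am aware of, the released interface lives on the network object as a method named `apply_feature_transform` (used in the Laplace_disk.py example), so check the documentation for your version rather than relying on this sketch.

```python
import numpy as np


class TinyFNN:
    """Minimal tanh MLP with an optional input feature transform (hypothetical sketch)."""

    def __init__(self, layer_sizes, seed=0):
        # layer_sizes[0] is the width AFTER the feature transform is applied.
        rng = np.random.default_rng(seed)
        self.params = [(rng.normal(size=(m, n)) / np.sqrt(m), rng.normal(size=n))
                       for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
        self._feature_transform = None

    def apply_feature_transform(self, transform):
        """Register a function applied to the raw inputs before the first dense layer."""
        self._feature_transform = transform

    def __call__(self, x):
        y = x if self._feature_transform is None else self._feature_transform(x)
        for W, b in self.params[:-1]:
            y = np.tanh(y @ W + b)
        W, b = self.params[-1]
        return y @ W + b


def polar_periodic(x):
    """(r, theta) -> (r*sin(theta), r*cos(theta)), the transform from this thread."""
    r, theta = x[:, 0:1], x[:, 1:2]
    return np.hstack([r * np.sin(theta), r * np.cos(theta)])


net = TinyFNN([2, 20, 20, 1])
net.apply_feature_transform(polar_periodic)
u = net(np.array([[0.5, 0.0], [0.5, 2 * np.pi]]))  # same point on the disk
print(u.ravel())
```

Keeping the transform outside the network class means the FNN code needs no problem-specific flag such as periodic_layer; any user-supplied feature map can be plugged in.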

lululxvi commented 4 years ago

The new feature is supported in DeepXDE, see the new example Laplace_disk.py.

smao-astro commented 4 years ago


Great! Thank you very much!

lululxvi commented 4 years ago

@tangqi Check this approach for periodic BC.

tangqi commented 4 years ago

Thanks! This seems like an interesting way to impose periodic BCs. It is also good that you allow users to add features as inputs.