lululxvi / deepxde

A library for scientific machine learning and physics-informed learning
https://deepxde.readthedocs.io
GNU Lesser General Public License v2.1

Question related to heat equation in 2 dimensions #471

Closed. Saransh-cpp closed this issue 2 years ago.

Saransh-cpp commented 2 years ago

Greetings,

I am new to DeepXDE and am trying to solve the heat equation in two spatial dimensions (apologies if I am making a silly mistake), but I am getting an error that is probably related to the boundary conditions (BCs). Could someone please help me figure out where I am going wrong?

Thank you!

Code

import numpy as np

import deepxde as dde
from deepxde.backend import tf

c = 1
n = 1

def pde(x, u):
    u_t = dde.grad.jacobian(u, x, j=2)
    u_xx = dde.grad.hessian(u, x, i=0, j=0)
    u_yy = dde.grad.hessian(u, x, i=1, j=1)
    return u_t - (c ** 2) * (u_xx + u_yy)

def boundary_u(x, on_boundary):
    return on_boundary and np.isclose(x[1], 1)

def boundary_b(x, on_boundary):
    return on_boundary and np.isclose(x[1], 0)

def boundary_r_and_l(x, on_boundary):
    return on_boundary and (np.isclose(x[0], 0) or np.isclose(x[0], 1))

geom = dde.geometry.Rectangle(xmin=[x_min, y_min], xmax=[x_max, y_max])
timedomain = dde.geometry.TimeDomain(t_min, t_max)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)

d_bc_b = dde.DirichletBC(geom, lambda x: np.sin(n * np.pi * x[:, 0:1] / L), boundary_b)
d_bc_u = dde.DirichletBC(geom, lambda x: 0, boundary_u)
n_bc = dde.NeumannBC(geom, lambda X: 0, boundary_r_and_l)

data = dde.data.TimePDE(
    geomtime,
    pde,
    [
        d_bc_b,
        d_bc_u,
        n_bc
    ],
    num_domain=2540,
    num_boundary=80,
    num_initial=160,
    num_test=2540,
)

net = dde.maps.FNN([3] + [32] * 3 + [1], "tanh", "Glorot uniform")

model = dde.Model(data, net)

model.compile("adam", lr=0.001)
model.train(epochs=20000)
model.compile("L-BFGS")
losshistory, train_state = model.train()

dde.saveplot(losshistory, train_state, issave=True, isplot=True)

Error

ValueError: operands could not be broadcast together with shapes (2000,3) (2,)

Stacktrace

Traceback (most recent call last):
  File "c:\Users\Saransh\Saransh_softwares\Learning Projects\python\NNPDEs\heat_2D.py", line 56, in <module>
    data = dde.data.TimePDE(
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\data\pde.py", line 279, in __init__
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\data\pde.py", line 126, in __init__
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\utils\internal.py", line 41, in wrapper
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\data\pde.py", line 165, in train_next_batch     
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\utils\internal.py", line 41, in wrapper
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\data\pde.py", line 239, in bc_points
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\data\pde.py", line 239, in <listcomp>
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\icbcs\boundary_conditions.py", line 54, in collocation_points
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\icbcs\boundary_conditions.py", line 51, in filter
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\geometry\geometry_nd.py", line 38, in on_boundary
  File "<__array_function__ internals>", line 5, in isclose
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\numpy\core\numeric.py", line 2358, in isclose
    return within_tol(x, y, atol, rtol)
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\numpy\core\numeric.py", line 2339, in within_tol
    return less_equal(abs(x-y), atol + rtol * abs(y))
ValueError: operands could not be broadcast together with shapes (2780,3) (2,)
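A note on the pde function above: it encodes the residual u_t - c^2 (u_xx + u_yy) and relies on the network inputs being ordered (x, y, t), so j=2 in dde.grad.jacobian selects the time derivative. One way to sanity-check a residual of this form outside DeepXDE is to evaluate it on a known exact solution, with finite differences standing in for automatic differentiation. The sketch below assumes the separable solution sin(pi x) sin(pi y) exp(-2 pi^2 c^2 t), chosen only for illustration; it is not the boundary-value problem in the post.

```python
import numpy as np

c = 1.0


def exact(x, y, t):
    # Assumed test solution of u_t = c^2 (u_xx + u_yy); not the problem in the post.
    return np.sin(np.pi * x) * np.sin(np.pi * y) * np.exp(-2 * (c * np.pi) ** 2 * t)


def residual(x, y, t, h=1e-4):
    # Central differences stand in for dde.grad.jacobian / dde.grad.hessian.
    u_t = (exact(x, y, t + h) - exact(x, y, t - h)) / (2 * h)
    u_xx = (exact(x + h, y, t) - 2 * exact(x, y, t) + exact(x - h, y, t)) / h**2
    u_yy = (exact(x, y + h, t) - 2 * exact(x, y, t) + exact(x, y - h, t)) / h**2
    # Same form as the pde function: u_t - c^2 (u_xx + u_yy)
    return u_t - c**2 * (u_xx + u_yy)


rng = np.random.default_rng(0)
pts = rng.random((1000, 3))  # columns: x, y, t (same order as the network input)
r = residual(pts[:, 0], pts[:, 1], pts[:, 2])
print(f"max |residual| = {np.abs(r).max():.2e}")
```

The maximum residual should be small and dominated by finite-difference error. Flipping the sign of the diffusion term makes it large, which is a quick way to catch sign or column-index mistakes.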
lululxvi commented 2 years ago

x is a 1D array (a single point), so index it directly: np.isclose(x[0], 1)
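To illustrate this point: boundary functions such as boundary_r_and_l are evaluated point by point, so x is one point rather than a batch. The sketch below mimics that loop in plain NumPy. The loop is a hypothetical stand-in for DeepXDE's internal filtering, not its actual code, and the points are made up, ordered (x, y, t) as a GeometryXTime domain would produce them.

```python
import numpy as np


def boundary_r_and_l(x, on_boundary):
    # x is ONE point [x, y, t]; x[0] is its spatial x-coordinate.
    return on_boundary and (np.isclose(x[0], 0) or np.isclose(x[0], 1))


# Hypothetical stand-in for the filtering step: the boundary function is
# called once per sampled point, so it always receives a 1D array.
points = np.array([
    [0.0, 0.5, 0.1],  # on the left edge (x = 0)
    [1.0, 0.2, 0.7],  # on the right edge (x = 1)
    [0.5, 0.0, 0.3],  # on the bottom edge (y = 0), not left/right
])
on_b = [True, True, True]
mask = np.array([boundary_r_and_l(p, b) for p, b in zip(points, on_b)])
print(mask)  # [ True  True False]

# Column slicing such as x[:, 0:1] needs a 2D array, so it fails on one point.
try:
    points[0][:, 0:1]
except IndexError as err:
    print("IndexError:", err)
```

By contrast, the value functions passed to DirichletBC (for example the sine profile) receive a whole batch of points, which is why x[:, 0:1] is the right indexing there.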

Saransh-cpp commented 2 years ago

Thank you for the response, Dr. Lu. I tried changing x[:, 0:1] to x[0] in both boundary functions, but I am still getting the same error -

Update: I tweaked the BCs a bit, as per your suggestions. I have edited the code above to reflect the changes.

Traceback (most recent call last):
  File "c:\Users\Saransh\Saransh_softwares\Learning Projects\python\NNPDEs\heat_2D.py", line 56, in <module>
    data = dde.data.TimePDE(
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\data\pde.py", line 279, in __init__
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\data\pde.py", line 126, in __init__
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\utils\internal.py", line 41, in wrapper
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\data\pde.py", line 165, in train_next_batch     
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\utils\internal.py", line 41, in wrapper
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\data\pde.py", line 239, in bc_points
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\data\pde.py", line 239, in <listcomp>
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\icbcs\boundary_conditions.py", line 54, in collocation_points
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\icbcs\boundary_conditions.py", line 51, in filter
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\geometry\geometry_nd.py", line 38, in on_boundary
  File "<__array_function__ internals>", line 5, in isclose
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\numpy\core\numeric.py", line 2358, in isclose
    return within_tol(x, y, atol, rtol)
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\numpy\core\numeric.py", line 2339, in within_tol
    return less_equal(abs(x-y), atol + rtol * abs(y))
ValueError: operands could not be broadcast together with shapes (2780,3) (2,)
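Both tracebacks end in geometry_nd.py, where the spatial Rectangle checks the sampled points against its bounds with np.isclose. The shapes in the message point to the cause: the points have three columns (x, y, t), as sampled from the space-time domain, while the 2D rectangle's bounds have only two entries. A minimal NumPy reproduction, with the bounds and random points as illustrative assumptions:

```python
import numpy as np

# A 2D rectangle stores 2-element bounds: [x_min, y_min] and [x_max, y_max].
xmin = np.array([0.0, 0.0])
xmax = np.array([1.0, 1.0])

# Space-time points have three columns: (x, y, t).
rng = np.random.default_rng(0)
X = rng.random((2780, 3))

try:
    # (2780, 3) vs (2,): trailing dimensions 3 and 2 cannot broadcast.
    np.isclose(X, xmin)
except ValueError as err:
    print("ValueError:", err)

# Comparing only the spatial columns broadcasts fine.
print(np.isclose(X[:, :2], xmin).shape)  # (2780, 2)
```

Building the BCs on geomtime rather than geom keeps the geometry's dimension consistent with the points, which is the fix identified in the next comment.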
Saransh-cpp commented 2 years ago

Update - I was passing geom instead of geomtime to the BCs. The error above is now gone, but the code raises a TypeError. The code now -

import numpy as np

import deepxde as dde
from deepxde.backend import tf

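c = 1  # the thermal diffusivity is c ** 2
n = 1  # mode number of the sine profile on the bottom edge

# The domain bounds and L were not defined in the post. The x and y bounds
# follow from the boundary functions below; the time interval is assumed.
x_min, x_max = 0, 1
y_min, y_max = 0, 1
t_min, t_max = 0, 1
L = x_max - x_min  # length of the domain in x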

def pde(x, u):
    # Inputs are ordered (x, y, t), so j=2 is the time derivative.
    u_t = dde.grad.jacobian(u, x, j=2)
    u_xx = dde.grad.hessian(u, x, i=0, j=0)
    u_yy = dde.grad.hessian(u, x, i=1, j=1)
    return u_t - (c ** 2) * (u_xx + u_yy)

# Each boundary function is called with a single point x = [x, y, t].
def boundary_u(x, on_boundary):
    return on_boundary and np.isclose(x[1], 1)

def boundary_b(x, on_boundary):
    return on_boundary and np.isclose(x[1], 0)

def boundary_r_and_l(x, on_boundary):
    return on_boundary and (np.isclose(x[0], 0) or np.isclose(x[0], 1))

geom = dde.geometry.Rectangle(xmin=[x_min, y_min], xmax=[x_max, y_max])
timedomain = dde.geometry.TimeDomain(t_min, t_max)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)

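# Value functions receive a batch of points (N, 3), so x[:, 0:1] is correct here.
# u = sin(n * pi * x / L) on the bottom edge (y = 0)
d_bc_b = dde.DirichletBC(geomtime, lambda x: np.sin(n * np.pi * x[:, 0:1] / L), boundary_b)
# u = 0 on the top edge (y = 1)
d_bc_u = dde.DirichletBC(geomtime, lambda x: 0, boundary_u)
# zero normal derivative (insulated) on the left and right edges
n_bc = dde.NeumannBC(geomtime, lambda X: 0, boundary_r_and_l)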

data = dde.data.TimePDE(
    geomtime,
    pde,
    [
        d_bc_b,
        d_bc_u,
        n_bc
    ],
    num_domain=2540,
    num_boundary=80,
    num_initial=160,
    num_test=2540,
)

# 3 inputs (x, y, t), three hidden layers of 32 neurons, 1 output u
net = dde.maps.FNN([3] + [32] * 3 + [1], "tanh", "Glorot uniform")

model = dde.Model(data, net)

model.compile("adam", lr=0.001)
model.train(epochs=20000)
model.compile("L-BFGS")
losshistory, train_state = model.train()

dde.saveplot(losshistory, train_state, issave=True, isplot=True)

The error now -

Traceback (most recent call last):
  File "c:\Users\Saransh\Saransh_softwares\Learning Projects\python\NNPDEs\heat_2D.py", line 91, in <module>
    model.train(epochs=20000)
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\utils\internal.py", line 26, in wrapper
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\model.py", line 353, in train
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\model.py", line 491, in _test
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\model.py", line 284, in _outputs_losses
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\tensorflow\python\eager\def_function.py", line 885, in __call__
    result = self._call(*args, **kwds)
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\tensorflow\python\eager\def_function.py", line 917, in _call
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\tensorflow\python\eager\function.py", line 3038, in __call__
    filtered_flat_args) = self._maybe_define_function(args, kwargs)
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\tensorflow\python\eager\function.py", line 3463, in _maybe_define_function       
    graph_function = self._create_graph_function(args, kwargs)
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\tensorflow\python\eager\function.py", line 3298, in _create_graph_function       
    func_graph_module.func_graph_from_py_func(
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\tensorflow\python\framework\func_graph.py", line 1007, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\tensorflow\python\eager\def_function.py", line 668, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\tensorflow\python\framework\func_graph.py", line 994, in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: in user code:

    C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\model.py:169 outputs_losses  *
        losses = self.data.losses(targets, outputs_, loss_fn, self)
    C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\data\pde.py:158 losses  *
        error = bc.error(self.train_x, model.net.inputs, outputs, beg, end)
    C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\icbcs\boundary_conditions.py:92 error  *
        return self.normal_derivative(X, inputs, outputs, beg, end) - values
    C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\deepxde-0.13.6-py3.9.egg\deepxde\icbcs\boundary_conditions.py:59 normal_derivative  *
        return bkd.sum(dydx * n, 1, keepdims=True)
    C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\tensorflow\python\ops\math_ops.py:1383 binary_op_wrapper
        raise e
    C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\tensorflow\python\ops\math_ops.py:1367 binary_op_wrapper
        return func(x, y, name=name)
    C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\tensorflow\python\ops\math_ops.py:1710 _mul_dispatch
        return multiply(x, y, name=name)
    C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\tensorflow\python\util\dispatch.py:206 wrapper
        return target(*args, **kwargs)
    C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\tensorflow\python\ops\math_ops.py:530 multiply
        return gen_math_ops.mul(x, y, name)
    C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\tensorflow\python\ops\gen_math_ops.py:6244 mul
        _, _, _op, _outputs = _op_def_library._apply_op_helper(
    C:\Users\Saransh\Saransh_softwares\Python_3.9\lib\site-packages\tensorflow\python\framework\op_def_library.py:555 _apply_op_helper
        raise TypeError(

    TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type float64 of argument 'x'.
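The TypeError comes from the Neumann BC: the last DeepXDE frame is normal_derivative, which multiplies the gradient dydx by the boundary normal n and sums over the columns. The message shows that one operand of that multiplication is float32 and the other float64. TensorFlow's Mul op requires both operands to share a dtype, whereas NumPy silently upcasts. The NumPy sketch below uses made-up arrays to show the mismatch and how casting one operand makes the dtypes agree. It illustrates the mechanism only and is not a patch to DeepXDE.

```python
import numpy as np

# Made-up stand-ins for the two operands of dydx * n in normal_derivative:
# a float32 gradient and a float64 unit normal pointing in +x.
dydx = np.full((4, 3), 0.5, dtype=np.float32)
n = np.tile(np.array([1.0, 0.0, 0.0]), (4, 1))  # NumPy default dtype: float64

print(dydx.dtype, n.dtype)  # float32 float64
# NumPy upcasts the mixed product to float64 without complaint;
# TensorFlow's Mul op raises a TypeError instead.
print((dydx * n).dtype)  # float64

# Casting one operand so both share a dtype makes the product well defined.
du_dn = np.sum(dydx * n.astype(dydx.dtype), axis=1, keepdims=True)
print(du_dn.dtype, du_dn.shape)  # float32 (4, 1)
```

In this thread the error was sidestepped by switching the backend to tensorflow.compat.v1, as the following comments describe.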
lululxvi commented 2 years ago

What is your backend?

Saransh-cpp commented 2 years ago

tensorflow. Should I use tensorflow.compat.v1?
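For reference, DeepXDE selects its backend when the package is first imported. One documented way to choose it is the DDE_BACKEND environment variable, for example running DDE_BACKEND=tensorflow.compat.v1 python heat_2D.py from the shell. A minimal sketch of setting it from inside the script instead, assuming this code runs before any import of deepxde:

```python
import os

# Select DeepXDE's backend before deepxde is imported for the first time;
# the choice is read when the package loads.
os.environ["DDE_BACKEND"] = "tensorflow.compat.v1"

# ...then, later in the script:
# import deepxde as dde
```

Setting the variable after deepxde has already been imported has no effect on the running process, so it belongs at the very top of the script.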

Saransh-cpp commented 2 years ago

The code works with tensorflow.compat.v1! Thank you for all the help, Dr. Lu. One last question -

I am using the following line to plot the loss history and the solution -

dde.saveplot(losshistory, train_state, issave=True, isplot=True)

but for some reason it only plots the loss, not the solution, even though the same line plots the solution in the examples. Am I doing something wrong?

Edit: I think dde.saveplot does not plot the solution when the input has three or more dimensions (here x, y, and t), since such a function cannot be shown directly in a single plot. Thanks again for all the help, Dr. Lu!
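Since dde.saveplot does not draw the solution here, one option is to evaluate the trained network on a grid at a fixed time and plot that slice. The sketch below assumes a predict callable that maps an (N, 3) array of (x, y, t) points to an (N, 1) array, such as model.predict from the script above, and the unit square [0, 1] x [0, 1] implied by the boundary functions. The helper names predict_on_grid and plot_time_slice are hypothetical.

```python
import numpy as np


def predict_on_grid(predict, t, nx=101, ny=101):
    """Evaluate predict on a uniform (x, y) grid over [0, 1]^2 at time t."""
    x = np.linspace(0, 1, nx)
    y = np.linspace(0, 1, ny)
    xx, yy = np.meshgrid(x, y)  # both have shape (ny, nx)
    # Stack into (N, 3) points in the same (x, y, t) order as the network input.
    pts = np.column_stack([xx.ravel(), yy.ravel(), np.full(xx.size, t)])
    u = np.asarray(predict(pts)).reshape(ny, nx)
    return xx, yy, u


def plot_time_slice(predict, t, filename=None):
    """Draw a filled contour plot of u(x, y, t) at one time t."""
    import matplotlib.pyplot as plt  # imported here so the grid helper needs only NumPy

    xx, yy, u = predict_on_grid(predict, t)
    fig, ax = plt.subplots()
    cs = ax.contourf(xx, yy, u, levels=50)
    fig.colorbar(cs, ax=ax, label="u")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title(f"t = {t}")
    if filename is not None:
        fig.savefig(filename)
    return fig
```

With the trained model from the post, plot_time_slice(model.predict, t=0.5) would show one snapshot; calling it for several values of t gives a sequence of snapshots of the evolving temperature field.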