benmoseley / FBPINNs

Solve forward and inverse problems related to partial differential equations using finite basis physics-informed neural networks (FBPINNs)
MIT License
308 stars 63 forks source link

Empty gradient when the dim of u is 2 #8

Closed wenscarl closed 9 months ago

wenscarl commented 9 months ago

I am trying to model the 1D Maxwell equations: dH/dt - dE/dx = 0; dE/dt - dH/dx = source. The problem dims are set to (2,2). I can implement these equations in PyTorch, but when it comes to JAX, I found that the gradients of E (dE/dt and dE/dx) are both empty ([]), which results in NaN in the loss. Please help to identify the issue. Thanks @benmoseley

import jax
import jax.numpy as jnp
import numpy as np

from fbpinns.domains import RectangularDomainND
from fbpinns.problems import Problem
from fbpinns.decompositions import RectangularDecompositionND
from fbpinns.networks import FCN
from fbpinns.constants import Constants, get_subdomain_ws
from fbpinns.trainers import FBPINNTrainer, PINNTrainer

class FDTD2D(Problem):
    """Time-dependent (1+1)D Maxwell system for the two fields u = [H, E].

    The residuals penalised by ``loss_fn`` are

        dH     dE
        --  -  --  -  s(x, t)  =  0
        dx     dt

        dE     dH
        --  -  --              =  0
        dx     dt

    with the source term

        s(x, t) = 2e3 * (1 + e) * exp(e),    e = -(x^2 + t^2) / (2 sd^2)

    Initial conditions, hard-constrained by ``constraining_fn``:

        u(x, 0) = 0
        du
        --(x, 0) = 0
        dt

    Input coordinates are ordered (x, t); output components are ordered (H, E).
    """

    @staticmethod
    def init_params(c=1, sd=1):
        """Build the problem parameters.

        Args:
            c: scale factor used in the time ansatz of ``constraining_fn``.
            sd: width used by both the time ansatz and the source term.

        Returns:
            A ``(static_params, trainable_params)`` pair; the second entry is
            empty because this problem has no trainable parameters.
        """
        static_params = {
            # two output components (H, E) and two input coordinates (x, t)
            "dims": (2, 2),
            "c": c,
            "sd": sd,
        }
        return static_params, {}

    @staticmethod
    def sample_constraints(all_params, domain, key, sampler, batch_shapes):
        """Sample interior points and list the derivatives the loss needs.

        Each requested derivative is written as
        ``(output component index, (input coordinate indices,))``.
        The order here must match the unpacking order in ``loss_fn``.
        """
        # physics loss
        x_batch_phys = domain.sample_interior(all_params, key, sampler, batch_shapes[0])
        required_ujs_phys = (
            (0, (0,)),  # dH/dx
            (1, (0,)),  # dE/dx
            (0, (1,)),  # dH/dt
            (1, (1,)),  # dE/dt
        )
        return [[x_batch_phys, required_ujs_phys],]

    @staticmethod
    def constraining_fn(all_params, x_batch, u):
        """Multiply the network output by a time ansatz that vanishes at t = 0."""
        c = all_params["static"]["problem"]["c"]
        sd = all_params["static"]["problem"]["sd"]
        t = x_batch[:, 1:2]

        # tanh^2 and its first derivative are both zero at t = 0, so this
        # constrains u(x, 0) = u_t(x, 0) = 0 for every output component
        u = (jax.nn.tanh(c*t/(2*sd))**2)*u
        return u

    @staticmethod
    def loss_fn(all_params, constraints):
        """Return the summed mean-squared residual of the two equations."""
        sd = all_params["static"]["problem"]["sd"]
        # same order as required_ujs_phys in sample_constraints
        x_batch, dHdx, dEdx, dHdt, dEdt = constraints[0]
        # NOTE(review): `jax.debug.print("ret {}", dEdx)` at this point was
        # reported to show dEdx and dEdt as [] while dHdx and dHdt are OK.

        x, t = x_batch[:, 0:1], x_batch[:, 1:2]

        e = -0.5*(x**2 + t**2)/(sd**2)
        s = 2e3*(1+e)*jnp.exp(e)  # ricker source term

        phys1 = jnp.mean((dHdx - dEdt - s)**2)
        phys2 = jnp.mean((dEdx - dHdt)**2)
        return phys1 + phys2

    @staticmethod
    def exact_solution(all_params, x_batch, batch_shape):
        """Return placeholder reference values (no analytic solution is coded).

        The values are standard-normal noise drawn with a fixed PRNG key, one
        value per row of ``x_batch``.
        """
        # NOTE(review): only one column is returned although "dims" declares
        # two output components -- confirm what the trainer/plotting expects.
        key = jax.random.PRNGKey(0)
        return jax.random.normal(key, (x_batch.shape[0], 1))

# Per-dimension subdomain grids (x in [-1, 1], t in [0, 1]) and the widths
# computed from them by get_subdomain_ws with a factor of 1.9.
subdomain_xs = [np.linspace(-1, 1, 5), np.linspace(0, 1, 5)]
subdomain_ws = get_subdomain_ws(subdomain_xs, 1.9)

c = Constants(
    run="test",
    domain=RectangularDomainND,
    domain_init_kwargs={
        "xmin": np.array([-1, 0]),
        "xmax": np.array([1, 1]),
    },
    problem=FDTD2D,
    problem_init_kwargs={"c": 1, "sd": 0.1},
    decomposition=RectangularDecompositionND,
    decomposition_init_kwargs={
        "subdomain_xs": subdomain_xs,
        "subdomain_ws": subdomain_ws,
        "unnorm": (0., 1.),
    },
    network=FCN,
    # two inputs (x, t) and two outputs (H, E)
    network_init_kwargs={"layer_sizes": [2, 32, 2]},
    ns=((100, 50),),
    n_test=(100, 5),
    n_steps=5000,
    optimiser_kwargs={"learning_rate": 1e-3},
    summary_freq=200,
    test_freq=200,
    show_figures=True,
    clear_output=True,
)

# The FBPINN variant is left disabled here; to run it, call
# FBPINNTrainer(c).train() before the network override below.

# Train a single PINN with a larger network over the whole domain.
c["network_init_kwargs"] = {"layer_sizes": [2, 128, 128, 2]}
run = PINNTrainer(c)
run.train()
ciri1998 commented 5 months ago

Good morning, how did you solve the problem? I am in a similar situation. I saw that by using u_test[:,0] in plot_trainer_2D and excluding the histograms you can get results. Did you follow a similar path, or did you do it differently?

In particular, to add a second variable to the case, did you just modify the problem class and the `c = Constants(...)` configuration? Did you modify other parts of the code?