Open engsbk opened 2 years ago
Also, it would be helpful to know how I can change the source function from
$\sin(2\pi x) \qquad (1)$
to
$\sin(2\pi t) \qquad (2)$
Simply replacing expression (1) with (2) in the initial condition does not work: the condition is then only satisfied at t = 0, not at all other times t.
Hi @engsbk, sorry for the slow reply. Please check out the latest FBPINN release - it is a major update, and there is no need to update the gradients of the FBPINN by hand when applying constraining operators; this is now done automatically using autodiff, so your workflow should be much simpler. The memory performance was also improved, so you may be able to train/test with more points now.
Hi @engsbk, I am also interested in implementing a source term. I have done it for the 1D wave equation by adding a term in physics_loss, but the 2D case seems problematic. I am not sure what the ansatz should look like. @benmoseley, could you give an example of how a source term with zero ICs should be implemented? Thanks!
Here is some code for the (2+1)D wave equation, with zero ICs and a source term. Please ignore the exact solution - you would need to add e.g. FD modelling code to compare to this.
import jax
import jax.numpy as jnp
import numpy as np
from fbpinns.domains import RectangularDomainND
from fbpinns.problems import Problem
from fbpinns.decompositions import RectangularDecompositionND
from fbpinns.networks import FCN
from fbpinns.constants import Constants, get_subdomain_ws
from fbpinns.trainers import FBPINNTrainer, PINNTrainer
class WaveEquation3D(Problem):
    """Time-dependent (2+1)D wave equation with constant velocity and a source.

    The PDE being solved is::

        d^2 u   d^2 u    1  d^2 u
        ----- + ----- - --- ----- = s(x,y,t)
        dx^2    dy^2    c^2 dt^2

    subject to zero conditions at t = 0::

        u(x,y,0) = 0

        du
        --(x,y,0) = 0
        dt

    The conditions at t = 0 are imposed through ``constraining_fn`` rather
    than through an extra loss term; ``loss_fn`` only contains the PDE
    residual.
    """

    @staticmethod
    def init_params(c=1, sd=1):
        """Return the problem parameters.

        Args:
            c: wave velocity used in the PDE and in the t = 0 constraint.
            sd: length scale of the source term, also used to scale the
                t = 0 constraint.

        Returns:
            A ``(static_params, {})`` pair. The second dict is empty
            (presumably the slot for trainable problem parameters - this
            problem defines none; confirm against the ``Problem`` base class).
        """
        static_params = {"dims": (1, 3), "c": c, "sd": sd}
        return static_params, {}

    @staticmethod
    def sample_constraints(all_params, domain, key, sampler, batch_shapes):
        """Sample the points and list the derivatives needed by ``loss_fn``.

        Only one constraint is produced: interior points for the physics
        (PDE residual) loss, together with the second derivatives of output 0
        with respect to each of the three input coordinates.
        """
        interior_points = domain.sample_interior(
            all_params, key, sampler, batch_shapes[0]
        )
        # (output index, (coord, coord)) for coords 0, 1, 2 -> u_xx, u_yy, u_tt
        second_derivatives = tuple((0, (axis, axis)) for axis in range(3))
        return [[interior_points, second_derivatives]]

    @staticmethod
    def constraining_fn(all_params, x_batch, u):
        """Multiply ``u`` by tanh^2(c*t/(2*sd)) so u and u_t vanish at t = 0."""
        problem_params = all_params["static"]["problem"]
        c, sd = problem_params["c"], problem_params["sd"]
        t = x_batch[:, 2:3]  # time is the third input coordinate
        # tanh^2 is zero with zero slope at t = 0, which enforces
        # u(x,y,0) = u_t(x,y,0) = 0 for any network output.
        envelope = jax.nn.tanh(c*t/(2*sd))**2
        return envelope*u

    @staticmethod
    def loss_fn(all_params, constraints):
        """Return the mean squared PDE residual, including the source term."""
        problem_params = all_params["static"]["problem"]
        c, sd = problem_params["c"], problem_params["sd"]

        x_batch, uxx, uyy, utt = constraints[0]
        x = x_batch[:, 0:1]
        y = x_batch[:, 1:2]
        t = x_batch[:, 2:3]

        # Ricker source term centred on the origin (x, y, t) = (0, 0, 0).
        squared_radius = x**2 + y**2 + t**2
        e = -0.5*squared_radius/(sd**2)
        source = 2e3*(1+e)*jnp.exp(e)

        laplacian = uxx + uyy
        residual = laplacian - (1/c**2)*utt - source
        return jnp.mean(residual**2)

    @staticmethod
    def exact_solution(all_params, x_batch, batch_shape):
        """Return placeholder values; this is NOT a real reference solution.

        The output is fixed-seed Gaussian noise with one value per row of
        ``x_batch``. Replace it with e.g. finite-difference modelling to get
        a meaningful comparison.
        """
        n_points = x_batch.shape[0]
        return jax.random.normal(jax.random.PRNGKey(0), (n_points, 1))
# Per-axis grids (x, y, t) handed to the rectangular decomposition:
# 5 evenly spaced values on each axis.
subdomain_xs = [
    np.linspace(lo, hi, 5) for lo, hi in ((-1, 1), (-1, 1), (0, 1))
]
# NOTE(review): 1.9 is presumably the subdomain width/overlap factor -
# confirm against get_subdomain_ws.
subdomain_ws = get_subdomain_ws(subdomain_xs, 1.9)

c = Constants(
    run="test",
    # Domain: x, y in [-1, 1], t in [0, 1].
    domain=RectangularDomainND,
    domain_init_kwargs={
        "xmin": np.array([-1, -1, 0]),
        "xmax": np.array([1, 1, 1]),
    },
    # Problem: unit velocity, source scale sd = 0.1.
    problem=WaveEquation3D,
    problem_init_kwargs={"c": 1, "sd": 0.1},
    decomposition=RectangularDecompositionND,
    decomposition_init_kwargs={
        "subdomain_xs": subdomain_xs,
        "subdomain_ws": subdomain_ws,
        "unnorm": (0., 1.),
    },
    # Network configured here; it is overridden below before PINN training.
    network=FCN,
    network_init_kwargs={"layer_sizes": [3, 32, 1]},
    ns=((50, 50, 50),),
    n_test=(100, 100, 5),
    n_steps=5000,
    optimiser_kwargs={"learning_rate": 1e-3},
    summary_freq=200,
    test_freq=200,
    show_figures=True,
    clear_output=True,
)

# To train an FBPINN with the configuration above instead, use:
#   run = FBPINNTrainer(c)
#   run.train()

# Train a PINN with a larger network (two hidden layers of 64 units).
c["network_init_kwargs"] = {"layer_sizes": [3, 64, 64, 1]}
run = PINNTrainer(c)
run.train()
Thank you for the innovative contribution!
I tried modifying the wave 3D problem to have the following boundary conditions:
$u(x,y,0) = 0$, $u(0,0,t) = 2\sin(2\pi t)$ (time-dependent source)
in this way:
I also made some changes to the FD file, as follows:
Mainly attempting to change the IC and add a time-dependent source term to the equation so it becomes:
where,
but the results are as shown in the image. So, my questions are:
The results in the image were executed with these batch sizes:
because of the limited memory on my GPU.
batch_size_test
without getting an OOM error? Thanks again! Looking forward to your reply.