benmoseley / FBPINNs

Solve forward and inverse problems related to partial differential equations using finite basis physics-informed neural networks (FBPINNs)
MIT License
293 stars, 59 forks

Higher-Order Gradient Derivative Problem #17

Status: Open. Opened by siwuxei 1 month ago.

siwuxei commented 1 month ago

Thank you for sharing your work; it's very interesting! The new JAX version is indeed much faster, but I'm not very familiar with JAX (I mostly use PyTorch). Recently, while solving a PDE, I ran into the following problem:

$\mathrm{Loss}_1=\frac{\partial u}{\partial x}+\frac{\partial v}{\partial y}$

$\mathrm{Loss}_2=\frac{\partial}{\partial y}\left[ \left( v+\frac{v_t}{\sigma _k} \right) \frac{\partial k}{\partial y} \right] $

Here $\sigma _k$ is given, the inputs of the neural network are $x$ and $y$, and its outputs are $u$, $v$, and $k$.

When constructing the physics loss $\mathrm{Loss}_2$ above, I need $\frac{\partial k}{\partial y}$. The current FBPINN framework requests the gradients it needs through required_ujs_phys, as in the following code:

    def sample_constraints(all_params, domain, key, sampler, batch_shapes):

        # physics loss
        y_batch_phys = domain.sample_interior(all_params, key, sampler, batch_shapes[0])
        required_ujs_phys = (
            (0,()),     # u
            (1,()),     # v
            (2,()),     # k
            (2,(1,)),   # k_y
        )

        return [[y_batch_phys, required_ujs_phys]]
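For readers unfamiliar with the format: as far as I can tell from the snippet above and the FBPINNs examples (treat this as an assumption), each entry of required_ujs_phys is a pair (output index, tuple of input dimensions to differentiate along), so (2,(1,)) is $\partial k/\partial y$ and (2,(1,1)) would be $\partial^2 k/\partial y^2$. Each entry therefore names one partial derivative of one output, which is why a derivative of a product of outputs cannot be listed directly. The hypothetical helper describe_uj below (not part of FBPINNs) simply decodes that convention:

```python
# Hypothetical helper (not part of FBPINNs): decode a required_ujs entry of
# the assumed form (output_index, input_dims) into a readable partial derivative.
OUTPUTS = ("u", "v", "k")   # network outputs, in the order given in the question
INPUTS = ("x", "y")         # network inputs

def describe_uj(uj):
    """Return e.g. 'k_yy' for (2, (1, 1)); an empty dims tuple means the raw output."""
    i, dims = uj
    name = OUTPUTS[i]
    if not dims:
        return name
    return name + "_" + "".join(INPUTS[d] for d in dims)

# The tuple from the question
required_ujs_phys = (
    (0, ()),      # u
    (1, ()),      # v
    (2, ()),      # k
    (2, (1,)),    # k_y
)

for uj in required_ujs_phys:
    print(uj, "->", describe_uj(uj))
```

With the question's output ordering, describe_uj((1,(1,))) gives "v_y" and describe_uj((2,(1,1))) gives "k_yy".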

The problem is that I can't compute $\frac{\partial}{\partial y}\left[ \left( v+\frac{v_t}{\sigma _k} \right) \frac{\partial k}{\partial y} \right]$: it is a composite second-order derivative in which the outer derivative acts on an expression that already contains the first-order derivative $\frac{\partial k}{\partial y}$, so it can't be requested through required_ujs_phys.

This kind of composite gradient is quite common. Do you have any good suggestions to solve this problem?

Thanks for reading!

benmoseley commented 1 month ago

Hi @siwuxei, I think you are right that this gradient cannot be specified in required_ujs_phys. However, you might be able to use the product rule to decompose it into computable quantities, i.e. derivatives of the individual factors rather than of their product:

$\frac{\partial}{\partial y}\left[ \left( v+\frac{v_t}{\sigma _k} \right) \frac{\partial k}{\partial y} \right] = \left( v + \frac{v_t}{\sigma_k} \right) \frac{\partial^2 k}{\partial y^2} + \left( \frac{\partial v}{\partial y} + \frac{\sigma_k \frac{\partial v_t}{\partial y} - v_t \frac{\partial \sigma_k}{\partial y}}{\sigma_k^2} \right) \frac{\partial k}{\partial y}$
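As a quick sanity check of this decomposition (a standalone sketch, not FBPINNs code), JAX can differentiate the bracketed product directly and compare it with the right-hand side. The 1D profiles used for $v$, $v_t$, $\sigma_k$ and $k$ are arbitrary placeholders; $\sigma_k$ is made $y$-dependent only so that every term is exercised. If $\sigma_k$ is a constant, the $\partial \sigma_k/\partial y$ term simply vanishes.

```python
import jax
import jax.numpy as jnp

# Placeholder 1D profiles along y (assumptions, chosen only to exercise each term)
def v(y):       return jnp.sin(y)
def v_t(y):     return 0.5 + y**2
def sigma_k(y): return 1.0 + 0.1 * y
def k(y):       return jnp.exp(-y**2)

d = jax.grad  # derivative of a scalar function of a scalar input

def lhs(y):
    # d/dy [ (v + v_t / sigma_k) * dk/dy ], differentiated directly by JAX
    flux = lambda s: (v(s) + v_t(s) / sigma_k(s)) * d(k)(s)
    return d(flux)(y)

def rhs(y):
    # Product/quotient-rule decomposition from the reply above
    coeff = v(y) + v_t(y) / sigma_k(y)
    dcoeff = d(v)(y) + (sigma_k(y) * d(v_t)(y) - v_t(y) * d(sigma_k)(y)) / sigma_k(y) ** 2
    return coeff * d(d(k))(y) + dcoeff * d(k)(y)

ys = jnp.linspace(-2.0, 2.0, 9)
# Maximum absolute difference; should be at float32 round-off level
print(jnp.max(jnp.abs(jax.vmap(lhs)(ys) - jax.vmap(rhs)(ys))))
```

To evaluate the right-hand side inside FBPINNs, required_ujs_phys would presumably also need entries for $\partial v/\partial y$ and $\partial^2 k/\partial y^2$, i.e. (1,(1,)) and (2,(1,1)) under the tuple convention in the question. How $v_t$ and $\partial v_t/\partial y$ are obtained depends on how $v_t$ is defined, which the question does not specify.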

To avoid manipulating gradients by hand like this, another option is to define custom composite JAX gradient transformations in trainers.py (by changing https://github.com/benmoseley/FBPINNs/blob/main/fbpinns/trainers.py#L197). This is more involved, but more flexible.
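As a rough, standalone illustration of why the second option is more flexible (this is plain JAX, not the code in FBPINNs' trainers.py; the tiny network, the placeholder v_t, and the constant SIGMA_K are all assumptions): because JAX transformations compose, the flux $\left( v+\frac{v_t}{\sigma _k} \right) \frac{\partial k}{\partial y}$ can be written as an ordinary function of the input point and differentiated again, with no hand derivation. How to wire this into FBPINNs' own gradient code around the linked line is not shown here.

```python
import jax
import jax.numpy as jnp

# Hypothetical tiny MLP standing in for one network: (x, y) -> (u, v, k)
def init_params(key, sizes=(2, 16, 16, 3)):
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        params.append((jax.random.normal(sub, (n_in, n_out)) / jnp.sqrt(n_in), jnp.zeros(n_out)))
    return params

def net(params, xy):
    h = xy
    for W, b in params[:-1]:
        h = jnp.tanh(h @ W + b)
    W, b = params[-1]
    return h @ W + b  # shape (3,): u, v, k

SIGMA_K = 1.0  # assumed constant sigma_k

def v_t(xy):
    # Placeholder: the question does not define v_t, so a fixed function is assumed
    return 0.1 * (1.0 + xy[1] ** 2)

def flux(params, xy):
    # (v + v_t / sigma_k) * dk/dy at a single point xy = (x, y)
    v = net(params, xy)[1]
    k_y = jax.grad(lambda p: net(params, p)[2])(xy)[1]
    return (v + v_t(xy) / SIGMA_K) * k_y

def loss2_term(params, xy):
    # d/dy of the flux, obtained by differentiating the composite function directly
    return jax.grad(lambda p: flux(params, p))(xy)[1]

def residuals(params, xy):
    # Loss_1 term u_x + v_y from the Jacobian J[i, j] = d out_i / d in_j
    J = jax.jacfwd(lambda p: net(params, p))(xy)
    return J[0, 0] + J[1, 1], loss2_term(params, xy)

params = init_params(jax.random.PRNGKey(0))
pts = jax.random.uniform(jax.random.PRNGKey(1), (128, 2))  # random collocation points
r1, r2 = jax.vmap(residuals, in_axes=(None, 0))(params, pts)
loss = jnp.mean(r1**2) + jnp.mean(r2**2)
```

The trade-off is one extra nested differentiation per point, but the same pattern extends to more complicated composite terms without rederiving them by hand.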

siwuxei commented 1 month ago

Thank you for your response!

  1. Manually decomposing the gradients is indeed a solution. The equation I am solving is more complex than this example, so it might be a bit difficult to work with. Nonetheless, I'll give it a try. O(∩_∩)O
  2. Modifying trainers.py could be a more permanent solution. I am currently working on getting up to speed with the relevant JAX APIs, so progress is slow.

Thanks again for sharing your work, which has motivated me to try new things. 😄