Open siwuxei opened 1 month ago
Hi @siwuxei, I think you are right that this gradient cannot be specified in `required_ujs_phys`. However, I think you might be able to use the chain rule to decompose it into computable quantities?
$\frac{\partial}{\partial y}\left[ \left( v+\frac{v_t}{\sigma _k} \right) \frac{\partial k}{\partial y} \right] = \left( v + \frac{v_t}{\sigma_k} \right) \frac{\partial^2 k}{\partial y^2} + \left( \frac{\partial v}{\partial y} + \frac{\sigma_k \frac{\partial v_t}{\partial y} - v_t \frac{\partial \sigma_k}{\partial y}}{\sigma_k^2} \right) \frac{\partial k}{\partial y}$
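If $\sigma_k$ is a constant (an assumption here; the question only says it is "given"), then $\frac{\partial \sigma_k}{\partial y} = 0$ and the expansion simplifies to:
$\frac{\partial}{\partial y}\left[ \left( v+\frac{v_t}{\sigma_k} \right) \frac{\partial k}{\partial y} \right] = \left( v + \frac{v_t}{\sigma_k} \right) \frac{\partial^2 k}{\partial y^2} + \left( \frac{\partial v}{\partial y} + \frac{1}{\sigma_k} \frac{\partial v_t}{\partial y} \right) \frac{\partial k}{\partial y}$
Every factor on the right is then a plain first or second $y$-derivative of a single field. Since $v_t$ is not listed among the network outputs, $\frac{\partial v_t}{\partial y}$ would presumably have to be expanded further from however $v_t$ is defined.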
To avoid manipulating gradients by hand like this, another option is to define custom composite JAX gradient transformations in `trainers.py` (changing https://github.com/benmoseley/FBPINNs/blob/main/fbpinns/trainers.py#L197). This is more involved but more flexible.
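As a rough illustration of the composite-gradient idea in plain JAX (not the actual FBPINNs internals), the sketch below differentiates the whole flux term with `jax.grad` and checks it against the chain-rule expansion. The names `net`, `v_t`, `flux`, and `SIGMA_K` are hypothetical: `net` is a closed-form stand-in for the network, the `v_t` closure is invented because the thread does not define it, and $\sigma_k$ is assumed constant.

```python
import jax
import jax.numpy as jnp

SIGMA_K = 1.0  # assumption: sigma_k is a constant, so its y-derivative is zero


def net(x, y):
    """Hypothetical stand-in for the network: smooth closed-form u, v, k."""
    u = jnp.sin(x) * jnp.cos(y)
    v = jnp.cos(x) * jnp.sin(y) + 2.0
    k = jnp.exp(-x) * jnp.sin(2.0 * y) + 1.5
    return u, v, k


def v_field(x, y):
    return net(x, y)[1]


def k_field(x, y):
    return net(x, y)[2]


def v_t(x, y):
    """Hypothetical closure for v_t; the thread does not say how it is defined."""
    return 0.09 * k_field(x, y) ** 2


def flux(x, y):
    """(v + v_t / sigma_k) * dk/dy, the quantity inside the outer derivative."""
    dk_dy = jax.grad(k_field, argnums=1)(x, y)
    return (v_field(x, y) + v_t(x, y) / SIGMA_K) * dk_dy


# Composite derivative: differentiate the flux itself with respect to y.
loss2_direct = jax.grad(flux, argnums=1)


def loss2_chain_rule(x, y):
    """The same quantity from the chain-rule expansion (constant sigma_k)."""
    dk_dy = jax.grad(k_field, argnums=1)(x, y)
    d2k_dy2 = jax.grad(jax.grad(k_field, argnums=1), argnums=1)(x, y)
    dv_dy = jax.grad(v_field, argnums=1)(x, y)
    dvt_dy = jax.grad(v_t, argnums=1)(x, y)
    coeff = v_field(x, y) + v_t(x, y) / SIGMA_K
    return coeff * d2k_dy2 + (dv_dy + dvt_dy / SIGMA_K) * dk_dy


x0, y0 = 0.3, 0.7
direct = loss2_direct(x0, y0)
expanded = loss2_chain_rule(x0, y0)
print(direct, expanded)  # the two values should agree to float32 precision
```

Both routes give the same number, so the trade-off is about where the work sits: the chain-rule form only needs plain per-field derivatives, while the composite form lets autodiff handle the product rule at the cost of wiring a custom function into the trainer. To evaluate over a batch of collocation points, the scalar function can be wrapped as `jax.vmap(loss2_direct)(xs, ys)`.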
Thank you for your response!
Changing `trainers.py` could be a more permanent solution. I am currently getting up to speed with the relevant JAX APIs, so progress is slow. Thanks again for sharing your work, which has motivated me to try new things. 😄
Thank you for sharing your work, it's very interesting! The new version using JAX is indeed much faster, but I'm not very familiar with it (I use PyTorch more). Recently, when solving a PDE, I encountered this problem:
$\mathrm{Loss}_1=\frac{\partial u}{\partial x}+\frac{\partial v}{\partial y}$
$\mathrm{Loss}_2=\frac{\partial}{\partial y}\left[ \left( v+\frac{v_t}{\sigma _k} \right) \frac{\partial k}{\partial y} \right] $
$\sigma_k$ is given. The inputs of the neural network are $x$ and $y$, and its outputs are $u$, $v$, and $k$.
When constructing the physical loss $\mathrm{Loss}_2$ above, $\frac{\partial k}{\partial y}$ is needed. The current FBPINN framework uses `required_ujs_phys` to request gradients, as shown in the following code framework:
This causes a problem: I can't calculate $\frac{\partial}{\partial y}\left[ \left( v+\frac{v_t}{\sigma_k} \right) \frac{\partial k}{\partial y} \right]$, because it's a composite second-order gradient: the first-order $\frac{\partial k}{\partial y}$ has to be multiplied by its coefficient before the result is differentiated again. It can't be requested through `required_ujs_phys`.
This kind of composite gradient is quite common. Do you have any good suggestions to solve this problem?
Thank you for reading!