lululxvi / deepxde

A library for scientific machine learning and physics-informed learning
https://deepxde.readthedocs.io
GNU Lesser General Public License v2.1

Separable-PINN in DeepXDE #1776

Open bonneted opened 2 weeks ago

bonneted commented 2 weeks ago

This is an implementation of the SPINN model: https://jwcho5576.github.io/spinn.github.io/

The code for the network architecture (snn.py) is adapted directly from the original paper's code (https://github.com/stnamjef/SPINN).

With this implementation, SPINN converges much faster than PINN (consistent with the paper's claim), for both forward and inverse quantification on the linear elastic plate problem. Forward comparison: https://github.com/lululxvi/deepxde/assets/53513604/499a961d-748c-458f-be99-56156b516ace

Inverse with PINN: https://github.com/lululxvi/deepxde/assets/53513604/ab89554d-b82b-406a-8d73-05b3f72a3961

Inverse with SPINN: https://github.com/lululxvi/deepxde/assets/53513604/07171442-ea03-48b4-87a6-8b5094f6809c

The implementation was more complicated than expected for the following reasons: due to its architecture, SPINN takes an input of size n and outputs an array of size n**ndim (it takes the Cartesian product of the coordinates): `(n, 2) --> SPINN --> n**2`

This makes it difficult to fit in with how inputs are handled in data.pde. All inputs (PDE and BC points) are concatenated and passed through the net simultaneously. So if we have n_PDE PDE points and n_BC BC points, we end up with (n_PDE+n_BC)**2 points instead of n_PDE**2+n_BC**2.
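To make the size mismatch concrete, here is a minimal NumPy sketch. The helper `cartesian_grid` and the point counts are made up for illustration; they are not part of DeepXDE or of this PR.

```python
import numpy as np


def cartesian_grid(x, y):
    """Return all (x_i, y_j) pairs as an array of shape (len(x) * len(y), 2)."""
    xx, yy = np.meshgrid(x, y, indexing="ij")
    return np.stack([xx.ravel(), yy.ravel()], axis=1)


n_pde, n_bc = 4, 3  # illustrative sizes only
rng = np.random.default_rng(0)
pde_pts = rng.uniform(size=(n_pde, 2))
bc_pts = rng.uniform(size=(n_bc, 2))

# Stacked first: a single Cartesian product over all points together.
stacked = np.vstack([pde_pts, bc_pts])
grid_stacked = cartesian_grid(stacked[:, 0], stacked[:, 1])

# Kept separate: one Cartesian product per point set.
grid_pde = cartesian_grid(pde_pts[:, 0], pde_pts[:, 1])
grid_bc = cartesian_grid(bc_pts[:, 0], bc_pts[:, 1])

print(grid_stacked.shape[0])                 # (n_pde + n_bc)**2 = 49
print(grid_pde.shape[0] + grid_bc.shape[0])  # n_pde**2 + n_bc**2 = 25
```

The stacked version also evaluates mixed PDE/BC coordinate pairs that neither loss term needs.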

I tried to find a workaround with minimal changes to model.py, and came up with the following: adding a list_handler decorator to the outputs function in JAX so that it can handle list inputs by applying the function to each input and then concatenating the results.
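The idea behind such a decorator can be sketched roughly as follows. This is a simplified NumPy stand-in, not the code from the PR: `toy_outputs` is a hypothetical placeholder for the real JAX outputs function, and concatenation along the first axis is an assumption.

```python
import functools

import numpy as np


def list_handler(func):
    """Let func accept either a single array or a list/tuple of arrays.

    For a list, func is applied to each element and the results are
    concatenated along the first axis (assumed layout).
    """

    @functools.wraps(func)
    def wrapper(inputs, *args, **kwargs):
        if isinstance(inputs, (list, tuple)):
            results = [func(x, *args, **kwargs) for x in inputs]
            return np.concatenate(results, axis=0)
        return func(inputs, *args, **kwargs)

    return wrapper


@list_handler
def toy_outputs(points):
    # Hypothetical stand-in for a SPINN forward pass: (n, 2) in, (n**2, 1) out.
    xx, yy = np.meshgrid(points[:, 0], points[:, 1], indexing="ij")
    return (xx * yy).reshape(-1, 1)


pde_pts = np.ones((4, 2))
bc_pts = np.ones((3, 2))

print(toy_outputs(pde_pts).shape)            # (16, 1)
print(toy_outputs([pde_pts, bc_pts]).shape)  # (25, 1), i.e. 4**2 + 3**2 rows
```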

I then modified the pde.py file by adding an is_SPINN argument: if true, the PDE and BC inputs are put together in a list instead of being stacked. The bcs_start should also be modified, as the output sizes no longer equal the input sizes.
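As a rough illustration of why the offsets change, the sketch below compares offsets computed from input sizes with offsets computed from output sizes. It assumes each BC's points form their own list entry, so each produces n**ndim output rows; the variable names and point counts are hypothetical and do not mirror the actual pde.py code.

```python
import numpy as np

ndim = 2          # number of input coordinates
num_bcs = [3, 5]  # illustrative point counts, one entry per BC

# Offsets when outputs have one row per input point.
input_start = np.cumsum([0] + num_bcs)

# Offsets when each BC entry yields n**ndim output rows (assumption).
output_start = np.cumsum([0] + [n**ndim for n in num_bcs])

print(input_start.tolist())   # [0, 3, 8]
print(output_start.tolist())  # [0, 9, 34]
```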

I understand that this brings a lot of changes to data.pde, so another possibility is to create a separate data subclass dedicated to SPINN so that the data.pde class isn't overly complicated.

lululxvi commented 2 weeks ago

There are too many modifications. We can start with a mathematically equivalent (but slower) implementation by repeating the n inputs to n**2. This is similar to DeepONet

https://github.com/lululxvi/deepxde/blob/ad6399b27c3a0702abd24d45f889ca55497476c2/deepxde/nn/tensorflow/deeponet.py#L17

vs DeepONetCartesianProd

https://github.com/lululxvi/deepxde/blob/ad6399b27c3a0702abd24d45f889ca55497476c2/deepxde/nn/tensorflow/deeponet.py#L153

Then the code change would be very minimal.
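A small NumPy sketch of that equivalence, under stated assumptions: the one-layer `body_x` and `body_y` functions and the rank are toy placeholders for the per-axis SPINN body networks, not the PR's architecture. The pointwise version takes the n**2 repeated points as an ordinary (n**2, 2) input, so no special handling of input and output sizes is needed, at the cost of evaluating each body network n**2 times instead of n.

```python
import numpy as np

rng = np.random.default_rng(0)
rank = 8  # illustrative number of separable features
w_x = rng.normal(size=(1, rank))
w_y = rng.normal(size=(1, rank))


def body_x(x):
    """Toy per-axis network for x: (m, 1) -> (m, rank)."""
    return np.tanh(x @ w_x)


def body_y(y):
    """Toy per-axis network for y: (m, 1) -> (m, rank)."""
    return np.tanh(y @ w_y)


def spinn_cartesian(x, y):
    """Cartesian-product form: n points per axis -> (n, n) grid of outputs."""
    return np.einsum("ir,jr->ij", body_x(x), body_y(y))


def spinn_pointwise(points):
    """Pointwise form: (m, 2) points -> (m,) outputs, one per input row."""
    return np.sum(body_x(points[:, :1]) * body_y(points[:, 1:]), axis=1)


n = 5
x = rng.uniform(size=(n, 1))
y = rng.uniform(size=(n, 1))

# Repeat the n per-axis inputs into n**2 explicit (x_i, y_j) points.
xx, yy = np.meshgrid(x[:, 0], y[:, 0], indexing="ij")
repeated = np.stack([xx.ravel(), yy.ravel()], axis=1)

fast = spinn_cartesian(x, y).ravel()
slow = spinn_pointwise(repeated)
print(np.allclose(fast, slow))
```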

bonneted commented 23 hours ago

OK, I'll try that when I have more time, hopefully soon.