lululxvi / deepxde

A library for scientific machine learning and physics-informed learning
https://deepxde.readthedocs.io
GNU Lesser General Public License v2.1
2.63k stars 739 forks

fix external variable initialization #1775

Open bonneted opened 3 months ago

bonneted commented 3 months ago

I've run into two bugs while trying to implement https://github.com/lu-group/sbinn/blob/b2c1c94d6564732189722f6e6772af0f63cb0d8c/sbinn/sbinn_tf.py#L8

lululxvi commented 3 months ago

Could you point out an example for using this code?

bonneted commented 3 months ago

Here : https://github.com/bonneted/sbinn/blob/main/sbinn/sbinn_jax.py

This is the implementation of SBINN using JAX. We first train the model without the external variables:

    def ODE(t, y, unknowns=[var.value for var in var_list_]):
    ...

    model.compile("adam", lr=1e-3, loss_weights=[0, 0, 0, 0, 0, 0, 1e-2])
    model.train(epochs=firsttrain, display_every=1000)
    model.compile(
        "adam",
        lr=1e-3,
        loss_weights=[1, 1, 1e-2, 1, 1, 1, 1e-2],
        external_trainable_variables=var_list_,
    )
    variablefilename = "variables.csv"
    variable = dde.callbacks.VariableValue(
        var_list_, period=callbackperiod, filename=variablefilename
    )
    losshistory, train_state = model.train(
        epochs=maxepochs, display_every=1000, callbacks=[variable]
    )

For this first training stage, we want the ODE to use its default unknowns argument.
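Python's default-argument semantics are what this relies on: the list `[var.value for var in var_list_]` is evaluated once, when ODE is defined, so the first training stage sees those frozen initial values even if the variables change later. A minimal standalone sketch (the `Var` class here is an illustrative stand-in, not DeepXDE's `dde.Variable`):

```python
class Var:
    """Stand-in for an external trainable variable (illustrative only)."""
    def __init__(self, value):
        self.value = value

var_list_ = [Var(0.1), Var(2.0)]

# The default is built exactly once, at definition time, capturing
# the variables' current values.
def ODE(t, y, unknowns=[var.value for var in var_list_]):
    # During pretraining no unknowns are passed in, so these frozen
    # default values are used.
    return unknowns

print(ODE(0.0, None))   # [0.1, 2.0]

# Mutating the variables afterwards does NOT change the default:
var_list_[0].value = 9.9
print(ODE(0.0, None))   # still [0.1, 2.0]
```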

lululxvi commented 3 months ago

The code modification seems necessary. But there is another example, https://github.com/lululxvi/deepxde/blob/master/examples/pinn_inverse/Lorenz_inverse.py, which works well (at least it worked earlier).

bonneted commented 3 months ago

That one was already working because there is no pretraining stage without the external variables; the model is only ever compiled with the external trainable variables:

model.compile(
    "adam", lr=0.001, external_trainable_variables=external_trainable_variables
)
losshistory, train_state = model.train(iterations=20000, callbacks=[variable])

The problem occurs when we compile without the external trainable variables, which is when we want the PDE to use the default unknowns argument.
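To make the failure mode concrete, here is a hedged sketch of the dispatch in question (illustrative only, not DeepXDE's actual internals): with the JAX backend, the PDE should receive the live values of the external trainable variables when they are registered via compile, and fall back to the PDE's own default unknowns when none are registered.

```python
class Var:
    """Stand-in for an external trainable variable (illustrative only)."""
    def __init__(self, value):
        self.value = value

def call_pde(pde, t, y, external_vars=None):
    """Illustrative dispatch, not DeepXDE's real code path."""
    if external_vars is not None:
        # Inverse stage: pass the live, trainable values explicitly.
        return pde(t, y, [v.value for v in external_vars])
    # Pretraining stage: rely on the PDE's default unknowns argument.
    return pde(t, y)

def ODE(t, y, unknowns=[0.1, 2.0]):  # defaults frozen at definition
    return unknowns

print(call_pde(ODE, 0.0, None))              # [0.1, 2.0] (defaults)
print(call_pde(ODE, 0.0, None, [Var(5.0)]))  # [5.0] (live value)
```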

lululxvi commented 3 months ago

The code seems OK. But the underlying logic becomes extremely complicated now.

In fact, you can simply add external_trainable_variables in the first compile as well. Since the PDE loss weight is 0, those variables won't get updated anyway.
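The claim above can be checked numerically: with a zero PDE loss weight, the total loss is constant in the unknowns, so their gradient vanishes and a gradient step leaves them unchanged. A minimal sketch with made-up loss terms (all names and values here are illustrative, not DeepXDE code):

```python
def total_loss(unknown, pde_weight):
    data_loss = 1.0                     # stand-in, independent of the unknown
    pde_loss = (unknown - 3.0) ** 2     # stand-in PDE residual term
    return data_loss + pde_weight * pde_loss

def grad_wrt_unknown(unknown, pde_weight, eps=1e-6):
    # Central finite difference of the total loss w.r.t. the unknown.
    return (total_loss(unknown + eps, pde_weight)
            - total_loss(unknown - eps, pde_weight)) / (2 * eps)

print(grad_wrt_unknown(0.5, pde_weight=0.0))   # 0.0 -> variable stays frozen
print(grad_wrt_unknown(0.5, pde_weight=1e-2))  # nonzero -> variable updates
```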

bonneted commented 3 months ago

That's true in that case, but it can be useful to start training the model with frozen parameters (see, for example, https://doi.org/10.1126/sciadv.abk0644). Moreover, it would mean that the default unknowns values in the PDE are useless and misleading, since they could never be used.

lululxvi commented 3 months ago

Please resolve the conflicts.

bonneted commented 2 months ago

I've resolved the conflicts based on your improved logic. In the JAX backend conditional, I added handling for the case where there are no external trainable variables but default values are available.