lululxvi / deepxde

A library for scientific machine learning and physics-informed learning
https://deepxde.readthedocs.io
GNU Lesser General Public License v2.1

Jax VariableValue callback #1689

Closed: bonneted closed this 4 months ago

bonneted commented 6 months ago

This is not the cleanest approach, since it has to reset the model at each iteration to access the variables through self.model.external_trainable_variables. Maybe there is a better way, using an approach similar to the one used for TF v1?
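One possible direction, sketched below as an illustration only: in JAX the trainable parameters are an explicit pytree passed into each training step. A callback could therefore read the current value of an external variable straight from that pytree after every step, without rebuilding the model. This is not DeepXDE's API. The toy inverse problem (recovering C from y = C * x), the `VariableValueLogger` class and its `on_step_end(step_idx, params)` hook are all hypothetical names chosen for the example.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy inverse problem: recover an unknown scalar C from data y = C * x.
x = jnp.linspace(0.0, 1.0, 20)
y_obs = 3.0 * x
LR = 0.5  # learning rate for plain gradient descent


def loss(params):
    # params is a pytree (here a dict) holding the external trainable variable "C".
    return jnp.mean((params["C"] * x - y_obs) ** 2)


@jax.jit
def step(params):
    # One gradient-descent step applied to every leaf of the params pytree.
    grads = jax.grad(loss)(params)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)


class VariableValueLogger:
    """Hypothetical callback: records selected variables read from the params pytree."""

    def __init__(self, names, period=1):
        self.names = names
        self.period = period
        self.history = []  # list of (step index, {name: value})

    def on_step_end(self, step_idx, params):
        # Values are read from the current state; the model is never reset.
        if step_idx % self.period == 0:
            self.history.append(
                (step_idx, {n: float(params[n]) for n in self.names})
            )


params = {"C": jnp.array(1.0)}
logger = VariableValueLogger(["C"], period=10)
for i in range(100):
    params = step(params)
    logger.on_step_end(i, params)

for step_idx, values in logger.history:
    print(step_idx, values)
```

The design choice here is that the training loop hands the callback the same pytree it just updated, so the callback stays a pure reader of state. Whether this fits DeepXDE's JAX backend depends on how that backend stores and passes its parameters, which this sketch does not assume.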

bonneted commented 4 months ago

I just added the [...]. L-BFGS is currently not implemented for JAX, but hopefully it will be soon: jaxopt (which supports L-BFGS) is currently being merged into optax, the optimization library used by DeepXDE. See https://github.com/google/jaxopt