Closed: bonneted closed this 4 months ago.
I just added the … L-BFGS is currently not implemented for JAX, but hopefully it will be soon, as jaxopt (which supports L-BFGS) is currently being merged into optax, the optimization library used by DeepXDE: https://github.com/google/jaxopt
This is not the cleanest approach, since it has to reset the model at each iteration to access the variables through `self.model.external_trainable_variables`. Maybe there is a better way, using an approach similar to the one for TF v1?