ucl-bug / jaxdf

A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations
GNU Lesser General Public License v3.0
117 stars, 7 forks

Consider moving to equinox modules #121

Closed: astanziola closed this issue 9 months ago

astanziola commented 1 year ago

Following the suggestion in #110, it probably makes sense to allow for a custom backend in general or, if that is too complicated, to leverage equinox and make it the default backend.

astanziola commented 1 year ago

At the moment, it is not obvious what advantages equinox offers over declaring Fields as custom PyTrees, since the main feature of equinox is composing multiple modules together, and we don't really have fields within fields*.
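To make the comparison concrete, here is a minimal sketch of what "declaring Fields as custom PyTrees" can look like in plain JAX, with no equinox dependency. `GridFieldSketch` and `energy` are hypothetical names for illustration, not jaxdf's actual classes. The field has a single array leaf (`params`) and keeps the grid spacing `dx` as static auxiliary data; because nothing is nested, `register_pytree_node_class` is enough for `jax.jit` and `jax.grad` to accept it.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class GridFieldSketch:
    """Hypothetical grid field: sampled values plus a static grid spacing."""

    def __init__(self, params, dx):
        self.params = params  # array of samples: the only dynamic leaf
        self.dx = dx          # Python float, stored as static aux data

    def tree_flatten(self):
        # Children are traced by JAX; aux data must be hashable and static.
        return (self.params,), (self.dx,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (dx,) = aux_data
        (params,) = children
        return cls(params, dx)


def energy(u):
    # Discrete Dirichlet energy 0.5 * sum((du/dx)^2) * dx on a 1D grid.
    du = jnp.diff(u.params) / u.dx
    return 0.5 * jnp.sum(du ** 2) * u.dx


u = GridFieldSketch(jnp.array([0.0, 1.0, 4.0, 9.0]), dx=0.5)
e = jax.jit(energy)(u)  # the field passes through jit as a PyTree
g = jax.grad(energy)(u)  # the gradient comes back as a GridFieldSketch
```

Here `jax.grad` returns another `GridFieldSketch` whose `params` hold the gradient, with `dx` carried through unchanged. Migrating to equinox would mostly replace the hand-written `tree_flatten`/`tree_unflatten` pair with a subclass of `eqx.Module`, where `dx` would be marked static (for example with `eqx.field(static=True)` in recent Equinox versions).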

On the other hand, there seems to be no harm in migrating the Field classes to equinox, apart from adding a dependency to the package. Let's keep this open while more documentation is written, as the cost-benefit trade-off may easily change as new requirements arise.

*Note that, currently, the parameters of a field can be an equinox module if needed. This may be used to make a better interface for the Continuous field.