dynamicslab / pysindy

A package for the sparse identification of nonlinear dynamical systems from data
https://pysindy.readthedocs.io/en/latest/

Support jax arrays (and optionally, cvxpy expressions) everywhere #574

Open Jacob-Stevens-Haas opened 1 month ago

Jacob-Stevens-Haas commented 1 month ago

See #562

This was expected to be easy, because in many cases jax arrays are an almost drop-in replacement for numpy arrays. However, jax arrays are far less amenable to subclassing. Why does that matter?

The codebase became much more readable once AxesArray let arrays dynamically track what their axes mean, even after indexing changes their shape. However, AxesArray cannot be extended to dynamically subclass either numpy.ndarray or jax.Array; even a static subclass of jax.Array is impossible.
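To make the subclassing point concrete, here is a minimal sketch of the mechanism numpy offers for attaching axis metadata to an ndarray subclass. `LabeledArray` is a hypothetical name; this is not PySINDy's AxesArray, and its `__getitem__` only tracks plain integer and slice indices. It relies on `__array_finalize__`, the hook numpy calls whenever it creates a new instance of the subclass (for example a view or slice). Per this issue, jax.Array cannot be subclassed at all, so the pattern does not carry over.

```python
import numpy as np


def _is_int(k):
    # Plain integer index (Python or numpy int), excluding booleans.
    return isinstance(k, (int, np.integer)) and not isinstance(k, (bool, np.bool_))


class LabeledArray(np.ndarray):
    """Hypothetical ndarray subclass carrying one label per axis (sketch only)."""

    def __new__(cls, input_array, labels):
        obj = np.asarray(input_array).view(cls)
        if len(labels) != obj.ndim:
            raise ValueError("need one label per axis")
        obj.labels = tuple(labels)
        return obj

    def __array_finalize__(self, obj):
        # numpy calls this for views and slices; inherit the parent's labels.
        self.labels = getattr(obj, "labels", None)

    def __getitem__(self, key):
        result = super().__getitem__(key)
        if not isinstance(result, LabeledArray) or self.labels is None:
            return result  # scalar result, or nothing to track
        keys = key if isinstance(key, tuple) else (key,)
        if not all(_is_int(k) or isinstance(k, slice) for k in keys):
            # Fancy indexing, Ellipsis, newaxis, booleans: out of scope here.
            result.labels = None
            return result
        # An integer index removes its axis; a slice keeps it.
        kept = [lab for lab, k in zip(self.labels, keys) if not _is_int(k)]
        kept += list(self.labels[len(keys):])  # axes not indexed are kept
        result.labels = tuple(kept)
        return result
```

With this sketch, `x = LabeledArray(np.zeros((100, 3)), ("time", "coord"))` gives `x[5].labels == ("coord",)` and `x[:, 0].labels == ("time",)`: the labels follow the data through indexing, which is the property that is lost without subclassing.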

Long term, we will need our own metadata type that carries around an array, its array package (numpy, jax.numpy or cvxpy.numpy), its bidirectional mapping between axis index and axis meaning, and maybe even something from sympy. The hard part is already done, since AxesArray's functionality deals only with the axes.
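One possible shape for such a metadata type is a container that holds the array instead of subclassing it. The sketch below is an assumption, not a design from the issue: `ArrayWithAxes`, its fields, and its methods are hypothetical. Only numpy is exercised; jax.numpy appears in comments as the intended alternative module, and the cvxpy and sympy pieces are left out.

```python
from dataclasses import dataclass
from types import ModuleType

import numpy as np


@dataclass(frozen=True)
class ArrayWithAxes:
    """Hypothetical container: an array, its array module, and axis meanings."""

    data: object      # np.ndarray here; a jax.Array is the intended alternative
    xp: ModuleType    # numpy here; jax.numpy is the intended alternative
    names: tuple      # names[i] is the meaning of axis i

    def __post_init__(self):
        if len(self.names) != self.data.ndim:
            raise ValueError("need one name per axis")

    def axis(self, name):
        """Meaning -> axis index (raises ValueError if the name is absent)."""
        return self.names.index(name)

    def meaning(self, index):
        """Axis index -> meaning; negative indices work as usual."""
        return self.names[index]

    def take(self, name, i):
        """Select integer position i along a named axis, dropping that axis."""
        ax = self.axis(name)
        new_data = self.xp.take(self.data, i, axis=ax)
        new_names = self.names[:ax] + self.names[ax + 1:]
        return ArrayWithAxes(new_data, self.xp, new_names)
```

For example, `ArrayWithAxes(np.zeros((100, 32, 3)), np, ("time", "space", "coord")).take("time", 0)` yields data of shape `(32, 3)` whose `axis("coord")` is `1`. Because the container never subclasses the array type, the only requirement on `xp` in this sketch is a `take` function with an `axis` argument, which jax.numpy also provides.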

Short term, we should expose our general expectations for axis definitions as global constants. This is still error-prone, since the constants are incorrect for arrays whose shape has changed due to indexing, but it will be far more readable than magic numbers.
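A minimal sketch of the short-term approach follows. The constant names `AX_TIME` and `AX_COORD` and the helper functions are hypothetical (the issue does not name them). The last line also shows the failure mode described above: after `x[0]` removes the time axis, `AX_TIME` silently points at the coordinate axis.

```python
import numpy as np

# Hypothetical global constants for the default layout of a trajectory
# array: axis 0 is time, the last axis is the state coordinate.
AX_TIME = 0
AX_COORD = -1


def n_samples(x):
    """Number of time samples, read through the named constant."""
    return x.shape[AX_TIME]


def n_coords(x):
    """Number of state coordinates, read through the named constant."""
    return x.shape[AX_COORD]


x = np.zeros((100, 3))  # 100 time samples of a 3-dimensional state
print(n_samples(x))     # 100
print(n_coords(x))      # 3
print(n_samples(x[0]))  # 3: x[0] has no time axis, yet no error is raised
```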