brentyi / jaxls

Sparse nonlinear least squares in JAX
MIT License
179 stars 12 forks

PyTree registration refactor #3

Closed (brentyi closed this 3 years ago)

brentyi commented 3 years ago

@SuperN1ck would be nice if you could skim through this and let me know if it makes sense!

Thing that affects you: after this is merged, factor classes will need to be decorated with @register_pytree_dataclass*.

Main change is to refactor the logic used for designating static dataclass fields. Allows us to:

*I may just pull all this dataclass code into its own library and do some renaming; it seems somewhat generally useful and could help reduce boilerplate in jaxlie as well.
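
For anyone skimming, here is a rough sketch of what decorating a factor class and designating a static field could look like. It is not the code in this PR: only the decorator name `register_pytree_dataclass` comes from the description above, and its body here is an assumption built on `jax.tree_util.register_pytree_node`. The `static_field` helper, the `PriorFactor` class, and the `cost` function are hypothetical names made up for illustration.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


def static_field(**kwargs: Any) -> Any:
    """Hypothetical helper: tag a dataclass field as static via metadata."""
    return dataclasses.field(metadata={"static": True}, **kwargs)


def register_pytree_dataclass(cls):
    """Assumed sketch of the decorator, not the actual jaxls implementation.

    Fields tagged as static go into the pytree's auxiliary data (and must be
    hashable); all other fields become children that JAX can trace.
    """
    assert dataclasses.is_dataclass(cls)
    fields = dataclasses.fields(cls)
    child_names = [f.name for f in fields if not f.metadata.get("static", False)]
    static_names = [f.name for f in fields if f.metadata.get("static", False)]

    def flatten(obj):
        children = tuple(getattr(obj, name) for name in child_names)
        aux = tuple(getattr(obj, name) for name in static_names)
        return children, aux

    def unflatten(aux, children):
        kwargs = dict(zip(child_names, children))
        kwargs.update(zip(static_names, aux))
        return cls(**kwargs)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@register_pytree_dataclass
@dataclasses.dataclass(frozen=True)
class PriorFactor:
    """Hypothetical factor: penalizes deviation of a value from a target."""

    target: jnp.ndarray                       # traced array field (pytree child)
    scale: float = static_field(default=1.0)  # static field (auxiliary data)

    def residual(self, value):
        return self.scale * (value - self.target)


@jax.jit
def cost(factor, value):
    # The factor can cross the jit boundary because it is a registered pytree.
    return jnp.sum(factor.residual(value) ** 2)


factor = PriorFactor(target=jnp.array([1.0, 2.0]), scale=2.0)
leaves = jax.tree_util.tree_leaves(factor)  # contains only `target`
total = float(cost(factor, jnp.array([2.0, 2.0])))
```

In this sketch, changing `scale` triggers a recompile of `cost` because static values are part of the tree structure, while changing `target` does not.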