To reduce the amount of code we need to maintain in GPJax, we should explore migrating our PyTree module to either Equinox or Flax. Aspects to consider would be:
- Ease of defining a parameter's support, e.g., the variance is defined on the strictly positive real line.
- How a parameter's value is stored, e.g., as leaves in the PyTree (as currently) or as stateful objects (as in Flax).
- Compatibility with the wider ecosystem, e.g., optimistix, BlackJax, NumPyro.
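
To make the first two points concrete, here is a rough sketch of a positive-constrained parameter stored as a PyTree leaf, using only `jax.tree_util`. It is not how GPJax currently does it, nor the Equinox or Flax API; `PositiveParam`, `from_value`, and the toy `loss` are hypothetical names, and the softplus transform is just one assumed way to enforce positivity.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class PositiveParam:
    """Hypothetical parameter whose constrained value lies in (0, inf)."""

    def __init__(self, unconstrained):
        # The single PyTree leaf is the unconstrained value on the real line.
        self.unconstrained = unconstrained

    @classmethod
    def from_value(cls, value):
        # Inverse softplus, log(exp(value) - 1), maps (0, inf) onto the real line.
        value = jnp.asarray(value, dtype=jnp.float32)
        return cls(jnp.log(jnp.expm1(value)))

    @property
    def value(self):
        # Softplus maps the real line back onto (0, inf).
        return jax.nn.softplus(self.unconstrained)

    def tree_flatten(self):
        # Children (leaves) first, then static auxiliary data (none here).
        return (self.unconstrained,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(variance):
    # Toy objective, minimised when the constrained value equals 2.0.
    return (variance.value - 2.0) ** 2


variance = PositiveParam.from_value(0.5)
leaves = jax.tree_util.tree_leaves(variance)

# The gradient comes back with the same PyTree structure as the input.
grads = jax.grad(loss)(variance)

# One plain gradient step taken in unconstrained space.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, variance, grads)
```

Because the step is taken in unconstrained space, `updated.value` stays strictly positive whatever the step size. Any optimiser that works on PyTrees only ever sees the leaf, which is the property the ecosystem-compatibility point depends on.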
Relevant libraries