RobertTLange / evosax

Evolution Strategies in JAX 🦎
Apache License 2.0
475 stars, 44 forks

ParameterReshaper should support any PyTrees not only dicts #26

Closed ynotzort closed 1 year ago

ynotzort commented 1 year ago

See issue: https://github.com/RobertTLange/evosax/issues/25

The current version of ParameterReshaper only supports dicts for reshaping. This is fine for libraries like Flax and Haiku, but Equinox, for example, represents models as dataclasses internally, which are not supported here. A simple solution would be to use the tree_flatten and tree_unflatten functions provided by jax.tree_util. This would also enable support for lists, tuples, and any other kind of PyTree, as long as its type is registered with JAX.
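To illustrate the idea, here is a minimal sketch of a reshaper built on jax.tree_util. It is not the evosax implementation: the class name PyTreeReshaper and its methods flatten and reshape_single are hypothetical names chosen for this example, and it assumes every leaf is an array with a static shape.

```python
import jax
import jax.numpy as jnp
import numpy as np


class PyTreeReshaper:
    """Hypothetical helper (not the evosax ParameterReshaper)."""

    def __init__(self, placeholder_params):
        # tree_flatten works on any registered PyTree: dicts, lists,
        # tuples, and custom node types registered with JAX.
        leaves, self.treedef = jax.tree_util.tree_flatten(placeholder_params)
        self.shapes = [leaf.shape for leaf in leaves]
        self.sizes = [int(np.prod(shape)) for shape in self.shapes]
        self.total_params = sum(self.sizes)

    def flatten(self, params):
        # Concatenate all leaves into one flat parameter vector.
        leaves = jax.tree_util.tree_leaves(params)
        return jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])

    def reshape_single(self, flat_vector):
        # Slice the flat vector back into leaves, then rebuild the tree.
        leaves, start = [], 0
        for shape, size in zip(self.shapes, self.sizes):
            leaves.append(flat_vector[start:start + size].reshape(shape))
            start += size
        return jax.tree_util.tree_unflatten(self.treedef, leaves)


# Example with a nested tuple/list/dict instead of a plain dict.
params = (jnp.ones((2, 3)), [jnp.zeros(4), {"b": jnp.arange(5.0)}])
reshaper = PyTreeReshaper(params)
flat = reshaper.flatten(params)
restored = reshaper.reshape_single(flat)
```

Because the tree structure is stored as a treedef rather than as dict keys, the same code path handles every container type. For the single-vector case, jax.flatten_util.ravel_pytree provides comparable flatten and unflatten behavior out of the box.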

RobertTLange commented 1 year ago

Thank you very much! This is great.