RobertTLange / evosax

Evolution Strategies in JAX 🦎
Apache License 2.0
475 stars, 44 forks

ParameterReshaper should support any PyTree, not only dicts #25

Closed: ynotzort closed this issue 1 year ago

ynotzort commented 1 year ago

The current version of ParameterReshaper only supports dicts as placeholder parameters. This is fine for libraries like Flax and Haiku, but Equinox, for example, represents modules as dataclasses, which are not supported here. A simple solution would be to use the tree_flatten and tree_unflatten functions provided by jax.tree_util. This would also enable support for lists, tuples, and any other PyTree node type, as long as it is registered with JAX.
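A minimal sketch of the proposed approach, assuming only the public `jax.tree_util` functions `tree_flatten`, `tree_leaves`, `tree_unflatten`, and `register_pytree_node`. The class `TreeReshaper`, its methods, and the `Linear` dataclass are hypothetical illustrations. They are not the evosax `ParameterReshaper` API or the code from the pull request. The key idea is that the `treedef` captured from the placeholder records the container structure, so dicts, lists, tuples, and registered dataclasses all round-trip the same way.

```python
import dataclasses

import jax
import jax.numpy as jnp
import numpy as np
from jax import tree_util


class TreeReshaper:
    """Hypothetical sketch (not evosax's code): map flat vectors onto any PyTree."""

    def __init__(self, placeholder_params):
        # tree_flatten yields the leaves in a fixed order plus a treedef that
        # records the container structure (dict, list, tuple, dataclass, ...).
        leaves, self.treedef = tree_util.tree_flatten(placeholder_params)
        self.shapes = [jnp.shape(leaf) for leaf in leaves]
        sizes = [int(np.prod(shape)) for shape in self.shapes]
        # Static (Python int) start offsets of each leaf inside the flat vector.
        self.offsets = [0]
        for size in sizes:
            self.offsets.append(self.offsets[-1] + size)
        self.total_params = self.offsets[-1]

    def flatten_single(self, params):
        """Concatenate all leaves of one PyTree into a 1-D vector."""
        leaves = tree_util.tree_leaves(params)
        return jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])

    def reshape_single(self, flat):
        """Inverse of flatten_single: rebuild the original PyTree structure."""
        leaves = [
            flat[start:end].reshape(shape)
            for start, end, shape in zip(self.offsets[:-1], self.offsets[1:], self.shapes)
        ]
        return tree_util.tree_unflatten(self.treedef, leaves)

    def reshape(self, population):
        """Map a (popsize, total_params) array to a PyTree with batched leaves."""
        return jax.vmap(self.reshape_single)(population)


@dataclasses.dataclass
class Linear:
    """Hypothetical stand-in for an Equinox-style dataclass module."""
    weight: jnp.ndarray
    bias: jnp.ndarray


# Register Linear so jax.tree_util knows how to flatten and rebuild it.
tree_util.register_pytree_node(
    Linear,
    lambda m: ((m.weight, m.bias), None),
    lambda _, children: Linear(*children),
)

layer = Linear(weight=jnp.ones((3, 2)), bias=jnp.zeros(2))
reshaper = TreeReshaper(layer)
print(reshaper.total_params)  # 8
batched = reshaper.reshape(jnp.zeros((5, reshaper.total_params)))
print(type(batched).__name__, batched.weight.shape)  # Linear (5, 3, 2)
```

Slicing with precomputed Python-int offsets keeps the split points static, so `reshape_single` can be traced under `jax.vmap` to reshape a whole population in one call.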

A minimal example:

```python
from evosax import ParameterReshaper
import jax.numpy as jnp

# Placeholder parameters stored as a plain list instead of a dict
params = [jnp.array([1, 2, 3]), jnp.array([[1, 2, 3], [4, 5, 6]])]
ParameterReshaper(placeholder_params=params)
# -> AttributeError: 'list' object has no attribute 'shape'
```

Expected behavior: the constructor should not raise an exception.
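For comparison, JAX itself ships `jax.flatten_util.ravel_pytree`, which performs the flatten/unflatten round trip described in the issue for any registered PyTree. The sketch below is an illustration, not evosax code. It applies `ravel_pytree` to a float version of the list placeholder from the minimal example and shows the behavior the issue asks for: no exception, and a population of flat vectors mapped back onto the list structure.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Float version of the list placeholder from the minimal example.
params = [jnp.array([1.0, 2.0, 3.0]), jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])]

# ravel_pytree returns a single 1-D vector and a function that restores the list.
flat, unravel = ravel_pytree(params)
print(flat.shape)  # (9,)

# A population of flat candidate vectors is mapped back leaf by leaf with vmap.
population = jnp.zeros((8, flat.shape[0]))
batched = jax.vmap(unravel)(population)
print([leaf.shape for leaf in batched])  # [(8, 3), (8, 2, 3)]
```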

I have opened a pull request that fixes this: https://github.com/RobertTLange/evosax/pull/26

RobertTLange commented 1 year ago

Awesome! Thank you for the PR. It is merged and will be included in the next release!