RobertTLange / evosax

Evolution Strategies in JAX 🦎
Apache License 2.0

ParameterReshaper does not work with Haiku #20

Closed: vuoristo closed this issue 2 years ago

vuoristo commented 2 years ago

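evosax.ParameterReshaper assumes a naming convention for network parameters that reserves the "/" character for evosax's internal use. However, dm-haiku also uses "/" in its parameter names (a module's parameters are keyed by its full path, e.g. net/linear), so the two packages clash: after reshaping, the Haiku names are split into nested dicts and Haiku can no longer find its parameters. Minimal reproduction below, tested at commit 36e9f75aa9024091db85382d3cb934fc8ac11da0.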
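The separator clash itself can be shown without evosax or Haiku. The snippet below is a simplified, hypothetical sketch of a scheme that joins nested dict keys with "/" and later splits on "/" to rebuild them; flatten_names and unflatten_names are illustrative helpers, not evosax's actual implementation.

```python
# Hypothetical helpers illustrating a "/"-joined naming scheme (NOT evosax's code).
SEP = "/"


def flatten_names(params, prefix=""):
    """Flatten a nested dict into {"a/b/c": leaf}, joining keys with SEP."""
    flat = {}
    for key, value in params.items():
        name = f"{prefix}{SEP}{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_names(value, name))
        else:
            flat[name] = value
    return flat


def unflatten_names(flat):
    """Rebuild a nested dict by splitting every flattened name on SEP."""
    nested = {}
    for name, leaf in flat.items():
        *parents, last = name.split(SEP)
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[last] = leaf
    return nested


# Haiku-style parameters: the module path "net/linear" is ONE key.
haiku_params = {"net/linear": {"b": 0.0, "w": 1.0}}

flat = flatten_names(haiku_params)
print(flat)                   # {'net/linear/b': 0.0, 'net/linear/w': 1.0}
print(unflatten_names(flat))  # {'net': {'linear': {'b': 0.0, 'w': 1.0}}}
```

Because the original key already contains the separator, the round trip cannot tell "net/linear" apart from a nested "net" -> "linear" path, which is the same renaming visible in the reproduction's output.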

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom
import haiku as hk

from evosax import ParameterReshaper, OpenES


class Net(hk.Module):
    def __init__(self, name=None):
        super().__init__(name=name)

    def __call__(self, inputs):
        return hk.Linear(1)(inputs)


# Haiku stores this module's parameters under the key "net/linear"
init_theta, apply_theta = hk.without_apply_rng(hk.transform(lambda inputs: Net()(inputs)))
k = jrandom.PRNGKey(0)
theta = init_theta(k, jnp.ones(1))
print("Printing the names of the parameters\n", jax.tree_util.tree_map(lambda a: (), theta))
outputs = apply_theta(theta, jnp.ones(1))  # works with the original names

param_reshaper = ParameterReshaper(theta)
es_strategy = OpenES(
    num_dims=param_reshaper.total_params,
    popsize=2,
)
es_params = es_strategy.default_params
es_state = es_strategy.initialize(k, es_params)
x, es_state = es_strategy.ask(k, es_state, es_params)

# This creates the problem by renaming the parameters
theta = param_reshaper.reshape(x)
print("Printing the names of the parameters\n", jax.tree_util.tree_map(lambda a: (), theta))
# Fails (traceback below): the reshaped dict no longer has the "net/linear" key
outputs = jax.vmap(apply_theta, (0, None))(theta, jnp.ones(1))
```

Output:

```text
python -m param_reshaper_bug
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Printing the names of the parameters
 {'net/linear': {'b': (), 'w': ()}}
Printing the names of the parameters
 {'net': {'linear': {'b': (), 'w': ()}}}
Traceback (most recent call last):
....
```

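One possible workaround, sketched under assumptions: instead of ParameterReshaper, flatten the Haiku parameters with jax.flatten_util.ravel_pytree, which rebuilds dicts from the recorded pytree structure rather than by parsing key strings, so "net/linear" is kept verbatim. The theta dict below is a hand-written stand-in for init_theta's output (same shapes as hk.Linear(1) on a 1-D input), and x is a stand-in for the (popsize, num_dims) array returned by ask. This is a swapped-in technique, not necessarily how the merged PR fixes the issue.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Stand-in for Haiku parameters of hk.Linear(1) applied to a length-1 input.
theta = {"net/linear": {"b": jnp.zeros((1,)), "w": jnp.ones((1, 1))}}

# flat_theta: 1-D vector of all parameters.
# unravel: maps a vector of that length back to theta's exact pytree structure.
flat_theta, unravel = ravel_pytree(theta)
num_dims = flat_theta.shape[0]  # 2 parameters in total here

# Stand-in for a population from an ES ask step: shape (popsize, num_dims).
popsize = 4
x = jnp.arange(popsize * num_dims, dtype=jnp.float32).reshape(popsize, num_dims)

# vmap unravel over the population axis; dict keys are reproduced as-is.
batched_theta = jax.vmap(unravel)(x)
print(jax.tree_util.tree_map(jnp.shape, batched_theta))
# {'net/linear': {'b': (4, 1), 'w': (4, 1, 1)}}
```

In the reproduction, num_dims would take the place of param_reshaper.total_params and jax.vmap(unravel)(x) the place of param_reshaper.reshape(x), so the batched parameters keep the "net/linear" key that apply_theta expects.
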
RobertTLange commented 2 years ago

Thank you very much! The PR is merged and will be part of the next release.