`evosax.ParameterReshaper` assumes a naming convention for the network parameters that reserves the `/` character for evosax-internal use. However, `dm-haiku` also uses `/` in its parameter naming convention, which causes a clash between the two packages. A demo follows below. Tested at commit 36e9f75aa9024091db85382d3cb934fc8ac11da0.
import jax
import jax.numpy as jnp
import jax.random as jrandom
import haiku as hk
from evosax import ParameterReshaper, OpenES
class Net(hk.Module):
    """Minimal Haiku module: a single linear layer with one output unit."""

    def __init__(self, name=None):
        """Create the module, forwarding the optional ``name`` to ``hk.Module``."""
        super().__init__(name=name)

    def __call__(self, inputs):
        """Apply a freshly constructed ``hk.Linear(1)`` layer to ``inputs``."""
        linear = hk.Linear(1)
        return linear(inputs)
# Build pure init/apply functions for the network; no RNG is needed at apply time.
init_theta, apply_theta = hk.without_apply_rng(hk.transform(lambda inputs: Net()(inputs)))

# Derive an independent sub-key for every random operation instead of reusing
# the root key `k` for all of them.
k = jrandom.PRNGKey(0)
init_key, es_init_key, ask_key = jrandom.split(k, 3)

theta = init_theta(init_key, jnp.ones(1))
# Map every leaf to `()` so that only the parameter names (tree structure) are printed.
# `jax.tree_util.tree_map` is used because the top-level `jax.tree_map` alias is deprecated.
print("Printing the names of the parameters\n", jax.tree_util.tree_map(lambda a: (), theta))
# Sanity check: applying the network with the parameters returned by `init` works.
outputs = apply_theta(theta, jnp.ones(1))

param_reshaper = ParameterReshaper(theta)
es_strategy = OpenES(
    num_dims=param_reshaper.total_params,
    popsize=2,
)
es_params = es_strategy.default_params
es_state = es_strategy.initialize(es_init_key, es_params)
x, es_state = es_strategy.ask(ask_key, es_state, es_params)

# This creates the problem by renaming the parameters: in the reported output the
# key 'net/linear' comes back as the nested {'net': {'linear': ...}}.
theta = param_reshaper.reshape(x)
print("Printing the names of the parameters\n", jax.tree_util.tree_map(lambda a: (), theta))
# Intentionally left failing: per the reported traceback, applying the network to the
# reshaped (renamed) parameters raises. This call is the bug being demonstrated.
outputs = jax.vmap(apply_theta, (0, None))(theta, jnp.ones(1))
Output:
python -m param_reshaper_bug
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Printing the names of the parameters
{'net/linear': {'b': (), 'w': ()}}
Printing the names of the parameters
{'net': {'linear': {'b': (), 'w': ()}}}
Traceback (most recent call last):
....
`evosax.ParameterReshaper` assumes a naming convention for the network parameters, which reserves the `/` character for evosax internal use. However, `dm-haiku` also uses `/` in its parameter naming convention, causing a clash between the two packages. Demo below. Tested at 36e9f75aa9024091db85382d3cb934fc8ac11da0
Output: