Closed — grfrederic closed this issue 4 months ago
I appreciate the effort, but I don't think it'll work without further adjustments, for the following reason: jax.nn.relu
is a parameter-free function, so it ends up as a non-array leaf of the critic pytree, and Optax fails when initialising optimizer state for such leaves...
```
src/bmi/estimators/neural/_estimators.py:167: in estimate
    return self.estimate_with_info(x, y).mi_estimate
src/bmi/estimators/neural/_estimators.py:144: in estimate_with_info
    training_log, new_critic = basic_training(
src/bmi/estimators/neural/_basic_training.py:56: in basic_training
    opt_state = optimizer.init(critic)
../../micromamba/envs/bmi/lib/python3.10/site-packages/optax/_src/combine.py:50: in init_fn
    return tuple(fn(params) for fn in init_fns)
../../micromamba/envs/bmi/lib/python3.10/site-packages/optax/_src/combine.py:50: in <genexpr>
    return tuple(fn(params) for fn in init_fns)
../../micromamba/envs/bmi/lib/python3.10/site-packages/optax/_src/transform.py:353: in init_fn
    mu = jax.tree_util.tree_map(  # First moment
../../micromamba/envs/bmi/lib/python3.10/site-packages/jax/_src/tree_util.py:312: in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../../micromamba/envs/bmi/lib/python3.10/site-packages/jax/_src/tree_util.py:312: in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../../micromamba/envs/bmi/lib/python3.10/site-packages/optax/_src/transform.py:354: in <lambda>
    lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
../../micromamba/envs/bmi/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2248: in zeros_like
```
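
To make the failure concrete, here's a minimal sketch reproducing it outside of bmi (the dict below is a hypothetical stand-in for the critic pytree, not the actual code): Adam's `init` allocates a zeroed first-moment buffer for every leaf of the parameter pytree, and `jnp.zeros_like` is meaningless for a function leaf such as `jax.nn.relu`.

```python
import jax
import jax.numpy as jnp
import optax

# Hypothetical stand-in for the critic: a pytree whose leaves are
# a weight matrix and the activation function itself.
fake_critic = {"weight": jnp.ones((4, 4)), "activation": jax.nn.relu}

optimizer = optax.adam(1e-3)
# Fails: Adam's init calls jnp.zeros_like on every leaf,
# including the function leaf jax.nn.relu.
opt_state = optimizer.init(fake_critic)  # raises TypeError
```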
One possible solution in Equinox is filtering, but I worry that it'll make the rest of the codebase a bit more complex. What do you think?
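
For reference, a minimal sketch of what the filtering approach could look like (the MLP and `loss_fn` below are hypothetical placeholders, not code from this repo):

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax

# Hypothetical stand-in for the critic: an MLP whose pytree
# contains jax.nn.relu as a non-array activation leaf.
critic = eqx.nn.MLP(in_size=2, out_size=1, width_size=16, depth=2,
                    activation=jax.nn.relu, key=jr.PRNGKey(0))

optimizer = optax.adam(1e-3)
# Filter out non-array leaves (like the activation) before init:
opt_state = optimizer.init(eqx.filter(critic, eqx.is_inexact_array))

def loss_fn(model, x):  # hypothetical loss, just for illustration
    return jnp.mean(jax.vmap(model)(x) ** 2)

@eqx.filter_jit
def training_step(model, opt_state, x):
    # filter_value_and_grad differentiates only the array leaves.
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

critic, opt_state, loss = training_step(critic, opt_state, jnp.ones((8, 2)))
```

The extra complexity would mostly be confined to `basic_training`: the `optimizer.init` call and the gradient/update steps need the filtered variants, while the rest of the codebase could keep passing the critic around unchanged.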
getting back in the saddle :)