cbg-ethz / bmi

Mutual information estimators and benchmark
https://cbg-ethz.github.io/bmi/
MIT License

clean up equinox mlp implementation #144

Closed by grfrederic 4 months ago

grfrederic commented 4 months ago

getting back in the saddle :)

pawel-czyz commented 4 months ago

I appreciate the effort, but I don't think it'll work without further adjustments, for the following reason: jax.nn.relu is a parameter-free function, and Optax has problems initialising optimiser state for modules containing such leaves (it calls jnp.zeros_like on every pytree leaf)...

src/bmi/estimators/neural/_estimators.py:167: in estimate
    return self.estimate_with_info(x, y).mi_estimate
src/bmi/estimators/neural/_estimators.py:144: in estimate_with_info
    training_log, new_critic = basic_training(
src/bmi/estimators/neural/_basic_training.py:56: in basic_training
    opt_state = optimizer.init(critic)
../../micromamba/envs/bmi/lib/python3.10/site-packages/optax/_src/combine.py:50: in init_fn
    return tuple(fn(params) for fn in init_fns)
../../micromamba/envs/bmi/lib/python3.10/site-packages/optax/_src/combine.py:50: in <genexpr>
    return tuple(fn(params) for fn in init_fns)
../../micromamba/envs/bmi/lib/python3.10/site-packages/optax/_src/transform.py:353: in init_fn
    mu = jax.tree_util.tree_map(  # First moment
../../micromamba/envs/bmi/lib/python3.10/site-packages/jax/_src/tree_util.py:312: in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../../micromamba/envs/bmi/lib/python3.10/site-packages/jax/_src/tree_util.py:312: in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
../../micromamba/envs/bmi/lib/python3.10/site-packages/optax/_src/transform.py:354: in <lambda>
    lambda t: jnp.zeros_like(t, dtype=mu_dtype), params)
../../micromamba/envs/bmi/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2248: in zeros_like
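
If it helps, the failure can be reproduced in isolation. A minimal sketch, assuming recent Equinox/Optax (the MLP sizes here are arbitrary, not from this repo):

```python
import equinox as eqx
import jax
import optax

critic = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2,
                    key=jax.random.PRNGKey(0))
optimizer = optax.adam(1e-3)

# Fails: the module's pytree contains a non-array leaf (the jax.nn.relu
# activation), and Optax calls jnp.zeros_like on every leaf to build
# its moment buffers.
opt_state = optimizer.init(critic)
```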

One possible solution in Equinox is filtering (initialising the optimiser on the array leaves only; see the sketch below), but I worry that it'll make the rest of the codebase a bit more complex. What do you think?
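
For concreteness, here's a minimal sketch of that filtering pattern, with placeholder names (loss_fn, batch, and the MLP sizes are illustrative, not from this codebase):

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

critic = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2,
                    key=jax.random.PRNGKey(0))
optimizer = optax.adam(1e-3)

# Initialise the optimiser on the floating-point array leaves only, so
# jnp.zeros_like never sees a parameter-free leaf like jax.nn.relu.
opt_state = optimizer.init(eqx.filter(critic, eqx.is_inexact_array))

def loss_fn(critic, batch):
    # Placeholder loss standing in for the real critic objective.
    return jnp.mean(jax.vmap(critic)(batch) ** 2)

@eqx.filter_jit
def step(critic, opt_state, batch):
    # Differentiate with respect to the array leaves only; the gradient
    # pytree has None at the parameter-free leaves, which
    # eqx.apply_updates knows how to skip.
    grads = eqx.filter_grad(loss_fn)(critic, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    critic = eqx.apply_updates(critic, updates)
    return critic, opt_state
```

Applied here, the optimizer.init(critic) call in _basic_training.py would become the filtered version above, plus switching the gradient call to eqx.filter_grad.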