patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

Dropout causes `None`s with gradient accumulation #772

Open neel04 opened 3 months ago

neel04 commented 3 months ago

`Dropout` really seems to be the bane of Equinox. This is a loose follow-up to #681: I'm trying to fix a problem that cropped up a while ago when using `optax.MultiSteps` for gradient accumulation.

Concretely, the change is:

```diff
- optim = optax.adamw(1e-3)
+ accum_steps: int = 3
+ optim = optax.MultiSteps(optax.adamw(1e-3), accum_steps)
```

With that change, I get the following error:


```py
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/neel/Documents/research/ReAct_Jax/train_model.py", line 171, in
    main(key)
  File "/Users/neel/Documents/research/ReAct_Jax/train_model.py", line 133, in main
    trainer.train()
  File "/Users/neel/Documents/research/ReAct_Jax/ReAct/utils/trainer.py", line 363, in train
    loss, model, opt_state = make_step(model, opt_state, filter_spec, seq, label, pad_mask,
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/scalax/sharding.py", line 300, in wrapped
    results = jitted_fn(*args)
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/scalax/sharding.py", line 264, in sharding_constrained_fun
    return fun(*args)
  File "/Users/neel/Documents/research/ReAct_Jax/ReAct/utils/trainer.py", line 131, in make_step
    updates, opt_state = optim.update(grads, opt_state, model)
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/combine.py", line 73, in update_fn
    updates, new_s = fn(updates, s, params, **extra_args)
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/base.py", line 337, in update
    return tx.update(updates, state, params)
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/wrappers.py", line 435, in update
    new_updates, new_state = jax.lax.cond(
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/wrappers.py", line 407, in _do_update
    inner_opt_state=jax.tree_util.tree_map(
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/wrappers.py", line 408, in
    lambda st, nst: jnp.where(emit, nst, st),
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1140, in where
    util.check_arraylike("where", acondition, if_true, if_false)
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 335, in check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: where requires ndarray or scalar arguments, got at position 1.
```

Here's a Colab MRE (minimal reproducible example) if you want to play around with this.

From a little debugging, the problem seems to come solely from `eqx.nn.Dropout`, perhaps from its `inference` flag? According to the debugger, the problem is with the `nst: PyTree` variable here in optax.

It's weird, but if you print out the leaves:

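```py
# Debug print placed inside optax's MultiSteps update, where `emit` and `state` are in scope
jax.tree_util.tree_map(
    lambda st, nst: print(f"{st} : {nst} | {emit} | {nst is None}"),
    state.inner_opt_state,
    new_inner_state,
)
```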

Then it appears that wherever the new state has a `None` leaf, the old state has a bool tracer. For example:

```
Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=2/0)> : None | Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=2/0)> | True
```

The only bool around should be the `inference` field of `Dropout`.

I'm not too sure, but I suppose that even with the fix we discussed earlier (filtering by `is_array_like`), we don't really have a trivial way to filter out bools.

This might be coming from optax's side somehow, but I'm not sure why `inference` would be promoted to a tracer at all. I'm already partitioning out `Dropout`:

```py
eqx.partition(layer, eqx.is_array, is_leaf=lambda x: isinstance(x, eqx.nn.Dropout))
```

Not sure what's going on here. Hopefully you can shed some light on the internals and figure out what's going wrong...

patrick-kidger commented 3 months ago

At first glance, it's probably because you're doing `opt_state = optim.init(eqx.filter(model, eqx.is_array_like))` rather than `opt_state = optim.init(eqx.filter(model, eqx.is_array))`. Equinox generally wants you to treat non-arrays as static.

lockwo commented 3 months ago

Can confirm: if you make that change in the provided code, it works.

neel04 commented 3 months ago

That doesn't work for my codebase. I'm filtering by `is_array_like` to catch `dropout.p`, which is a float and gets promoted to a tracer, as you explained here 😄

Ideally it should catch `inference` as well, since that's a bool, so perhaps the source of the error is elsewhere.

This is the traceback when filtering by `eqx.is_array`. It still runs into `None`s, just earlier 🤷


```py
Traceback (most recent call last):
  File "/Users/neel/Documents/research/ReAct_Jax/train_model.py", line 171, in
    main(key)
  File "/Users/neel/Documents/research/ReAct_Jax/train_model.py", line 133, in main
    trainer.train()
  File "/Users/neel/Documents/research/ReAct_Jax/ReAct/utils/trainer.py", line 364, in train
    loss, model, opt_state = make_step(model, opt_state, filter_spec, seq, label, pad_mask,
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/scalax/sharding.py", line 300, in wrapped
    results = jitted_fn(*args)
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/scalax/sharding.py", line 264, in sharding_constrained_fun
    return fun(*args)
  File "/Users/neel/Documents/research/ReAct_Jax/ReAct/utils/trainer.py", line 132, in make_step
    updates, opt_state = optim.update(grads, opt_state, model)
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/combine.py", line 73, in update_fn
    updates, new_s = fn(updates, s, params, **extra_args)
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/base.py", line 337, in update
    return tx.update(updates, state, params)
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/wrappers.py", line 435, in update
    new_updates, new_state = jax.lax.cond(
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/wrappers.py", line 391, in _do_update
    acc_grads = jax.tree_util.tree_map(
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/wrappers.py", line 392, in
    lambda upd, acc: self._acc_update(upd, acc, n_acc=state.mini_step),
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/wrappers.py", line 344, in
    lambda grad, acc, *, n_acc: acc + (grad - acc) / (n_acc + 1))
TypeError: unsupported operand type(s) for -: 'DynamicJaxprTracer' and 'NoneType'
```

clementpoiret commented 1 month ago

Same bug for me, but not with `MultiSteps`: I hit it with `optax.contrib.schedule_free()`, even when filtering with `eqx.is_array`... https://gist.github.com/clementpoiret/f9e53a934462d38e56f32be87730f661

Removing the dropout assignment resolves the issue. Marking `p` and `inference` as static fields also solves it.

haydn-jones commented 2 weeks ago

Unsure if you've fixed this, but I ran into this in #853. For `adamw`, filtering at initialization alone is not sufficient, because `adamw` requires you to pass in the model parameters during the optimizer update. So you need to do `optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))` as well. Perhaps the documentation should be updated to say that you may need to filter during both `init` and `update`?

neel04 commented 2 weeks ago

Oh yeah, I'd fixed it but forgot to update this thread 😅

It should definitely be added to the FAQ in the docs, as it would save beginners quite a bit of time. It's a pity one needs to do so much filtering.

neel04 commented 2 weeks ago

@clementpoiret Unfortunately, you'd have to step through the optax code and figure out when and how to filter the PyTree. If pre-filtering doesn't work for some reason, maybe try inserting additional filter statements into the optax code itself.

As I've mentioned before, `Dropout` is an absolute bane for Equinox. I know there are some discussions here about the best way to fix it, but at this point I feel that a stronger version of the `filter_*` transforms that explicitly takes `Dropout` into account would be the way to solve this problem, at the cost of breaking compatibility I suppose...

clementpoiret commented 2 weeks ago

Thanks all for your pointers. I ended up marking `p` and `inference` as static fields. This adds some verbosity to the code, as I now need to pass the `inference` parameter in each forward pass 😅