Open · neel04 opened 4 months ago
At first glance it's probably that you're doing `opt_state = optim.init(eqx.filter(model, eqx.is_array_like))` rather than `opt_state = optim.init(eqx.filter(model, eqx.is_array))`. Equinox generally wants you to treat non-arrays as static.
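For concreteness, a minimal sketch of the suggested pattern (the model and optimizer here are stand-ins, not the code from this issue):

```py
import equinox as eqx
import jax
import optax

# Stand-in model/optimizer, purely to illustrate the filtering advice above.
key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(in_size=4, out_size=1, width_size=8, depth=2, key=key)
optim = optax.adamw(learning_rate=1e-3)

# Initialise the optimizer state from the array leaves only; Python floats,
# bools, activation functions, etc. are treated as static and become None here.
opt_state = optim.init(eqx.filter(model, eqx.is_array))
```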
Can confirm: if you make that change in the code provided, it works.
That doesn't work for my codebase. I'm filtering by `is_array_like` to catch `dropout.p`, which is a float and gets promoted to a tracer, according to how you explained it here 😄 Ideally it should be catching `inference` as well, since that's a bool, so perhaps the source of the error is elsewhere.

This is the traceback when filtering by `eqx.is_array`. It just encounters `None`s, but earlier 🤷
```py
Traceback (most recent call last):
  File "/Users/neel/Documents/research/ReAct_Jax/train_model.py", line 171, in <module>
```
Same bug for me, but not with `MultiSteps`: I have it with `optax.contrib.scheduler_free()`, even when filtering with `eqx.is_array`... https://gist.github.com/clementpoiret/f9e53a934462d38e56f32be87730f661
Removing the dropout assignment resolves the issue. Marking `p` and `inference` as static fields also solves it.
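One possible way to realise the "static fields" workaround is a small custom dropout module whose hyperparameters are pytree metadata rather than leaves. This is a hedged sketch (`StaticDropout` is a made-up name, not code from this thread):

```py
import equinox as eqx
import jax
import jax.numpy as jnp

class StaticDropout(eqx.Module):
    # Marked static, so `p` is never a pytree leaf and optax never sees it.
    p: float = eqx.field(static=True)

    def __call__(self, x, *, key, inference=False):
        # `inference` is a plain call argument rather than a field.
        if inference or self.p == 0.0:
            return x
        keep = jax.random.bernoulli(key, 1.0 - self.p, x.shape)
        return jnp.where(keep, x / (1.0 - self.p), 0.0)
```

The trade-off, as noted further down the thread, is that `inference` then has to be passed explicitly on every forward pass.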
Unsure if you've fixed this, but I ran into this in #853. For `adamw`, filtering in the initialization alone is not sufficient, as `adamw` requires you to pass in the model parameters during an optimizer update. So you need to do `optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))` as well. Perhaps the documentation should be updated to indicate that you may need to filter during both init and update?
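A minimal sketch of the full pattern being described (again with a stand-in model and loss, not the code from this issue):

```py
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(in_size=4, out_size=1, width_size=8, depth=2, key=key)
optim = optax.adamw(learning_rate=1e-3)

# Filter at init time...
opt_state = optim.init(eqx.filter(model, eqx.is_array))

def loss_fn(model, x):
    return jnp.sum(jax.vmap(model)(x) ** 2)

x = jnp.ones((16, 4))
grads = eqx.filter_grad(loss_fn)(model, x)

# ...and filter again when passing params to update, since adamw's weight
# decay needs a parameter pytree that lines up with the (filtered) gradients.
updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
model = eqx.apply_updates(model, updates)
```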
Oh yea, I'd fixed it but forgot to update this thread 😅
It should definitely be added to the FAQs in the docs. It's a pity one needs to do so much `filter`-ing; documenting it would save some beginners quite a bit of time.
@clementpoiret unfortunately, you'd have to step through the `optax` code and figure out when and how to filter out the PyTree; if pre-filtering doesn't work for some reason, then maybe try inserting additional `filter` statements in the `optax` code itself.
As I've mentioned before, `Dropout` is an absolute bane for `equinox`. I know there are some discussions here about the best way to fix it, but at this point I feel like using a stronger version of the `filter_*` transforms that explicitly takes `Dropout` into account would be the way to solve this problem, at the cost of breaking compatibility, I suppose...
Thanks all for your pointers. I ended up marking `p` and `inference` as static fields. This adds some verbosity to the code, as I now need to pass the inference parameter in each forward pass 😅
Dropout is really the bane of equinox, it seems.

Loose follow-up of #681. Effectively, I'm trying to fix this problem that cropped up a while ago when using `optax.MultiSteps` for gradient accumulation. The change is:
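The diff itself isn't reproduced here, but the change being described is, roughly, wrapping the existing optimizer in `optax.MultiSteps`. A hedged sketch (hyperparameters are illustrative, not from the codebase):

```py
import optax

# Wrap the existing optimizer in optax.MultiSteps so gradients are
# accumulated over several steps before an update is applied.
base_optim = optax.adamw(learning_rate=1e-3)
optim = optax.MultiSteps(base_optim, every_k_schedule=4)
# optim.init / optim.update are then used in place of base_optim's.
```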
Leading to this error:
```py
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/neel/Documents/research/ReAct_Jax/train_model.py", line 171, in <module>
    main(key)
  File "/Users/neel/Documents/research/ReAct_Jax/train_model.py", line 133, in main
    trainer.train()
  File "/Users/neel/Documents/research/ReAct_Jax/ReAct/utils/trainer.py", line 363, in train
    loss, model, opt_state = make_step(model, opt_state, filter_spec, seq, label, pad_mask,
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/scalax/sharding.py", line 300, in wrapped
    results = jitted_fn(*args)
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/scalax/sharding.py", line 264, in sharding_constrained_fun
    return fun(*args)
  File "/Users/neel/Documents/research/ReAct_Jax/ReAct/utils/trainer.py", line 131, in make_step
    updates, opt_state = optim.update(grads, opt_state, model)
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/combine.py", line 73, in update_fn
    updates, new_s = fn(updates, s, params, **extra_args)
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/base.py", line 337, in update
    return tx.update(updates, state, params)
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/wrappers.py", line 435, in update
    new_updates, new_state = jax.lax.cond(
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/wrappers.py", line 407, in _do_update
    inner_opt_state=jax.tree_util.tree_map(
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/optax/_src/wrappers.py", line 408, in <lambda>
    lambda st, nst: jnp.where(emit, nst, st),
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1140, in where
    util.check_arraylike("where", acondition, if_true, if_false)
  File "/Users/neel/miniconda3/envs/react_jax/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 335, in check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: where requires ndarray or scalar arguments, got <class 'NoneType'> at position 1.
```
From a little debugging, the problem seems to be solely with `eqx.nn.Dropout` - somehow with the `inference` flag maybe? According to the debugger, the problem is with the `nst: PyTree` variable here in optax.

It's weird, but if you print out the leaves, it appears the `None` leaves are `bool` tracers. The only `bool` around should be the `inference` argument of `Dropout`.

I'm not too sure - but I suppose even with our earlier discussed fix of filtering by `is_array_like`, we don't really have a way to filter `bool`s trivially.

Maybe this might be from `optax`'s side somehow - but I'm not sure why `inference` would be promoted to a tracer at all. I'm already `partition`-ing out `Dropout`.

Not sure what's going on here. Hopefully you can shed some light on the internals and figure out what's going wrong...
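For reference, a minimal sketch of the kind of partitioning and leaf inspection described above (the model here is a stand-in, not the one from the codebase):

```py
import equinox as eqx
import jax

key = jax.random.PRNGKey(0)
model = eqx.nn.Sequential([
    eqx.nn.Linear(4, 4, key=key),
    eqx.nn.Dropout(p=0.1),
])

# Split the model into array leaves (what optax should see) and everything else.
params, static = eqx.partition(model, eqx.is_array)

print(jax.tree_util.tree_leaves(params))  # only the Linear weight/bias arrays
print(jax.tree_util.tree_leaves(model))   # also includes Dropout's float p and bool inference
```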