Closed: ciupakabra closed this issue 1 year ago.
`loss` is redefined in here:

`loss, drift, opt_state = make_step(drift, 10, dim, step_key, 32, opt_state)`

If you rename the `loss` function to `loss_fn`, it works.
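For reference, the failure is ordinary Python name shadowing. A minimal sketch (hypothetical names, not the original training loop):

```python
def loss(params, batch):
    # the function you intended to keep calling
    return 0.0

def make_step(params):
    # hypothetical helper that returns (loss_value, new_params)
    return 1.23, params

params = {}
for _ in range(2):
    # After the first iteration, `loss` is the float returned by make_step,
    # not the original function, so anything expecting a callable now fails.
    loss, params = make_step(params)
```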
As @uuirs points out, you've redefined the variable `loss` from a function to an array.

That said, it's a little odd that you're only hitting this when using `BatchNorm`. Is `BatchNorm` triggering a recompilation, somehow?
Yeah, it's odd that this happened in such a way, and you're right -- I checked, and with `bn=False` it compiles once, whereas with `bn=True` it compiles twice.
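(One way to verify how many times things compile, as a hypothetical check rather than the exact code used here, is to ask JAX to log every compilation; a trace-time `print` inside the jitted function, as in the example further down, works too.)

```python
import jax

# Log every XLA compilation, so repeated compiles of the same function show up.
jax.config.update("jax_log_compiles", True)
```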
Had some time to debug this tonight. I think it's because the state of `BatchNorm` is first initialized with a pytree of a different PyTreeDef than what it stores once it actually gets a state, i.e. it's first initialized as an array:

but once it calculates the running means and vars it sets it to a tuple of arrays:

This changes the PyTreeDef of the `StateIndex` module, which then changes the hash of the PyTreeDef of the whole model. And since the PyTreeDef is used as a static argument somewhere in Equinox's jitting logic, it gets recompiled.
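A minimal sketch of the mechanism in plain JAX (not Equinox internals; names and shapes are illustrative): a leaf that changes from a single array to a tuple of arrays changes the treedef, and since `jit` keys its cache on the treedef of its arguments, the function gets traced and compiled again:

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(model):
    print("Compiling!")  # only runs at trace time
    return jax.tree_util.tree_map(lambda a: a + 1, model)


# "state" as it looks at initialisation vs. once running stats are stored.
f({"state": jnp.zeros(10)})                   # Compiling!
f({"state": (jnp.zeros(10), jnp.ones(10))})   # Compiling! (new treedef)
f({"state": (jnp.zeros(10), jnp.ones(10))})   # cached, no recompilation
```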
I'm assuming the fix is a one-line change in `BatchNorm`'s initialization?
Aren't these different state indices?
Sorry, you're right, for initialization I meant to reference:
To be fair, even then I don't understand some of the behaviour. A small working example of the bug above is:
```python
import jax
import jax.random as jrandom

import equinox as eqx


@eqx.filter_jit
def fun(bn, inp):
    print("Compiling!")
    return jax.vmap(bn, axis_name="batch")(inp)


def info(bn):
    print(f"bn.state_index._state: {bn.state_index._state}")  # prints None all the time
    # children should be dynamic_field_values, so children[0] should be the value of _state
    children, aux = bn.state_index.tree_flatten()
    print(f"children[0] of bn.state_index flatten: {children[0]}")  # prints a tuple of arrays after 1st call to fun


if __name__ == "__main__":
    bn = eqx.experimental.BatchNorm(10, axis_name="batch")
    inp = jrandom.normal(jrandom.PRNGKey(0), (32, 10))

    info(bn)
    fun(bn, inp)
    info(bn)
    fun(bn, inp)
    info(bn)
    fun(bn, inp)
```
which outputs:

```
bn.state_index._state: None
children[0] of bn.state_index flatten: None
Compiling!
bn.state_index._state: None
children[0] of bn.state_index flatten: (DeviceArray([-0.32395738, -0.21207958, -0.31645954,  0.05969752,
             -0.11307174,  0.0944065 , -0.14875616, -0.05194131,
              0.10097986,  0.25392908], dtype=float32), DeviceArray([0.92587423, 0.95594984, 1.0194211 , 0.84475476, 0.76749337,
             0.77187854, 1.38814   , 0.8497227 , 1.1132355 , 0.86574566], dtype=float32))
Compiling!
bn.state_index._state: None
children[0] of bn.state_index flatten: (DeviceArray([-0.32395738, -0.21207958, -0.31645954,  0.05969752,
             -0.11307174,  0.0944065 , -0.14875618, -0.05194131,
              0.10097986,  0.25392908], dtype=float32), DeviceArray([0.92587423, 0.95594984, 1.0194211 , 0.84475476, 0.76749337,
             0.77187854, 1.38814   , 0.8497227 , 1.1132355 , 0.86574566], dtype=float32))
```
So you can clearly see that, at least when accessing `_state` through `tree_flatten` of `StateIndex`, `_state` changes from `None` to a 2-tuple of arrays, changing the PyTreeDef of the static args.
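Concretely (a standalone check, not Equinox code), `None` is an empty pytree, so its structure differs from that of the 2-tuple the state is later replaced with:

```python
import jax
import jax.numpy as jnp

print(jax.tree_util.tree_structure(None))                           # PyTreeDef(None)
print(jax.tree_util.tree_structure((jnp.zeros(10), jnp.ones(10))))  # PyTreeDef((*, *))
```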
I don't understand why `bn.state_index._state` is constantly outputting `None`, though. I thought this would set it to a new value:

but I guess it gets deleted somewhere in between the two prints by `_delete_smuggled_state`.
Closing, as `eqx.experimental.BatchNorm` is now available (in theory without bugs) as `eqx.nn.BatchNorm`.
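For anyone landing here later, a rough sketch of what the replacement stateful API looks like (based on the current Equinox docs; exact argument names may have shifted between versions):

```python
import jax
import jax.random as jrandom
import equinox as eqx

# eqx.nn.BatchNorm threads its running statistics through an explicit
# eqx.nn.State object instead of a StateIndex hidden inside the module.
bn, state = eqx.nn.make_with_state(eqx.nn.BatchNorm)(10, axis_name="batch")
inp = jrandom.normal(jrandom.PRNGKey(0), (32, 10))

# The state is shared across the batch, hence in_axes/out_axes of None for it.
out, state = jax.vmap(
    bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
)(inp, state)
```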
I'm getting weird errors with `BatchNorm`. One example is the code below, where ODE parameters are optimized and the drift is a neural network with some `BatchNorm` layers. The error thrown is `TypeError: Expected a callable value, got inf` when vmapping a loss function.

I wasn't able to reduce this to something without `diffrax` -- double vmaps or vmapping a scan don't raise exceptions. I'm guessing this is an equinox issue, since without batchnorm (`bn=False`) there are no problems.