mblondel closed this issue 1 year ago.
This sounds like the right fix to me. A change in weak type for a function input is like changing the dtype: it can change the function's behavior, and must trigger a recompilation. If you want to avoid recompilation, you need to make sure the inputs of the second function call match the inputs of the first function call in shape, dtype, and weak type.
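A minimal sketch of this retracing behavior follows. The function `double` and the `trace_count` counter are hypothetical, for illustration only. It relies on the fact that Python side effects inside a jitted function run only while JAX traces it, so the counter counts traces (each new trace is followed by a compilation):

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: incremented only while JAX traces `double`, which
# happens once per new abstract input signature (shape, dtype, weak_type).
trace_count = 0


@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: executes at trace time only
    return x * 2


weak = jnp.asarray(1.0)                       # float32, weak_type=True
strong = jnp.asarray(1.0, dtype=jnp.float32)  # float32, weak_type=False

double(weak)                                  # first trace
double(strong)                                # weak_type differs: retrace
double(jnp.asarray(3.0, dtype=jnp.float32))   # same signature as `strong`: cached
print(trace_count)  # 2
```

Both inputs are `float32` scalars, yet the second call retraces solely because `weak_type` differs.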
Thanks @jakevdp! I think I tracked down the inconsistency to this behavior:
```python
>>> import jax.numpy as jnp
>>> a = jnp.asarray(0.0, dtype=jnp.float32)
>>> a.weak_type
False
>>> a.dtype
dtype('float32')
>>> b = jnp.asarray(0.0)
>>> b.weak_type
True
>>> b.dtype
dtype('float32')
```
That is, when we explicitly specify `dtype`, `weak_type` is `False`, while `weak_type` is `True` if we don't explicitly specify `dtype`.
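If a weakly typed scalar has already been created, one way to normalize it is `jax.lax.convert_element_type`, which returns a non-weak result. This is a sketch of one possible normalization, not necessarily the approach taken in the eventual fix:

```python
import jax
import jax.numpy as jnp

b = jnp.asarray(0.0)               # dtype not specified: weak_type=True
print(b.dtype, b.weak_type)        # float32 True

# Convert to the same dtype; the public lax API always returns weak_type=False.
c = jax.lax.convert_element_type(b, b.dtype)
print(c.dtype, c.weak_type)        # float32 False
```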
Yes, that's expected. Roughly, the mental model of "weak type" is that it's a value whose dtype has not been specified by the user. It's the mechanism that allows `(x + 1).dtype == x.dtype` to hold true within JAX code.
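A short illustration of that property, using low-precision dtypes where the difference is visible. Python scalars are weakly typed and defer to the array's dtype, whereas a strongly typed `float32` scalar takes part in ordinary type promotion:

```python
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float16)

# Python scalars are weakly typed, so they defer to x's dtype.
print((x + 1).dtype)                 # float16
print((x + 1.0).dtype)               # float16

# A strongly typed float32 scalar promotes the result to float32.
print((x + jnp.float32(1)).dtype)    # float32

# The same holds for integer arrays.
i = jnp.arange(3, dtype=jnp.int8)
print((i + 1).dtype)                 # int8
```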
fixed by #458
Each solver in JAXopt maintains a state, which contains several attributes, some of which are scalar-valued (e.g., `stepsize`). In some solvers, the state returned by `init_state` and the state returned by `update` have inconsistent `weak_type` for some of these attributes, which triggers a JIT recompilation of `update`. This can be seen with the code example below (code by @fllinares):
Output:
What's happening: `state0` is obtained from the `init_state` call and is given as input to `update`. A first JIT compilation happens, with `state0.stepsize.weak_type = False`. Then `update` outputs a new state `state1` with `state1.stepsize.weak_type = True`. When we use that state as input to `update`, a JIT recompilation happens since `weak_type` has changed. For subsequent calls to `update`, no recompilation occurs, since `weak_type` remains `True`.

Similarly to the `dtype` and `aux` consistency checks in `common_test.py`, we need to check `weak_type` consistency for each solver in a systematic manner, and make fixes if necessary.

@froystig Your opinion on the best way to fix this would be welcome.
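The scenario described here, and the kind of systematic check being proposed, can be sketched with a toy solver. Everything in it is hypothetical: `ToyGradientDescent`, `ToyState`, and `weak_type_mismatches` are illustrative names, not JAXopt's API, and this is not @fllinares's original reproduction. `init_state` creates `stepsize` with an explicit dtype (strong), while `update` creates it without one (weak), so the second call to the jitted `update` retraces:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical solver state with one scalar-valued attribute.
    iter_num: jax.Array
    stepsize: jax.Array


class ToyGradientDescent:
    """Hypothetical solver mimicking an init_state / update split."""

    def __init__(self, fun, stepsize=0.1):
        self.fun = fun
        self.stepsize = stepsize  # plain Python float
        self.trace_count = 0

    def init_state(self, params):
        # dtype given explicitly -> weak_type=False
        return ToyState(iter_num=jnp.asarray(0, dtype=jnp.int32),
                        stepsize=jnp.asarray(self.stepsize, dtype=jnp.float32))

    def update(self, params, state):
        self.trace_count += 1  # runs only when update is (re)traced
        params = params - self.stepsize * jax.grad(self.fun)(params)
        # The inconsistency being illustrated: no dtype -> weak_type=True
        new_state = ToyState(iter_num=state.iter_num + 1,
                             stepsize=jnp.asarray(self.stepsize))
        return params, new_state


def weak_type_mismatches(state_a, state_b):
    """Hypothetical check: names of state fields whose weak_type differs."""
    return [name for name, a, b in zip(state_a._fields, state_a, state_b)
            if a.weak_type != b.weak_type]


solver = ToyGradientDescent(fun=lambda p: jnp.sum(p ** 2))
update = jax.jit(solver.update)

params = jnp.ones(3, dtype=jnp.float32)
state0 = solver.init_state(params)
params, state1 = update(params, state0)  # trace 1: stepsize is strong
params, state2 = update(params, state1)  # trace 2: stepsize is now weak
params, state3 = update(params, state2)  # cached: weak_type unchanged

print(solver.trace_count)                    # 2
print(weak_type_mismatches(state0, state1))  # ['stepsize']
```

In this toy, giving `update` the same explicit dtype as `init_state` (for example, `jnp.asarray(self.stepsize, dtype=jnp.float32)`) makes `weak_type_mismatches(state0, state1)` return an empty list and leaves a single trace. A check of this shape could be run over every solver, in the spirit of the existing `dtype` and `aux` checks.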