Closed kvablack closed 1 month ago
Indeed, the usual behaviour is for jaxtyping's checks to be completely done by trace time. Probably you're constructing a typechecked dataclass or calling a typechecked function somewhere inside your training loop!
This is all happening inside the `train_step`, which is jitted. You can even see in the trace that the `isinstance` calls are all inside the `PjitFunction(train_step)` call. Nothing else is happening in the train loop except for the `train_step` call.
Is it expected that typechecking dataclasses would run during a jitted function call?
It isn't: once a function has been JIT'd, the Python code is never evaluated again.
We don't use `jax.debug.callback`, `jax.pure_callback`, or `jax.experimental.io_callback`, and those are the only possible escape hatches for running normal Python in JIT'd code. Since we don't use them, we couldn't run typechecking many times even if we wanted to!
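To illustrate why that rules out repeated typechecking, here is a toy model of jit tracing (an assumption-laden sketch using only the standard library, not real JAX: `toy_jit`, `typecheck`, and the recorded lambdas are all made up for illustration). The Python body, including any typecheck in it, runs exactly once while tracing; every later call only replays the recorded ops:

```python
# Toy model of jit tracing (NOT real JAX; simplified for illustration).
typecheck_calls = 0

def typecheck(x):
    # Stand-in for a jaxtyping/isinstance check in the function body.
    global typecheck_calls
    typecheck_calls += 1
    assert isinstance(x, float)

def toy_jit(fn):
    plan = []  # recorded "primitive" ops, filled in on the first call

    def wrapper(x):
        if not plan:
            fn(x, record=plan.append)  # trace: the Python body runs only here
        result = x
        for op in plan:                # replay: only the recorded ops run
            result = op(result)
        return result

    return wrapper

@toy_jit
def train_step(x, record):
    typecheck(x)                  # executed only during tracing
    record(lambda v: v * 2.0)
    record(lambda v: v + 1.0)

for step in (3.0, 4.0, 5.0):
    train_step(step)

print(typecheck_calls)   # -> 1: the check ran once, at trace time
print(train_step(10.0))  # -> 21.0
```

The same logic is why a typecheck showing up on every step of a real jitted train loop is surprising: absent a callback escape hatch, the check should only ever run during tracing.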
One possibility is that Flax is deserialising via its `__init__` method (which is where typechecking occurs on dataclasses), so that all of this extra overhead is occurring at the very end of a JIT'd call, when reconstructing the Python objects being passed out of the JIT'd region. If so then that's a Flax bug, but before we go pointing fingers: can you test what result you get when you don't return any Flax objects from your JIT'd call?
> One possibility is that Flax is deserialising via its `__init__` method (which is where typechecking occurs on dataclasses), so that all of this extra overhead is occurring at the very end of a JIT'd call, when reconstructing the Python objects being passed out of the JIT'd region.
Pretty sure this is it. Here's a minimal repro that shows a performance diff only if a dataclass is being returned:
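In the same spirit as that repro, here is a stdlib-only sketch of the mechanism (no JAX or Flax; `Params`, the counter, and `unflatten_via_init` are illustrative stand-ins, not actual Flax code): if the pytree `unflatten_func` rebuilds the object via its constructor, then any typechecking hooked into `__init__`/`__post_init__` fires on every single reconstruction, i.e. once per step:

```python
# Stdlib-only simulation of unflattening a typechecked dataclass via __init__.
import dataclasses

check_calls = 0

@dataclasses.dataclass
class Params:
    w: float
    b: float

    def __post_init__(self):
        # Stand-in for the typechecking that runs in the dataclass constructor.
        global check_calls
        check_calls += 1
        assert isinstance(self.w, float) and isinstance(self.b, float)

def unflatten_via_init(leaves):
    # How an unflatten_func that goes via the constructor rebuilds the object:
    # every call re-triggers __post_init__, and hence the typecheck.
    return Params(*leaves)

for step in range(1000):              # every step that returns a Params...
    params = unflatten_via_init((1.0, 2.0))

print(check_calls)  # -> 1000: once per reconstruction, not once total
```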
And here's the line where the constructor gets called --- in the `unflatten_func`.
Is there another way to do it? I'm not too familiar with PyTree internals, but with a custom PyTree, you need to construct the Python object at some point, right? How does Equinox do it?
Ah, going via `__init__` like this is known to be a dodgy thing to do. See the JAX docs here. Not only does this run afoul of typechecking, it also misbehaves if you ever define a custom `__init__` method.
Correct behaviour is to go via `__new__` and then construct the desired fields directly. Here's the Equinox implementation, which does exactly that:
I'm afraid this one isn't something we can really fix from jaxtyping's end. I'd suggest avoiding returning the Flax objects in question here. You say this is a `flax.struct.dataclass`. If that's the case then `equinox.Module` should be more-or-less a drop-in replacement for this. (It also fixes several other known issues around inheritance, bound methods, etc.) Alternatively you could use normal Python types: tuples/dictionaries/etc.
Ok, thanks, that makes total sense. I'm sorry, I'm sure Equinox is great, I just have too much infra built up around `flax.struct.dataclass` 🥲. FWIW, I'm a huge fan of jaxtyping, and the fact that it's worked this well for me and my sprawling Flax codebase is a testament to its excellent design.
At first, I hacked together a solution that disabled typechecking completely during Flax's PyTree unflattening. But then I looked at the trace again and realized almost all of the overhead was from PyTree typechecks. So I just removed the `PyTree[Float[Array, "..."]]` annotation from my `params` tree, and lo and behold, training was fast again. Using Perfetto, I was even able to measure the remaining typechecking overhead from dataclass unflattening: it was only ~1 ms, which is pretty acceptable. In theory it should be entirely negated by JAX's asynchronous dispatch, although I'm not sure that's working correctly, considering that the unflattening doesn't seem to happen until after all the GPU operations have finished.
So, I finally got around to profiling my train step, and I saw this:
To my absolute horror, I disabled typechecking and saw roughly a 2x speedup in training!!
My understanding was that typechecking should run only during tracing, and not during subsequent calls to a jitted `train_step` function. I can work harder to get a minimal repro going, but I wanted to see if you had any ideas off the top of your head. Here are some random additional facts:

- … `jax.device_put(inputs, data_parallel_sharding)` and letting JAX do the rest
- … a `flax.struct.dataclass` that is being typechecked, which is probably where most of the typechecking overhead is coming from. Only some of its nodes are PyTree nodes, but quite a few are metadata nodes (i.e., not registered with JAX). Come to think of it, this is most likely the culprit.