patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

jaxtyping with JAX severely slowing down training speed #213

Closed: kvablack closed this issue 1 month ago

kvablack commented 1 month ago

So, I finally got around to profiling my train step, and I saw this:

[Image: profiler trace of the train step]

To my absolute horror, I disabled typechecking and saw roughly a 2x speedup in training!!

My understanding was that typechecking should run only during tracing, and not during subsequent calls to a jitted train_step function. I can work harder to get a minimal repro going, but I wanted to see if you had any ideas off the top of your head. Here are some random additional facts:

patrick-kidger commented 1 month ago

Indeed, the usual behaviour is for jaxtyping to be completely done by runtime. Probably you're constructing a typechecked dataclass or calling a typechecked function somewhere inside your training loop!

kvablack commented 1 month ago

This is all happening inside the train_step, which is jitted. You can even see in the trace that the isinstance calls are all inside the PjitFunction(train_step) call. Nothing else is happening in the train loop except for the train_step call.

Is it expected that typechecking dataclasses would run during a jitted function call?

patrick-kidger commented 1 month ago

So once a function has been JIT'd, its Python code is never evaluated again (unless it has to be retraced, e.g. for inputs with new shapes or dtypes).

jaxtyping doesn't use jax.debug.callback, jax.pure_callback, or jax.experimental.io_callback, and those are the only possible escape hatches for running normal Python in JIT'd code. So we couldn't run typechecking many times even if we wanted to!
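A minimal sketch of the "traced once" behaviour, using plain JAX and no jaxtyping. The counter below is a hypothetical stand-in for any Python-level work (such as a typecheck) done inside the function body; it only increments while JAX traces the function, not on every call of the compiled version.

```python
import jax
import jax.numpy as jnp

TRACE_COUNT = 0  # incremented by ordinary Python code inside the jitted function


@jax.jit
def train_step(x):
    global TRACE_COUNT
    TRACE_COUNT += 1  # plain Python side effect: runs only while JAX traces
    return x * 2.0


x = jnp.ones(3)
for _ in range(10):
    x = train_step(x)

# Far fewer than 10: the body ran during tracing, not on every call.
print(TRACE_COUNT)
```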

One possibility is that Flax is deserialising via its __init__ method (which is where typechecking occurs on dataclasses), so that all of this extra overhead is occurring during the very end of a jit'd call, when reconstructing the Python objects being passed out of the JIT'd region.
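A minimal sketch of this hypothesis without Flax or jaxtyping. It assumes a custom pytree whose unflatten function rebuilds the object by calling __init__; the hypothetical Node class stands in for a flax.struct.dataclass, and the counter stands in for the typecheck. Even after compilation, every call has to rebuild the returned Node, so __init__ keeps running.

```python
import jax
import jax.numpy as jnp

INIT_CALLS = 0  # how many times Node.__init__ has run


class Node:
    def __init__(self, x):
        global INIT_CALLS
        INIT_CALLS += 1  # stand-in for expensive work, e.g. a runtime typecheck
        self.x = x


def _flatten(node):
    return (node.x,), None  # (children, static aux data)


def _unflatten_via_init(aux, children):
    return Node(*children)  # rebuilds the object by calling __init__


jax.tree_util.register_pytree_node(Node, _flatten, _unflatten_via_init)


@jax.jit
def step(node):
    return Node(node.x + 1)


n = step(step(Node(jnp.zeros(3))))  # warm up: trace and compile
before = INIT_CALLS
for _ in range(5):
    n = step(n)
# No retracing happens here, yet __init__ still runs on every call,
# because the Node returned from the compiled code must be rebuilt.
print(INIT_CALLS - before)
```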

If so then that's a Flax bug, but before we go pointing fingers: can you test what result you get when you don't return any Flax objects from your JIT'd call?

kvablack commented 1 month ago

> One possibility is that Flax is deserialising via its __init__ method (which is where typechecking occurs on dataclasses), so that all of this extra overhead is occurring during the very end of a jit'd call, when reconstructing the Python objects being passed out of the JIT'd region.

Pretty sure this is it. Here's a minimal repro that shows a performance diff only if a dataclass is being returned:

```python
# Run in IPython/Jupyter (uses the %timeit magic).
import flax.serialization
from jaxtyping import jaxtyped, Array, Int, ArrayLike, config
from typeguard import typechecked
import jax
from functools import partial
from flax import struct


@partial(jaxtyped, typechecker=typechecked)
@struct.dataclass
class Node:
    x: list[Int[ArrayLike, ""]]
    y: list[str | int | float] = struct.field(pytree_node=False)


# returning a dataclass
@jax.jit
@partial(jaxtyped, typechecker=typechecked)
def f(x: Node) -> Node:
    return x


x = Node(
    list(range(1000)),
    list(map(str, range(1000))) + list(range(1000)) + list(map(float, range(1000))),
)
f(x)

config.update("jaxtyping_disable", False)
f(x)
%timeit f(x)  # 21ms

config.update("jaxtyping_disable", True)
f(x)
%timeit f(x)  # 15ms


# not returning a dataclass
@jax.jit
@partial(jaxtyped, typechecker=typechecked)
def f(x: Node) -> dict:
    return flax.serialization.to_state_dict(x)


f(x)

config.update("jaxtyping_disable", False)
f(x)
%timeit f(x)  # 15ms

config.update("jaxtyping_disable", True)
f(x)
%timeit f(x)  # 15ms
```

And here's the line where the constructor gets called: in the unflatten_func.

Is there another way to do it? I'm not too familiar with PyTree internals, but with a custom PyTree, you need to construct the Python object at some point, right? How does Equinox do it?

patrick-kidger commented 1 month ago

Ah, going via __init__ like this is known to be a dodgy thing to do. See the JAX docs here.

Not only does this run afoul of typechecking, it also misbehaves if you ever define a custom __init__ method.

Correct behaviour is to go via __new__ and then set the desired fields directly. Here's the Equinox implementation, which does exactly that:

https://github.com/patrick-kidger/equinox/blob/955c0347de2690b07e59aad70e5666ded4ee28ef/equinox/_module.py#L913
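A minimal sketch of the __new__-based approach on a frozen standard-library dataclass. This is not the Equinox code; Node, its counter, and the helper names are hypothetical. Unflattening allocates the object with object.__new__ and writes the fields with object.__setattr__, so __init__ (and any checks in __post_init__) only run when user code constructs a Node explicitly.

```python
import dataclasses
import jax
import jax.numpy as jnp

INIT_CALLS = 0  # how many times the dataclass-generated __init__ has run


@dataclasses.dataclass(frozen=True)
class Node:
    x: jax.Array

    def __post_init__(self):
        global INIT_CALLS
        INIT_CALLS += 1  # stand-in for a runtime typecheck done at construction


def _flatten(node):
    return (node.x,), None  # (children, static aux data)


def _unflatten_via_new(aux, children):
    node = object.__new__(Node)  # allocate without running __init__/__post_init__
    for field, value in zip(dataclasses.fields(Node), children):
        # frozen dataclasses block normal assignment, so bypass __setattr__
        object.__setattr__(node, field.name, value)
    return node


jax.tree_util.register_pytree_node(Node, _flatten, _unflatten_via_new)


@jax.jit
def step(node):
    return Node(node.x + 1)  # explicit construction still runs __post_init__


n = step(step(Node(jnp.zeros(3))))  # warm up: trace and compile
before = INIT_CALLS
for _ in range(5):
    n = step(n)
print(INIT_CALLS - before)  # 0: outputs are rebuilt without calling __init__
```

Because unflattening never calls __init__, it also keeps working if the class later gains a custom __init__ with a different signature.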

I'm afraid this one isn't something we can really fix from jaxtyping's end. I'd suggest avoiding returning the Flax objects in question here. You say this is a flax.struct.dataclass. If that's the case then equinox.Module should be more-or-less a drop-in replacement for this. (It also fixes several other known issues around inheritance, bound methods, etc. etc.) Alternatively you could use normal Python types: tuples/dictionaries/etc.

kvablack commented 1 month ago

Ok thanks, that makes total sense. I'm sorry, I'm sure Equinox is great, I just have too much infra built up around flax.struct.dataclass 🥲 . FWIW, I'm a huge fan of jaxtyping, and the fact that it's worked this well for me and my sprawling Flax codebase is a testament to its excellent design.

At first, I hacked together a solution that disabled typechecking completely during Flax's PyTree unflattening. But then I looked at the trace again and realized almost all of the overhead was from PyTree typechecks. So I just removed the PyTree[Float[Array, "..."]] annotation from my params tree, and lo and behold, training was fast again. Using Perfetto, I was even able to measure the remaining typechecking overhead from dataclass unflattening, and it was only ~1ms, which is pretty acceptable. In theory I think it should be entirely hidden by JAX's asynchronous dispatch, although I'm not sure that's working correctly, considering that the unflattening doesn't seem to happen until after all the GPU operations have finished.