kriscao-cohere opened this issue 1 week ago.
Thanks for posting this @kriscao-cohere! We do have a global context that keeps track of graph references during nnx.jit and all other NNX transforms. That might be the first place I would look:
https://github.com/google/flax/blob/e4dad9ca2453b37da77cba790e49b35f4492fde9/flax/nnx/graph.py#L740
Hopefully the update_context context manager is not messing anything up:
https://github.com/google/flax/blob/e4dad9ca2453b37da77cba790e49b35f4492fde9/flax/nnx/graph.py#L1046
I'll look into it. In the meantime, if this is blocking you, consider using regular jax.jit as described in Performance Considerations for NNX.
Just as a sanity check, I ran a simple training loop that calls print(nnx.graph.GRAPH_CONTEXT) after each step, but all the context stacks are empty. Maybe the Python garbage collector is causing a spike?
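One way to test the garbage-collector hypothesis is to time every collector pass through the standard library's gc.callbacks hook, then compare that time with each step's wall time. This is a sketch that assumes the training step can be wrapped in a zero-argument callable. GCPauseRecorder and timed_steps are hypothetical helper names.

```python
import gc
import time


class GCPauseRecorder:
    """Records (generation, seconds) for every garbage-collector pass while active."""

    def __init__(self):
        self.pauses = []
        self._start = None

    def _callback(self, phase, info):
        # gc invokes callbacks with phase "start" before and "stop" after each pass.
        if phase == "start":
            self._start = time.perf_counter()
        elif phase == "stop" and self._start is not None:
            self.pauses.append((info["generation"], time.perf_counter() - self._start))
            self._start = None

    def __enter__(self):
        gc.callbacks.append(self._callback)
        return self

    def __exit__(self, *exc):
        gc.callbacks.remove(self._callback)
        return False


def timed_steps(step_fn, num_steps):
    """Run step_fn num_steps times; return a (wall time, GC time) pair per step."""
    results = []
    with GCPauseRecorder() as rec:
        for _ in range(num_steps):
            n_before = len(rec.pauses)
            t0 = time.perf_counter()
            step_fn()
            elapsed = time.perf_counter() - t0
            gc_time = sum(s for _, s in rec.pauses[n_before:])
            results.append((elapsed, gc_time))
    return results
```

Suppose the roughly 2-second steps line up with long generation-2 pauses. Then standard-library options such as calling gc.freeze() after model setup, or gc.disable() with a periodic manual gc.collect(), would be worth trying. Whether they help in this case is untested.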
Hi @cgarciae, thanks for the updates! And thanks for pointing me to the JAX-only pattern for using NNX. I tried it, and it eliminated the wait time (it also sped up my experiment loop a lot, though I'm only running small-model experiments).
As for a repro: I only saw it with certain model architectures (mainly Transformers past a certain depth). I tried to reproduce the slow nnx.split in isolation by running nnx.split by itself 100 times, but there were no unexpected slowdowns. Only when I ran 100 model forward passes (even on dummy data) did I see the occasional slow step caused by NNX graph traversal.
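The comparison between running nnx.split in isolation and running it inside a loop can be made explicit with a small timing harness. The harness times one probe call per step, optionally after an untimed workload, and then flags steps much slower than the median. This is a sketch. time_probe, find_slow_steps and the 10× threshold are illustrative choices, not part of NNX.

```python
import statistics
import time
from typing import Callable, List, Optional, Tuple


def time_probe(
    probe: Callable[[], object],
    num_steps: int,
    workload: Optional[Callable[[], object]] = None,
) -> List[float]:
    """Time one call to `probe` per step; `workload` (if given) runs first, untimed."""
    durations = []
    for _ in range(num_steps):
        if workload is not None:
            workload()  # e.g. a forward pass; excluded from the measurement
        t0 = time.perf_counter()
        probe()
        durations.append(time.perf_counter() - t0)
    return durations


def find_slow_steps(
    durations: List[float], factor: float = 10.0
) -> List[Tuple[int, float]]:
    """Return (step index, duration) for every step slower than `factor` x the median."""
    median = statistics.median(durations)
    return [(i, d) for i, d in enumerate(durations) if d > factor * median]
```

With NNX, one might compare find_slow_steps(time_probe(lambda: nnx.split(model), 100)) against the same call with workload=lambda: jax.block_until_ready(model(x)). The block_until_ready keeps asynchronous device work from leaking into the split timing.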
I'm using NNX for a toy Transformer on Wikitext-103, and I'm observing that roughly one in every ~100 steps takes much longer than the rest (on the order of 2 seconds vs. 0.02 seconds). I managed to track down the culprit with a profile, and it seems that some NNX-internal machinery in nnx.split is taking the bulk of the time. Is there anything NNX-related that could be causing this to take so long?
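To capture only the rare slow step, one option is to profile every step with the standard library's cProfile and keep the report only when the step exceeds a time threshold. This is a sketch that assumes the step can be wrapped in a zero-argument callable. profile_slow_steps and the 15-line report cutoff are hypothetical choices. Python-level profiling sees host-side work such as graph traversal in nnx.split, not time spent on the device.

```python
import cProfile
import io
import pstats
import time


def profile_slow_steps(step_fn, num_steps, threshold_s):
    """Profile each step; keep a cumulative-time report only for steps over threshold_s."""
    reports = {}
    for i in range(num_steps):
        prof = cProfile.Profile()
        t0 = time.perf_counter()
        prof.enable()
        step_fn()
        prof.disable()
        elapsed = time.perf_counter() - t0
        if elapsed > threshold_s:
            buf = io.StringIO()
            # Top 15 entries by cumulative time, written into the buffer.
            pstats.Stats(prof, stream=buf).sort_stats("cumulative").print_stats(15)
            reports[i] = buf.getvalue()
    return reports
```

For the case above, a threshold around 0.5 s would separate the ~2 s steps from the ~0.02 s ones, and only the outliers' reports would be kept.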