Closed kriscao-cohere closed 4 weeks ago

This may be WAI, but I found that nnx.jit aggressively recompiles if nnx.Module/nnx.Optimizer input instances are different, even if the state shapes and dtypes are the same. This particularly causes trouble if I'm reinitializing modules with e.g. different seeds to test stability across different initializations. It would be good if in this instance nnx.jit detects that the underlying computation graph is the same, and doesn't recompile.
Hi @kriscao-cohere, can you provide a minimal reproducible example? I created this test script, and nnx.jit does not seem to recompile based on instance identity:
import jax
from flax import nnx

rngs = nnx.Rngs(0)
m1 = nnx.Linear(2, 3, rngs=rngs)
m2 = nnx.Linear(2, 3, rngs=rngs)
N = 0  # counts how many times the body is traced, i.e. how often jit compiles
@nnx.jit
def forward(m, x):
global N
N += 1
return m(x)
x1 = jax.random.uniform(rngs(), (5, 2))
x2 = jax.random.uniform(rngs(), (5, 2))
y1 = forward(m1, x1)
assert N == 1
y2 = forward(m2, x2)
assert N == 1 # did not recompile
> It would be good if in this instance nnx.jit detects that the underlying computation graph is the same, and doesn't recompile.
This is precisely what it does: internally, nnx.split is used on each input Module / Object to extract GraphDef and State pytrees, and these are passed to jax.jit in place of the Modules/Objects. jax.jit then uses its hash-based cache to decide whether or not to recompile.

Maybe you are using some non-nnx objects that hash by identity? That would certainly cause jit to recompile even when most of the state is the same. To test, you can manually call nnx.split on your objects and check the hash of their GraphDef.
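For example, a quick way to check is to compare the GraphDefs of two independently initialized modules (a minimal sketch; the Linear module and seeds here are placeholders):

from flax import nnx

m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(1))  # different seed, same structure

graphdef1, state1 = nnx.split(m1)
graphdef2, state2 = nnx.split(m2)

# If the computation graphs match, the GraphDefs compare and hash equal,
# so jax.jit's cache treats both modules as the same static structure.
assert graphdef1 == graphdef2
assert hash(graphdef1) == hash(graphdef2)

If the hashes differ, something stored on the module (e.g. a function or other object that hashes by identity) is breaking the cache.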
Thanks for the suggestion! I didn't realize that the way I was passing the kernel init function into my layers was creating a new local function object each time I initialized my model.
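For anyone hitting the same thing, here is a sketch of the pitfall (the lambda-per-call pattern below is illustrative, not the reporter's actual code). Creating the init function inside the model factory yields a new function object per model, and functions hash by identity, so each instance gets a different GraphDef and nnx.jit recompiles. Creating the initializer once and reusing it keeps the GraphDef stable:

from flax import nnx

def make_model_recompiles(seed):
    # A new lambda is created on every call; functions hash by identity,
    # so each model carries a distinct static attribute and jit recompiles.
    init_fn = lambda key, shape, dtype: nnx.initializers.lecun_normal()(key, shape, dtype)
    return nnx.Linear(2, 3, kernel_init=init_fn, rngs=nnx.Rngs(seed))

KERNEL_INIT = nnx.initializers.lecun_normal()  # created once, shared by all models

def make_model_cached(seed):
    # Reusing the same initializer object keeps the GraphDef identical
    # across seeds, so nnx.jit hits its compilation cache.
    return nnx.Linear(2, 3, kernel_init=KERNEL_INIT, rngs=nnx.Rngs(seed))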