google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

nnx.jit recompiles for new instance of Module even if parameter shapes and dtypes are unchanged #4329

Closed: kriscao-cohere closed this issue 4 weeks ago

kriscao-cohere commented 4 weeks ago

This may be working as intended, but I found that nnx.jit aggressively recompiles when the nnx.Module/nnx.Optimizer input instances are different, even if the state shapes and dtypes are the same. This is particularly troublesome when I reinitialize modules with, e.g., different seeds to test stability across initializations. It would be good if in this case nnx.jit detected that the underlying computation graph is the same and did not recompile.

cgarciae commented 4 weeks ago

Hi @kriscao-cohere, can you provide a minimal reproducible example? I created this test script, and nnx.jit does not seem to recompile based on instance identity:

import jax
from flax import nnx

# Two Linear modules with the same structure but different parameter values.
rngs = nnx.Rngs(0)
m1 = nnx.Linear(2, 3, rngs=rngs)
m2 = nnx.Linear(2, 3, rngs=rngs)

N = 0  # counts how many times the function body is traced

@nnx.jit
def forward(m, x):
  global N
  N += 1  # only executes during tracing, so N tracks (re)compilations
  return m(x)

x1 = jax.random.uniform(rngs(), (5, 2))
x2 = jax.random.uniform(rngs(), (5, 2))

y1 = forward(m1, x1)
assert N == 1  # first call traces and compiles

y2 = forward(m2, x2)
assert N == 1  # same shapes, dtypes, and graph structure: no recompilation

> It would be good if in this instance nnx.jit detects that the underlying computation graph is the same, and doesn't recompile.

This is precisely what it does. Internally, nnx.split is used on each input Module/Object to extract a GraphDef and a State pytree; these are passed in place of the Modules/Objects to jax.jit, which then uses its hash-based cache to decide whether to recompile.
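
To make that concrete, here is a minimal sketch of the same idea using the functional API (a simplification, not nnx.jit's actual implementation): split each Module into a hashable GraphDef plus a State pytree and pass those to jax.jit, so two modules with identical structure share one compiled trace.

import jax
from flax import nnx

@jax.jit
def forward_fn(graphdef, state, x):
  # Rebuild the Module inside the traced function from its GraphDef and State.
  m = nnx.merge(graphdef, state)
  return m(x)

m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
m2 = nnx.Linear(2, 3, rngs=nnx.Rngs(1))
x = jax.random.uniform(jax.random.key(0), (5, 2))

# Both calls produce the same GraphDef and the same State shapes/dtypes,
# so the second call reuses the cached compilation.
y1 = forward_fn(*nnx.split(m1), x)
y2 = forward_fn(*nnx.split(m2), x)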

cgarciae commented 4 weeks ago

Maybe you are using some non-nnx objects that hash by identity? That would certainly cause jit to recompile even if most of the state is the same. To test, you can manually call nnx.split on your objects and compare the hashes of their GraphDefs.
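
For example, a quick check along these lines (the Linear module and seeds below are placeholders for your actual model construction):

from flax import nnx

# Two independently constructed instances of the same architecture.
g1, _ = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
g2, _ = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(1)))

# Equal GraphDefs mean jax.jit reuses its cached trace; a mismatch points at
# some attribute (e.g. a locally created function) that hashes by identity.
print(hash(g1) == hash(g2), g1 == g2)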

kriscao-cohere commented 4 weeks ago

Thanks for the suggestion, I didn't realize that the way I was passing the kernel init function to my layers created a new local function object each time I initialized my model. Thanks!
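
For anyone hitting the same issue, a hedged sketch of the pitfall (the helper below is illustrative, not the reporter's actual code): wrapping the initializer in a fresh lambda on every model construction gives each instance a different static attribute, so the GraphDefs compare unequal and jit retraces.

from flax import nnx

def make_model(seed):
  # A new lambda object is created on every call; functions compare by identity.
  init = lambda *args: nnx.initializers.lecun_normal()(*args)
  return nnx.Linear(2, 3, kernel_init=init, rngs=nnx.Rngs(seed))

g1, _ = nnx.split(make_model(0))
g2, _ = nnx.split(make_model(1))
print(g1 == g2)  # False: different lambdas -> different GraphDefs -> recompilation

# Defining the initializer once (e.g. a module-level nnx.initializers.lecun_normal())
# and reusing it keeps the GraphDefs equal and avoids retracing.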