colehaus opened this issue 4 months ago
Ah, I realized this is a workable solution:
```python
import functools as ft
import equinox as eqx
import jax
import jax.numpy as jnp

def call(x, dynamic, static):
    model = eqx.combine(dynamic, static)  # reassemble the module from arrays + static structure
    return model(x)

# `t` is the Equinox module and `d` the input dimension from the original report.
dynamic, static = eqx.partition(t, eqx.is_array)
jax.jit(ft.partial(call, static=static))(jnp.ones(d), dynamic)
```
I think this is happening because you're grabbing `__call__`, which as a magic method isn't subject to the same bound-methods-are-PyTrees treatment as regular methods. This is the reason `t` is being closed over, rather than provided as an input.
Can you try doing just `eqx.filter_jit(t)(np.ones(d))` instead?
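To make the distinction concrete, here is a minimal sketch (reusing `t` and `d` from the report above; the `apply` wrapper is just an illustrative name, not anything from the issue): closing over the module keeps its arrays inside the traced function, while passing it through `eqx.filter_jit` makes them ordinary jit inputs.

```python
import equinox as eqx
import jax
import jax.numpy as jnp

# Closed over: `t`'s arrays are captured by the traced function instead of
# being passed to it, so tracing/compilation sees them as embedded values.
jax.jit(lambda x: t(x))(jnp.ones(d))

# Passed in: `t` is a pytree argument of the filtered function, so its arrays
# become ordinary jit inputs.
@eqx.filter_jit
def apply(model, x):
    return model(x)

apply(t, jnp.ones(d))
```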
Ahh, that does make a big difference. I had gotten into the habit of writing an explicit `__call__` so that jump-to-definition in my editor would be more useful, and hadn't thought of it as anything more than a trivial syntactic transformation.
As you can see from the output, the jaxpr-to-MLIR lowering and XLA compilation steps take longer and longer as the array dimension increases, until it finally crashes during compilation. I believe this is because we're effectively closing over larger and larger values, and JAX does work that scales with the size of the closed-over values (https://github.com/google/jax/issues/16278 may be related).
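A tiny standalone sketch (with a hypothetical weight `w`, not taken from the issue) of why the closure matters: a closed-over array shows up as a constant of the traced jaxpr, whereas an array passed as an argument is an ordinary input.

```python
import jax
import jax.numpy as jnp

w = jnp.ones((1000, 1000))  # hypothetical weight, stands in for the model's arrays

def closed_over(x):
    return x @ w  # `w` is captured from the enclosing scope

def passed_in(w, x):
    return x @ w  # `w` is an explicit argument

print(jax.make_jaxpr(closed_over)(jnp.ones(1000)))  # `w` appears as a jaxpr constant
print(jax.make_jaxpr(passed_in)(w, jnp.ones(1000)))  # `w` appears as an input variable
```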
Flax avoids this issue because it passes the parameters/weights directly as arguments to the function. That seems like the best approach at the moment. Is there a reasonable way to achieve behavior like that in Equinox?
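For contrast, a rough sketch of the Flax-style convention referred to here (shown only to illustrate what "passing parameters as arguments" looks like; the `MLP` model and `apply_fn` wrapper are hypothetical): the parameters live in a separate pytree that is handed to `jit` explicitly.

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(8)(x)

model = MLP()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((4,)))

@jax.jit
def apply_fn(variables, x):
    # `model` holds no arrays here; all weights arrive through `variables`.
    return model.apply(variables, x)

apply_fn(variables, jnp.ones((4,)))
```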
(Unless I'm missing something, this is a pretty significant limitation for e.g. doing inference with language models where you'd want to JIT the sampling for a fixed model.)
Equinox version 0.11.4