GallagherCommaJack opened this issue 1 year ago
Thanks for the report -- do you have a way to reproduce this?
I think it's because `KerasVariable` is not registered as a pytree?
Does it error out when trying to process a KerasVariable in a JAX op? That is theoretically possible because KerasVariable implements `__jax_array__`; however, JAX does not always respect that property. If that's the case, we can look into a workaround.
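To illustrate the protocol being discussed: a minimal sketch of a wrapper class implementing `__jax_array__` (the `Box` class here is a hypothetical stand-in for `KerasVariable`, not the actual Keras implementation). Conversion paths like `jnp.asarray` honor the protocol, but other JAX code paths, such as pytree flattening inside `jit`, may not.

```python
import jax.numpy as jnp

class Box:
    """Hypothetical stand-in for a KerasVariable-style wrapper."""
    def __init__(self, value):
        self._value = jnp.asarray(value)

    def __jax_array__(self):
        # Called by the JAX code paths that honor the protocol.
        return self._value

b = Box([1.0, 2.0, 3.0])
arr = jnp.asarray(b)   # conversion works via __jax_array__
total = jnp.sum(arr)
print(total)           # 6.0
```

Passing `b` itself into a jitted function is a different code path: `jit` flattens its arguments as pytrees, where an unregistered wrapper is an opaque leaf, which is where the error in this issue comes from.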
I can reproduce with this code:

```python
import jax
import keras

model = keras.layers.Dense(10)

@jax.jit
def init_params():
    model.build((3, 6))
    return model.trainable_variables

params = init_params()
```
The flax `Partitioned` class might be a useful example to look at for how to set up variables with metadata in a JAX-friendly way.
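For context, the general mechanism flax relies on is pytree registration: a wrapper that carries static metadata alongside its array can be registered with `jax.tree_util.register_pytree_node` so that JAX transformations trace the array and carry the metadata through untouched. A minimal sketch (the `MetaVar` class and its fields are hypothetical, loosely modeled on flax's `Partitioned`):

```python
import jax
import jax.numpy as jnp

class MetaVar:
    """Hypothetical variable wrapper carrying static sharding metadata."""
    def __init__(self, value, names):
        self.value = value
        self.names = names  # static metadata, e.g. partition axis names

def _flatten(v):
    # Return (traced children, static aux data).
    return (v.value,), v.names

def _unflatten(names, children):
    return MetaVar(children[0], names)

jax.tree_util.register_pytree_node(MetaVar, _flatten, _unflatten)

v = MetaVar(jnp.ones((2, 3)), ("model", None))
leaves = jax.tree_util.tree_leaves(v)                 # just the array
doubled = jax.tree_util.tree_map(lambda x: x * 2, v)  # metadata survives
```

Once registered this way, instances can flow through `jit`, `grad`, etc. like any other pytree.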
> the state may not fit entirely on a single host
Can you share more info about the context here -- why are you trying to jit the build function? Could this be solved by incremental sharding as you create the variables (if we're talking about device memory) or buffering (if talking about host memory)?
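One way to read "incremental sharding" here: materialize each variable on host one at a time and immediately place it onto the device mesh, so the full state never sits unsharded in one place. A hedged sketch using `jax.device_put` with a `NamedSharding` (the mesh setup and `make_param` helper are illustrative assumptions, not Keras API):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over the available devices (a real run would use a
# multi-device or multi-host mesh).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def make_param(shape):
    # Create one parameter at a time on host, then immediately move it
    # onto the mesh before creating the next one.
    host = np.zeros(shape, dtype=np.float32)
    # P() replicates; a real config might shard with P("data", None).
    return jax.device_put(host, NamedSharding(mesh, P()))

params = [make_param((4, 4)) for _ in range(3)]
print(params[0].sharding)
```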
If it helps, I'm able to run your code snippet by unwrapping the variables:

```python
@jax.jit
def init_params():
    model.build((3, 6))
    return [v.value for v in model.trainable_variables]

params = init_params()
```
In general, making sure that JAX is able to process objects that implement `__jax_array__` (e.g. our KerasVariables) exactly as if they were JAX arrays has been an ongoing thread that Matt Johnson has been helping us with. We'll likely get there in a future JAX release. For now, when you hit this issue, you can always unwrap them via `variable.value`.
CC @mattjj
This is really a JAX bug IMO. Let me take a look.
@mattjj any progress on this?
In distributed training, it's often desirable to wrap the `model.build()` call in a `jax.jit`, since the state may not fit entirely on a single host. Right now that's throwing this error:

I think it's because `KerasVariable` is not registered as a pytree?