keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

no way to jit `model.build()` #18424

Open GallagherCommaJack opened 1 year ago

GallagherCommaJack commented 1 year ago

In distributed training, it's often desirable to wrap the `model.build()` call in `jax.jit`, since the state may not fit entirely on a single host.

Right now that throws this error:

TypeError: Value Tracedwith with type  is not a valid JAX type

I think it's because `KerasVariable` is not registered as a pytree?
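To illustrate the pytree hypothesis, here is a minimal sketch using two hypothetical wrapper classes (`PlainVariable`, `BoxedVariable`), not KerasVariable itself. An unregistered wrapper returned from `jax.jit` is rejected as an invalid JAX type, while the same kind of wrapper registered with `jax.tree_util.register_pytree_node_class` passes through `jit` intact. Whether registration is the right fix for KerasVariable is an assumption here.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class PlainVariable:
    """Hypothetical wrapper that JAX knows nothing about."""

    def __init__(self, value):
        self.value = value


@register_pytree_node_class
class BoxedVariable:
    """Hypothetical wrapper registered as a pytree node."""

    def __init__(self, value, name):
        self.value = value
        self.name = name

    def tree_flatten(self):
        # The array is the only leaf; the name is static auxiliary data.
        return (self.value,), self.name

    @classmethod
    def tree_unflatten(cls, name, children):
        return cls(children[0], name)


@jax.jit
def init_plain():
    return PlainVariable(jnp.zeros((6, 10)))


@jax.jit
def init_boxed():
    return [BoxedVariable(jnp.zeros((6, 10)), "kernel"),
            BoxedVariable(jnp.zeros((10,)), "bias")]


try:
    init_plain()
    plain_error = None
except TypeError as err:  # jit rejects the unregistered object as an output
    plain_error = err

params = init_boxed()  # list of BoxedVariable wrapping concrete jax.Arrays
```

Registration tells `jit` how to split the wrapper into array leaves on the way out and rebuild it afterwards, which is exactly the step the unregistered class is missing.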

fchollet commented 1 year ago

Thanks for the report -- do you have a way to reproduce this?

> I think because KerasVariable is not registered as a pytree?

Does it error out when trying to process a KerasVariable in a JAX op? That is possible in theory: KerasVariable implements `__jax_array__`, but JAX does not always respect that method. If that's the case, we can look into a workaround.

GallagherCommaJack commented 1 year ago

I can reproduce it with this code:

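import os
os.environ["KERAS_BACKEND"] = "jax"  # select the JAX backend before importing keras

import jax
import keras

model = keras.layers.Dense(10)

@jax.jit
def init_params():
    model.build((3, 6))
    return model.trainable_variables  # a list of KerasVariable objects

params = init_params()  # raises the TypeError above
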
GallagherCommaJack commented 1 year ago

The Flax `Partitioned` class might be a useful example of how to set up variables with metadata in a JAX-friendly way.
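A minimal sketch of the `Partitioned` idea, assuming a hypothetical `AxisNamedValue` box rather than Flax's actual class: the array is the pytree leaf and the per-dimension mesh axis names are static aux data, so the metadata survives `jax.jit` and can later be turned into `PartitionSpec`s. The `"model"` axis name is also illustrative.

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class AxisNamedValue:
    """Hypothetical Partitioned-style box: an array plus per-axis mesh names."""

    def __init__(self, value, names):
        self.value = value
        self.names = names  # one mesh axis name (or None) per array dimension

    def tree_flatten(self):
        # Only the array is traced; the names stay static metadata.
        return (self.value,), self.names

    @classmethod
    def tree_unflatten(cls, names, children):
        return cls(children[0], names)


@jax.jit
def init_params():
    return {
        "kernel": AxisNamedValue(jnp.zeros((6, 10)), (None, "model")),
        "bias": AxisNamedValue(jnp.zeros((10,)), ("model",)),
    }


params = init_params()

# Derive a PartitionSpec per variable from the metadata alone.
specs = jax.tree_util.tree_map(
    lambda box: P(*box.names),
    params,
    is_leaf=lambda x: isinstance(x, AxisNamedValue),
)
```

Keeping the names out of the leaves is the design choice that matters: traced code only ever sees plain arrays, while the sharding intent rides along in the tree structure.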

fchollet commented 1 year ago

> the state may not fit entirely on a single host

Can you share more info about the context here -- why are you trying to jit the build function? Could this be solved by incremental sharding as you create the variables (if we're talking about device memory) or buffering (if talking about host memory)?
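For comparison, a minimal sketch of sharded initialization in plain JAX, assuming the parameters can be produced by a pure function that returns ordinary `jax.Array`s rather than KerasVariables. It passes `out_shardings` to `jax.jit` so each device ends up holding only its slice of the outputs. The Dense-style shapes, the `"model"` mesh axis, and `init_params` are illustrative assumptions, not Keras API.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative shapes; UNITS is a multiple of the device count so the
# columns split evenly across devices.
IN_FEATURES = 6
UNITS = 2 * len(jax.devices())


def init_params(key):
    # Plain jax.Arrays (no variable wrappers), so jit can return them directly.
    kernel = jax.random.normal(key, (IN_FEATURES, UNITS)) * 0.02
    bias = jnp.zeros((UNITS,))
    return {"kernel": kernel, "bias": bias}


mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
shardings = {
    "kernel": NamedSharding(mesh, P(None, "model")),  # split columns across devices
    "bias": NamedSharding(mesh, P("model")),
}

# out_shardings fixes the layout of the results, so each device holds only
# its own shard of every parameter once init returns.
init_sharded = jax.jit(init_params, out_shardings=shardings)
params = init_sharded(jax.random.PRNGKey(0))
```

Whether intermediate values such as the random draw are also partitioned is left to XLA's sharding propagation, so peak per-device memory during initialization is not guaranteed to equal the shard size.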

fchollet commented 1 year ago

If that helps, I'm able to run your code snippet by unwrapping the variables:

@jax.jit
def init_params():
    model.build((3, 6))
    # unwrap each KerasVariable into its underlying JAX array
    return [v.value for v in model.trainable_variables]

params = init_params()

In general, making sure that JAX processes objects that implement `__jax_array__` (e.g. our KerasVariables) exactly as if they were JAX arrays has been an ongoing effort that Matt Johnson has been helping us with. We'll likely get there in a future JAX release. For now, when you hit this issue, you can always unwrap them via `variable.value`.

fchollet commented 1 year ago

CC @mattjj

mattjj commented 1 year ago

This is really a JAX bug IMO. Let me take a look.

GallagherCommaJack commented 1 year ago

@mattjj any progress on this?