keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

model.export on JAX saves all model weights as constants embedded in the graph #19132

Status: Open. martin-gorner opened this issue 4 months ago.

martin-gorner commented 4 months ago

Repro colab:

https://colab.research.google.com/drive/1QHg0zpFsJS6qfTDfBwts8KLule7B84RO?usp=sharing

Requested fix: when exporting a model through jax2tf, weights must be wrapped in tf.Variable before jax2tf is called.

Relevant jax2tf documentation (https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#saved-model-with-parameters): "Some special care is needed to ensure that the model parameters are not embedded as constants in the graph and are instead saved separately as variables. This is useful for two reasons: the parameters could be very large and exceed the 2GB limits of the GraphDef part of the SavedModel, or you may want to fine-tune the model and change the value of the parameters."

fchollet commented 3 months ago

@nkovela1 this is fixed, right?

nkovela1 commented 3 months ago

@fchollet Yes, this is fixed. Closing the issue, thanks!

google-ml-butler[bot] commented 3 months ago

Are you satisfied with the resolution of your issue?

martin-gorner commented 3 months ago

I don't see any change in the repro Colab: as far as I can tell, it still saves all weights as constants in the graph. I did test with keras-nightly; see the repro Colab.
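One rough way to check whether an export embedded its weights is to compare the size of `saved_model.pb` with the size of the `variables/` directory. The sketch below uses a hypothetical `export_sizes` helper and a toy `tf.Module` rather than the Keras exporter. This is a heuristic, not an official tool. If the weights were embedded as constants, `saved_model.pb` grows roughly with the model size while `variables/` stays small or is absent.

```python
import os
import tempfile

import tensorflow as tf


def export_sizes(export_dir):
    """Hypothetical helper: return (graph_bytes, variable_bytes) for a SavedModel dir."""
    graph_bytes = os.path.getsize(os.path.join(export_dir, "saved_model.pb"))
    var_dir = os.path.join(export_dir, "variables")
    variable_bytes = 0
    if os.path.isdir(var_dir):
        for name in os.listdir(var_dir):
            variable_bytes += os.path.getsize(os.path.join(var_dir, name))
    return graph_bytes, variable_bytes


# Toy module: a 1000x1000 float32 weight (4,000,000 bytes) held in a tf.Variable.
module = tf.Module()
module.w = tf.Variable(tf.ones([1000, 1000]))
module.f = tf.function(
    lambda x: tf.matmul(x, module.w),
    input_signature=[tf.TensorSpec([1, 1000], tf.float32)],
)

export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)
graph_bytes, variable_bytes = export_sizes(export_dir)
print(f"saved_model.pb: {graph_bytes} bytes, variables/: {variable_bytes} bytes")
```

Running the same helper on the directory produced by `model.export` in the repro Colab should show which of the two files holds the weights.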