Closed lbortolotti closed 5 months ago
import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
import numpy as np
window_size = 100
inference_window_size = None
X = np.random.random((1024, window_size, 1))
Y = np.random.random((1024, window_size, 1))
model = keras.Sequential([
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(1, activation='relu'),
])
model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.mean_squared_error)
model.fit(X, Y, batch_size=32)
model.summary()
model.save("keras_model.keras")
2. Then change the backend to tensorflow and restart the Python session, since the backend must be set before keras is imported. This time, save the model using `tf.saved_model.save`.
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
import keras
import tensorflow as tf
model = keras.saving.load_model("keras_model.keras")
model.summary()
tf.saved_model.save(model, "saved_model_dir")
Hi @lbortolotti,
Thanks for reporting the issue. I can replicate it with Keras 3.0.2 when the tf.function annotation is used.
If I remove the tf.function annotation from 'fn(x)' and save the model, it works fine. Please refer to the attached gist.
The ValueError: Invalid dtype: _DimExpr raised from standardize_dtype() indicates that the dtype from the serialized config is passed as _DimExpr, which may be generated by the JAX API itself.
I also tested with tf_keras (Keras 2.x) and got a different error (gist attached):
AttributeError: 'DynamicJaxprTracer' object has no attribute '_keras_mask'
- Save in keras format
import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
import numpy as np
window_size = 100
inference_window_size = None
X = np.random.random((1024, window_size, 1))
Y = np.random.random((1024, window_size, 1))
model = keras.Sequential([
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(1, activation='relu'),
])
model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.mean_squared_error)
model.fit(X, Y, batch_size=32)
model.summary()
model.save("keras_model.keras")
2. Then change the backend to tensorflow and restart. This time, save the model using `tf.saved_model.save`.
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
import keras
import tensorflow as tf
model = keras.saving.load_model("keras_model.keras")
model.summary()
tf.saved_model.save(model, "saved_model_dir")
Sorry, I should have been more explicit in my original issue. The reason I can't take this approach is that in my real use case, I am using a fairly complex custom keras layer, for which I only have a jax-backend implementation. So reloading the model with the tensorflow backend is not an option.
Similar issue with the torch backend: https://github.com/keras-team/keras/issues/19017
I suspect that https://github.com/keras-team/keras/issues/19017 is impossible by design, it would require quite a lot of magic on keras' part!
This issue, however, leverages JAX's jax2tf, so in principle what I am trying to do should be possible.
One of the two issues I found (Cannot convert '(b, t, 1)' to a shape. Found invalid entry 't') is also clearly triggered by a part of keras that is unique to the JAX backend integration.
Hi @lbortolotti,
It seems that your workflow with jax2tf and TF SavedModel is exactly what the model.export() API is useful for.
You can simply call model.export("my_model"), or get more customization with ExportArchive. Here is a reference tutorial on it: https://keras.io/api/models/model_saving_apis/export/
Does model.export() automatically turn my JAX-backend-keras-model (with custom layers that use JAX ops) into a TF savedmodel? I'd be very surprised if it did? The documentation you link looks like it might only be valid for keras 2, or when using keras with the tensorflow backend?
Yes, model.export() does turn a JAX-backend Keras model into a TF SavedModel. The API is not as mature for the JAX backend as it is for TF (which is why the docs have not been updated yet), but it does work for the case specified in this issue, as I was able to run model.export("my_model"): https://colab.sandbox.google.com/gist/nkovela1/9d8d3cd435611fd2fca45f2635d0cd30/jax_export.ipynb
Ah excellent! I see it here: https://github.com/keras-team/keras/blob/c4dd4fab5bd9491b32b4ab1e360d290ad8e8a238/keras/export/export_lib.py#L267
It's roughly doing what I was trying to do "by hand", and manages to get the model exported. But the exported model is broken, because of this: https://github.com/keras-team/keras/blob/c4dd4fab5bd9491b32b4ab1e360d290ad8e8a238/keras/export/export_lib.py#L460
It's not correct to replace all "None" dimensions with "b", as this would imply that all the unknown dimensions have the same length (ref. https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
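To make the naming problem concrete, here is a minimal pure-Python sketch. Both helper names (`polymorphic_shape_all_b`, `polymorphic_shape_unique`) are made up for illustration and are not Keras or jax2tf functions; they only build the shape strings that would be handed to jax2tf, contrasting the "replace every None with b" behaviour described above with one name per unknown dimension.

```python
from string import ascii_lowercase


def polymorphic_shape_all_b(shape):
    """Mimic the behaviour described above: every None becomes the same symbol 'b'."""
    return str(tuple(shape)).replace("None", "b")


def polymorphic_shape_unique(shape):
    """Hypothetical alternative: give each None dimension its own symbol.

    Supports at most 26 unknown dimensions, which is plenty for a sketch.
    """
    names = iter(ascii_lowercase)
    dims = [next(names) if d is None else str(d) for d in shape]
    return "(" + ", ".join(dims) + ")"


input_shape = (None, None, 1)  # unknown batch size and unknown window size
print(polymorphic_shape_all_b(input_shape))   # same symbol twice: dims are tied together
print(polymorphic_shape_unique(input_shape))  # distinct symbols: dims are independent
```

With a shared symbol, the two unknown dimensions are constrained to be equal; with distinct symbols they can vary independently.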
Updated code to reproduce (much neater now, thank you very much for the pointers!)
import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
import numpy as np
window_size = 100
inference_batch_size = None
inference_window_size = None
X = np.random.random((1024, window_size, 1))
Y = np.random.random((1024, window_size, 1))
model = keras.Sequential([
    keras.layers.Input(batch_shape=(inference_batch_size, inference_window_size, 1)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(1, activation='relu'),
])
model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.mean_squared_error)
model.summary()
model.fit(X, Y, batch_size=32)
model.export(f'model_export.tf')
import tensorflow as tf
# load model
tf_model = tf.saved_model.load('model_export.tf')
# call tf_model
tf_model_fn = tf_model.signatures['serving_default']
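# This works: the batch is sliced to window_size (100) rows, so both unknown dimensions are 100
out = tf_model_fn(tf.constant(X[:window_size, ...], dtype=tf.float32))
# And this does not: the batch dimension is 1024 while the window dimension is 100
out = tf_model_fn(tf.constant(X, dtype=tf.float32))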
Throws the error:
Detected at node XlaCallModule defined at (most recent call last):
<stack traces unavailable>
Module shape refinement failed: UNKNOWN: \venv\lib\site-packages\keras\src\layers\core\dense.py:112:0: error: 'stablehlo.add' op requires compatible types for all operands and results
\venv\lib\site-packages\keras\src\layers\core\dense.py:112:0: note: see current operation: %18 = "stablehlo.add"(%15, %17) : (tensor<1024x100x128xf32>, tensor<1024x1024x128xf32>) -> tensor<?x?x128xf32>
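The difference between the two calls can be illustrated with a toy model of symbolic dimension binding. This is only a sketch under the assumption that both unknown dimensions were exported under the same name; `bind_dims` is a hypothetical helper, and the real check happens inside XLA's shape refinement, not in Python.

```python
def bind_dims(spec, concrete):
    """Bind symbolic dimension names in `spec` to the sizes in `concrete`.

    `spec` mixes ints (fixed sizes) and strings (symbolic names). A name that
    appears more than once must resolve to the same size each time.
    """
    if len(spec) != len(concrete):
        raise ValueError(f"rank mismatch: {spec} vs {concrete}")
    bindings = {}
    for dim, size in zip(spec, concrete):
        if isinstance(dim, int):
            if dim != size:
                raise ValueError(f"expected size {dim}, got {size}")
        elif bindings.setdefault(dim, size) != size:
            raise ValueError(
                f"dimension {dim!r} is bound to {bindings[dim]} but got {size}"
            )
    return bindings


# Same name for both unknown dims: only consistent when batch == window.
print(bind_dims(("b", "b", 1), (100, 100, 1)))

# Distinct names: batch and window can differ.
print(bind_dims(("b", "t", 1), (1024, 100, 1)))

# Same name, different sizes: inconsistent, analogous to the failure above.
try:
    bind_dims(("b", "b", 1), (1024, 100, 1))
except ValueError as err:
    print("rejected:", err)
```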
@lbortolotti Try this
spec_shape = str(spec_shape).replace("None", "b", 1).replace("None", "d")
if config.backend() == "jax":
    if str(e) == "b" or str(e) == "d":
        # JAX2TF tracing represents `None` dimensions as `b` or `d`
        continue
These changes fix my specific example. I suppose the "proper" fix would be to:
- Somehow expose polymorphic_shape to the user (or is there already a way in keras to specify that two None dimensions should be e.g. equal to each other?)
- Accept any symbolic dimension in standardize_shape when the backend is jax, as opposed to specifically only b or d or whatever else.
Thanks!
Luca
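The second suggestion could look roughly like the sketch below. It is a simplified, hypothetical stand-in and not the Keras implementation: the function name `standardize_shape_sketch` and the `allow_symbolic` flag are invented here, and symbolic dimensions are modelled as plain strings, whereas jax2tf tracing actually passes dimension-expression objects.

```python
def standardize_shape_sketch(shape, allow_symbolic=False):
    """Validate a shape tuple, optionally accepting any symbolic dimension name.

    `allow_symbolic` stands in for "the backend is jax"; it is an assumption
    made for this sketch, not an existing Keras argument.
    """
    standardized = []
    for e in shape:
        if e is None:
            standardized.append(None)
        elif allow_symbolic and isinstance(e, str) and e.isidentifier():
            # Accept any name (b, t, batch, ...), not just a hard-coded few.
            standardized.append(e)
        elif isinstance(e, int) and e >= 0:
            standardized.append(e)
        else:
            raise ValueError(
                f"Cannot convert {shape!r} to a shape. Found invalid entry {e!r}."
            )
    return tuple(standardized)


print(standardize_shape_sketch(("b", "t", 1), allow_symbolic=True))
print(standardize_shape_sketch((None, 100, 1)))
```

Accepting every valid identifier avoids having to keep the validator in sync with whatever names the export code happens to generate.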
I'm also facing AttributeError: 'DynamicJaxprTracer' object has no attribute '_keras_mask'. Any hack to mitigate this?
Hi @lbortolotti,
import keras_core as keras
keras_core is deprecated, please use keras instead.
The latest version of Keras (3.1.2) supports exactly what you want out of the box using ExportArchive. It will detect the None dimension and give it a name in the polymorphic_shapes given to jax2tf.
import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
import numpy as np
import tensorflow as tf
window_size = 100
inference_window_size = None
X = np.random.random((1024, window_size, 1))
Y = np.random.random((1024, window_size, 1))
model = keras.Sequential([
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(1, activation='relu'),
])
model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.mean_squared_error)
model.fit(X, Y, batch_size=32)
model.summary()
archive = keras.export.ExportArchive()
archive.track(model)
signature = [tf.TensorSpec(shape=(None, inference_window_size, 1), dtype=tf.float32)]
archive.add_endpoint(
    "call",
    model.__call__,
    input_signature=signature
)
archive.write_out("my_export")
revived_model = tf.saved_model.load("my_export")
revived_model.call(np.random.random((1024, 50, 1)))
Does model.export() automatically turn my JAX-backend-keras-model (with custom layers that use JAX ops) into a TF savedmodel?
Yes, it uses jax2tf under the hood.
It's not correct to replace all "None" dimensions with "b"
This is fixed.
Somehow expose polymorphic_shape to the user (or is there already a way in keras to specify that two None dimensions should be e.g. equal to each other?)
Both options were added. If you have Nones in the signature, it will create a polymorphic_shapes with individual names for each None dimension, meaning they're uncorrelated. If you want more control, you can pass polymorphic_shapes in the jax2tf argument to override the default. Here is an example of how to do this: https://github.com/keras-team/keras/blob/master/keras/export/export_lib_test.py#L274
Sorry, I should have been more explicit in my original issue. The reason I can't take this approach is that in my real use case, I am using a fairly complex custom keras layer, for which I only have a jax-backend implementation. So reloading the model with the tensorflow backend is not an option.
I'm a bit confused about this. This export approach produces a TensorFlow graph that you can only reload with TensorFlow. However, what the Keras "save" approach (model.save("keras_model.keras")) produces can be reloaded with Keras with the JAX backend. So if that is what you want, you definitely want to go with the "save" approach.
@chococigar @SuryanarayanaY ,
AttributeError: 'DynamicJaxprTracer' object has no attribute '_keras_mask'
Keras 2.0 does not work with JAX, it is Tensorflow only. You'll have to switch to Keras 3.
@innat ,
The torch backend is not supported, there is no way to make the Tensorflow export work with torch.
@innat ,
Thanks! Interesting. So when we get PyTorch XLA to work with Keras 3, then this will be possible.
I would like to develop a model using keras with the JAX backend, but then export it as a TensorFlow savedmodel.
This was already broken with keras-core==0.1.7, and is even-more-broken in keras==3.0.2.
Code to reproduce
Package versions - keras-core
tensorflow==2.15 jax[cpu]==0.4.16 keras-core==0.1.7
Throws
Looking at the keras code, this part seems to assume that the only polymorphic shape "marker" is the letter b, but in fact it can be any letter, as far as I understood from the jax2tf documentation. For example, in my case, I have an unknown batch dimension (b) and an unknown time dimension (t), and these are not the same as each other (which I think JAX would imply if I were to use b to identify both of them).
I had originally opened this as a JAX bug (https://github.com/google/jax/issues/17803) but I suspect it's actually a Keras bug.
Looking at the keras3 code, it seems this bug is still present. There unfortunately seems to be an additional one as well, which breaks earlier, and which follows:
Package versions - keras3
tensorflow==2.15 jax[cpu]==0.4.16 keras==3.0.2
Throws
For this latter issue I don't have immediate ideas...
Let me know if you need any more information.
Thanks.