keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

JAX backend - exporting to TF savedmodel via JAX2TF errors (Cannot convert '(b, t, 1)' to a shape. Found invalid entry 't' AND ValueError: Invalid dtype: _DimExpr) #19048

Closed lbortolotti closed 5 months ago

lbortolotti commented 8 months ago

I would like to develop a model using Keras with the JAX backend, but then export it as a TensorFlow SavedModel.

This was already broken with keras-core==0.1.7, and is even-more-broken in keras==3.0.2.

Code to reproduce

import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras_core as keras # import keras directly to use keras==3.0.2
import numpy as np

window_size = 100
inference_window_size = None

X = np.random.random((1024, window_size, 1))
Y = np.random.random((1024, window_size, 1))

model = keras.Sequential([keras.layers.Dense(128, activation='relu'),
                          keras.layers.Dense(64, activation='relu'),
                          keras.layers.Dense(1, activation='relu'),
                          ])

model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.mean_squared_error)

model.fit(X, Y, batch_size=32)

model.summary()

import tensorflow as tf

module = tf.Module()
module.model = model

if 'KERAS_BACKEND' in os.environ and os.environ['KERAS_BACKEND'] == 'jax':
    from jax.experimental import jax2tf

    poly_shape_window_size = 't' if inference_window_size is None else inference_window_size
    poly_shape = jax2tf.PolyShape('b', poly_shape_window_size, 1)

    print(f'poly_shape: {poly_shape}')

    predict_fn = jax2tf.convert(module.model.call,
                                native_serialization_platforms=('cpu',),
                                polymorphic_shapes=str(poly_shape),
                                )
else:
    predict_fn = module.model.call

@tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, inference_window_size, 1), dtype=tf.float32)])
def fn(x):
    outputs = predict_fn(x)
    return outputs

module.fn = fn

tf.saved_model.save(module, 'model.tf', signatures={'serving_default': module.fn})

Package versions - keras-core

tensorflow==2.15 jax[cpu]==0.4.16 keras-core==0.1.7

Throws

File "repro_github.py", line 47, in fn
    outputs = predict_fn(x)
  File "venv_repro\lib\site-packages\jax\experimental\jax2tf\jax2tf.py", line 402, in converted_fun_tf
    impl.before_conversion()
  File "venv_repro\lib\site-packages\jax\experimental\jax2tf\jax2tf.py", line 506, in before_conversion
    self.exported = export.export(
  File "venv_repro\lib\site-packages\jax\experimental\export\export.py", line 422, in do_export
    lowered = wrapped_fun_jax.lower(
  File "venv_repro\lib\site-packages\keras_core\src\models\sequential.py", line 187, in call
    return self._functional.call(inputs, training=training, mask=mask)
  File "venv_repro\lib\site-packages\keras_core\src\models\functional.py", line 188, in call
    outputs = self._run_through_graph(
  File "venv_repro\lib\site-packages\keras_core\src\ops\function.py", line 140, in _run_through_graph
    outputs = operation_fn(node.operation)(*args, **kwargs)
  File "venv_repro\lib\site-packages\keras_core\src\models\functional.py", line 574, in call
    return operation(*args, **kwargs)
  File "venv_repro\lib\site-packages\keras_core\src\utils\traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "venv_repro\lib\site-packages\keras_core\src\backend\common\variables.py", line 437, in standardize_shape
    raise ValueError(
ValueError: Cannot convert '(b, t, 1)' to a shape. Found invalid entry 't'.

Looking at the Keras code, this part seems to assume that the only polymorphic shape "marker" is the letter b, but as far as I understand from the jax2tf documentation, it can be any identifier. In my case, I have an unknown batch dimension (b) and an unknown time dimension (t), and these are not necessarily equal to each other (which I think JAX would assume if I used b for both of them).
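To make the distinction concrete, here is a minimal pure-Python sketch of the binding rule as I read it in the jax2tf documentation: a dimension name that appears more than once must resolve to the same size. `bind_dims` is a hypothetical helper written only for illustration. It is not part of jax2tf or Keras, and it ignores richer dimension expressions.

```python
def bind_dims(spec, shape):
    """Bind symbolic names in a spec such as "(b, t, 1)" to a concrete shape.

    Toy model of the rule "one name means one size"; hypothetical helper,
    not the jax2tf implementation.
    """
    entries = [e.strip() for e in spec.strip("()").split(",") if e.strip()]
    if len(entries) != len(shape):
        raise ValueError(f"rank mismatch: {spec} vs {shape}")
    bindings = {}
    for entry, size in zip(entries, shape):
        if entry.isdigit():
            # A literal dimension must match exactly.
            if int(entry) != size:
                raise ValueError(f"expected {entry}, got {size}")
        elif bindings.setdefault(entry, size) != size:
            # The same name was already bound to a different size.
            raise ValueError(
                f"'{entry}' is bound to {bindings[entry]}, got {size}"
            )
    return bindings


# Independent batch and time dimensions are accepted.
print(bind_dims("(b, t, 1)", (1024, 100, 1)))

# Reusing "b" for both forces batch size == window size.
try:
    bind_dims("(b, b, 1)", (1024, 100, 1))
except ValueError as err:
    print("rejected:", err)
```

Under this reading, a spec that reuses `b` only accepts inputs whose batch size happens to equal the window size.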

I had originally opened this as a JAX bug (https://github.com/google/jax/issues/17803) but I suspect it's actually a Keras bug.

Looking at the keras3 code, it seems this bug is still present. There unfortunately seems to be an additional one as well, which breaks earlier, and which follows:

Package versions - keras3

tensorflow==2.15 jax[cpu]==0.4.16 keras==3.0.2

Throws

  File "repro_github.py", line 47, in fn
    outputs = predict_fn(x)
  File "venv\lib\site-packages\jax\experimental\jax2tf\jax2tf.py", line 402, in converted_fun_tf
    impl.before_conversion()
  File "venv\lib\site-packages\jax\experimental\jax2tf\jax2tf.py", line 500, in before_conversion
    self.exported = export.export(
  File "venv\lib\site-packages\jax\experimental\export\export.py", line 470, in do_export
    lowered = wrapped_fun_jax.lower(
  File "venv\lib\site-packages\keras\src\models\sequential.py", line 203, in call
    return self._functional.call(inputs, training=training, mask=mask)
  File "venv\lib\site-packages\keras\src\models\functional.py", line 188, in call
    outputs = self._run_through_graph(
  File "venv\lib\site-packages\keras\src\ops\function.py", line 153, in _run_through_graph
    outputs = operation_fn(node.operation)(*args, **kwargs)
  File "venv\lib\site-packages\keras\src\models\functional.py", line 572, in call
    return operation(*args, **kwargs)
  File "venv\lib\site-packages\keras\src\utils\traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "venv\lib\site-packages\keras\src\backend\common\variables.py", line 396, in standardize_dtype
    raise ValueError(f"Invalid dtype: {dtype}")
ValueError: Invalid dtype: _DimExpr

For this latter issue I don't have immediate ideas...

Let me know if you need any more information.

Thanks.

dugujiujian1999 commented 8 months ago

1. Save in keras format

import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
import numpy as np

window_size = 100
inference_window_size = None

X = np.random.random((1024, window_size, 1))
Y = np.random.random((1024, window_size, 1))

model = keras.Sequential([keras.layers.Dense(128, activation='relu'),
                          keras.layers.Dense(64, activation='relu'),
                          keras.layers.Dense(1, activation='relu'),
                          ])

model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.mean_squared_error)

model.fit(X, Y, batch_size=32)

model.summary()
model.save("keras_model.keras")

2. Then change the backend to tensorflow and restart the Python process. This time, save the model using `tf.saved_model.save`

import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
import keras
import tensorflow as tf

model = keras.saving.load_model("keras_model.keras")
model.summary()

tf.saved_model.save(model, "saved_model_dir")

SuryanarayanaY commented 8 months ago

Hi @lbortolotti ,

Thanks for reporting the issue. I can replicate it with Keras 3.0.2 when fn(x) is decorated with tf.function. If I remove the tf.function decorator from fn(x) and save the model, it works fine. Please refer to the attached gist.

The ValueError: Invalid dtype: _DimExpr raised by standardize_dtype() indicates that the dtype from the serialized config is passed as a _DimExpr, which may be generated by the JAX API itself.

SuryanarayanaY commented 8 months ago

I also tested with tf_keras (Keras 2.x) and got a different error. See the attached gist.

AttributeError: 'DynamicJaxprTracer' object has no attribute '_keras_mask'

lbortolotti commented 8 months ago
  1. Save in keras format
import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
import numpy as np

window_size = 100
inference_window_size = None

X = np.random.random((1024, window_size, 1))
Y = np.random.random((1024, window_size, 1))

model = keras.Sequential([keras.layers.Dense(128, activation='relu'),
                          keras.layers.Dense(64, activation='relu'),
                          keras.layers.Dense(1, activation='relu'),
                          ])

model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.mean_squared_error)

model.fit(X, Y, batch_size=32)

model.summary()
model.save("keras_model.keras")
2. Then change the backend to tensorflow
   And reboot it again
   This time we save it using `tf.saved_model.save`
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
import keras
import tensorflow as tf

model = keras.saving.load_model("keras_model.keras")
model.summary()

tf.saved_model.save(model, "saved_model_dir")

Sorry, I should have been more explicit in my original issue. The reason I can't take this approach is that in my real use case I am using a fairly complex custom Keras layer, for which I only have a JAX-backend implementation. So reloading the model with the tensorflow backend is not an option.

innat commented 8 months ago

Similar issue with torch backend. https://github.com/keras-team/keras/issues/19017

lbortolotti commented 8 months ago

Similar issue with torch backend. #19017

I suspect that https://github.com/keras-team/keras/issues/19017 is impossible by design, it would require quite a lot of magic on keras' part!

This issue, however, leverages JAX's jax2tf, so in principle what I am trying to do should be possible.

One of the two issues I found (Cannot convert '(b, t, 1)' to a shape. Found invalid entry 't') is also clearly triggered by a part of keras that is unique to the JAX backend integration.

nkovela1 commented 8 months ago

Hi @lbortolotti,

It seems that your workflow with jax2tf and TF SavedModel is exactly what the model.export() API is useful for. You can simply call model.export("my_model"), or get more customization with ExportArchive. Here is a reference tutorial on it: https://keras.io/api/models/model_saving_apis/export/

lbortolotti commented 8 months ago

Hi @lbortolotti,

It seems that your workflow with jax2tf and TF SavedModel is exactly what the model.export() API is useful for. You can simply call model.export("my_model"), or get more customization with ExportArchive. Here is a reference tutorial on it: https://keras.io/api/models/model_saving_apis/export/

Does model.export() automatically turn my JAX-backend-keras-model (with custom layers that use JAX ops) into a TF savedmodel? I'd be very surprised if it did? The documentation you link looks like it might only be valid for keras 2, or when using keras with the tensorflow backend?

nkovela1 commented 8 months ago

Yes, model.export() does turn a JAX-backend Keras model into a TF SavedModel. The API is not as mature for JAX-backend as it is for TF (hence why docs have not been updated yet), but it does work for the case specified in this issue, as I was able to run model.export("my_model"): https://colab.sandbox.google.com/gist/nkovela1/9d8d3cd435611fd2fca45f2635d0cd30/jax_export.ipynb

lbortolotti commented 8 months ago

Ah excellent! I see it here: https://github.com/keras-team/keras/blob/c4dd4fab5bd9491b32b4ab1e360d290ad8e8a238/keras/export/export_lib.py#L267

It's roughly doing what I was trying to do "by hand", and it manages to get the model exported. But the exported model is broken because of this line: https://github.com/keras-team/keras/blob/c4dd4fab5bd9491b32b4ab1e360d290ad8e8a238/keras/export/export_lib.py#L460

It's not correct to replace all "None" dimensions with "b", as this implies that all the unknown dimensions have the same length (ref. https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).

Updated code to reproduce (much neater now, thank you very much for the pointers!)

import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
import numpy as np

window_size = 100
inference_batch_size = None
inference_window_size = None

X = np.random.random((1024, window_size, 1))
Y = np.random.random((1024, window_size, 1))

model = keras.Sequential([keras.layers.Input(batch_shape=(inference_batch_size , inference_window_size , 1)),
                          keras.layers.Dense(128, activation='relu'),
                          keras.layers.Dense(64, activation='relu'),
                          keras.layers.Dense(1, activation='relu'),
                          ])

model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.mean_squared_error)

model.summary()

model.fit(X, Y, batch_size=32)

model.export('model_export.tf')

import tensorflow as tf

# load model
tf_model = tf.saved_model.load('model_export.tf')

# call tf_model
tf_model_fn = tf_model.signatures['serving_default']

# This works
out = tf_model_fn(tf.constant(X[:window_size, ...], dtype=tf.float32))

# And this does not
out = tf_model_fn(tf.constant(X, dtype=tf.float32))

Throws the error:

Detected at node XlaCallModule defined at (most recent call last):
<stack traces unavailable>
Module shape refinement failed: UNKNOWN: \venv\lib\site-packages\keras\src\layers\core\dense.py:112:0: error: 'stablehlo.add' op requires compatible types for all operands and results
\venv\lib\site-packages\keras\src\layers\core\dense.py:112:0: note: see current operation: %18 = "stablehlo.add"(%15, %17) : (tensor<1024x100x128xf32>, tensor<1024x1024x128xf32>) -> tensor<?x?x128xf32>

dugujiujian1999 commented 8 months ago

@lbortolotti Try this

spec_shape = str(spec_shape).replace("None", "b", 1).replace("None", "d")

https://github.com/keras-team/keras/blob/c4dd4fab5bd9491b32b4ab1e360d290ad8e8a238/keras/backend/common/variables.py#L422-L424

        if config.backend() == "jax":
            # JAX2TF tracing represents `None` dimensions as `b` or `d`
            if str(e) in ("b", "d"):
                continue
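
For reference, the effect of the two string substitutions can be checked with plain Python, no Keras required. The shapes below are illustrative only. Note that the chained version still assigns `d` to every `None` after the first, so it separates at most two unknown dimensions.

```python
spec_shape = (None, None, 1)

# Original behaviour: every None becomes "b", so both dims share one name.
all_b = str(spec_shape).replace("None", "b")

# Suggested patch: the first None becomes "b", the remaining ones become "d".
b_then_d = str(spec_shape).replace("None", "b", 1).replace("None", "d")

# With three unknown dims, the last two still share the name "d".
three_unknown = (
    str((None, None, None)).replace("None", "b", 1).replace("None", "d")
)

print(all_b)          # (b, b, 1)
print(b_then_d)       # (b, d, 1)
print(three_unknown)  # (b, d, d)
```
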

lbortolotti commented 8 months ago

These changes fix my specific example. I suppose the "proper" fix would be to:

  1. Somehow expose polymorphic_shape to the user (or is there already a way in keras to specify that two None dimensions should be e.g. equal to each other?)
  2. Accept arbitrary strings to represent None dimensions in standardize_shape when the backend is jax, as opposed to specifically only b or d or whatever else.
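
Point 2 could look roughly like the sketch below. It is not the Keras implementation: `standardize_shape_sketch` and `FakeDim` are hypothetical names, and `FakeDim` merely stands in for the symbolic dimension objects that show up during jax2tf tracing. The working assumption is that any entry whose string form is a valid identifier denotes an unknown dimension.

```python
class FakeDim:
    """Stand-in for a symbolic dimension seen during tracing (hypothetical)."""

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return self.name


def standardize_shape_sketch(shape, symbolic_ok=True):
    """Validate a shape, keeping None and (optionally) symbolic entries.

    `symbolic_ok` plays the role of "the backend is jax" in this sketch.
    """
    result = []
    for e in shape:
        if e is None:
            result.append(None)
        elif isinstance(e, int) and not isinstance(e, bool):
            if e < 0:
                raise ValueError(f"Negative dimension {e} in shape")
            result.append(e)
        elif symbolic_ok and str(e).isidentifier():
            # Any named dimension ("b", "t", "batch", ...) counts as unknown.
            result.append(e)
        else:
            raise ValueError(
                f"Cannot convert shape to a valid shape. Found invalid entry '{e}'."
            )
    return tuple(result)


shape = standardize_shape_sketch((FakeDim("b"), FakeDim("t"), 1))
print([str(d) for d in shape])  # ['b', 't', '1']
```

With `symbolic_ok=False` the same input is rejected, which mirrors the current behaviour for names other than the hard-coded ones.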

Thanks!

Luca

chococigar commented 7 months ago

I'm also facing AttributeError: 'DynamicJaxprTracer' object has no attribute '_keras_mask'. Any hack to mitigate this?

hertschuh commented 5 months ago

Hi @lbortolotti ,

import keras_core as keras.

keras_core is deprecated, please use keras instead.

The latest version of Keras (3.1.2) supports exactly what you want out of the box using ExportArchive. It will detect the None dimension and give it a name in the polymorphic_shapes given to jax2tf.

import os
os.environ['KERAS_BACKEND'] = 'jax'
import keras
import numpy as np
import tensorflow as tf

window_size = 100
inference_window_size = None

X = np.random.random((1024, window_size, 1))
Y = np.random.random((1024, window_size, 1))

model = keras.Sequential([keras.layers.Dense(128, activation='relu'),
                          keras.layers.Dense(64, activation='relu'),
                          keras.layers.Dense(1, activation='relu'),
                          ])

model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.mean_squared_error)

model.fit(X, Y, batch_size=32)

model.summary()

archive = keras.export.ExportArchive()
archive.track(model)
signature = [tf.TensorSpec(shape=(None, inference_window_size, 1), dtype=tf.float32)]
archive.add_endpoint(
    "call",
    model.__call__,
    input_signature=signature
)
archive.write_out("my_export")

revived_model = tf.saved_model.load("my_export")
revived_model.call(np.random.random((1024, 50, 1)))

Does model.export() automatically turn my JAX-backend-keras-model (with custom layers that use JAX ops) into a TF savedmodel?

Yes, it uses jax2tf under the hood.

It's not correct to replace all "None" dimensions with "b"

This is fixed.

Somehow expose polymorphic_shape to the user (or is there already a way in keras to specify that two None dimensions should be e.g. equal to each other?)

Both options were added. If you have Nones in the signature, it will create a polymorphic_shapes with individual names for each None dimension, meaning they're uncorrelated. If you want more control, you can pass polymorphic_shapes in the jax2tf argument to override the default. Here is an example of how to do this: https://github.com/keras-team/keras/blob/master/keras/export/export_lib_test.py#L274
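
As an illustration of the default described above, the sketch below builds a polymorphic shape string in which every `None` gets its own name, so no two unknown dimensions are tied together. `make_polymorphic_shape` and the `dim_` prefix are assumptions made for this example; the naming scheme Keras actually uses may differ.

```python
def make_polymorphic_shape(shape, prefix="dim_"):
    """Return a string such as "(dim_0, dim_1, 1)" for shape (None, None, 1).

    Hypothetical helper for illustration, not the Keras implementation.
    """
    names = []
    unknown_count = 0
    for dim in shape:
        if dim is None:
            # Each unknown dimension gets a fresh name, so none are correlated.
            names.append(f"{prefix}{unknown_count}")
            unknown_count += 1
        else:
            names.append(str(dim))
    return "(" + ", ".join(names) + ")"


print(make_polymorphic_shape((None, None, 1)))  # (dim_0, dim_1, 1)
```

To tie two dimensions together instead, one would reuse the same name in a hand-written string, which is what overriding polymorphic_shapes allows.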

Sorry, I should have been more explicit in my original issue. The reason I can't take this approach is that in my real use case I am using a fairly complex custom Keras layer, for which I only have a JAX-backend implementation. So reloading the model with the tensorflow backend is not an option.

I'm a bit confused about this. This export approach produces a TensorFlow graph that you can only reload with TensorFlow. However, what the Keras "save" approach (model.save("keras_model.keras")) produces can be reloaded with Keras using the JAX backend. So if that is what you want, you definitely want to go with the "save" approach.

@chococigar @SuryanarayanaY ,

AttributeError: 'DynamicJaxprTracer' object has no attribute '_keras_mask'

Keras 2.0 does not work with JAX; it is TensorFlow only. You'll have to switch to Keras 3.

@innat ,

The torch backend is not supported; there is no way to make the TensorFlow export work with torch.

innat commented 5 months ago

@hertschuh https://github.com/keras-team/keras/issues/19017#issuecomment-1889271772

hertschuh commented 5 months ago

@innat ,

Thanks! Interesting. So when we get PyTorch XLA to work with Keras 3, then this will be possible.