keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

`keras.ops.cond` issue in jax #19379

Closed innat closed 6 months ago

innat commented 6 months ago

System Info

keras: 3.0.5

Describe

Please check the following code, which uses the `cond` operation. It works fine with the tensorflow and torch backends but not with jax. Is this expected?

import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # one of: tensorflow, jax, torch

import keras  # must be imported after KERAS_BACKEND is set

do_pad = True 
x = keras.ops.ones(shape=(2, 4, 70, 70, 96))
depth = 3
height = 64
width = 96
result = keras.ops.cond(
    do_pad,
    lambda: x[:, :depth, :height, :width, :],
    lambda: x
)
result.shape

# tensorflow backend
TensorShape([2, 3, 64, 70, 96])

# torch backend
torch.Size([2, 3, 64, 70, 96])

But with the jax backend, it raises:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[3], line 6
      4 height = 64
      5 width = 96
----> 6 result = keras.ops.cond(
      7     do_pad,
      8     lambda: x[:, :depth, :height, :width, :],
      9     lambda: x
     10 )
     11 result.shape

File /opt/conda/lib/python3.10/site-packages/keras/src/ops/core.py:607, in cond(pred, true_fn, false_fn)
    595 @keras_export("keras.ops.cond")
    596 def cond(pred, true_fn, false_fn):
    597     """Conditionally applies `true_fn` or `false_fn`.
    598 
    599     Args:
   (...)
    605         The output of either `true_fn` or `false_fn` depending on pred.
    606     """
--> 607     return Cond()(pred, true_fn, false_fn)

File /opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:123, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    120     filtered_tb = _process_traceback_frames(e.__traceback__)
    121     # To get the full stack trace, call:
    122     # `keras.config.disable_traceback_filtering()`
--> 123     raise e.with_traceback(filtered_tb) from None
    124 finally:
    125     del filtered_tb

File /opt/conda/lib/python3.10/site-packages/keras/src/ops/core.py:546, in Cond.compute_output_spec(self, pred, true_fn, false_fn)
    544 false_fn_spec = backend.compute_output_spec(call_fn, false_fn)
    545 if not self._check_output_spec(true_fn_spec, false_fn_spec):
--> 546     raise ValueError(
    547         "`true_fn` and `false_fn` should return outputs "
    548         "of the same kind (struct, dtype and shape). "
    549         f"Got {true_fn_spec} and {false_fn_spec} instead."
    550     )
    551 return true_fn_spec

ValueError: Exception encountered when calling Cond.call().

`true_fn` and `false_fn` should return outputs of the same kind (struct, dtype and shape). Got <KerasTensor shape=(2, 3, 64, 70, 96), dtype=float32, sparse=False, name=keras_tensor> and <KerasTensor shape=(2, 4, 70, 70, 96), dtype=float32, sparse=False, name=keras_tensor_1> instead.

Arguments received by Cond.call():
  • args=('True', '<function <lambda> at 0x7c8c4ed1c9d0>', '<function <lambda> at 0x7c8c4ed1ce50>')
  • kwargs=<class 'inspect._empty'>

innat commented 6 months ago

Looks like jax is not happy with this.

result = keras.ops.cond(
    do_pad,
    lambda: x[:, :depth, :height, :width, :],
    lambda: x
)

However, this works

def pad_or_crop(x, do_pad, depth=3, height=64, width=96):
    if do_pad:
        return x[:, :depth, :height, :width, :]
    else:
        return x

import jax.numpy as jnp

x = jnp.ones((2, 4, 70, 70, 96))
result = pad_or_crop(x, do_pad=True)  # do_pad is a plain Python bool here

But I'm hesitant to use a Python `if` in the `call` method with the tensorflow backend, because it fails like this:

--> 176         if do_pad:
    177             return x[:, :depth, :height, :width, :]
    179         return x

OperatorNotAllowedInGraphError: Exception encountered when calling VideoSwinTransformerBlock.call().

Using a symbolic `tf.Tensor` as a Python `bool` is not allowed. You can attempt the following resolutions to the problem: If you are running in Graph mode, use Eager execution mode or decorate this function with @tf.function. If you are using AutoGraph, you can try decorating this function with @tf.function. If that does not work, then you may be using an unsupported feature or your source code may not be visible to AutoGraph. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code for more information.

Arguments received by VideoSwinTransformerBlock.call():
  • x=tf.Tensor(shape=(None, 4, 64, 64, 96), dtype=float32)
  • mask_matrix=tf.Tensor(shape=(100, 196, 196), dtype=float32)
  • training=True

tirthasheshpatel commented 6 months ago

This behavior is expected from `keras.ops.cond`, since dynamic shapes break XLA compilation. Because the two branches return arrays with different shapes, Keras won't be able to compile the graph, hence the error.

For the `if do_pad` case: this will only work if `do_pad` is a Python bool/object at compile time. If it's a tensor, then `keras.ops.cond` is the right thing to use.
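To make the shape requirement concrete, here is a rough pure-Python sketch of the comparison behind the error above. The helper `sliced_shape` is hypothetical and is not how Keras computes output specs; it only mimics how basic slicing changes a static shape, using the shapes from the original report.

```python
def sliced_shape(shape, stops):
    """Static shape of x[:s0, :s1, ...]; a stop of None keeps the whole axis.

    Hypothetical helper for illustration only, not part of Keras.
    """
    return tuple(
        len(range(*slice(stop).indices(dim)))
        for dim, stop in zip(shape, stops)
    )


x_shape = (2, 4, 70, 70, 96)
depth, height, width = 3, 64, 96

# Shape returned by the true branch: x[:, :depth, :height, :width, :]
true_shape = sliced_shape(x_shape, (None, depth, height, width, None))
# Shape returned by the false branch: x itself
false_shape = x_shape

print(true_shape)                 # (2, 3, 64, 70, 96)
print(true_shape == false_shape)  # False
```

The width axis stays at 70 because the slice stop (96) exceeds the axis length, while the depth and height axes shrink. The two branches therefore disagree on shape, which is the mismatch the error message reports.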

innat commented 6 months ago

It looks like `ops.cond` is practically useless for the jax backend for now. It also breaks consistency with the other backends.

tirthasheshpatel commented 6 months ago

> It looks like ops.cond is practically useless for jax backend for now

Not really, it could be useful. It's just that JAX is a thin layer around XLA and, unlike TensorFlow, doesn't trigger recompilation when dynamic shapes are involved. So we are bottlenecked by the backend and can only support static shapes across all backends. `keras.ops.cond` can still be useful if you want to compute something differently based on the runtime value of a tensor:

keras.ops.cond(cond, lambda: x + 1, lambda: x + 2)

tirthasheshpatel commented 6 months ago

Closing this, please feel free to reopen if you have any other questions!