innat closed this issue 6 months ago.
Looks like JAX is not happy with this:
result = keras.ops.cond(
do_pad,
lambda: x[:, :depth, :height, :width, :],
lambda: x
)
However, this works:
import jax.numpy as jnp

def pad_or_crop(x, do_pad, depth=3, height=64, width=96):
    # do_pad is a plain Python bool here, so an ordinary if is fine.
    if do_pad:
        return x[:, :depth, :height, :width, :]
    else:
        return x

x = jnp.ones((2, 4, 70, 70, 96))
result = pad_or_crop(x, do_pad=True)
But I'm wary of using a Python if in the call method with the TensorFlow backend, where it fails like this:
--> 176 if do_pad:
177 return x[:, :depth, :height, :width, :]
179 return x
OperatorNotAllowedInGraphError: Exception encountered when calling VideoSwinTransformerBlock.call().
Using a symbolic `tf.Tensor` as a Python `bool` is not allowed. You can attempt the following resolutions to the problem: If you are running in Graph mode, use Eager execution mode or decorate this function with @tf.function. If you are using AutoGraph, you can try decorating this function with @tf.function. If that does not work, then you may be using an unsupported feature or your source code may not be visible to AutoGraph. See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code for more information.
Arguments received by VideoSwinTransformerBlock.call():
• x=tf.Tensor(shape=(None, 4, 64, 64, 96), dtype=float32)
• mask_matrix=tf.Tensor(shape=(100, 196, 196), dtype=float32)
• training=True
This behavior is expected from keras.ops.cond, since dynamic shapes break XLA compilation. Because the two branches return arrays with different shapes, Keras can't compile the graph, hence the error.
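A minimal JAX sketch of that shape constraint, assuming a recent JAX install and reusing the crop sizes from the example above. It relies on the assumption that keras.ops.cond dispatches to jax.lax.cond on the JAX backend. lax.cond traces both branches and rejects them when their outputs differ in shape, while same-shaped branches are accepted. The names crop, identity and cond_crop are illustrative only.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((2, 4, 70, 70, 96))

def crop(v):
    # Static slice: output shape (2, 3, 64, 70, 96), unlike the input.
    return v[:, :3, :64, :96, :]

def identity(v):
    # Output shape (2, 4, 70, 70, 96), same as the input.
    return v

@jax.jit
def cond_crop(pred, v):
    # Both branches are traced; their output shapes must match.
    return jax.lax.cond(pred, crop, identity, v)

# Different output shapes: JAX raises a TypeError while tracing.
try:
    cond_crop(jnp.array(True), x)
    mismatch_error = None
except TypeError as e:
    mismatch_error = e

# Identical output shapes and dtypes: accepted.
same_shape = jax.lax.cond(jnp.array(True), lambda v: v + 1.0, lambda v: v - 1.0, x)
```

Since the output shape of a compiled XLA computation must be known up front, the usual workaround is to make the decision before tracing, or to make both branches produce the same shape.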
For the if do_pad case: this only works if do_pad is a Python bool (or other Python object) at compile time. If it's a tensor, then keras.ops.cond is the right thing to use.
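To illustrate the Python-bool case, here is a sketch in plain JAX. The jax.jit wrapper and the static_argnames choice are assumptions for illustration, not taken from the model code. When do_pad is marked static, it is an ordinary Python value at trace time, so a plain if works and the branches may return different shapes. Passing it as an array instead makes it a tracer, and the if fails with JAX's analogue of the TensorFlow error above. pad_or_crop_traced is a hypothetical name.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("do_pad", "depth", "height", "width"))
def pad_or_crop(x, do_pad, depth=3, height=64, width=96):
    # do_pad is static: a concrete Python bool during tracing.
    if do_pad:
        return x[:, :depth, :height, :width, :]
    return x

@jax.jit
def pad_or_crop_traced(x, do_pad):
    # do_pad is traced here, so converting it to a Python bool fails.
    if do_pad:
        return x[:, :3, :64, :96, :]
    return x

x = jnp.ones((2, 4, 70, 70, 96))
cropped = pad_or_crop(x, do_pad=True)     # shape (2, 3, 64, 70, 96)
untouched = pad_or_crop(x, do_pad=False)  # shape (2, 4, 70, 70, 96)

try:
    pad_or_crop_traced(x, jnp.array(True))
    traced_error = None
except jax.errors.ConcretizationTypeError as e:
    traced_error = e
```

Each distinct static value triggers a separate trace, which is fine for a handful of configurations. If do_pad is computed from the input's shape (an assumption about how it is produced), it is already a Python bool under JAX, because x.shape is a tuple of ints at trace time.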
It looks like ops.cond is practically useless for the JAX backend for now. It also breaks consistency with the other backends.
It looks like ops.cond is practically useless for jax backend for now
Not really, it can still be useful. It's just that JAX is a thin layer around XLA and, unlike TensorFlow, doesn't trigger recompilation when dynamic shapes are involved. So the JAX backend limits Keras to supporting only static shapes across all backends. keras.ops.cond can still be useful if you want to compute something differently based on the runtime value of a tensor, as long as both branches return the same shape:
keras.ops.cond(pred, lambda: x + 1, lambda: x + 2)
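A sketch of that runtime-predicate pattern in plain JAX. The shift function and the sum-based predicate are hypothetical. The predicate depends on the data, so it stays a traced boolean inside jax.jit, and both branches return arrays of the same shape and dtype, which is what lax.cond requires. jax.lax.cond takes the operand as an extra argument and passes it to the branch functions, which is why its branches take a parameter. keras.ops.cond branches, by contrast, are zero-argument callables.

```python
import jax
import jax.numpy as jnp

@jax.jit
def shift(x):
    # Runtime predicate: only known once the data is available.
    pred = jnp.sum(x) > 0
    # Both branches keep x's shape and dtype.
    return jax.lax.cond(pred, lambda v: v + 1, lambda v: v + 2, x)

print(shift(jnp.array([1.0, 2.0])))    # -> [2. 3.]
print(shift(jnp.array([-1.0, -2.0])))  # -> [1. 0.]
```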
Closing this; please feel free to reopen if you have any other questions!
Please check the following code, where the cond operation is used. It works just fine with the TensorFlow and torch backends, but not with JAX. Is this expected? With the JAX backend, it gives