aboubezari opened 2 days ago
Thanks for the report. Could you add a code snippet that reproduces the bug?
Yep, this minimal snippet shows that the squeeze operator does not handle this case properly:
import tensorflow as tf
import tf2jax
import numpy as np


@tf.function
def forward(a, b):
    # With axis=None (the default), tf.squeeze removes every size-1 axis.
    a_squeezed = tf.squeeze(a)
    return a_squeezed + b


unsqueezed_shape = [1, 1, 1, 1, 1, 2]
squeezed_shape = [2]
a_in = np.zeros(unsqueezed_shape)
b_in = np.zeros(squeezed_shape)

# Convert the TF function to JAX, tracing it with the example inputs.
jax_func = tf2jax.convert_functional(forward, a_in, b_in)

jax_result = jax_func(a_in, b_in)
tf_result = forward(a_in, b_in)

# Both results should have shape (2,).
assert jax_result.shape == tf_result.shape, f"Shapes do not match: {jax_result.shape} != {tf_result.shape}"
The assertion fails with this code: the JAX version does not squeeze anything, so its result keeps shape (1, 1, 1, 1, 1, 2) instead of (2,).
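A minimal sketch of why the shapes diverge, assuming the converted function effectively squeezes over an empty list of axes. Broadcasting (NumPy's rules, which JAX follows) then keeps the leading size-1 axes, so the sum does not collapse to shape (2,). Plain NumPy is used here only to illustrate the shapes; it is not the tf2jax code path.

```python
import numpy as np

a_in = np.zeros([1, 1, 1, 1, 1, 2])
b_in = np.zeros([2])

# What TF computes: squeeze every size-1 axis first, then add.
tf_like = np.squeeze(a_in) + b_in

# What the buggy conversion effectively computes: an empty axis tuple
# squeezes nothing, and broadcasting against b_in keeps the size-1 axes.
jax_like = np.squeeze(a_in, axis=()) + b_in

print(tf_like.shape)   # (2,)
print(jax_like.shape)  # (1, 1, 1, 1, 1, 2)
```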
Thanks. What are the TF, JAX, and tf2jax versions you are using?
Our versions are:
tensorflow==2.12
jax==0.4.23
tf2jax==0.3.5
Ah, okay. I can't reproduce this on 0.3.6; can you upgrade?
The squeeze operator is not converted to JAX properly when the axis argument is None. This is converted to an empty list of axes in JAX, and squeezing over an empty list does nothing. It should instead propagate the None value so that all axes of size 1 are squeezed.
Offending code: https://github.com/google-deepmind/tf2jax/blob/main/tf2jax/_src/ops.py#L1882-L1888
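A sketch of the propagation described above. It assumes the converter receives axis=None as an empty list of dimensions (the raw TF Squeeze op's squeeze_dims attribute uses an empty list to mean "every size-1 axis"). `squeeze_axis_from_tf` is a hypothetical helper for illustration, not the actual tf2jax code at the linked lines.

```python
import jax.numpy as jnp


def squeeze_axis_from_tf(squeeze_dims):
    """Hypothetical helper: map TF's squeeze_dims to jnp.squeeze's axis.

    An empty list means "squeeze every size-1 axis", which jnp.squeeze
    expresses as axis=None; any other list is passed through as a tuple.
    """
    if not squeeze_dims:
        return None
    return tuple(squeeze_dims)


x = jnp.zeros((1, 1, 1, 1, 1, 2))

# Buggy behavior: an empty axis tuple squeezes nothing.
buggy = jnp.squeeze(x, axis=())

# Propagating None squeezes all size-1 axes, matching tf.squeeze(a).
fixed = jnp.squeeze(x, axis=squeeze_axis_from_tf([]))

# Explicitly listed dimensions are still honored.
partial = jnp.squeeze(x, axis=squeeze_axis_from_tf([0, 1]))

print(buggy.shape)    # (1, 1, 1, 1, 1, 2)
print(fixed.shape)    # (2,)
print(partial.shape)  # (1, 1, 1, 2)
```

Mapping the empty list to None keeps explicit squeeze_dims unchanged, so only the axis=None case behaves differently.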