google-deepmind / tf2jax

Apache License 2.0

[Bug report] Issue in conversion for squeeze op #226

Open aboubezari opened 2 days ago

aboubezari commented 2 days ago

The squeeze operator is not converted to JAX correctly when the axis argument is None. The None is converted to an empty list of axes, and squeezing over an empty list in JAX does nothing. The converter should propagate the None value instead, so that all axes of size 1 are squeezed.
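
For reference, jnp.squeeze itself distinguishes the two cases: axis=None drops every size-1 axis, while an explicit empty tuple of axes drops none. A minimal sketch, using the input shape [1, 1, 1, 1, 1, 2] from the reproduction in this thread (the function name compare_squeeze_semantics is purely illustrative):

```python
import jax.numpy as jnp


def compare_squeeze_semantics():
    """Return the result shapes of jnp.squeeze for axis=None and axis=()."""
    x = jnp.zeros((1, 1, 1, 1, 1, 2))
    # axis=None: remove every axis whose size is 1.
    squeeze_all = jnp.squeeze(x, axis=None)
    # axis=(): an explicit, empty set of axes, so nothing is removed.
    squeeze_none = jnp.squeeze(x, axis=())
    return squeeze_all.shape, squeeze_none.shape


all_shape, none_shape = compare_squeeze_semantics()
print(all_shape)   # (2,)
print(none_shape)  # (1, 1, 1, 1, 1, 2)
```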

Offending code: https://github.com/google-deepmind/tf2jax/blob/main/tf2jax/_src/ops.py#L1882-L1888
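
One possible fix is to map an empty axis list back to None before calling jnp.squeeze. The sketch below assumes the TF Squeeze node stores its axes in a list attribute (squeeze_dims in TensorFlow's op definition) where an empty list means "squeeze all size-1 axes". squeeze_dims_to_axis and convert_squeeze are hypothetical helpers, not the actual tf2jax code at the linked lines.

```python
from typing import Optional, Sequence, Tuple

import jax.numpy as jnp


def squeeze_dims_to_axis(squeeze_dims: Sequence[int]) -> Optional[Tuple[int, ...]]:
    """Translate a TF-style squeeze_dims list into a jnp.squeeze axis argument.

    Assumption: TF records axis=None as an empty list, whereas jnp.squeeze
    needs axis=None to squeeze all size-1 axes.
    """
    dims = tuple(int(d) for d in squeeze_dims)
    return dims if dims else None


def convert_squeeze(x, squeeze_dims: Sequence[int]):
    # Hypothetical converter body: squeeze with the translated axis argument.
    return jnp.squeeze(x, axis=squeeze_dims_to_axis(squeeze_dims))


x = jnp.zeros((1, 1, 1, 1, 1, 2))
print(convert_squeeze(x, []).shape)      # (2,)
print(convert_squeeze(x, [0, 1]).shape)  # (1, 1, 1, 2)
```

Explicit axes (including negative ones) pass through unchanged; only the empty list is special-cased.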

shaobohou commented 1 day ago

Thanks for the report, could you add a code snippet that reproduces the bug?

aboubezari commented 1 day ago

Yep, this simple snippet shows that the squeeze operator does not handle this case properly:

import tensorflow as tf
import tf2jax
import numpy as np

@tf.function
def forward(a, b):
    a_squeezed = tf.squeeze(a)  # axis=None: remove every size-1 axis
    return a_squeezed + b

unsqueezed_shape = [1, 1, 1, 1, 1, 2]
squeezed_shape = [2]
a_in = np.zeros(unsqueezed_shape)
b_in = np.zeros(squeezed_shape)

jax_func = tf2jax.convert_functional(forward, a_in, b_in)

jax_result = jax_func(a_in, b_in)
tf_result = forward(a_in, b_in)

assert jax_result.shape == tf_result.shape, f"Shapes do not match: {jax_result.shape} != {tf_result.shape}"

The assertion fails: the converted JAX function does not squeeze a at all, so its result keeps shape (1, 1, 1, 1, 1, 2) instead of (2,).

shaobohou commented 1 day ago

Thanks. Which TF, JAX and tf2jax versions are you using?

aboubezari commented 1 day ago


Our versions are:

tensorflow==2.12
jax==0.4.23
tf2jax==0.3.5
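
Version details like these can also be collected programmatically. A small sketch using the standard library's importlib.metadata; it assumes TensorFlow is installed under the distribution name tensorflow (variants such as tensorflow-cpu would need to be added to the list), and installed_versions is an illustrative name:

```python
from importlib import metadata


def installed_versions(packages=("tensorflow", "jax", "tf2jax")):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


for name, version in installed_versions().items():
    print(f"{name}=={version}" if version else f"{name}: not installed")
```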

shaobohou commented 1 day ago

Ah, okay. I can't reproduce this on 0.3.6; can you upgrade?
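
Before rerunning the reproduction after upgrading, a quick check that the installed tf2jax is at least 0.3.6 can help. This sketch assumes plain numeric "X.Y.Z" release strings; pre-release tags such as "0.3.6rc1" would need a full version parser (e.g. packaging.version). parse_release and is_at_least are hypothetical helpers:

```python
def parse_release(version: str) -> tuple:
    """Parse a plain 'X.Y.Z' release string into a tuple of ints."""
    return tuple(int(part) for part in version.split("."))


def is_at_least(installed: str, required: str = "0.3.6") -> bool:
    # Tuple comparison orders releases numerically, e.g. (0, 3, 10) > (0, 3, 6).
    return parse_release(installed) >= parse_release(required)


print(is_at_least("0.3.5"))  # False: the version used in the report
print(is_at_least("0.3.6"))  # True
```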