I'm using `tensorflow==2.9.1` and ran into a situation where a `tf.Tensor` has no `ndim` attribute at some point during execution. I also found that the problem does not occur with `tensorflow==2.16`.
Here is a minimal example that reproduces the error:
```python
import tensorflow as tf
import jaxtyping as jax
import beartype


@tf.function()
@jax.jaxtyped(typechecker=beartype.beartype)
def map_function(tensor: jax.Float[tf.Tensor, "h w c"]) -> jax.Float[tf.Tensor, "h w c"]:
    # Runtime shape/dtype check on a single (h, w, c) image, then invert it.
    return 1 - tensor


def main():
    # Run tf.function and tf.data transformations eagerly for debugging.
    tf.config.run_functions_eagerly(True)
    tf.data.experimental.enable_debug_mode()
    # 100 random 30x30x3 images; each dataset element is one image.
    dataset = tf.data.Dataset.from_tensor_slices(tensors=tf.random.uniform((100, 30, 30, 3)))
    dataset = dataset.map(map_function)


if __name__ == "__main__":
    main()
```
This is what I get:
```
AttributeError: 'Tensor' object has no attribute 'ndim'
```
Unfortunately, I'm pinned to version 2.9.1 of TensorFlow. Let me know what you think.
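Until this is fixed, one workaround I could imagine (untested on 2.9.1) is to give the tensor class an `ndim` property derived from its shape, since the failing check apparently only needs the number of dimensions. The sketch below shows the idea in plain Python; `HypotheticalTensor` is a hypothetical stand-in for a tensor type that exposes `shape` but not `ndim`, and `add_ndim_property` is a hypothetical helper, not part of TensorFlow, jaxtyping, or beartype.

```python
class HypotheticalTensor:
    """Hypothetical stand-in for a tensor type that has `shape` but no `ndim`."""

    def __init__(self, shape):
        self.shape = tuple(shape)


def add_ndim_property(cls):
    """Attach an `ndim` property derived from `shape`, unless `cls` already has one."""
    if not hasattr(cls, "ndim"):
        # The number of dimensions is the length of the shape.
        cls.ndim = property(lambda self: len(self.shape))
    return cls


if __name__ == "__main__":
    t = HypotheticalTensor((30, 30, 3))
    print(hasattr(t, "ndim"))  # False: accessing t.ndim would raise AttributeError
    add_ndim_property(HypotheticalTensor)
    print(t.ndim)  # 3
```

Applied to TensorFlow, this would presumably be `add_ndim_property(tf.Tensor)` at import time, before `map_function` is defined. That the patched property is also picked up by eager tensors is an assumption, and `len(tensor.shape)` only works when the rank is known, so `tensor.shape.rank` might be the safer choice for symbolic tensors.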