patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Replace ndim with len(shape) #193

dziulek closed this 4 months ago

dziulek commented 4 months ago

Hi!

I use tensorflow==2.9.1 and came across a situation where a tf.Tensor has no ndim attribute at some point during execution. I also found that the problem does not occur with tensorflow==2.16. Here is a simple piece of code that reproduces the error:

import tensorflow as tf
import jaxtyping as jax
import beartype

@tf.function()
@jax.jaxtyped(typechecker=beartype.beartype)
def map_function(tensor: jax.Float[tf.Tensor, "h w c"]) -> jax.Float[tf.Tensor, "h w c"]:
    return 1 - tensor

def main():

    tf.config.run_functions_eagerly(True)
    tf.data.experimental.enable_debug_mode()

    dataset = tf.data.Dataset.from_tensor_slices(tensors=tf.random.uniform((100,30,30,3)))
    dataset = dataset.map(map_function)

if __name__ == "__main__":

    main()

This is what I get: AttributeError: 'Tensor' object has no attribute 'ndim'

Unfortunately I'm pinned to version 2.9.1 of tensorflow. Let me know what you think.

patrick-kidger commented 4 months ago

LGTM! No idea why they don't have an ndim attribute, but happy to support this use-case.
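The change named in the title replaces reading `x.ndim` with computing `len(x.shape)`. Below is a minimal sketch of that idea, assuming only that the array object exposes a `shape` attribute whose length is its rank. `ShapeOnlyTensor`, `array_rank`, and `check_rank` are hypothetical names for illustration; they are not jaxtyping's internals, and the stand-in class only mimics the reported tf.Tensor behaviour rather than using TensorFlow itself.

```python
# Hypothetical stand-in for a tensor type that exposes `shape` but not `ndim`,
# mimicking what was reported for tf.Tensor under tensorflow==2.9.1.
class ShapeOnlyTensor:
    def __init__(self, shape):
        self.shape = tuple(shape)


def array_rank(x) -> int:
    """Return the number of dimensions of an array-like object.

    Uses len(x.shape) instead of x.ndim: anything with a shape has a rank,
    but not every array type defines an ndim attribute.
    """
    return len(x.shape)


def check_rank(x, expected_dims: str) -> bool:
    """Hypothetical helper: compare the rank of `x` with the number of
    names in a space-separated dimension string such as "h w c"."""
    return array_rank(x) == len(expected_dims.split())


t = ShapeOnlyTensor((30, 30, 3))
print(hasattr(t, "ndim"))      # False, so t.ndim would raise AttributeError
print(array_rank(t))           # 3
print(check_rank(t, "h w c"))  # True
```

The design choice is simply to depend on the more widely supported attribute: `shape` is part of essentially every array API, while `ndim` is a convenience that some versions of some libraries omit.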