patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

numpy structured dtype support #204

Closed. alexfanqi closed this issue 3 months ago.

alexfanqi commented 4 months ago

Hey,

This is a really useful library that saves me a lot of debugging time. Thanks for maintaining it all this time!

I am wondering whether it would be possible to support NumPy's structured arrays (https://numpy.org/doc/stable/user/basics.rec.html#structured-arrays). I mainly use them to store multiple labels per sample.
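For context, a structured dtype packs several named fields into each array element, so one array can carry several labels per sample. The sketch below (the field names are hypothetical) also shows why a check based only on `dtype.type.__name__` cannot tell two structured dtypes apart: NumPy gives every structured dtype the scalar type `numpy.void`.

```python
import numpy as np

# Hypothetical label layout: three labels stored for each sample.
label_t = np.dtype([("class_id", np.uint8), ("score", np.float32), ("is_valid", bool)])
labels = np.array([(3, 0.9, True), (7, 0.4, False)], dtype=label_t)

# Each field can be read back as an ordinary 1-D array.
class_ids = labels["class_id"]
valid = labels["is_valid"]

# An unrelated structured dtype with a completely different layout.
other_t = np.dtype([("x", np.int64)])

# Both report the same scalar type name, so a name-based check cannot separate them.
print(label_t.type.__name__)  # void
print(other_t.type.__name__)  # void
print(label_t == other_t)     # False
```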

I made a small hack to get this working, but I am unsure whether it is safe:

```diff
@@ -166,6 +166,9 @@ class _MetaAbstractArray(type):
         if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
             # JAX, numpy
             dtype = obj.dtype.type.__name__
+            # structured dtypes in numpy all have scalar type `void`
+            if dtype == "void" and obj.dtype.fields is not None:
+                dtype = str(obj.dtype)
         elif hasattr(obj.dtype, "as_numpy_dtype"):
             # TensorFlow
             dtype = obj.dtype.as_numpy_dtype.__name__
```
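The patched branch can be tried outside jaxtyping with a small stand-alone function. `dtype_name` below is a hypothetical helper, not part of jaxtyping's API; it mirrors only the NumPy branch and, as an assumption, detects structured dtypes with `dtype.fields is not None`, so that unstructured raw void dtypes such as `V4` keep the plain `void` name.

```python
import numpy as np


def dtype_name(obj) -> str:
    """Hypothetical helper: the dtype name a jaxtyping-style check would
    compare against ``AbstractDtype.dtypes`` (NumPy branch only)."""
    dtype = obj.dtype.type.__name__
    # Every structured dtype has scalar type numpy.void; use the full dtype
    # description instead so that different layouts get different names.
    if dtype == "void" and obj.dtype.fields is not None:
        dtype = str(obj.dtype)
    return dtype


plain = np.zeros(3, dtype=np.float32)
struct = np.zeros(3, dtype=[("finger_count", np.uint8), ("lightness", np.int16)])
raw = np.zeros(3, dtype="V4")

print(dtype_name(plain))   # float32
print(dtype_name(struct))  # the structured dtype's string form
print(dtype_name(raw))     # void
```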

Then I declare a new `AbstractDtype`:

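```python
import numpy as np
from jaxtyping import AbstractDtype

annotation_t = np.dtype([("finger_count", np.uint8), ("lightness", np.int16), ("finger split", bool)])

class AnnotationT(AbstractDtype):
    # `dtypes` is a list of dtype names; with the patch above, a structured
    # dtype's name is its full string form.
    dtypes = [str(annotation_t)]

assert isinstance(np.array([(1, 1, False)], dtype=annotation_t), AnnotationT[np.ndarray, "..."])  # passes
```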
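Because the name stored in `dtypes` is `str(annotation_t)`, an array only matches when its structured dtype has exactly the same field names, field order, field types, and byte order. A NumPy-only sketch of that behaviour (the two variant dtypes are illustrative, not from the original report):

```python
import numpy as np

annotation_t = np.dtype([("finger_count", np.uint8), ("lightness", np.int16), ("finger split", bool)])

# An array built with annotation_t has exactly the same string form.
arr = np.array([(1, 1, False)], dtype=annotation_t)
print(str(arr.dtype) == str(annotation_t))  # True

# Same fields in a different order: a different layout and a different string.
reordered_t = np.dtype([("lightness", np.int16), ("finger_count", np.uint8), ("finger split", bool)])

# Same field order, but "lightness" stored in the non-native byte order.
swapped_i16 = np.dtype(np.int16).newbyteorder()
swapped_t = np.dtype([("finger_count", np.uint8), ("lightness", swapped_i16), ("finger split", bool)])

for t in (reordered_t, swapped_t):
    # Neither variant equals annotation_t, and neither string form matches it.
    print(t == annotation_t, str(t) == str(annotation_t))  # False False
```

Since `dtypes` is a list, listing more than one string form there would be one way to accept several layouts.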
patrick-kidger commented 4 months ago

Something like this looks reasonable to me! I'd be happy to take a PR adding support for this.