patrick-kidger / jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

Provide a way to get error diagnostics out of isinstance checks #167

Open reinerp opened 8 months ago

reinerp commented 8 months ago

The assert isinstance(...) pattern produces a mostly useless message: just "AssertionError", with no explanation. Would it be possible to expose an assertIsInstance(x, ty) API that prints the expected versus actual values, like we get for errors in function arguments?
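To make the request concrete, here is a minimal sketch of what such a helper could look like in plain Python. The name assert_isinstance is hypothetical (it is not a jaxtyping API), and the optional use of a type-provided __instancecheck_str__ method is an assumption about how a richer diagnostic might be plugged in.

```python
def assert_isinstance(obj, ty):
    """Hypothetical helper (not a jaxtyping API): a drop-in for
    `assert isinstance(obj, ty)` whose AssertionError explains the failure."""
    if isinstance(obj, ty):
        return
    message = f"expected an instance of {ty!r}, got {type(obj).__name__}: {obj!r}"
    # Assumption: if the type (via its metaclass) offers a diagnostic method,
    # append the reason it returns.
    describe = getattr(ty, "__instancecheck_str__", None)
    if callable(describe):
        reason = describe(obj)
        if reason:
            message += f" ({reason})"
    raise AssertionError(message)


assert_isinstance(3, int)  # passes silently
try:
    assert_isinstance("3", int)
except AssertionError as err:
    print(err)  # expected an instance of <class 'int'>, got str: '3'
```

Unlike a bare assert, a function like this is not stripped under python -O, which may or may not be what you want for runtime checks.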

patrick-kidger commented 8 months ago

So this actually dovetails well with another feature I would like to add.

Beartype now supports checking for an __instancecheck_str__ method (see the Beartype release notes and the relevant jaxtyping discussion thread).

Once this is added, your use case could easily be supported via assert isinstance(x, ty), ty.__instancecheck_str__(x).
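A minimal sketch of that assert pattern, using a toy type instead of jaxtyping. The names _PositiveIntMeta, PositiveInt and check are hypothetical; only the __instancecheck_str__ convention (an empty string on success, a human-readable reason otherwise) follows the plan described in this thread.

```python
class _PositiveIntMeta(type):
    """Toy metaclass (illustrative only) following the proposed protocol:
    __instancecheck_str__ returns "" on success, else the reason for failure."""

    def __instancecheck__(cls, obj):
        # isinstance() succeeds exactly when there is no failure reason.
        return cls.__instancecheck_str__(obj) == ""

    def __instancecheck_str__(cls, obj):
        if not isinstance(obj, int):
            return f"expected an int, got {type(obj).__name__}"
        if obj <= 0:
            return f"expected a positive int, got {obj}"
        return ""


class PositiveInt(metaclass=_PositiveIntMeta):
    pass


def check(x):
    # The message expression after the comma is evaluated only if the assertion fails.
    assert isinstance(x, PositiveInt), PositiveInt.__instancecheck_str__(x)
    return x
```

Because the message expression is only evaluated on failure, the diagnostic costs nothing on the happy path; as with any assert, though, the whole check disappears under python -O.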

This shouldn't be much work to add. To discuss some jaxtyping internals briefly: the plan is essentially to rewrite things from

class _MetaAbstractArray(type):
    def __instancecheck__(cls, obj):
        if something_bad:
            return False
        ...
        return True

to

class _MetaAbstractArray(type):
    def __instancecheck__(cls, obj):
        # An empty string means "no failure reason", i.e. the check passed.
        return cls.__instancecheck_str__(obj) == ""

    def __instancecheck_str__(cls, obj):
        if something_bad:
            return "something bad!" + _exc_shape_info(get_shape_memo())
        ...
        return ""

This would report both exactly how the check failed (more than we get at the moment under any circumstances!) and all the extra information about the current values of the bindings (via _exc_shape_info).

I'd be happy to guide a pull request on this; otherwise, I'm hoping to get around to it myself in the near future.