patrick-kidger / torchtyping

Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
Apache License 2.0

Support for an or condition, or other way to accomplish this pattern? #30

Open: SeanEaster opened this issue 2 years ago

SeanEaster commented 2 years ago

n00b to this very cool project here. I'm looking to enforce a broadcastability pattern, where a dimension in one tensor either matches the corresponding dimension in another tensor or can be broadcast to it (i.e. has size 1).
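Stated without any typing machinery, the pattern is a rule on shapes. Below is a minimal pure-Python sketch of that rule for the argument shapes used in mwe. The helper name mwe_output_shape is hypothetical, and it covers only the one-directional case asked about here: x's "bar" may be 1, y's may not.

```python
from typing import Tuple


def mwe_output_shape(x_shape: Tuple[int, ...], y_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Hypothetical helper: x is (..., foo, bar) or (..., foo, 1), y is (bar,)."""
    if len(y_shape) != 1:
        raise ValueError(f"y must be 1-D, got shape {y_shape}")
    if len(x_shape) < 2:
        raise ValueError(f"x needs at least two dims, got shape {x_shape}")
    bar = y_shape[0]
    # The last dim of x must equal "bar" (taken from y) or be 1 so it can broadcast.
    if x_shape[-1] not in (bar, 1):
        raise ValueError(f"last dim of x is {x_shape[-1]}; expected {bar} or 1")
    # x * y then has x's leading dims followed by "bar".
    return x_shape[:-1] + (bar,)
```

Plain torch applies this rule (plus the reverse direction) when it evaluates x * y. The question is how to check it at the function boundary with named dimensions.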

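```python
import typeguard
import torchtyping

torchtyping.patch_typeguard()  # needed for torchtyping's dimension checks under typeguard


@typeguard.typechecked
def mwe(
    x: torchtyping.TensorType[
        ...,
        "foo",
        "bar",  # How do we make this "match bar from y, or equal 1"?
    ],
    y: torchtyping.TensorType[
        "bar",
    ],
) -> torchtyping.TensorType[..., "foo", "bar"]:
    return x * y
```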

Am I missing an existing way to do this in torchtyping out of the box? Would this need an extension?

patrick-kidger commented 2 years ago

Yep, this is possible: it can be done with Union[TensorType[..., "foo", 1], TensorType[..., "foo", "bar"]].

One caveat: switching the order of the elements of the Union will cause a spurious failure (the 1 case has to go before the "bar" case). That's a bug, really, but probably a thorny one to fix.


Incidentally, broadcasting is a common enough operation that I'd be willing to accept a PR making this neater than the Union solution. Essentially all that's needed is some syntax like TensorType["foo": OrOne], which TensorType.__class_getitem__ expands into a Union of the form given above.
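To make the proposed expansion concrete, here is a pure-Python sketch of the step __class_getitem__ would perform, without touching torchtyping itself. OrOne, expand_or_one and Spec are all hypothetical names. Inside a subscript, "bar": OrOne arrives as slice("bar", OrOne, None), and each such entry is expanded into a literal-1 variant followed by a named variant.

```python
import itertools
from typing import Any, List, Tuple


class OrOne:
    """Hypothetical sentinel: this dimension may also have size 1."""


def expand_or_one(dims: Tuple[Any, ...]) -> List[Tuple[Any, ...]]:
    """Expand each `"name": OrOne` entry into (1, "name"); keep other entries as-is.

    itertools.product yields the 1-variants before the named ones, matching the
    ordering the Union needs (the 1 case before the "bar" case).
    """
    choices = []
    for dim in dims:
        if isinstance(dim, slice) and dim.stop is OrOne:
            choices.append((1, dim.start))  # literal 1 first, then the dimension name
        else:
            choices.append((dim,))  # ordinary entries, including "name": size slices
    return list(itertools.product(*choices))


class Spec:
    """Stand-in for TensorType: returns the expanded variants instead of a type."""

    def __class_getitem__(cls, item):
        if not isinstance(item, tuple):
            item = (item,)
        return expand_or_one(item)


print(Spec[..., "foo", "bar": OrOne])
# [(Ellipsis, 'foo', 1), (Ellipsis, 'foo', 'bar')]
```

In torchtyping, the variants would then presumably be wrapped as Union[tuple(TensorType[v] for v in variants)]; that step is not shown here. With k OrOne dimensions the expansion produces 2**k variants.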

This should be pretty simple, so it'd be a good first issue for anyone looking to contribute.