pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Statically checked tensor shapes #26889

Open vsiles opened 4 years ago

vsiles commented 4 years ago

🚀 Feature

(Long-term request, mostly to gather feedback on our current experiment.) We would like to extend the Tensor class with a description of its value type and the shape of its dimensions, to enable static checking of tensor operations with respect to shapes (e.g. detecting an illegal call to Tensor.mm statically rather than with a runtime exception). The new syntax would look like Tensor[int32, dim0, dim1, dim2].

Motivation

At the moment, PyTorch is more or less untyped: everything is a Tensor, and there is no information whatsoever about the dimensions of these tensors. By being more descriptive, we could statically check (that is, at compile time rather than at runtime) that tensor operations are executed on arguments with the right shape. For example, type checking could catch errors like this one:

T0: Tensor[int32, D3, D4] = ...
T1: Tensor[int32, D5, D5] = ...
T2 = Tensor.mm(T0, T1)  # mismatch: D4 != D5
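The mm example can be roughly approximated with today's `typing` module. Here is a minimal sketch, assuming a hypothetical 2-D `Tensor` class that is generic over a dtype and two dimension types. The marker classes `D3`, `D4`, `D5` stand in for sizes, and plain `int` stands in for `int32`. Because `mm` reuses the type variable `K` for both operands, a static checker such as mypy or Pyre would be expected to reject `mm(T0, T1)`. The runtime check only mirrors that rule so the sketch can be executed.

```python
from __future__ import annotations

from typing import Generic, TypeVar

DT = TypeVar("DT")  # element type (dtype)
M = TypeVar("M")    # rows of the left operand
K = TypeVar("K")    # inner dimension shared by both operands
N = TypeVar("N")    # columns of the right operand


class D3: ...  # hypothetical marker types standing in for sizes 3, 4, 5
class D4: ...
class D5: ...


class Tensor(Generic[DT, M, N]):
    """Hypothetical 2-D tensor stub that only records its runtime shape."""

    def __init__(self, shape: tuple[int, int]) -> None:
        self.shape = shape


def mm(x: Tensor[DT, M, K], y: Tensor[DT, K, N]) -> Tensor[DT, M, N]:
    # The signature states the rule for the type checker: x's columns are y's rows.
    # The runtime check mirrors it so the sketch also fails when executed.
    if x.shape[1] != y.shape[0]:
        raise ValueError(f"mm: inner sizes differ ({x.shape[1]} != {y.shape[0]})")
    return Tensor((x.shape[0], y.shape[1]))


T0: Tensor[int, D3, D4] = Tensor((3, 4))
T1: Tensor[int, D5, D5] = Tensor((5, 5))
# T2 = mm(T0, T1)  # expected static error: K cannot be both D4 and D5
```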

This could save a lot of compute time (fewer runtime errors) as well as debugging time.

Pitch

Pyre recently added experimental support for variadic type variables, which make it possible to describe the shape of a tensor. This allowed me to write some initial stubs for PyTorch in which the tensor type carries its shape. Pyre can check this shape statically to prevent most misuses of PyTorch operators.

As an example, I took this script and translated it into this typed version. The main stubs are located here.

We already got some very positive feedback from the Python typing community last Friday, during Facebook MPK's Python meetup. So now I'm asking the PyTorch community :D

Known limitations: this is an early draft of the project, so we can't type everything at the moment. For example, we only support simple broadcasting (like Tensor.__add__ when the right-hand side is a scalar; nothing for Tensor.matmul yet). Also, some functions just can't be statically checked (like Tensor.cat) and require manual annotation.

Alternatives

The currently known alternatives are all runtime checks (like the Named Tensor proposal). They address the same problem, but still at runtime, which could be less efficient when programs run for several hours or days.

Additional context

I don't expect PyTorch to migrate to this solution right now; I'm gathering feedback on the experiment to see where to go next. Our next step is to support broadcasting, and I would gladly take some direction on which killer feature we should try to support after that.

soumith commented 4 years ago

cc: @apaszke @malmaud @t-vi @zdevito as folks who might be interested

gchanan commented 4 years ago

What's the limitation with torch.cat? Is it because you don't have type(s) for "Tuple of Tensors of size X, Y, Z"?

Another example of this problem is torch.nonzero, whose shape depends on the runtime values.
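To make the nonzero point concrete: `torch.nonzero` on an n-D input returns a tensor of shape (z, n), where z is the number of non-zero elements. The pure-Python stand-in below is a hypothetical helper, not PyTorch code. It shows two inputs with the same shape producing outputs of different lengths, so no signature over input shapes alone can give the output's leading dimension.

```python
from __future__ import annotations


def nonzero_2d(rows: list[list[float]]) -> list[tuple[int, int]]:
    """(row, col) indices of the non-zero entries, mirroring torch.nonzero on 2-D input."""
    return [(i, j) for i, row in enumerate(rows) for j, v in enumerate(row) if v != 0]


# Same input shape (2, 2), different output lengths: the result's leading
# dimension is only known once the values are known.
sparse = nonzero_2d([[1.0, 0.0], [0.0, 2.0]])
dense = nonzero_2d([[1.0, 1.0], [1.0, 1.0]])
```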

vsiles commented 4 years ago

If I understand correctly, the contract of cat is that all tensors in the tuple must have the same shape except along one dimension (the one passed as input). So we could partially type the inputs using tons of overloads, but the output size seems problematic, as it is the sum of the inputs' sizes along that dimension.
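The cat rule described above can be written as a runtime shape function. This is a minimal sketch with a hypothetical helper name, `cat_shape`, and it ignores any special cases PyTorch may have. Checking that the non-concatenated dimensions agree is easy to express with overloads. The final summation is the part that would need type-level integer arithmetic.

```python
from __future__ import annotations

from collections.abc import Sequence


def cat_shape(shapes: Sequence[tuple[int, ...]], dim: int) -> tuple[int, ...]:
    """Shape of concatenating tensors with the given shapes along `dim`."""
    if not shapes:
        raise ValueError("cat needs at least one input")
    first = shapes[0]
    rank = len(first)
    if not -rank <= dim < rank:
        raise IndexError(f"dim {dim} out of range for rank {rank}")
    d = dim % rank  # normalize negative dims
    for s in shapes[1:]:
        # Every dimension except `d` must agree across all inputs.
        if len(s) != rank or any(s[i] != first[i] for i in range(rank) if i != d):
            raise ValueError(f"shape {s} incompatible with {first} for cat on dim {d}")
    # The concatenated size is a sum over the inputs: the step a type checker
    # would need type-level integer arithmetic to express.
    total = sum(s[d] for s in shapes)
    return first[:d] + (total,) + first[d + 1:]
```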

vsiles commented 4 years ago

@gchanan Now that you mention it, maybe we could do something for cat with the intvar prototype too. I have to check its current status and I'll get back to you.

jerry73204 commented 4 years ago

Could you take a look at my little project tch-typed-tensor? It provides a statically checked tensor type built on tch, a Rust binding for PyTorch.

vsiles commented 4 years ago

@jerry73204 sure, I'll have a look (and maybe ask a couple of questions; I'll ask them over there!)

t-vi commented 4 years ago

Out of curiosity, how would the mechanism be able to handle convolutions? Or, to pick a simpler example, max or sum?

My impression was that types (as we currently use them) are not the best way to express shape information. Shapes are "more dynamic" than, say, Tensor vs. int or the dtype, in that they depend on the non-tensor parameters. Ideally, we'd have some dynamic annotation mechanism where we can pass in shapes for tensors and specify the other parameters as well as we can (this still doesn't help with nonzero, but it would cover a lot more cases than something as static as types). This would have to go somewhere near native_functions.yaml and would be a rather large amount of work...

But I must admit that I'm mainly looking at this from a JIT perspective, where we have the luxury of more information than with type checking. Don't let the lofty ideas of a passer-by spoil your enthusiasm for practical improvements.

vsiles commented 4 years ago

I had a quick look last week at convolutions. The equation for their output shape is clearly more complex than the basic example I did, but (with some time) it seems doable as long as most of the sizes are literals. We could use integer generics (another experimental typing extension, still reaaaaally experimental) to deal with these cases. But for that to work, all inputs must be known statically (which seems to be the case from what I have read so far).
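For reference, this is the per-dimension output size a checker with integer generics would have to evaluate for a convolution. The formula follows the torch.nn.Conv2d documentation. The function name is a hypothetical helper, and in a type checker the same arithmetic would run on literal types rather than ints.

```python
def conv_out_size(size: int, kernel_size: int, stride: int = 1,
                  padding: int = 0, dilation: int = 1) -> int:
    """Output length of one spatial dimension of a convolution.

    Follows the formula in the torch.nn.Conv2d documentation:
    floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
    """
    # Integer floor division matches the floor in the formula for positive stride.
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


same = conv_out_size(32, kernel_size=3, padding=1)               # 32
halved = conv_out_size(32, kernel_size=3, stride=2, padding=1)  # 16
```

Every argument has to be a known literal for the result to be a literal, which matches the requirement above that all inputs be known statically.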

Using more dynamic annotations could be done with the new Annotated PEP (PEP 593; I have a draft of how to do simple variadics with it), which would expand our expressivity if the type checker knows how to parse these annotations. But I'm not sure that's the right way to go; it needs further investigation.
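A small sketch of how PEP 593's `Annotated` carries extra metadata alongside a type. The tuple-of-dimension-names convention and the `Tensor` and `linear` names are hypothetical. A static checker that doesn't understand the metadata just sees `Tensor`, while tools can read it back with `typing.get_type_hints(..., include_extras=True)`.

```python
from typing import Annotated, get_args, get_type_hints


class Tensor:
    """Hypothetical stand-in for torch.Tensor."""


# The second Annotated argument is arbitrary metadata: here, a tuple of
# dimension names (a made-up convention). Static checkers that don't
# understand it simply treat each hint as plain `Tensor`.
def linear(x: Annotated[Tensor, ("batch", "in")]) -> Annotated[Tensor, ("batch", "out")]:
    return x


hints = get_type_hints(linear, include_extras=True)
base, dims = get_args(hints["x"])  # (Tensor, ("batch", "in"))
```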

Ideally, we'd have some dynamic annotation mechanism where we can pass in shapes for tensors and specify the other parameters as well as we can

I think that's basically what I did for cat for now: we don't know (yet) how to support it, so I passed the expected type as an annotation to help the type checker. Is that what you have in mind, or are you thinking about something else?
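One way to "pass the expected type as an annotation" is `typing.cast`, sketched below. This is an illustration, not necessarily how the stubs discussed here do it. The `Tensor` stub, the dimension markers and the simplified `cat` are hypothetical. `cast` is trusted by the checker and does nothing at runtime, so the burden of getting the shape right moves to the caller.

```python
from __future__ import annotations

from typing import Any, Generic, TypeVar, cast

DT = TypeVar("DT")
R = TypeVar("R")  # rows
C = TypeVar("C")  # columns


class D2: ...  # hypothetical dimension marker types
class D3: ...
class D5: ...


class Tensor(Generic[DT, R, C]):
    """Hypothetical 2-D tensor stub that only records its runtime shape."""

    def __init__(self, shape: tuple[int, int]) -> None:
        self.shape = shape


def cat(tensors: list[Tensor[Any, Any, Any]], dim: int) -> Tensor[Any, Any, Any]:
    # The stub can't express "sum of the sizes along dim", so its return type is loose.
    rows = sum(t.shape[0] for t in tensors) if dim == 0 else tensors[0].shape[0]
    cols = sum(t.shape[1] for t in tensors) if dim == 1 else tensors[0].shape[1]
    return Tensor((rows, cols))


a: Tensor[int, D2, D3] = Tensor((2, 3))
b: Tensor[int, D3, D3] = Tensor((3, 3))
# The caller supplies the expected type by hand (2 + 3 rows = D5). cast() is
# trusted by the checker and does nothing at runtime, so a wrong cast goes unnoticed.
c = cast("Tensor[int, D5, D3]", cat([a, b], dim=0))
```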

t-vi commented 4 years ago

So there are two bits I don't know how to handle with annotations (but you're obviously much more au courant with them):

Maybe I just didn't wrap my head around the possibilities of type annotations enough yet. Also, I think we might want to find a way of presenting the information in a way that makes it accessible to the JIT when the time comes, too.

And that's saying nothing about things like "do we know if the number of elements is divisible by X"-type information, which I understand would be beneficial for optimizations.

vsiles commented 4 years ago

Just answering one of your points, the other will require some thinking on my side :D

  • Things that depend not on the types passed in but on the values. Leaving out upcasting, for some sum(t: Tensor[dt, a, b], dim: int) the result is Tensor[dt, a] or Tensor[dt, b] depending on the value of dim.

This is usually done with overloading (e.g. look at unsqueeze here): you specify each reasonable case using Literal and @overload. For your example, I would do something like:

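from typing import Literal, overload

@overload
def sum(t: Tensor[dt, a, b], dim: Literal[0]) -> Tensor[dt, b]:  # reducing dim 0 drops a
    ...

@overload
def sum(t: Tensor[dt, a, b], dim: Literal[1]) -> Tensor[dt, a]:  # reducing dim 1 drops b
    ...

# default case: the actual implementation, which must follow the overloads
def sum(t, dim):
    ...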

If instead of [dt, a, b] you have a more generic Shape and need to index into it, that's not currently supported, but it is on our roadmap.
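A self-contained, runnable variant of the overload pattern above, under some assumptions. It uses hypothetical `Tensor1D`/`Tensor2D` stubs so that 1-D and 2-D results have distinct generic types. It names the function `reduce_sum` to avoid shadowing the builtin. Note that reducing over a dimension removes that axis, so the size of the *other* dimension survives.

```python
from __future__ import annotations

from typing import Any, Generic, Literal, TypeVar, overload

DT = TypeVar("DT")
A = TypeVar("A")
B = TypeVar("B")


class Tensor1D(Generic[DT, A]):
    """Hypothetical 1-D tensor stub."""

    def __init__(self, size: int) -> None:
        self.shape = (size,)


class Tensor2D(Generic[DT, A, B]):
    """Hypothetical 2-D tensor stub."""

    def __init__(self, rows: int, cols: int) -> None:
        self.shape = (rows, cols)


@overload
def reduce_sum(t: Tensor2D[DT, A, B], dim: Literal[0]) -> Tensor1D[DT, B]: ...


@overload
def reduce_sum(t: Tensor2D[DT, A, B], dim: Literal[1]) -> Tensor1D[DT, A]: ...


def reduce_sum(t: Tensor2D[Any, Any, Any], dim: int) -> Tensor1D[Any, Any]:
    # Reducing over `dim` removes that axis, so the other size survives.
    if dim not in (0, 1):
        raise ValueError(f"dim must be 0 or 1, got {dim}")
    return Tensor1D(t.shape[1 - dim])
```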

orionr commented 4 years ago

cc @XuanQi as an FYI.

jerry73204 commented 4 years ago

@vsiles thanks for taking a peek at my little project. Please check out the old-type-freak-api branch before running cargo build. It depends on type-freak, which has changed a lot recently, so the master branch is being refactored.

I don't have time to write fancy docs. Let me leave some notes here in case you're interested.

jerry73204 commented 4 years ago

Based on my earlier experience in Rust, I would expect Python types to provide these features.

IMHO, purely type-level dimensions do not work well in Python (especially recursive types). I would expect a Keras-style design, or some annotation tricks, to do this job.

malmaud commented 4 years ago

How would this work with the upcoming named tensor feature, where dimensions are keyed by name instead of index? The proposed tensor type parameters are intrinsically positional, so that seems incompatible.

gchanan commented 4 years ago

CC @zou3519

I don't think the named tensor case adds that much complexity. Dimensions are still ordered positionally in that regime, it's just that you can associate and access dimensions by name in addition to just by position.

So, going back to @t-vi's reduction example, you would have to associate names as well as sizes in the types, and have logic for mapping known strings as well as known integers to the correct overload.
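Because named dimensions are still positional, one way to handle names is to resolve a name (or a possibly negative index) to a position before choosing the result shape. A minimal runtime sketch follows. `resolve_dim` and `sum_shape` are hypothetical helpers, not the named tensor API.

```python
from __future__ import annotations


def resolve_dim(names: tuple[str, ...], dim: int | str) -> int:
    """Map a dimension given by name or by (possibly negative) position to an index."""
    if isinstance(dim, str):
        return names.index(dim)  # ValueError if the name is unknown
    if not -len(names) <= dim < len(names):
        raise IndexError(f"dim {dim} out of range for {names}")
    return dim % len(names)


def sum_shape(names: tuple[str, ...], sizes: tuple[int, ...],
              dim: int | str) -> tuple[tuple[str, ...], tuple[int, ...]]:
    """Names and sizes left after reducing over `dim`."""
    i = resolve_dim(names, dim)
    return names[:i] + names[i + 1:], sizes[:i] + sizes[i + 1:]


# "batch" and 0 pick the same result because names still map to positions.
by_name = sum_shape(("batch", "features"), (8, 16), "batch")
by_index = sum_shape(("batch", "features"), (8, 16), 0)
```

In a type-level version, each known string would get its own `Literal` overload alongside the integer ones, as described above.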

bewagner commented 4 years ago

@vsiles If there is any way I can help you with this, I would be glad to do so. I would love to see this feature in pytorch!

mrahtz commented 4 years ago

A group of us at DeepMind are interested in working on this too: some kind of system that would be interoperable with TensorFlow, PyTorch, and JAX. We've set up a mailing list at https://groups.google.com/g/python-shape-checkers to try and bring together all the conversations about this into one place. I've posted a summary there of what seems to be the current state of things, but stay tuned for updates!

jerry73204 commented 4 years ago

Cheers @mrahtz. Let me join the group too.

fylux commented 4 years ago

Hi! I have been working on a proposal for adding support for type arithmetic and some other useful features for tensor typing. I will present it next Monday in the next session of the Tensor Typing meetings that we have been organising, so if anyone is interested, feel free to attend: link

ezyang commented 3 years ago

A more recent development in this space is @patrick-kidger's torchtyping (https://github.com/patrick-kidger/torchtyping/). One really important thing is that his library uses Python 3.9's Annotated extension point (https://www.python.org/dev/peps/pep-0593/) to attach the annotations. This is important because it means we can start playing around with ways to express this information, even in core (assuming that the typing_extensions copy of Annotated works for all versions of Python we currently support), without having to wait for various features to make their way into mainstream mypy. We also don't have to worry about prematurely committing to the wrong representation, since multiple annotation schemes can coexist (modulo end users who are actually using them).

vadimkantorov commented 3 years ago

At https://github.com/pytorch/pytorch/issues/40373#issuecomment-815363986 I was also advocating for expressing flexible typing / shape constraints via assert statements.
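A minimal sketch of the assert-statement style, using a hypothetical `FakeTensor` that only carries a shape. It follows the batched-matmul rule that `torch.bmm` uses: (b, n, m) times (b, m, p) gives (b, n, p). The constraints are easy to read and cost nothing to write. However, they are only checked at runtime, and they are skipped entirely under `python -O`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in that only carries a shape."""
    shape: tuple[int, ...]


def bmm(x: FakeTensor, y: FakeTensor) -> FakeTensor:
    # Unpacking doubles as a rank check: it raises ValueError unless both are 3-D.
    b, n, m = x.shape
    b2, m2, p = y.shape
    # The shape contract as plain asserts (skipped entirely under `python -O`).
    assert b == b2, f"batch sizes differ: {b} != {b2}"
    assert m == m2, f"inner sizes differ: {m} != {m2}"
    return FakeTensor((b, n, p))
```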

vadimkantorov commented 1 year ago

There is now also https://github.com/google/jaxtyping

patrick-kidger commented 1 year ago

Oh, as the author of both TorchTyping and jaxtyping: allow me to strongly recommend jaxtyping, for all use cases! jaxtyping's name is now essentially historical -- it supports annotating shapes+dtypes for JAX+PyTorch+NumPy+TensorFlow, and doesn't actually have a JAX dependency.

One of the main advantages over TorchTyping is that jaxtyping works gracefully with static type checking. The shapes and dtypes won't get checked (beyond what's really feasible with a static type checker unfortunately), but the array-ness of the type will. (In contrast TorchTyping is basically just completely incompatible with static type checking.)

As another bonus, jaxtyping also has much cleaner internals, and doesn't do any awful monkey-patching of typeguard. In particular, this means it's compatible with other runtime type checkers, e.g. beartype.

vadimkantorov commented 1 year ago

@patrick-kidger In my own typing experiment, I did runtime shape checks triggered by a function decorator: https://gist.github.com/vadimkantorov/71155662378b46cd07b6287f6be7c951
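Here is a decorator-based runtime check in the same spirit. This is a sketch, not the code from the linked gist. It assumes each shaped parameter is annotated as `Annotated[Tensor, (dim names...)]`, with a hypothetical `Tensor` stand-in. Each dimension name is bound to a size on first use and checked on every later use, across arguments and the return value.

```python
import functools
import inspect
from typing import Annotated, Any, Callable, TypeVar, cast, get_args, get_origin, get_type_hints

F = TypeVar("F", bound=Callable[..., Any])


class Tensor:
    """Hypothetical stand-in that only carries a shape."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape


def _check(value: Any, hint: Any, env: dict[str, int], where: str) -> None:
    # Only Annotated[..., (dim names)] hints are checked; anything else is skipped.
    if get_origin(hint) is not Annotated:
        return
    _, dims = get_args(hint)  # assumes exactly one metadata item: a tuple of names
    if len(value.shape) != len(dims):
        raise TypeError(f"{where}: expected {len(dims)} dims, got shape {value.shape}")
    for name, size in zip(dims, value.shape):
        bound = env.setdefault(name, size)  # first use of a name binds its size
        if bound != size:
            raise TypeError(f"{where}: dim {name!r} is {size}, expected {bound}")


def shapecheck(fn: F) -> F:
    """Check Annotated shape hints of arguments and return value at call time."""
    hints = get_type_hints(fn, include_extras=True)
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        env: dict[str, int] = {}  # symbolic dim name -> size, shared across the call
        for name, value in sig.bind(*args, **kwargs).arguments.items():
            _check(value, hints.get(name), env, f"argument {name!r}")
        result = fn(*args, **kwargs)
        _check(result, hints.get("return"), env, "return value")
        return result

    return cast(F, wrapper)


@shapecheck
def mm(x: Annotated[Tensor, ("m", "k")],
       y: Annotated[Tensor, ("k", "n")]) -> Annotated[Tensor, ("m", "n")]:
    return Tensor(x.shape[0], y.shape[1])
```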

vadimkantorov commented 1 year ago

@ezyang these kinds of "typing decorators" or assert statements for expressing type constraints/hints can actually be used to populate torch.compile/torch.export shape constraints

patrick-kidger commented 1 year ago

Interesting! IIUC this implementation won't work with e.g. tuple[SomeTensorHint, SomeTensorHint] though?

vadimkantorov commented 1 year ago

Maybe one would need to extend the shapecheck function, or maybe it's not the most expressive design, but it should be possible in general! In the function, we can get the typing info and are then free to process it as we wish.
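To illustrate the `tuple[SomeTensorHint, SomeTensorHint]` case raised above: one way is to walk the hint recursively with `typing.get_origin`/`get_args` and share one binding environment across all the tensors found. This is a minimal sketch under the same assumptions as before. It uses a hypothetical `Tensor` stand-in and a tuple of dimension names as the only `Annotated` metadata, and it ignores every other hint.

```python
from typing import Annotated, Any, get_args, get_origin


class Tensor:
    """Hypothetical stand-in that only carries a shape."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape


def check(value: Any, hint: Any, env: dict[str, int]) -> None:
    """Recursively match `value` against `hint`, binding dim names in `env`."""
    origin = get_origin(hint)
    if origin is Annotated:
        # Assumes a plain-class base and one metadata item: a tuple of dim names.
        base, dims = get_args(hint)
        if not isinstance(value, base) or len(value.shape) != len(dims):
            raise TypeError(f"expected {base.__name__} with dims {dims}, got {value!r}")
        for name, size in zip(dims, value.shape):
            if env.setdefault(name, size) != size:
                raise TypeError(f"dim {name!r} is {size}, expected {env[name]}")
    elif origin is tuple:
        if not isinstance(value, tuple):
            raise TypeError(f"expected a tuple, got {value!r}")
        args = get_args(hint)
        if len(args) == 2 and args[1] is Ellipsis:  # tuple[X, ...]: homogeneous
            args = (args[0],) * len(value)
        if len(value) != len(args):
            raise TypeError(f"expected {len(args)} items, got {len(value)}")
        for item, item_hint in zip(value, args):
            check(item, item_hint, env)
    # Any other hint is accepted unchecked in this sketch.


Pair = tuple[Annotated[Tensor, ("b", "n")], Annotated[Tensor, ("b",)]]
env: dict[str, int] = {}
check((Tensor(4, 7), Tensor(4)), Pair, env)  # binds b=4 and n=7
```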