pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
BSD 3-Clause "New" or "Revised" License

Inductor decomps can decompose into mutations #1081

Closed by bdhirsh 1 year ago

bdhirsh commented 2 years ago

@jansel pointed out that some of the existing scatter ops decompose into their mutable variants, as in this example, where index_put decomposes to index_put_.

@Chillee I don't have a concrete example of a failure, but does that make running the decompositions underneath functionalization problematic (see thread for context)? We might now:

(1) Functionalize a graph (which includes index_put)
(2) Run further decomps that live in inductor (decomposing index_put into index_put_)
(3) Send a graph containing a mutation (index_put_) to the min-cut partitioner, which could silently do the wrong thing.
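To make the ordering problem concrete, here is a minimal toy sketch of the three steps. It does not use the real functionalization, decomposition, or partitioner APIs: a "graph" is just a list of op names, and `functionalize`, `decompose`, `find_mutations`, and the `DECOMPS` table (including the `clone` step) are hypothetical stand-ins invented for illustration. The only convention borrowed from PyTorch is that a trailing underscore marks an in-place op.

```python
# Toy model only: a "graph" is an ordered list of op names, not a real FX graph.

def is_mutation(op_name):
    """Treat a single trailing underscore as the in-place marker (e.g. index_put_)."""
    return op_name.endswith("_") and not op_name.endswith("__")

def functionalize(graph):
    """Hypothetical stand-in: swap each in-place op for its out-of-place variant."""
    return [op[:-1] if is_mutation(op) else op for op in graph]

# Hypothetical decomposition table mirroring the case described in the issue:
# the functional op is rewritten in terms of its in-place variant.
DECOMPS = {"index_put": ["clone", "index_put_"]}

def decompose(graph, table):
    """Expand every op that has an entry in the table; keep the others as-is."""
    out = []
    for op in graph:
        out.extend(table.get(op, [op]))
    return out

def find_mutations(graph):
    """Return the in-place ops that a mutation-unaware pass would be handed."""
    return [op for op in graph if is_mutation(op)]

traced = ["index_put_", "add", "sum"]      # what the user program traced to
functional = functionalize(traced)         # step (1)
decomposed = decompose(functional, DECOMPS)  # step (2)

print(find_mutations(functional))  # nothing left to worry about after step (1)
print(find_mutations(decomposed))  # the mutation is back before step (3)
```

A check like `find_mutations` run just before partitioning would at least turn a silent wrong answer into a visible failure, though it does not fix the ordering itself.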

Let me know if that sounds right to you. An easy fix (a) might be to wait to run the decomps until after partitioning, by re-tracing and running make_fx again with the decompositions separately on the partitioned fwd + bwd graphs. Or, (b) we could fully fix https://github.com/pytorch/pytorch/issues/83923#issuecomment-1226314106.
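A rough sketch of what option (a) changes, again as a toy model rather than the real pipeline: graphs are lists of op names, and `partition`, `decompose`, `decomps_before_partition`, `decomps_after_partition`, and the `DECOMPS` table are all hypothetical names made up here. One deliberate difference from reality: this toy partitioner raises on a mutation so the hazard is visible, whereas the concern above is that the real one could silently do the wrong thing.

```python
# Toy model only: op-name lists stand in for the joint, fwd, and bwd graphs.

def is_mutation(op_name):
    """A single trailing underscore marks an in-place op (e.g. index_put_)."""
    return op_name.endswith("_") and not op_name.endswith("__")

def decompose(graph, table):
    """Expand every op that has an entry in the (hypothetical) decomp table."""
    out = []
    for op in graph:
        out.extend(table.get(op, [op]))
    return out

def partition(graph, split_at):
    """Stand-in for a mutation-unaware partitioner; raises instead of miscompiling."""
    bad = [op for op in graph if is_mutation(op)]
    if bad:
        raise ValueError(f"partitioner received mutations: {bad}")
    return graph[:split_at], graph[split_at:]

def decomps_before_partition(graph, table, split_at):
    """Current ordering: decompose first, then partition the result."""
    return partition(decompose(graph, table), split_at)

def decomps_after_partition(graph, table, split_at):
    """Option (a): partition the functional graph, then decompose each half."""
    fwd, bwd = partition(graph, split_at)
    return decompose(fwd, table), decompose(bwd, table)

# Hypothetical table: the functional op decomposes into its in-place variant.
DECOMPS = {"index_put": ["clone", "index_put_"]}
joint = ["index_put", "mul", "sum", "sum_backward", "mul_backward"]

fwd, bwd = decomps_after_partition(joint, DECOMPS, split_at=3)
print(fwd)
print(bwd)
```

With this ordering the partitioner only ever sees the functionalized graph, and any mutation introduced by a decomposition stays local to the forward or backward graph it lands in.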

We could also try doing (a) and running the inductor + min_cut_partitioning benchmarks, and see if we fix any correctness failures.

anijain2305 commented 1 year ago

Fixed by https://github.com/pytorch/torchdynamo/pull/1390