@jansel pointed out that some of the existing scattering ops decompose into mutable variants, as in this example, where `index_put` decomposes to `index_put_`.
@Chillee I don't have a concrete example of a failure, but does that make running the decompositions underneath functionalization problematic (see thread for context)? We might now:
(1) Functionalize a graph (which includes `index_put`)
(2) Run further decomps that live in inductor (decomposing `index_put` into `index_put_`)
(3) Send a graph containing a mutation (`index_put_`) to the min-cut partitioner, which expects a functional graph and could silently do the wrong thing.
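To make the concern in step (3) concrete, here is a toy, pure-Python model (not PyTorch code; `index_put`, `index_put_` and `run` below are stand-ins written for illustration). A min-cut partitioner may rematerialize, i.e. recompute a forward value in the backward instead of saving it. That is only safe if nothing in between mutated the inputs, which holds for the functional op but not for the in-place one.

```python
def index_put(x, idx, v):
    # Functional stand-in: returns a new list, leaves x untouched.
    out = list(x)
    out[idx] = v
    return out


def index_put_(x, idx, v):
    # In-place stand-in: mutates x and returns it.
    x[idx] = v
    return x


def run(put, recompute):
    x = [1.0, 2.0, 3.0]
    s = sum(x)            # value the backward needs
    put(x, 0, 10.0)       # the scatter op in the forward graph
    # Backward: either use the saved value, or recompute it from x
    # (what a rematerializing partitioner may choose to do).
    return sum(x) if recompute else s


print(run(index_put, recompute=True), run(index_put, recompute=False))
print(run(index_put_, recompute=True), run(index_put_, recompute=False))
```

With the functional op both strategies agree (`6.0`), while with the in-place op recomputation silently yields `15.0` instead of `6.0`, which is the kind of silent wrong result the partitioner could produce.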
Let me know if that sounds right to you. One easy fix, (a), might be to hold off on running the decomps until after partitioning: re-trace and run `make_fx` again with the decompositions, separately, on the partitioned fwd and bwd graphs. Alternatively, (b), we could fully fix https://github.com/pytorch/pytorch/issues/83923#issuecomment-1226314106.
We could also try (a), then run the inductor + min_cut_partitioning benchmarks and see whether it fixes any correctness failures.
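The ordering change in fix (a) can be sketched with a toy pipeline over lists of op names. Everything here is hypothetical: `DECOMPS` is an assumed decomposition entry (`index_put` into `clone` + `index_put_`), and `partition` just splits the graph in half while checking the one property that matters here, namely that the partitioner only ever sees a functional graph. It is not the real min-cut partitioner or `make_fx`.

```python
MUTATING = {"index_put_"}
# Hypothetical decomposition table, for illustration only.
DECOMPS = {"index_put": ["clone", "index_put_"]}


def decompose(graph):
    # Replace each op by its decomposition, if it has one.
    out = []
    for op in graph:
        out.extend(DECOMPS.get(op, [op]))
    return out


def partition(graph):
    # Toy partitioner: assumes a functional graph, splits it in half.
    bad = [op for op in graph if op in MUTATING]
    if bad:
        raise ValueError(f"partitioner got mutating ops: {bad}")
    mid = len(graph) // 2
    return graph[:mid], graph[mid:]


def decomp_then_partition(graph):
    # Current ordering: mutations can leak into the partitioner.
    return partition(decompose(graph))


def partition_then_decomp(graph):
    # Fix (a): partition the functional graph, then decompose fwd/bwd separately.
    fwd, bwd = partition(graph)
    return decompose(fwd), decompose(bwd)


g = ["mul", "index_put", "add", "index_put"]
print(partition_then_decomp(g))
```

Here `partition_then_decomp(g)` returns `(['mul', 'clone', 'index_put_'], ['add', 'clone', 'index_put_'])`, while `decomp_then_partition(g)` raises because the partitioner is handed `index_put_`.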