pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
BSD 3-Clause "New" or "Revised" License

Questions for normalize_ir() #185

Closed: frank-wei closed this issue 2 years ago

frank-wei commented 2 years ago

It looks like normalize_ir() behaves differently than I expected. For example, I was hoping functionalization would replace an in-place op with the corresponding standard op, e.g. turn the in-place relu (inplace=True) into a plain relu. I ran a simple test program whose gm.graph is as follows:

    %x : torch.Tensor [#users=1] = placeholder[target=x]
    %relu : [#users=1] = call_function[target=torch.nn.functional.relu](args = (%x,), kwargs = {inplace: True})
    return (relu,)

After adding the normalizer to my_compiler, the resulting gm.graph does not change:

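    from typing import List
    import torch
    from torchdynamo.optimizations.normalize import normalize_ir  # assumed import path at the time

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        gm = normalize_ir(gm, example_inputs)  # try to remove mutation (e.g. in-place ops)
        print("result gm=", gm.graph)
        return gm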

After more debugging, I found that n.meta["is_input_mutation"] is True for the relu node. Here is the code.

So my questions are: 1) Is there a way to get the expected functionalization in this case, i.e. changing the in-place op to a standard op? 2) Out of curiosity, what does is_input_mutation mean, and in what situation would it be False for a relu node?

anijain2305 commented 2 years ago

Hi Wei, in this case you are mutating the input itself. If there is a user of the input x outside the FX graph, that user must see the updated (mutated) value of x.
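To see why that mutation cannot simply be dropped, here is a minimal pure-Python sketch in which lists stand in for tensors (relu_inplace and relu are hypothetical helpers, not PyTorch APIs). An in-place op on an argument is visible to the caller through the same object; an out-of-place op is not.

```python
def relu_inplace(xs: list[float]) -> list[float]:
    """Mimics relu(x, inplace=True): overwrites xs and returns the same object."""
    for i, v in enumerate(xs):
        if v < 0.0:
            xs[i] = 0.0
    return xs


def relu(xs: list[float]) -> list[float]:
    """Out-of-place relu: returns a new list and leaves xs untouched."""
    return [max(v, 0.0) for v in xs]


x = [-1.0, 2.0]
out = relu_inplace(x)
print(out is x, x)  # True [0.0, 2.0]: the caller's x was changed

y = [-1.0, 2.0]
out = relu(y)
print(out is y, y)  # False [-1.0, 2.0]: the caller's y is unchanged
```

If a graph swapped relu_inplace for relu and did nothing else, a caller still holding x would read stale values, which is why normalize_ir leaves this case alone.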

If the input were not mutated (as opposed to your example), the definition and all uses of the mutated variable would be contained within the FX graph, so we could rewrite the graph to get rid of the mutation. This is what normalize-ir does in the majority (though not all) of cases. Handling input mutation requires a little more effort.
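The distinction can be modeled with a toy IR. The sketch below is a hypothetical pure-Python pass (Node and remove_mutation are made-up names, not the normalize_ir implementation, and views/aliasing are ignored): in-place ops on values defined inside the graph become out-of-place ops and later uses are redirected to the new value, while in-place ops on graph inputs are only flagged with is_input_mutation.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    op: str
    args: list[str]
    inplace: bool = False
    meta: dict = field(default_factory=dict)


def remove_mutation(nodes: list[Node], inputs: set[str], output: str) -> str:
    """Toy pass: make in-place ops on graph-internal values out-of-place.

    Mutations of graph inputs are only flagged, because a user outside the
    graph may read the mutated input. Returns the (possibly renamed) output.
    """
    latest: dict[str, str] = {}  # mutated value -> node that now holds its value
    for node in nodes:
        target = node.args[0] if node.args else None  # name as written in the graph
        # Later uses of a mutated value must read the out-of-place result instead.
        node.args = [latest.get(a, a) for a in node.args]
        if not node.inplace:
            continue
        if target in inputs:
            node.meta["is_input_mutation"] = True  # leave the in-place op alone
        else:
            node.inplace = False
            latest[target] = node.name
    return latest.get(output, output)


# The graph from the question: relu mutates the input x, so it is only flagged.
g1 = [Node("relu", "relu", ["x"], inplace=True)]
out1 = remove_mutation(g1, inputs={"x"}, output="relu")
print(g1[0].inplace, g1[0].meta)  # True {'is_input_mutation': True}

# Mutation of an intermediate value: rewritten to an out-of-place op.
g2 = [
    Node("add", "add", ["x", "one"]),
    Node("relu", "relu", ["add"], inplace=True),
    Node("mul", "mul", ["add", "two"]),  # reads add after the in-place relu
]
out2 = remove_mutation(g2, inputs={"x"}, output="add")
print(g2[1].inplace, g2[2].args, out2)  # False ['relu', 'two'] relu
```

The renaming map is the important part: once a mutation is removed, every later reader of the old name has to be pointed at the new value, which is only possible when all those readers live inside the graph.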

1) If the input is not mutated, you should see normalize-ir remove a large number of mutations, including by changing in-place ops to their out-of-place versions.

2) Specifically for your example: no, we don't handle input mutation today. We could implement it by adding extra outputs to the FX graph that hold the mutated input values. This keeps the mutation's scope within the FX graph, so the mutation can be removed there. Finally, outside the FX graph, we would overwrite the original inputs with these extra outputs.
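A minimal pure-Python sketch of that proposal, assuming lists stand in for tensors (original_graph, functional_graph, and run_functional are hypothetical names, not torchdynamo code): the functional graph returns the mutated input as an extra output, and a wrapper outside the graph writes it back.

```python
def original_graph(x: list[float]) -> list[float]:
    """Graph with an input mutation: relu_(x), then y = x * 2."""
    for i, v in enumerate(x):
        x[i] = max(v, 0.0)  # in-place write into the caller's input
    return [v * 2 for v in x]


def functional_graph(x: list[float]) -> tuple[list[float], list[float]]:
    """Mutation-free version: the new value of x becomes an extra output."""
    new_x = [max(v, 0.0) for v in x]
    y = [v * 2 for v in new_x]
    return y, new_x


def run_functional(x: list[float]) -> list[float]:
    """Calls the functional graph, then copies the extra output back into the
    original input outside the graph, so callers still see the mutation."""
    y, new_x = functional_graph(x)
    x[:] = new_x  # the only mutation, done outside the graph
    return y


a, b = [-1.0, 3.0], [-1.0, 3.0]
print(original_graph(a), a)  # [0.0, 6.0] [0.0, 3.0]
print(run_functional(b), b)  # [0.0, 6.0] [0.0, 3.0]
```

Both paths give the caller the same result and the same final input, but only the wrapper performs a write, so a backend that rejects mutation would only ever see functional_graph.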

Somewhat related to this topic is the functionalization pass at the dispatcher level. Today, AOT Autograd relies on Dynamo's normalize-ir to remove mutation, but we plan to move to the dispatcher-level functionalization pass soon. I am not sure whether this benefits your work, but I'm happy to discuss.

frank-wei commented 2 years ago

@anijain2305 thanks for your explanation; it addresses my concerns. Mutating the input is an extreme case that we rarely see. I tried a case that does not mutate the input and saw the expected behavior. From a functionalization pass, I'd like to see in-place ops changed to standard ones. Out of curiosity, will the "dispatcher-level" functionalization you mentioned happen somewhere inside AOT Autograd? One more question, possibly related to the functionalization pass and AOT Autograd: will it help remove or transform the torch.ops.aten.copy_ in the case below? We need this because TRT does not support any in-place operations.

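    import torch

    class M(torch.nn.Module):  # hypothetical wrapper module for the forward below
        def forward(self, x, y):
            y = y + 3
            y[:, 0] = x[:, 0]  # in-place write into a view of y; traced as aten.copy_
            return y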

gm.graph = 
    %x_1 : [#users=1] = placeholder[target=x_1]
    %y_1 : [#users=1] = placeholder[target=y_1]
    %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0]
    %add : [#users=2] = call_function[target=torch.ops.aten.add](args = (%y_1, %_tensor_constant0), kwargs = {})
    %slice_1 : [#users=1] = call_function[target=torch.ops.aten.slice](args = (%x_1, 0, 0, 9223372036854775807), kwargs = {})
    %select : [#users=1] = call_function[target=torch.ops.aten.select](args = (%slice_1, 1, 0), kwargs = {})
    %slice_2 : [#users=1] = call_function[target=torch.ops.aten.slice](args = (%add, 0, 0, 9223372036854775807), kwargs = {})
    %select_1 : [#users=1] = call_function[target=torch.ops.aten.select](args = (%slice_2, 1, 0), kwargs = {})
    %copy_ : [#users=0] = call_function[target=torch.ops.aten.copy_](args = (%select_1, %select), kwargs = {})
    return add

anijain2305 commented 2 years ago

1) Dispatcher-level functionalization traces the original model and checks for mutation op by op at the dispatcher. When it sees a mutation, it removes it and keeps a record ("scorecard") so that later uses see the updated value. There is an excellent presentation by Brian Hirsh (the author of functionalization) on this topic.
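As a rough illustration of that idea only, here is a toy pure-Python tracer (Tracer and call are hypothetical names; the real pass runs inside PyTorch's dispatcher and also handles views, aliasing, and writing mutations back to graph inputs, none of which this sketch does). Each op is intercepted as it runs; an in-place op is replaced by its out-of-place version, and a small record maps the mutated tensor to its newest value so later uses read the right one.

```python
class Tracer:
    """Toy functionalizing tracer: records ops one at a time, never emits in-place ops."""

    def __init__(self) -> None:
        self.graph: list[tuple[str, str, tuple[str, ...]]] = []  # (result, op, args)
        self.latest: dict[str, str] = {}  # tensor name -> name of its newest value
        self._count = 0

    def _emit(self, op: str, *args: str) -> str:
        name = f"t{self._count}"
        self._count += 1
        # Read every argument through the record so later uses see earlier mutations.
        self.graph.append((name, op, tuple(self.latest.get(a, a) for a in args)))
        return name

    def call(self, op: str, *args: str) -> str:
        """Intercept one op. An in-place op (trailing underscore, e.g. "mul_") is
        replaced by its out-of-place version and the mutated tensor is remapped."""
        if op.endswith("_"):
            result = self._emit(op[:-1], *args)
            self.latest[args[0]] = result
            return args[0]  # in-place ops return the tensor they mutated
        return self._emit(op, *args)


tr = Tracer()
a = tr.call("add", "x", "y")  # t0 = add(x, y)
tr.call("mul_", a, "two")     # a.mul_(two) is recorded as t1 = mul(t0, two)
b = tr.call("relu", a)        # reads a's newest value: t2 = relu(t1)
for entry in tr.graph:
    print(entry)
# ('t0', 'add', ('x', 'y'))
# ('t1', 'mul', ('t0', 'two'))
# ('t2', 'relu', ('t1',))
```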

So this is not part of AOT Autograd per se. In the case of AOT Autograd, we plan to first get the forward and backward graphs from the usual AOT tracing, and then run functionalization on top of them.

Therefore, I don't see any reason why we cannot also use functionalization on Dynamo-created FX graphs.

2) Yes, it is supposed to remove the copy_ in the graph. Handling mutation is tough, and backends differ in how they support it: nvfuser, for example, leaves mutating ops untouched, which still runs correctly but leaves performance on the table, while compilers like TVM (and likely TRT) fail when they see mutation.
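For the y[:, 0] = x[:, 0] example, removing copy_ means expressing "write column 0 of x into column 0 of y" as an op that returns a new tensor instead of writing through a view. Recent PyTorch releases provide torch.select_scatter for this pattern; whether functionalization emits exactly that op for this graph is an assumption here. The pure-Python sketch below (select_scatter_2d and forward are hypothetical helpers over nested lists) only shows the shape of the rewrite.

```python
Matrix = list[list[float]]


def select_scatter_2d(base: Matrix, src: list[float], dim: int, index: int) -> Matrix:
    """Out-of-place counterpart of base.select(dim, index).copy_(src): returns a
    new matrix with row/column `index` replaced by src; base is left untouched."""
    out = [row[:] for row in base]  # fresh copy, no aliasing with base
    if dim == 0:
        out[index] = list(src)
    elif dim == 1:
        for r, value in enumerate(src):
            out[r][index] = value
    else:
        raise ValueError("toy helper supports 2-D inputs only")
    return out


def forward(x: Matrix, y: Matrix) -> Matrix:
    """Mutation-free version of: y = y + 3; y[:, 0] = x[:, 0]; return y."""
    add = [[v + 3 for v in row] for row in y]  # aten.add
    x_col0 = [row[0] for row in x]             # aten.slice + aten.select(x, 1, 0)
    return select_scatter_2d(add, x_col0, dim=1, index=0)  # replaces aten.copy_


x = [[1.0, 2.0], [3.0, 4.0]]
y = [[0.0, 0.0], [0.0, 0.0]]
print(forward(x, y))  # [[1.0, 3.0], [3.0, 3.0]]
print(y)              # [[0.0, 0.0], [0.0, 0.0]]: no input was mutated
```

The graph then ends in a pure op whose result is returned directly, which is the form a backend without in-place support can consume.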

More details on integrating functionalization with AOT Autograd are here: https://github.com/facebookresearch/torchdynamo/issues/88

frank-wei commented 2 years ago

@anijain2305 this PR looks great to me: https://github.com/pytorch/functorch/pull/703/files. I expect it will remove some mutations from Dynamo-created FX graphs, which will help us unblock some models.