pytorch / executorch

On-device AI across mobile, embedded and edge for PyTorch
https://pytorch.org/executorch/

Mutable Buffer Should Only Be Updated by In-Place Op #4042

Closed: YifanShenSZ closed this issue 2 months ago

YifanShenSZ commented 2 months ago

Consider a toy model like the following:

    import numpy as np
    import torch


    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("state", torch.tensor(np.array([7, 5, 6], dtype=np.float16)))

        def forward(self, x):
            x = x.type(torch.float16)
            self.state.mul_(x)
            self.state.add_(torch.tensor(np.array([1, 2, 3], dtype=np.float16)))
            return self.state * 9

I thought that in the exported program, mul_ and add_ would be in-place, while self.state * 9 would be out-of-place. However, all ops seem to be out-of-place?
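
For context, here is roughly how I exported and inspected the program (a sketch that assumes the standard torch.export.export entry point and the Model above):

    example_input = (torch.tensor(np.array([2, 3, 4], dtype=np.float16)),)
    ep = torch.export.export(Model(), example_input)

    print(ep.graph)            # functionalized ATen graph
    print(ep.graph_signature)  # shows which outputs are buffer mutations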

Jack-Khuu commented 2 months ago

ccing @JacobSzwejbka on Mutable Buffer behavior

JacobSzwejbka commented 2 months ago

Can you post the graph after export?

JacobSzwejbka commented 2 months ago

Oh I misread. EP functionalizes the graph so there are no in-place operations.

Here is the graph.

    graph():
        %b_state : [num_users=1] = placeholder[target=b_state]
        %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
        %x : [num_users=1] = placeholder[target=x]
        %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%b_state, %x), kwargs = {})
        %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
        %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%lift_fresh_copy,), kwargs = {dtype: torch.float16})
        %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%_to_copy,), kwargs = {})
        %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {})
        %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%detach,), kwargs = {dtype: torch.float16, device: cpu})
        %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%_to_copy_1,), kwargs = {})
        %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %detach_1), kwargs = {})
        %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 9), kwargs = {})
        return (add, mul_1)

Here is the graph signature

    ExportGraphSignature(
        input_specs=[
            InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_state'), target='state', persistent=True),
            InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='c_lifted_tensor_0'), target='lifted_tensor_0', persistent=None),
            InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)
        ],
        output_specs=[
            OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add'), target='state'),
            OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='mul_1'), target=None)
        ]
    )

The first output is flagged as BUFFER_MUTATION. If you were actually trying to run this, you can imagine an implicit copy from that output back to the state input between inferences. We make that copy explicit in to_executorch, so you end up with something like

    functional_mul
    functional_add
    state.copy_(add_result)

That is also the graph you get if you call print(ep.module().graph)

    graph():
        %state : [num_users=2] = get_attr[target=state]
        %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
        %x : [num_users=1] = placeholder[target=x]
        %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%state, %x), kwargs = {})
        %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
        %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%lift_fresh_copy,), kwargs = {dtype: torch.float16})
        %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%_to_copy,), kwargs = {})
        %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {})
        %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%detach,), kwargs = {dtype: torch.float16, device: cpu})
        %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%_to_copy_1,), kwargs = {})
        %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %detach_1), kwargs = {})
        %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 9), kwargs = {})
        %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy_.default](args = (%state, %add), kwargs = {})
        return (mul_1,)

We're interested in supporting further "injection" of mutation back into the graph after user passes have finished running in to_executorch, but nothing is currently in the works.

JacobSzwejbka commented 2 months ago

If the goal is to track which operations were originally in-place, you can walk the graph backwards from the BUFFER_MUTATION output and see which ops it was generated from. Each op along that chain, back up to the original buffer input, corresponds to a mutation on it.
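
A rough sketch of that walk (just an illustration, not an existing ExecuTorch utility), assuming the exported program ep from earlier; following the first tensor argument of each node is enough for this toy graph:

    import torch
    from torch.export.graph_signature import OutputKind

    def originally_inplace_ops(ep, buffer_name):
        """Collect the ops between a BUFFER_MUTATION output and its original
        buffer input, i.e. the ops that were in-place before functionalization."""
        sig = ep.graph_signature
        # Name of the output node that writes back to the requested buffer.
        out_name = next(
            spec.arg.name
            for spec in sig.output_specs
            if spec.kind == OutputKind.BUFFER_MUTATION and spec.target == buffer_name
        )
        node = next(n for n in ep.graph.nodes if n.name == out_name)
        chain = []
        # Walk backwards along the first tensor argument until we hit a placeholder.
        while node.op == "call_function":
            chain.append(node.target)
            node = next(a for a in node.args if isinstance(a, torch.fx.Node))
        return chain

    print(originally_inplace_ops(ep, "state"))
    # roughly: [aten.add.Tensor, aten.mul.Tensor] for the toy model above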

YifanShenSZ commented 2 months ago

Thanks @JacobSzwejbka for the detailed explanation!

To confirm I understand correctly, in EXIR:

  1. We only have out-of-place ops, since the graph is functionalized
  2. Stateful execution is represented by BUFFER_MUTATION, i.e. an in-place state.copy_ of the final buffer result (see the sketch below)
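
As a quick check of point 2, I'm planning to run something like the following (a sketch assuming the standard exir lowering entry points, to_edge and to_executorch, and the ep exported earlier); the final graph should contain the explicit copy back into state:

    from executorch.exir import to_edge

    # Lower to Edge dialect, then to an ExecuTorch program; to_executorch is
    # where the buffer write-back is made explicit as a copy node.
    edge_manager = to_edge(ep)
    et_manager = edge_manager.to_executorch()
    print(et_manager.exported_program().graph)
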
JacobSzwejbka commented 2 months ago

Yes, that's correct. Though delegates won't see the state.copy_ node; it is inserted near the very end of the lowering process. (There is the option for a user to inject a pass after that happens, so they can replace the node with a custom copy op if, for instance, their memory architecture has unusual rules.)

Also worth pointing out that any custom ops that do mutation are wrapped in a higher-order op called "auto_functionalize". Under the hood this essentially does:

    AutoFunctionalize(custom_op_args):
        state_copy = custom_op_args.state.clone()
        out = custom_op(state_copy, custom_op_args)
        return out, state_copy

Today in to_executorch we undo this functionalization blindly, which is safe for the LLM custom op that exists today, but isn't safe in general, since that pass doesn't reason about any downstream consequences of the mutation. We expect to have to revisit this in the future.
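
For illustration, a minimal sketch of a mutating custom op that ends up wrapped this way, assuming the torch.library.custom_op API (PyTorch 2.4+); the mylib::update_state op name is made up for the example:

    import torch

    # A custom op that mutates its first argument in place.
    @torch.library.custom_op("mylib::update_state", mutates_args=("state",))
    def update_state(state: torch.Tensor, x: torch.Tensor) -> None:
        state.mul_(x)

    # Fake (meta) kernel so the op can be traced by export.
    @update_state.register_fake
    def _(state, x):
        return None

    class StatefulModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("state", torch.ones(3))

        def forward(self, x):
            torch.ops.mylib.update_state(self.state, x)
            return self.state * 2

    ep_custom = torch.export.export(StatefulModel(), (torch.ones(3),))
    # The mutation shows up wrapped in the auto_functionalized higher-order op.
    print(ep_custom.graph)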

YifanShenSZ commented 2 months ago

Thanks Jacob! 🙏 I will start trying a stateful delegation to the Core ML backend, and will follow up on Slack if any new questions emerge.