ccing @JacobSzwejbka on Mutable Buffer behavior
Can you post the graph after export?
Oh I misread. EP functionalizes the graph so there are no in-place operations.
Here is the graph.
```
graph():
    %b_state : [num_users=1] = placeholder[target=b_state]
    %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0]
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%b_state, %x), kwargs = {})
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {})
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%lift_fresh_copy,), kwargs = {dtype: torch.float16})
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%_to_copy,), kwargs = {})
    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {})
    %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%detach,), kwargs = {dtype: torch.float16, device: cpu})
    %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%_to_copy_1,), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %detach_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 9), kwargs = {})
    return (add, mul_1)
```
Here is the graph signature:
```
ExportGraphSignature(
    input_specs=[
        InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_state'), target='state', persistent=True),
        InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='c_lifted_tensor_0'), target='lifted_tensor_0', persistent=None),
        InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None),
    ],
    output_specs=[
        OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add'), target='state'),
        OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='mul_1'), target=None),
    ],
)
```
The first output is flagged as BUFFER_MUTATION. If you were actually running this, there would be an implicit copy from that output back to the state input between inferences. We make that copy explicit in to_executorch, so you end up with something like:
```
functional_mul
functional_add
state.copy_(add_result)
```
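A rough sketch of where that happens in the lowering flow, assuming the `executorch.exir` `to_edge` / `to_executorch` entry points (variable names are illustrative):

```python
from executorch.exir import to_edge

edge_program = to_edge(ep)                 # ep: the ExportedProgram from torch.export
et_program = edge_program.to_executorch()  # the implicit buffer copy becomes an explicit node here
print(et_program.exported_program().graph)
```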
That is also the graph you get if you call `print(ep.module().graph)`:
```
graph():
    %state : [num_users=2] = get_attr[target=state]
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%state, %x), kwargs = {})
    %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%lifted_tensor_0,), kwargs = {})
    %_to_copy : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%lift_fresh_copy,), kwargs = {dtype: torch.float16})
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%_to_copy,), kwargs = {})
    %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {})
    %_to_copy_1 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%detach,), kwargs = {dtype: torch.float16, device: cpu})
    %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%_to_copy_1,), kwargs = {})
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %detach_1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 9), kwargs = {})
    %copy__default : [num_users=0] = call_function[target=torch.ops.aten.copy.default](args = (%state, %add), kwargs = {})
    return (mul_1,)
```
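For reference, the dumps above can be reproduced along these lines (assuming `ep` is the ExportedProgram returned by `torch.export.export`; `model` and `example_input` are placeholders):

```python
import torch

ep = torch.export.export(model, (example_input,))

print(ep.graph)            # functionalized graph with lifted buffer/constant inputs
print(ep.graph_signature)  # input/output specs, including the BUFFER_MUTATION output
print(ep.module().graph)   # unlifted module graph with the explicit copy back into state
```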
We are interested in supporting further "injection" of mutation back into the graph after user passes have finished running in to_executorch, but nothing is currently in the works.
If the goal is to track which operations were originally in-place, you can walk the graph backwards from the BUFFER_MUTATION output and see which ops it was generated from. Each change on the way back up to the original buffer input corresponds to a mutation on it.
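For illustration, a rough sketch of that backward walk over the torch.export graph (the helper name is made up, and it assumes the mutated value flows through the first node-valued argument of each op):

```python
import torch

def originally_inplace_ops(ep):
    # Names of outputs that write back into a buffer (BUFFER_MUTATION).
    mutated_outputs = [
        spec.arg.name
        for spec in ep.graph_signature.output_specs
        if spec.kind.name == "BUFFER_MUTATION"
    ]
    nodes_by_name = {node.name: node for node in ep.graph.nodes}
    ops = []
    for name in mutated_outputs:
        node = nodes_by_name[name]
        # Walk backwards until we hit the placeholder for the original buffer.
        while node.op == "call_function":
            ops.append(node.target)
            # Assumption: the tensor being mutated is the first Node-valued arg.
            node = next(arg for arg in node.args if isinstance(arg, torch.fx.Node))
    return ops
```

On the graph above this would collect `aten.add.Tensor` and `aten.mul.Tensor`, i.e. the ops that were originally in-place updates of the buffer.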
Thanks @JacobSzwejbka for the detailed explanation!
To confirm I understand correctly: in EXIR, the mutation is represented as a BUFFER_MUTATION (i.e. an in-place state.copy_) on the final buffer result?

Yes, that's correct. Though delegates won't see the state.copy_ node; it is added near the very end of the lowering process. (There is the option for a user to inject a pass after that happens, so they can replace the node with a custom copy op if their memory architecture has unusual rules, for instance.)
Also worth pointing out that any custom ops that do mutation are wrapped in a higher-order op called "auto_functionalize". Under the hood this essentially does:
```
AutoFunctionalize(custom_op_args):
    state_copy = custom_op_args.state.clone()
    out = custom_op(state_copy, custom_op_args)
    return out, state_copy
```
Today in to_executorch we undo this functionalization blindly, which is safe for the LLM custom op that exists today, but isn't safe in general since that pass doesn't reason about whether the mutation has any downstream consequences. We expect to have to revisit this in the future.
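For concreteness, a hypothetical mutating custom op of the kind this applies to (the library/op name and schema are made up; `torch.library.custom_op` with `mutates_args` is the PyTorch 2.4+ API):

```python
import torch

# Hypothetical custom op that declares it mutates its `state` argument in-place.
@torch.library.custom_op("mylib::update_state", mutates_args={"state"})
def update_state(state: torch.Tensor, x: torch.Tensor) -> None:
    state.add_(x)
```

When a module calling such an op is exported, the call shows up wrapped in the auto_functionalize higher-order op rather than as a direct mutation.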
Thanks Jacob! 🙏 I will start trying stateful delegation to the Core ML backend. Will follow up on Slack if any new questions emerge.
Consider such a toy model:
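(The model code is not included above; this is a minimal sketch consistent with the exported graphs later in the thread. The buffer shape, initial value, and the added constant are assumptions.)

```python
import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent mutable buffer named "state".
        self.register_buffer("state", torch.ones(3, dtype=torch.float16))

    def forward(self, x):
        self.state.mul_(x)                                       # in-place in eager mode
        self.state.add_(torch.tensor(1.0, dtype=torch.float16))  # in-place in eager mode
        return self.state * 9                                    # out-of-place
```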
I thought, in the exported program, `mul_` and `add_` would be in-place, while `self.state * 9` would be out-of-place. However, all ops seem to be out-of-place?