ByronHsu opened 5 months ago
Answer from Claude 3 Sonnet (much better than ChatGPT and Gemini in this case):
This behavior is related to how PyTorch handles in-place operations and how it tracks tensor versions for autograd. When a tensor is modified in-place, its version counter is incremented so that autograd can detect the change and ensure correct gradient computation.

In your case, when you use `TritonSquareFunction` without the redundant views, the input tensor's version stays at 0 because the Triton kernel writes to the input tensor's memory directly, without creating any new views or copies. PyTorch allows this as long as the tensor is not used in any other computation that requires its original value.

However, when you introduce the redundant views (`X.view(B, T, C).view(B*T, C)`), PyTorch creates new views of the input tensor, and the input tensor's version gets incremented. During the backward pass, when you try to modify the input tensor in-place using the Triton kernel, PyTorch detects that the tensor's version has changed and throws an error, because it expects the version to be the same as when the tensor was saved for the backward pass.

The PyTorch-native backward (`SquareFunction`) increments the input tensor's version after each backward pass because it performs an in-place multiplication on the input (`input *= grad_output`).

To safely modify the input tensor in-place during the backward pass, avoid creating any new views or copies of the input tensor before passing it to the forward function. If you need to reshape the input tensor, do it inside the forward function and save the reshaped tensor for the backward pass.
Additionally, use PyTorch's in-place operations carefully, or prefer PyTorch's functional APIs that perform out-of-place operations, to avoid modifying the input tensor directly.

In general, modifying input tensors in-place during the backward pass is not recommended: it can lead to unexpected behavior and subtle errors, especially with complex computations and control flow. It is safer to allocate new tensors for the gradients and perform out-of-place operations so the input tensors are left untouched.
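For contrast, a minimal out-of-place custom `Function` (a sketch with hypothetical names, not the code from this issue) allocates a fresh gradient tensor in backward, so the saved input's version counter never changes and repeated backward passes succeed:

```python
import torch

class SquareFunctionOutOfPlace(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output  # new tensor; x's version stays at 0

x = torch.randn(8, 8, requires_grad=True)
y = SquareFunctionOutOfPlace.apply(x).sum()
for _ in range(3):  # repeated backwards are fine: the saved x is untouched
    x.grad = None
    y.backward(retain_graph=True)
```

The trade-off is the extra allocation for `2 * x * grad_output`, which is exactly the memory the in-place trick tries to avoid.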
```
full(TritonSquareFunction.apply)
1st backward input version 0
2nd backward input version 0
3rd backward input version 0

full_with_view(TritonSquareFunction.apply)
1st backward input version 0
2nd backward input version 1
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 8]], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

full(SquareFunction.apply)
1st backward input version 1
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 8]] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

full_with_view(SquareFunction.apply)
1st backward input version 1
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 8]], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
```
Question
Hi Triton team, I am implementing a Triton kernel (Triton 2.3.0 on an A100) using the trick of modifying the input tensor in-place during backward to store the gradient and save memory. However, I found some unexpected behaviors. In general, I would like to know how to safely modify an input tensor in-place to save memory in the backward pass, or whether this is discouraged.
Reproduce
Output: