Open Tcc0403 opened 1 month ago
Should we adopt the second solution, since the first one introduces quite a lot of overhead? Also, can you elaborate on the cases in which this behavior happens?
@ByronHsu
> Also, can you elaborate on the cases in which this behavior happens?
Consider the following forward graph:
```mermaid
graph TD
    A[input] -->|a| B[exp]
    B -->|b| C[liger_ce]
    C -->|loss| output
```
To calculate the gradients of the exp layer, whose local derivative is exp(input) itself, we can either:
1. save `a` in the forward pass, then recompute `exp(a)` in the backward pass, or
2. save `b` in the forward pass, so no further operations are needed in the backward pass (assuming torch marks it as version 0).

Normally, we take the option with the least computation/memory, which is option 2 in this case.
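To make the two options concrete, here is a minimal sketch of both as custom `torch.autograd.Function`s; the class names are hypothetical. Option 2 is also what PyTorch's built-in `torch.exp` does: its backward multiplies the incoming gradient by the saved output.

```python
import torch


class ExpSaveInput(torch.autograd.Function):
    """Option 1 (hypothetical name): save the input a, recompute exp(a) in backward."""

    @staticmethod
    def forward(ctx, a):
        ctx.save_for_backward(a)
        return torch.exp(a)

    @staticmethod
    def backward(ctx, grad_b):
        (a,) = ctx.saved_tensors
        return grad_b * torch.exp(a)  # extra exp computation in backward


class ExpSaveOutput(torch.autograd.Function):
    """Option 2 (hypothetical name): save the output b, since d/da exp(a) = exp(a) = b."""

    @staticmethod
    def forward(ctx, a):
        b = torch.exp(a)
        ctx.save_for_backward(b)  # autograd records b's current version (v0) here
        return b

    @staticmethod
    def backward(ctx, grad_b):
        (b,) = ctx.saved_tensors
        return grad_b * b  # no recomputation needed


a1 = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
a2 = a1.detach().clone().requires_grad_()
ExpSaveInput.apply(a1).sum().backward()
ExpSaveOutput.apply(a2).sum().backward()
print(a1.grad, a2.grad)  # both equal exp(a)
```

Option 2 trades the extra `exp` in backward for keeping `b` alive, which is exactly why `b` becomes vulnerable if a later layer overwrites it.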
```mermaid
graph TD
    A[input] -->|a| B["exp <br> saved tensors: b (v0)"]
    B -->|b| C[liger_ce]
    C -->|loss| output
```
After a complete forward pass from input `a` to `loss`, we now call `loss.backward()`.
```mermaid
graph TD
    A[input] <-->|dx * grad_ce = b' * grad_ce| B["exp <br> saved tensors: b' (v0)<br>(changed by liger_ce)"]
    B <-->|grad_ce| C[liger_ce]
    C <-->|loss| output
```
Notice that in the forward pass, `liger_ce` stored its gradients in `b`, its own input tensor, so the tensor `b` saved by the exp layer has been changed as well. Since the saved tensor is corrupted, the exp layer can't produce the correct gradients.
Replacing `exp` with any layer that saves its output tensor, and `liger_ce` with any layer that performs in-place operations on its input, will result in the same behavior.
tl;dr The saved tensors are corrupted by in-place operations.
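The scenario can be reproduced on CPU without Triton. In this minimal sketch, a hypothetical `ToyInplaceGradLoss` stands in for `liger_ce`: it computes `loss = sum(x**2)` and overwrites its input with the gradient `2*x` during forward. Writing through `x.data` is an assumption used to mimic a Triton kernel: the values in memory change, but `x`'s version counter does not.

```python
import torch


class ToyInplaceGradLoss(torch.autograd.Function):
    """Hypothetical stand-in for liger_ce: loss = sum(x**2).

    The gradient d(loss)/dx = 2*x is stored in x's own memory during forward.
    Writing through x.data mimics a Triton kernel: the values change, but
    x's version counter is NOT bumped, so autograd cannot notice.
    """

    @staticmethod
    def forward(ctx, x):
        loss = (x ** 2).sum()
        x.data.mul_(2.0)  # x now holds 2*x, i.e. the gradient
        ctx.save_for_backward(x)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (grad_x,) = ctx.saved_tensors
        return grad_output * grad_x


a = torch.tensor([0.0, 1.0], requires_grad=True)
b = torch.exp(a)  # torch.exp saves its output b (v0) for backward
loss = ToyInplaceGradLoss.apply(b)
loss.backward()  # no error: b is still at v0

expected = 2 * torch.exp(2 * a.detach())  # true d/da sum(exp(a)**2)
print(a.grad / expected)  # about 2 everywhere: silently wrong
```

The exp layer multiplies by its saved `b`, which now holds `2*b`, so every input gradient is off by a factor of 2 and nothing warns about it.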
No error is raised because the Triton kernel doesn't bump the version counter when it writes in place, so the saved tensor is still at v0 when gradients are computed in the backward pass.
If we perform the in-place operation outside the kernel by calling a torch function, the version counter is updated correctly.
```mermaid
graph TD
    A[input] <-->|"dx * grad_output <br>= b' * grad_output"| B["exp <br> saved tensors: b' (v1)<br>(changed by inplace op)"]
    B <-->|grad_output| C["torch's inplace op"]
    C <-->|something| something
```
Thus, the error can be detected.
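For contrast, a minimal sketch where a hypothetical toy layer writes the gradient with a PyTorch in-place op (`x.mul_`) instead of a Triton-style write. The version counter of `b` goes from 0 to 1, so the exp layer's backward refuses to use the stale saved tensor.

```python
import torch


class ToyTorchInplaceGradLoss(torch.autograd.Function):
    """Hypothetical liger_ce stand-in: loss = sum(x**2), with the gradient
    2*x written into x via a PyTorch in-place op (which bumps x's version)."""

    @staticmethod
    def forward(ctx, x):
        loss = (x ** 2).sum()
        x.mul_(2.0)  # torch in-place op: x._version goes from 0 to 1
        ctx.save_for_backward(x)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (grad_x,) = ctx.saved_tensors
        return grad_output * grad_x


a = torch.tensor([0.0, 1.0], requires_grad=True)
b = torch.exp(a)  # torch.exp saves b at version 0
loss = ToyTorchInplaceGradLoss.apply(b)
print(b._version)  # 1

try:
    loss.backward()
    raised = False
except RuntimeError as err:  # "...modified by an inplace operation..."
    raised = True
    print(err)
```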
When designing a kernel, we can keep separate pointers for the gradients and add a boolean argument to the `autograd.Function` so that users can decide whether to store gradients in place.
If `False`, we allocate new memory and pass it to the kernel, e.g. `X_ptr` and `dX_ptr` below:
https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/ops/jsd.py#L64-L77
If `True`, we can just pass the existing tensor in which we want to store the results in place, e.g. `X_ptr` and `dX_ptr` below:
https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/ops/fused_linear_jsd.py#L75-L88
The examples above show that we can design a kernel that looks "out-of-place" but can still store results in place.
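A minimal CPU sketch of that design, assuming a hypothetical `_toy_kernel` that plays the role of the Triton kernel (it writes through `.data`, so, like Triton, it does not bump version counters). With `inplace=False` the gradient goes to a freshly allocated buffer and stays correct even when the input is saved by an upstream `exp`; with `inplace=True` it reuses the input's memory and is silently wrong in that case.

```python
import torch


def _toy_kernel(x, dx):
    """Hypothetical stand-in for a Triton kernel: writes d(sum(x**2))/dx = 2*x
    into dx. Writing through .data mimics Triton: no version bump."""
    dx.data.copy_(2.0 * x)


class ToyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, inplace: bool):
        loss = (x ** 2).sum()
        # Same kernel either way; only the gradient buffer differs.
        dx = x if inplace else torch.empty_like(x)
        _toy_kernel(x, dx)
        ctx.save_for_backward(dx)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (dx,) = ctx.saved_tensors
        return grad_output * dx, None  # no gradient for the `inplace` flag


def grad_through_exp(inplace: bool) -> torch.Tensor:
    a = torch.tensor([0.0, 1.0], requires_grad=True)
    ToyLoss.apply(torch.exp(a), inplace).backward()
    return a.grad


expected = 2 * torch.exp(2 * torch.tensor([0.0, 1.0]))  # d/da sum(exp(a)**2)
print(grad_through_exp(inplace=False))  # matches expected
print(grad_through_exp(inplace=True))   # silently 2x expected
```

The kernel itself is unchanged between the two modes; only the buffer passed as `dx` differs, which is the point of keeping a separate gradient pointer.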
One trivial solution is to perform a no-op in-place operation, such as `.add_(0)` or `.mul_(1)`, to explicitly declare that we have changed the tensor values in place; then the errors will be raised.
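A minimal sketch of the no-op check, assuming a hypothetical toy layer whose `.data` write stands in for the Triton kernel. The kernel-style write alone goes unnoticed, but the extra `x.add_(0)` bumps the version counter, so the exp layer's saved output is flagged in backward.

```python
import torch


class ToyLossWithCheck(torch.autograd.Function):
    """Hypothetical liger_ce stand-in: loss = sum(x**2), gradient stored in x."""

    @staticmethod
    def forward(ctx, x):
        loss = (x ** 2).sum()
        x.data.mul_(2.0)  # Triton-like in-place write: no version bump
        x.add_(0)         # no-op in-place op: values unchanged, version +1
        ctx.save_for_backward(x)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (grad_x,) = ctx.saved_tensors
        return grad_output * grad_x


a = torch.tensor([0.0, 1.0], requires_grad=True)
b = torch.exp(a)  # saves b at version 0
loss = ToyLossWithCheck.apply(b)
try:
    loss.backward()
    raised = False
except RuntimeError:  # exp's backward finds b at version 1 instead of 0
    raised = True
print(raised)  # True
```

Note that `add_(0)` is a full extra read and write over the tensor's memory, so the check is not free.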
Since the trivial solution introduces quite a lot of overhead, we could apply it only in the first pass, as an in-place correctness check.
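A possible implementation could look like this:

```python
import torch
import triton


@triton.jit
def _kernel(
    x_ptr,   # input tensor
    y_ptr,   # output tensor
    dx_ptr,  # gradients of input; may alias x_ptr when inplace=True
    # ... other kernel arguments
):
    ...  # do something


def forward(_input, inplace: bool):  # ... other arguments
    ...  # do something
    if inplace:
        # Store the gradients directly in the input's memory.
        dx = _input
        if first_pass:  # I haven't come up with a good way to detect first pass or not
            # No-op in-place op: bumps _input's version counter so autograd can
            # detect that a tensor saved elsewhere has been modified.
            _input.add_(0)
    else:
        # Allocate new memory for the gradients, leaving _input untouched.
        dx = torch.zeros_like(_input)
    _kernel[(...)](
        x_ptr=_input,
        y_ptr=output,
        dx_ptr=dx,
        # ... other kernel arguments
    )
    return output
```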
cc @ByronHsu @lancerts
🐛 Describe the bug
254 #262 (comments)
PyTorch’s autograd system records operations on tensors to construct a computational graph, which is used for computing gradients. When an in-place operation is performed on a tensor, the autograd system needs to ensure that the computational graph reflects the modified values.
https://pytorch.org/docs/stable/autograd.html#in-place-correctness-checks
Each tensor in PyTorch has an internal version counter that is incremented every time an in-place operation is performed.
https://github.com/pytorch/pytorch/blob/190e09d8b6a13f789b143f0fbd1325f924550967/c10/core/TensorImpl.h#L382
Since we don't explicitly call PyTorch in-place operations, the version counter doesn't change when Triton kernels modify tensors in place. As a result, PyTorch's "In-place correctness checks" mechanism doesn't work properly, and no error is shown to the user.
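A quick way to observe the version counter is the private `Tensor._version` attribute. In this sketch, writing through `.data` is used as a CPU stand-in for a Triton kernel's in-place store (an assumption for illustration): it changes the values without touching the counter.

```python
import torch

t = torch.zeros(3)
print(t._version)  # 0

t.add_(1)          # PyTorch in-place op: counter bumped
print(t._version)  # 1

t.add_(0)          # even a no-op in-place op bumps it
print(t._version)  # 2

t.data.mul_(5)     # write that bypasses t's counter (Triton-like)
print(t._version)  # still 2
print(t)           # values did change: tensor([5., 5., 5.])
```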
Reproduce
Solution
One trivial solution is to perform a no-op in-place operation, such as `.add_(0)` or `.mul_(1)`, to explicitly declare that we have changed the tensor values in place; then the errors will be raised.
With this approach, I suggest adding an `inplace=True/False` parameter to the functions that involve in-place operations, so users can set it to `False` (using extra tensors) when they get errors.

Versions
Environment Report:
- Operating System: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
- Python version: 3.10.12
- PyTorch version: 2.4.1+cu121
- CUDA version: 12.1
- Triton version: 3.0.0
- Transformers version: 4.45.0