linkedin / Liger-Kernel

Efficient Triton Kernels for LLM Training
https://arxiv.org/pdf/2410.10989
BSD 2-Clause "Simplified" License

In-place operations in triton kernel might result in incorrect gradient calculations #272

Open Tcc0403 opened 1 month ago

Tcc0403 commented 1 month ago

🐛 Describe the bug

254 #262 (comments)

PyTorch’s autograd system records operations on tensors to construct a computational graph, which is used for computing gradients. When an in-place operation modifies a tensor that autograd has saved for the backward pass, autograd needs to detect it, because the gradient formula would otherwise use the modified values instead of the ones seen during the forward pass.

https://pytorch.org/docs/stable/autograd.html#in-place-correctness-checks

Each tensor in PyTorch has an internal version counter that is incremented every time an in-place operation is performed.

https://github.com/pytorch/pytorch/blob/190e09d8b6a13f789b143f0fbd1325f924550967/c10/core/TensorImpl.h#L382
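A minimal sketch of the counter and the check it drives. It reads `Tensor._version`, an internal attribute that is handy for illustration but is not a stable public API:

```python
import torch

a = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
b = a.exp()             # ExpBackward saves its *output* b, recording version 0
print(b._version)       # 0

b.add_(1.0)             # any PyTorch in-place op increments the counter
print(b._version)       # 1

raised = False
try:
    b.sum().backward()  # saved b expected version 0 but is now at version 1
except RuntimeError as err:
    raised = True
    print("caught:", err)
```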

Since we don't call PyTorch in-place operations explicitly, the version counter doesn't change when a Triton kernel writes to a tensor in place. As a result, PyTorch's "in-place correctness checks" don't work, and the user silently gets incorrect gradients instead of an error.
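Triton kernels need a GPU, so the sketch below assumes that writing through `Tensor.data` is a fair CPU stand-in for a kernel store: `.data` shares the tensor's storage but gets its own version counter, so autograd never sees the write.

```python
import torch

a = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
b = a.exp()             # ExpBackward saves its output b
b.data.mul_(0.0)        # stand-in for a kernel store: same memory, no version bump
print(b._version)       # 0, so the in-place check has nothing to compare against

b.sum().backward()      # ExpBackward computes grad * saved_b using the zeroed buffer
print(a.grad)           # all zeros instead of exp(a), and no error was raised
```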

Reproduce

import torch
import torch.nn.functional as F

from liger_kernel.transformers.functional import liger_cross_entropy

def run_inplace_experiment(logits_p, logits_q, cross_entropy_fn):
    _p = logits_p.clone().detach().requires_grad_(True)
    _p.retain_grad()
    softmax = torch.nn.Softmax(dim=-1)
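    p = softmax(_p)  # softmax saves its output p for the backward pass
    p.retain_grad()
    loss = cross_entropy_fn(p, logits_q)  # liger_cross_entropy stores its gradient in p's buffer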
    loss.backward(retain_graph=True)

    print(f"Cross Entropy Loss: {loss.item()}")
    print(f"Input _p: {_p}")
    print(f"Input logits_q: {logits_q}")
    print(f"Gradients of p (batch item 0): {p.grad[0]}")
    print(f"Gradients of _p (batch item 0): {_p.grad[0]}")

torch.manual_seed(0)
logits_p = torch.randn(8, 8, requires_grad=True, device="cuda")
logits_q = torch.randint(0, 8, (8,), device="cuda", dtype=torch.long)

run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=F.cross_entropy)

print()
print("LIGER:")
run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=liger_cross_entropy)
❯ python3 inplace_bug.py
Cross Entropy Loss: 2.08567214012146
Input _p: tensor([[-0.9247, -0.4253, -2.6438,  0.1452, -0.1209, -0.5797, -0.6229, -0.3284],
        [-1.0745, -0.3631, -1.6711,  2.2655,  0.3117, -0.1842,  1.2866,  1.1820],
        [-0.1271,  1.2169,  1.4353,  1.0605, -0.4941, -1.4244, -0.7244, -1.2973],
        [ 0.0697, -0.0074,  1.8969,  0.6878, -0.0779, -0.8373,  1.3506, -0.2879],
        [-0.5965, -0.3283, -0.9086, -0.8059, -0.7407, -0.0504,  0.5435,  1.5150],
        [ 0.0141,  0.4532,  1.6349,  0.7124, -0.1806,  1.0252, -1.4622, -0.7554],
        [-0.1836,  0.3824,  0.3918, -0.0830,  0.8971, -1.1123,  0.1116,  0.4863],
        [-0.5499, -0.3231, -0.5469,  0.9049,  0.2837,  0.1210,  0.4730, -1.0823]],
       device='cuda:0', requires_grad=True)
Input logits_q: tensor([4, 6, 7, 2, 2, 6, 5, 5], device='cuda:0')
Gradients of p (batch item 0): tensor([ 0.0149,  0.0157,  0.0140,  0.0174, -0.1086,  0.0154,  0.0153,  0.0159],
       device='cuda:0')
Gradients of _p (batch item 0): tensor([ 0.0017,  0.0029,  0.0003,  0.0055, -0.0182,  0.0024,  0.0023,  0.0032],
       device='cuda:0')

LIGER:
Cross Entropy Loss: 2.08567214012146
Input _p: tensor([[-0.9247, -0.4253, -2.6438,  0.1452, -0.1209, -0.5797, -0.6229, -0.3284],
        [-1.0745, -0.3631, -1.6711,  2.2655,  0.3117, -0.1842,  1.2866,  1.1820],
        [-0.1271,  1.2169,  1.4353,  1.0605, -0.4941, -1.4244, -0.7244, -1.2973],
        [ 0.0697, -0.0074,  1.8969,  0.6878, -0.0779, -0.8373,  1.3506, -0.2879],
        [-0.5965, -0.3283, -0.9086, -0.8059, -0.7407, -0.0504,  0.5435,  1.5150],
        [ 0.0141,  0.4532,  1.6349,  0.7124, -0.1806,  1.0252, -1.4622, -0.7554],
        [-0.1836,  0.3824,  0.3918, -0.0830,  0.8971, -1.1123,  0.1116,  0.4863],
        [-0.5499, -0.3231, -0.5469,  0.9049,  0.2837,  0.1210,  0.4730, -1.0823]],
       device='cuda:0', requires_grad=True)
Input logits_q: tensor([4, 6, 7, 2, 2, 6, 5, 5], device='cuda:0')
Gradients of p (batch item 0): tensor([ 0.0149,  0.0157,  0.0140,  0.0174, -0.1086,  0.0154,  0.0153,  0.0159],
       device='cuda:0')
Gradients of _p (batch item 0): tensor([2.1320e-05, 3.4830e-05, 6.8024e-06, 6.7467e-05, 1.3247e-02, 2.9687e-05,
        2.8429e-05, 3.8656e-05], device='cuda:0')

Solution

One trivial solution is to perform a no-op in-place operation, such as .add_(0) or .mul_(1), which explicitly tells PyTorch that the tensor values have changed in place; the version counter is then bumped and the error is raised.

With this approach, I suggest adding an inplace=True/False parameter to the functions that perform in-place operations, so users who hit the error can set it to False (using extra tensors instead).

Versions

Environment Report:

Operating System: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python version: 3.10.12
PyTorch version: 2.4.1+cu121
CUDA version: 12.1
Triton version: 3.0.0
Transformers version: 4.45.0

ByronHsu commented 1 month ago

Should we adopt the second solution, since the first one introduces quite a lot of overhead? Also, can you elaborate on the cases in which this behavior happens?

Tcc0403 commented 1 month ago

@ByronHsu

Also, can you elaborate on the cases in which this behavior happens?

Consider the following forward graph:

graph TD
    A[input] -->|a| B[exp]
    B -->|b| C[liger_ce]
    C -->|loss| output

To calculate the gradients of the exp layer, i.e. exp(input), we can either:

  1. save the input tensor a in the forward pass, then recompute exp(a) in the backward pass
  2. save the output tensor b in the forward pass, so no further computation is needed in the backward pass (assume torch marks it as version 0)

Normally we take the option that needs the least computation/memory, which is option 2 in this case.
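The two options can be sketched as custom `torch.autograd.Function`s; `ExpSaveInput` and `ExpSaveOutput` are hypothetical names. Both yield exp(a) as the gradient on clean inputs, but if a downstream op overwrites `b` without bumping its version (simulated here with `.data`), only option 2 goes wrong:

```python
import torch

class ExpSaveInput(torch.autograd.Function):
    """Option 1: save a, recompute exp(a) in backward."""
    @staticmethod
    def forward(ctx, a):
        ctx.save_for_backward(a)
        return a.exp()

    @staticmethod
    def backward(ctx, grad_out):
        (a,) = ctx.saved_tensors
        return grad_out * a.exp()    # extra exp, but independent of b's memory

class ExpSaveOutput(torch.autograd.Function):
    """Option 2: save b = exp(a); backward needs no extra computation."""
    @staticmethod
    def forward(ctx, a):
        b = a.exp()
        ctx.save_for_backward(b)
        return b

    @staticmethod
    def backward(ctx, grad_out):
        (b,) = ctx.saved_tensors
        return grad_out * b          # d exp(a)/da = exp(a) = b

def grad_after_overwrite(fn):
    a = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
    b = fn.apply(a)
    b.data.fill_(123.0)              # a downstream op overwrites b's memory, untracked
    b.sum().backward()
    return a.grad

print(grad_after_overwrite(ExpSaveInput))   # exp(a): correct
print(grad_after_overwrite(ExpSaveOutput))  # all 123s: wrong, and no error
```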

graph TD
    A[input] -->|a| B["exp <br> saved tensors: b (v0)"]
    B -->|b| C[liger_ce]
    C -->|loss| output

After a complete forward pass from input a to loss, we call loss.backward().

graph TD
    A[input] <-->|dx * grad_ce = b' * grad_ce| B["exp <br> saved tensors: b' (v0)<br>(changed by liger_ce)"]
    B <-->|grad_ce| C[liger_ce]
    C <-->|loss| output

Notice that during the forward pass, liger_ce stores its gradients in b, its own input tensor, so the tensor b saved by the exp layer is overwritten as well (b'). Since the saved tensor is corrupted, the exp layer cannot produce the correct gradients.

The same behavior occurs if exp is replaced with any layer that saves its output tensor, and liger_ce with any layer that modifies its input in place.

tl;dr The saved tensors are corrupted by inplace operations.
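A CPU-only sketch of the same failure, without Triton or Liger. `InplaceGradCE` is a hypothetical stand-in for an in-place cross-entropy op (not Liger's code): its forward writes d(loss)/d(input) into the input buffer through `.data`, approximating an untracked kernel store, and its backward returns that buffer. Softmax plays the role of exp, since it also saves its output.

```python
import torch
import torch.nn.functional as F

class InplaceGradCE(torch.autograd.Function):
    """Hypothetical stand-in for an in-place CE kernel (not Liger's code)."""

    @staticmethod
    def forward(ctx, logits, target):
        n = logits.shape[0]
        loss = F.cross_entropy(logits, target)   # mean reduction
        grad = torch.softmax(logits, dim=-1)
        grad[torch.arange(n), target] -= 1.0
        grad /= n                                # d(mean CE)/d(logits)
        logits.data.copy_(grad)                  # untracked store into the input
        ctx.save_for_backward(logits)            # buffer now holds the gradient
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (grad,) = ctx.saved_tensors
        return grad * grad_output, None

def input_grad(ce_fn, x, target):
    x = x.clone().requires_grad_(True)
    p = torch.softmax(x, dim=-1)     # SoftmaxBackward saves its output p
    ce_fn(p, target).backward()      # the in-place CE overwrites p's memory
    return x.grad

torch.manual_seed(0)
x = torch.randn(8, 8)
target = torch.randint(0, 8, (8,))

ref = input_grad(F.cross_entropy, x, target)
bad = input_grad(InplaceGradCE.apply, x, target)
print(torch.allclose(ref, bad))      # False, and no error was raised
```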

Why no error?

No error is raised because the Triton kernel doesn't bump the version counter when it writes in place, so the saved tensor is still at v0 when gradients are computed in the backward pass.

If we instead perform the in-place operation outside the kernel by calling a PyTorch function, the version is updated correctly.

graph TD
    A[input] <-->|"dx * grad_output <br>= b' * grad_output"| B["exp <br> saved tensors: b' (v1)<br>(changed by inplace op)"]
    B <-->|grad_output| C["torch's inplace op"]
    C <-->|something| something

Thus, the error can be detected.
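A minimal sketch of that effect, again assuming `.data` as a stand-in for the kernel's untracked store: a no-op `add_(0)` leaves the values unchanged but bumps the counter, so backward raises instead of silently using the corrupted saved output.

```python
import torch

a = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
b = a.exp()                  # ExpBackward saves its output b (version 0)

with torch.no_grad():        # custom Function forwards also run with grad disabled
    b.data.mul_(0.0)         # stand-in for the kernel's untracked in-place store
    b.add_(0)                # no-op PyTorch in-place op: values unchanged, version 0 -> 1

print(b._version)            # 1
raised = False
try:
    b.sum().backward()
except RuntimeError as err:  # "...modified by an inplace operation..."
    raised = True
    print("caught:", err)
```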

Tcc0403 commented 3 weeks ago

When designing a kernel, we can keep separate pointers for the gradients, and add a boolean argument to the autograd.Function so users can decide whether to store gradients in place or not.

If False, we can allocate new memory and pass it to the kernel, e.g. X_ptr and dX_ptr in:
https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/ops/jsd.py#L64-L77

If True, we can just pass the existing tensor that should receive the gradients in place, e.g. X_ptr and dX_ptr in:
https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/ops/fused_linear_jsd.py#L75-L88

The examples above show that we can design a kernel that looks "out-of-place" but can still store results in place.
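The same idea in plain PyTorch: `ce_grad_kernel` is a hypothetical stand-in for a kernel with separate `X_ptr`/`dX_ptr` arguments that writes the mean cross-entropy gradient into `dx`. Whether the store happens in place depends only on what the caller passes as `dx`.

```python
import torch

def ce_grad_kernel(x, target, dx):
    """Hypothetical stand-in for a kernel with separate X_ptr / dX_ptr arguments.

    Reads logits from x and writes d(mean CE)/dx into dx. If dx *is* x, the
    store lands in the input's memory, i.e. in-place storing.
    """
    n = x.shape[0]
    g = torch.softmax(x, dim=-1)         # read everything before writing
    g[torch.arange(n), target] -= 1.0
    dx.copy_(g / n)

torch.manual_seed(0)
x = torch.randn(4, 5)
target = torch.randint(0, 5, (4,))

x_out = x.clone()
dx = torch.empty_like(x_out)             # inplace=False: fresh buffer
ce_grad_kernel(x_out, target, dx)

x_in = x.clone()
ce_grad_kernel(x_in, target, x_in)       # inplace=True: dx aliases x
print(dx.data_ptr() == x_out.data_ptr()) # False
print(torch.equal(x_in, dx))             # True: same gradient, stored over the input
```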

One trivial solution is to perform a no-op in-place operation, such as .add_(0) or .mul_(1), which explicitly tells PyTorch that the tensor values have changed in place, so the error is then raised.

Since the trivial solution introduces quite a lot of overhead, we can do it only in the first pass, as an in-place correctness check.

A possible implementation could be like this:

import torch
import triton

@triton.jit
def _kernel(
    x_ptr,   # input tensor
    y_ptr,   # output tensor
    dx_ptr,  # gradients of input (may point to the same memory as x_ptr)
    ...
):
    ...  # do something

def forward(_input, inplace: bool, ...):
    ...  # do something
    if inplace:
        dx = _input  # store gradients over the input buffer
        if first_pass:  # I haven't come up with a good way to detect the first pass
            _input.add_(0)  # no-op in-place op that bumps the version counter
    else:
        dx = torch.zeros_like(_input)  # separate buffer; _input is left intact
    _kernel[(...)](
        x_ptr=_input,
        y_ptr=output,
        dx_ptr=dx,
        ...
    )
    return output
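A CPU-only sketch of this proposal, assuming PyTorch ops in place of the Triton kernel and `.data` writes in place of kernel stores. `CEWithInplaceFlag` and its class-level `_checked` flag are hypothetical: the flag is just one possible way to run the `add_(0)` check only once (the "first pass"), not something this issue settles on.

```python
import torch
import torch.nn.functional as F

class CEWithInplaceFlag(torch.autograd.Function):
    """Hypothetical wrapper sketching the proposal (not Liger's API)."""
    _checked = False  # assumed "first pass" marker, shared by all calls

    @staticmethod
    def forward(ctx, _input, target, inplace=True):
        n = _input.shape[0]
        loss = F.cross_entropy(_input, target)
        if inplace:
            dx = _input                          # store gradients over the input
            if not CEWithInplaceFlag._checked:
                _input.add_(0)                   # bump version once so autograd can object
                CEWithInplaceFlag._checked = True
        else:
            dx = torch.zeros_like(_input)        # separate buffer, input untouched
        g = torch.softmax(_input, dim=-1)        # stand-in for the kernel body
        g[torch.arange(n), target] -= 1.0
        dx.data.copy_(g / n)                     # untracked store, like a kernel write
        ctx.save_for_backward(dx)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (dx,) = ctx.saved_tensors
        return dx * grad_output, None, None

torch.manual_seed(0)
x0 = torch.randn(4, 5)
target = torch.randint(0, 5, (4,))

def run(inplace):
    x = x0.clone().requires_grad_(True)
    p = torch.softmax(x, dim=-1)                 # SoftmaxBackward saves p
    CEWithInplaceFlag.apply(p, target, inplace).backward()
    return x.grad

ok = run(inplace=False)                          # fresh dx buffer: correct gradient
try:
    run(inplace=True)                            # first in-place call: check fires
    first_call_raised = False
except RuntimeError:
    first_call_raised = True
print(first_call_raised)                         # True
```

The sketch also makes the limitation of a one-time check visible: once `_checked` is set, later in-place calls skip the bump, so a corrupting graph that first appears after the first call would again fail silently.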

cc @ByronHsu @lancerts