ELS-RD / kernl

Kernl lets you run PyTorch transformer models several times faster on GPU with a single line of code, and is designed to be easily hackable.
http://www.kernl.ai
Apache License 2.0

bug: start_position support for the fused attention kernel #329

Open · ipoletaev opened this issue 1 year ago

ipoletaev commented 1 year ago

Description

Using a start position index with the fused attention kernel does not work.

Steps to reproduce

import math

import torch

# Start position shared by the reference and the fused kernel call below.
START_IDX = 128


def attention_reference(q: torch.Tensor, k: torch.Tensor,
                        v: torch.Tensor) -> torch.Tensor:
    # Causal mask shifted by START_IDX: query i may attend to keys j <= i + START_IDX.
    mask_y = torch.full((1, 1, q.size(2), q.size(2)), float("-inf"))
    mask_y = torch.triu(mask_y, diagonal=START_IDX + 1).float()
    att_y = (q @ k.transpose(-2, -1)) * scale
    att_y = att_y + mask_y.to(att_y)
    att_y = torch.nn.functional.softmax(att_y, dim=-1)
    return att_y @ v


q = torch.randn([4, 32, 4096, 128], dtype=torch.float16, device="cuda")
k = torch.randn([4, 32, 4096, 128], dtype=torch.float16, device="cuda")
v = torch.randn([4, 32, 4096, 128], dtype=torch.float16, device="cuda")
scale = 1 / math.sqrt(128)

# triton_fa is the fused attention kernel under test (causal=True, start position=START_IDX).
x = triton_fa(q, k, v, scale, True, START_IDX)
y = attention_reference(q, k, v)
print(torch.max(torch.abs(x - y)))  # max absolute error vs. the reference
print(torch.sum(x - y))             # aggregate error / quick NaN indicator
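
As a follow-up check (assuming the x and y tensors from the repro above), the snippet below narrows down where the fused output goes wrong instead of printing a single scalar:

import torch

# Assumes x (fused kernel output) and y (reference output) from the repro above.
nan_mask = torch.isnan(x)
print("any NaN in fused output:", nan_mask.any().item())
print("NaN fraction:", nan_mask.float().mean().item())
# First few query rows containing NaN for batch 0 / head 0, to see whether
# the NaNs are confined to particular rows of the attention output.
print(nan_mask[0, 0].any(dim=-1).nonzero().flatten()[:10])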

Expected Behavior

The fused kernel should produce nearly identical output to the vanilla (reference) implementation for any start position index.
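
Concretely, "nearly identical" means something like the following check passing; the rtol/atol values are only an assumption of what is reasonable for fp16 accumulation, not something kernl prescribes:

import torch

# Tolerances are a guess for half-precision outputs; adjust as needed.
torch.testing.assert_close(x, y, rtol=1e-2, atol=1e-2)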

Actual Behavior

The fused kernel returns NaN values for any START_IDX != 0.
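
As a side note, a common way masked attention ends up with NaN is a softmax row that is entirely -inf. The sketch below only illustrates that mechanism and checks that the reference mask cannot trigger it on its own; whether the fused kernel's start-index handling hits this case internally is an assumption that would need to be confirmed in the kernel code:

import torch

START_IDX = 128  # same value as in the repro above

# Softmax over a fully masked (-inf) row is NaN:
row = torch.full((8,), float("-inf"))
print(torch.softmax(row, dim=-1))  # all NaN

# The additive mask used by the reference never fully masks a row (column 0 stays
# unmasked for every query), so the reference itself cannot produce NaN this way:
mask = torch.triu(torch.full((4096, 4096), float("-inf")), diagonal=START_IDX + 1)
print((mask == float("-inf")).all(dim=-1).any())  # tensor(False)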

Your environment

torch==2.0.0 triton==2.0.0
