pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Significant Accuracy Difference between Compiled and Eager Flex Attention #135161

Open cora-codes opened 1 month ago

cora-codes commented 1 month ago

🐛 Describe the bug

I struggled a bit to get a minimal repro, but I think this one is reasonable and captures the behavior that causes my runs to diverge.

import torch
import torch.nn as nn
import torch.nn.attention.flex_attention

class Repro(nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv_proj = nn.Linear(256, 256 * 3)
        self.n_head = 256 // 64
        self.d_attn = 256
        self.qkv_proj.weight.data.fill_(0.1)
        self.qkv_proj.bias.data.fill_(0.1)

    def forward(self, x):
        n_batch, n_ctx, _ = x.shape
        q, k, v = self.qkv_proj(x).split([self.d_attn, self.d_attn, self.d_attn], dim=2)
        q = q.reshape(n_batch, n_ctx, self.n_head, -1).transpose(1, 2)
        k = k.reshape(n_batch, n_ctx, self.n_head, -1).transpose(1, 2)
        v = v.reshape(n_batch, n_ctx, self.n_head, -1).transpose(1, 2)
        return torch.nn.attention.flex_attention.flex_attention(q, k, v)

torch.set_default_device("cuda")
torch.manual_seed(0)

model = Repro()

compiled_model = Repro()
compiled_model = torch.compile(compiled_model)

x = torch.randn((1, 512, 256), requires_grad=True)
x_compiled = x.clone().detach().requires_grad_(True)

out = model(x)
out_compiled = compiled_model(x_compiled)

out.sum().backward()
out_compiled.sum().backward()

weight_diff = torch.max(torch.abs(model.qkv_proj.weight.grad - compiled_model.qkv_proj.weight.grad)).item()
bias_diff = torch.max(torch.abs(model.qkv_proj.bias.grad - compiled_model.qkv_proj.bias.grad)).item()

print(f"Weight grad max abs diff: {weight_diff:.2e}")
print(f"Bias grad max abs diff: {bias_diff:.2e}")

The difference between the compiled and eager versions is:

Weight grad max abs diff: 1.37e+01
Bias grad max abs diff: 1.82e+01

This might not seem like a big deal (maybe the matrix initialization is just too big after all), but consider the contrast if you use torch.nn.functional.scaled_dot_product_attention (with the same initialization and precision):

Weight grad max abs diff: 0.00e+00
Bias grad max abs diff: 1.22e-04

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng

Versions

'2.5.0.dev20240904+cu124'

cora-codes commented 1 month ago

Also, if I split the qkv_proj parameter into separate query, key, and value projections, I've found that the query and key gradients are incorrect while the value gradient error stays small.
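A minimal sketch of that split variant, so each projection's gradient can be compared on its own. `SplitProjRepro` and `per_projection_grad_diffs` are hypothetical names, and the attention function is a constructor argument that defaults to `torch.nn.functional.scaled_dot_product_attention` so the sketch also runs on CPU; pass `flex_attention` in to reproduce the comparison above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SplitProjRepro(nn.Module):
    """Hypothetical variant of Repro with separate q/k/v projections."""

    def __init__(self, d_attn=256, head_dim=64, attn_fn=F.scaled_dot_product_attention):
        super().__init__()
        self.n_head = d_attn // head_dim
        self.attn_fn = attn_fn
        self.q_proj = nn.Linear(d_attn, d_attn)
        self.k_proj = nn.Linear(d_attn, d_attn)
        self.v_proj = nn.Linear(d_attn, d_attn)
        for proj in (self.q_proj, self.k_proj, self.v_proj):
            # Same constant initialization as the original repro.
            proj.weight.data.fill_(0.1)
            proj.bias.data.fill_(0.1)

    def _heads(self, t):
        # (batch, ctx, d_attn) -> (batch, n_head, ctx, head_dim)
        n_batch, n_ctx, _ = t.shape
        return t.reshape(n_batch, n_ctx, self.n_head, -1).transpose(1, 2)

    def forward(self, x):
        q = self._heads(self.q_proj(x))
        k = self._heads(self.k_proj(x))
        v = self._heads(self.v_proj(x))
        return self.attn_fn(q, k, v)

def per_projection_grad_diffs(model_a, model_b, x):
    """Max abs weight-gradient difference for each of q_proj, k_proj, v_proj."""
    model_a(x.clone()).sum().backward()
    model_b(x.clone()).sum().backward()
    return {
        name: (getattr(model_a, name).weight.grad - getattr(model_b, name).weight.grad).abs().max().item()
        for name in ("q_proj", "k_proj", "v_proj")
    }
```

For the eager-vs-compiled case, the second model would be `torch.compile(SplitProjRepro(attn_fn=flex_attention))` on CUDA; the report above suggests `q_proj` and `k_proj` would then show large differences while `v_proj` stays small.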

bdhirsh commented 1 month ago

1.37e+01 seems like a pretty big delta

bdhirsh commented 1 month ago

(cc @drisspg )

drisspg commented 1 month ago

Yeah, @Chillee and I were looking at this, and it seems the large delta in fp32 is due to FlexAttention's kernels using tensor cores "by default". I have this PR to align usage: https://github.com/pytorch/pytorch/pull/135168
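To get a feel for how much precision tensor cores give up on fp32 inputs, here is a CPU-only sketch that emulates TF32 (fp32's 8-bit exponent with only 10 mantissa bits) by zeroing the low 13 mantissa bits of the inputs, then compares a q @ k^T score matmul against a float64 reference. `round_to_tf32` is a hypothetical helper, and truncation is a simplification of what the hardware actually does. For eager matmuls, TF32 is governed by `torch.backends.cuda.matmul.allow_tf32` and `torch.set_float32_matmul_precision`; this sketch makes no claim about what the linked PR changes inside FlexAttention's kernels.

```python
import torch

def round_to_tf32(x: torch.Tensor) -> torch.Tensor:
    # Keep the sign, exponent, and top 10 mantissa bits of each fp32 value;
    # zero the remaining 13 (truncation, an approximation of TF32 rounding).
    bits = x.contiguous().view(torch.int32)
    mask = ~((1 << 13) - 1)
    return (bits & mask).view(torch.float32)

torch.manual_seed(0)
q = torch.randn(512, 64)
k = torch.randn(512, 64)

ref = q.double() @ k.double().T                     # float64 reference scores
fp32 = q @ k.T                                      # plain fp32 matmul
tf32 = round_to_tf32(q) @ round_to_tf32(k).T        # TF32-rounded inputs, fp32 accumulate

fp32_err = (fp32.double() - ref).abs().max().item()
tf32_err = (tf32.double() - ref).abs().max().item()
print(f"fp32 max abs error: {fp32_err:.2e}")
print(f"tf32 max abs error: {tf32_err:.2e}")
```

The TF32-style error is orders of magnitude larger than the fp32 one, and in attention that error lands in the scores before the softmax and again in the backward matmuls.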

cora-codes commented 1 month ago

Yes, it definitely seems like an FP32 issue.

I've found that it happens exclusively in the backward pass, similar to issue #133431 (except that I see no segfault).
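Since the forward outputs match and only the gradients diverge, it can help to report the two separately. A minimal sketch of a hypothetical helper, `forward_backward_diffs`, that runs two implementations (for example an eager attention function and its `torch.compile`d version) on identical copies of the inputs; it assumes each function returns a single tensor.

```python
import torch

def forward_backward_diffs(fn_a, fn_b, *inputs):
    """Return (max abs output diff, [max abs grad diff per input])."""
    # Independent leaf copies so the two backward passes don't share .grad.
    copies_a = [t.detach().clone().requires_grad_(True) for t in inputs]
    copies_b = [t.detach().clone().requires_grad_(True) for t in inputs]

    out_a = fn_a(*copies_a)
    out_b = fn_b(*copies_b)
    out_a.sum().backward()
    out_b.sum().backward()

    out_diff = (out_a - out_b).abs().max().item()
    grad_diffs = [(a.grad - b.grad).abs().max().item() for a, b in zip(copies_a, copies_b)]
    return out_diff, grad_diffs
```

Usage on CUDA would look like `forward_backward_diffs(attn, torch.compile(attn), q, k, v)`, where `attn` is any function wrapping `flex_attention`. A near-zero first value with large second values is the forward-ok, backward-wrong pattern described here.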

cora-codes commented 1 month ago

Here is a notebook that showcases another (?) massive discrepancy between compiled and uncompiled flex attention when calculating gradients: https://drive.google.com/file/d/1d2xo02cOdqwmqiTT44LnWri_PqGcvjQx/view?usp=sharing (I've exported it as a PDF so you can see the odd-looking error distribution). This time, I attempted to pinpoint why I believe the LSE gradient is implemented incorrectly on the Triton side.
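As a reference point for what the LSE gradient should be: for a row of scores s, the derivative of logsumexp(s) with respect to s_j is softmax(s)_j. So whenever the returned LSE feeds into later computation (as in a merge of partial attentions), the backward has to add softmax(s) scaled by the incoming LSE gradient to the score gradient. A small float64 check of that identity with plain `torch.logsumexp` (this does not exercise FlexAttention itself):

```python
import torch

torch.manual_seed(0)
scores = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
lse = torch.logsumexp(scores, dim=-1)          # one LSE per query row

# Arbitrary upstream gradient for each row's LSE.
grad_lse = torch.randn(4, dtype=torch.float64)
(grad_scores,) = torch.autograd.grad(lse, scores, grad_outputs=grad_lse)

# Closed form: d lse / d s_j = softmax(s)_j, scaled per row by grad_lse.
expected = torch.softmax(scores, dim=-1) * grad_lse[:, None]
print(torch.allclose(grad_scores, expected))
```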

drisspg commented 1 month ago

@cora-codes Are local_attention and global_attention individually allclose?

cora-codes commented 1 month ago

Yes

cora-codes commented 1 month ago

I'm going to try to narrow down what's going on here further, starting with a smaller repro script. Here is what I've got so far; let me know if I should open a separate issue. I'm using the latest nightly: torch-2.6.0.dev20240925+cu124

import functools
import torch
import torch.nn.attention.flex_attention

torch.set_default_device("cuda")

def merge_attn(x_local, lse_local, x_global, lse_global):
    max_lse = torch.maximum(lse_local, lse_global).detach()
    exp_local = torch.exp(lse_local - max_lse)
    exp_global = torch.exp(lse_global - max_lse)
    numerator = (x_local * exp_local[..., None]) + (x_global * exp_global[..., None])
    denominator = exp_local[..., None] + exp_global[..., None]
    merged_x = numerator / denominator
    merged_lse = max_lse + torch.log(exp_local + exp_global)
    return merged_x, merged_lse

def create_masks(n_local_band):
    def sliding_window_causal(b, h, q_idx, kv_idx):
        return (q_idx >= kv_idx) & (q_idx - kv_idx <= n_local_band)

    def global_causal_v1(b, h, q_idx, kv_idx):
        return q_idx > kv_idx

    def global_causal_v2(b, h, q_idx, kv_idx):
        return (q_idx >= kv_idx) & (q_idx - kv_idx > n_local_band)

    sliding_window_causal_mask = torch.nn.attention.flex_attention.create_block_mask(
        sliding_window_causal, B=None, H=None, Q_LEN=512, KV_LEN=512
    )
    global_causal_mask_v1 = torch.nn.attention.flex_attention.create_block_mask(
        global_causal_v1, B=None, H=None, Q_LEN=512 - n_local_band, KV_LEN=512
    )
    global_causal_mask_v2 = torch.nn.attention.flex_attention.create_block_mask(
        global_causal_v2, B=None, H=None, Q_LEN=512, KV_LEN=512
    )

    return sliding_window_causal_mask, global_causal_mask_v1, global_causal_mask_v2

def attn_v1(query, key, value, sliding_window_causal_mask, global_causal_mask):
    n_batch, n_ctx, d_model = query.shape
    n_head, n_local_band = 16, 128

    query = query.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    key = key.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    value = value.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)

    local_attn = functools.partial(torch.nn.attention.flex_attention.flex_attention, block_mask=sliding_window_causal_mask, return_lse=True)
    global_attn = functools.partial(torch.nn.attention.flex_attention.flex_attention, block_mask=global_causal_mask, return_lse=True)

    x_local, lse_local = local_attn(query, key, value)
    x_global, lse_global = global_attn(query[:, :, n_local_band:, :], key, value)

    x, lse = merge_attn(
        x_local[:, :, n_local_band:, :], lse_local[:, :, n_local_band:], x_global, lse_global
    )
    x = torch.concat([x_local[:, :, :n_local_band, :], x], dim=2)
    x = x.transpose(1, 2).contiguous().reshape(n_batch, n_ctx, d_model)
    return x

def attn_v2(query, key, value, sliding_window_causal_mask, global_causal_mask):
    n_batch, n_ctx, d_model = query.shape
    n_head = 16

    query = query.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    key = key.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    value = value.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)

    local_attn = functools.partial(torch.nn.attention.flex_attention.flex_attention, block_mask=sliding_window_causal_mask, return_lse=True)
    global_attn = functools.partial(torch.nn.attention.flex_attention.flex_attention, block_mask=global_causal_mask, return_lse=True)

    x_local, lse_local = local_attn(query, key, value)
    x_global, lse_global = global_attn(query, key, value)

    x, lse = merge_attn(x_local, lse_local, x_global, lse_global)
    x = x.transpose(1, 2).contiguous().reshape(n_batch, n_ctx, d_model)
    return x

def run_comparison(compile=False):
    n_local_band = 128
    sliding_window_causal_mask, global_causal_mask_v1, global_causal_mask_v2 = create_masks(n_local_band)

    if compile:
        attn_v1_func = torch.compile(attn_v1)
        attn_v2_func = torch.compile(attn_v2)
    else:
        attn_v1_func = attn_v1
        attn_v2_func = attn_v2

    query_v1 = torch.randn(2, 512, 512, requires_grad=True)
    key_v1 = torch.randn(2, 512, 512, requires_grad=True)
    value_v1 = torch.randn(2, 512, 512, requires_grad=True)

    query_v2 = query_v1.clone().detach().requires_grad_(True)
    key_v2 = key_v1.clone().detach().requires_grad_(True)
    value_v2 = value_v1.clone().detach().requires_grad_(True)

    out_v1 = attn_v1_func(query_v1, key_v1, value_v1, sliding_window_causal_mask, global_causal_mask_v1)
    out_v1.sum().backward()

    out_v2 = attn_v2_func(query_v2, key_v2, value_v2, sliding_window_causal_mask, global_causal_mask_v2)
    out_v2.sum().backward()

    print(f"Output difference - Min: {(out_v1 - out_v2).min():.2e}, Max: {(out_v1 - out_v2).max():.2e}")

    for name, grad_1, grad_2 in zip(["query", "key", "value"], [query_v1, key_v1, value_v1], [query_v2, key_v2, value_v2]):
        print(f"{name} gradient difference - Min: {(grad_1.grad - grad_2.grad).min():.2e}, Max: {(grad_1.grad - grad_2.grad).max():.2e}")
        print(f"{name} gradients close: {torch.allclose(grad_1.grad, grad_2.grad)}")

print("Without compile:")
run_comparison(compile=False)

print("\nWith compile:")
run_comparison(compile=True)

Running it prints:

Without compile:
Output difference - Min: 0.00e+00, Max: 0.00e+00
query gradient difference - Min: 0.00e+00, Max: 0.00e+00
query gradients close: True
key gradient difference - Min: 0.00e+00, Max: 0.00e+00
key gradients close: True
value gradient difference - Min: 0.00e+00, Max: 0.00e+00
value gradients close: True

With compile:
Output difference - Min: 0.00e+00, Max: 0.00e+00
query gradient difference - Min: -1.04e+00, Max: 1.12e+00
query gradients close: False
key gradient difference - Min: -1.44e+00, Max: 1.68e+00
key gradients close: False
value gradient difference - Min: -1.23e+00, Max: 1.17e+00
value gradients close: False
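The merge step that both attn_v1 and attn_v2 rely on can be sanity-checked without FlexAttention. This CPU-only sketch uses a hypothetical hand-written `attn_with_lse` helper (dense masked softmax attention in float64), splits the keys into two disjoint halves, merges the halves with the same formula as merge_attn above, and checks that the output, the LSE, and the q/k/v gradients all match full attention. The key split is chosen so no row is fully masked, since plain softmax would return NaN for such rows.

```python
import torch

def attn_with_lse(q, k, v, mask):
    # q: (L, D); k, v: (S, D); mask: (L, S) bool, True = attend.
    scores = (q @ k.T) / q.shape[-1] ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v, torch.logsumexp(scores, dim=-1)

def merge_attn(x_a, lse_a, x_b, lse_b):
    # Same formula as the repro. Detaching max_lse is safe because the
    # merged result does not depend on the shift mathematically.
    max_lse = torch.maximum(lse_a, lse_b).detach()
    w_a = torch.exp(lse_a - max_lse)
    w_b = torch.exp(lse_b - max_lse)
    merged = (x_a * w_a[..., None] + x_b * w_b[..., None]) / (w_a + w_b)[..., None]
    return merged, max_lse + torch.log(w_a + w_b)

torch.manual_seed(0)
L, S, D = 8, 12, 4
q, k, v = (torch.randn(n, D, dtype=torch.float64, requires_grad=True) for n in (L, S, S))

mask_a = (torch.arange(S) < S // 2).expand(L, S)   # first half of the keys
mask_b = ~mask_a                                    # second half of the keys
full = torch.ones(L, S, dtype=torch.bool)

ref, ref_lse = attn_with_lse(q, k, v, full)
merged, merged_lse = merge_attn(*attn_with_lse(q, k, v, mask_a), *attn_with_lse(q, k, v, mask_b))

g_ref = torch.autograd.grad(ref.sum(), (q, k, v))
g_merged = torch.autograd.grad(merged.sum(), (q, k, v))
print(all(torch.allclose(a, b) for a, b in zip(g_ref, g_merged)))
```

If a check like this passes in float64 but the compiled FlexAttention version still disagrees with eager, that points at the kernels' backward (including the LSE gradient path) rather than at the merge math.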
drisspg commented 1 month ago

I can repro this on nightly. I'm just getting back from PTO, so I will take another look. I appreciate you diving deeper on this.

cora-codes commented 3 weeks ago

Hey, just wanted to follow up here.

I'm happy to take another stab at debugging this, but I wanted to see if you have any findings you can share so I'm not duplicating progress you've already made.

drisspg commented 3 weeks ago

Still debugging, TBH. I do agree that this is a high-priority issue and want to get to the bottom of it.

cora-codes commented 3 weeks ago

Glad to hear it. I wish I had something to report after spending a decent chunk of yesterday trying to debug this, but I've got nothing new 😔