cora-codes opened this issue 1 month ago
Also, if you split up the qkv_proj parameter into separate query, key, and value projections, I've found the query and key gradients are incorrect while the value gradient error is still small.
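(By "split up" I mean something like the following, a rough sketch where qkv_proj and d_model stand in for the real module, assuming the fused weight is laid out as [q; k; v] along the output dimension.)

import torch

d_model = 512  # stand-in for the real model dimension
qkv_proj = torch.nn.Linear(d_model, 3 * d_model, bias=False)

# Three separate projections copied from slices of the fused weight, so the math is
# identical and only the parameter layout (and gradient bookkeeping) changes.
q_proj = torch.nn.Linear(d_model, d_model, bias=False)
k_proj = torch.nn.Linear(d_model, d_model, bias=False)
v_proj = torch.nn.Linear(d_model, d_model, bias=False)
with torch.no_grad():
    q_proj.weight.copy_(qkv_proj.weight[:d_model])
    k_proj.weight.copy_(qkv_proj.weight[d_model:2 * d_model])
    v_proj.weight.copy_(qkv_proj.weight[2 * d_model:])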
1.37e+01 seems like a pretty big delta. (cc @drisspg)
Yeah, @Chillee and I were looking at this, and it seems that the large delta in fp32 is due to the use of tensor cores "by default" in FlexAttention's kernels. I have this PR to align usage: https://github.com/pytorch/pytorch/pull/135168
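If tensor cores are what's skewing things, one quick sanity check (just an assumption on my end, I haven't verified these flags reach the generated Triton kernel) is to force full-precision fp32 matmuls before running the repro and see whether the delta shrinks:

import torch

# Disable TF32 tensor-core math for fp32 matmuls (standard PyTorch switches).
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.set_float32_matmul_precision("highest")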
Yes, definitely seems like a FP32 issue.
I've found it happens exclusively in the backward pass - similar to issue #133431 (except I see no segfault).
Here is a notebook that precisely showcases another (?) massive discrepancy between compiled and uncompiled flex attention when calculating gradients: https://drive.google.com/file/d/1d2xo02cOdqwmqiTT44LnWri_PqGcvjQx/view?usp=sharing (I've exported it as a PDF so you can see the odd-looking error distribution). This time, I attempted to pinpoint why I believe the LSE gradient is incorrect as implemented on the Triton side.
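A minimal way to isolate the LSE gradient (a sketch of mine with arbitrary shapes, not the notebook's code) is to backpropagate through the returned LSE while zero-weighting the attention output, then compare eager vs. compiled:

import torch
from torch.nn.attention.flex_attention import flex_attention

torch.set_default_device("cuda")

def lse_only_grads(fn):
    # The zero-weighted out term keeps a defined (zero) gradient flowing to the
    # attention output, so any q/k mismatch below comes from the LSE gradient alone.
    torch.manual_seed(0)
    q = torch.randn(1, 2, 128, 32, requires_grad=True)
    k = torch.randn(1, 2, 128, 32, requires_grad=True)
    v = torch.randn(1, 2, 128, 32, requires_grad=True)
    out, lse = fn(q, k, v, return_lse=True)
    (0.0 * out.sum() + lse.sum()).backward()
    return q.grad, k.grad

eager_q, eager_k = lse_only_grads(flex_attention)
comp_q, comp_k = lse_only_grads(torch.compile(flex_attention))
print(torch.allclose(eager_q, comp_q), torch.allclose(eager_k, comp_k))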
@cora-codes Are local_attention and global_attention individually allclose?
Yes
I'm going to attempt to further narrow down what's going on here, starting with a smaller repro script - here is what I've got so far - let me know if I should open a separate issue or not. I'm using the latest nightly: torch-2.6.0.dev20240925+cu124.
import functools

import torch
import torch.nn.attention.flex_attention

torch.set_default_device("cuda")


def merge_attn(x_local, lse_local, x_global, lse_global):
    # Numerically stable merge of two partial attention results via their log-sum-exps.
    max_lse = torch.maximum(lse_local, lse_global).detach()
    exp_local = torch.exp(lse_local - max_lse)
    exp_global = torch.exp(lse_global - max_lse)
    numerator = (x_local * exp_local[..., None]) + (x_global * exp_global[..., None])
    denominator = exp_local[..., None] + exp_global[..., None]
    merged_x = numerator / denominator
    merged_lse = max_lse + torch.log(exp_local + exp_global)
    return merged_x, merged_lse
def create_masks(n_local_band):
    # Local branch: causal attention within a band of n_local_band positions.
    def sliding_window_causal(b, h, q_idx, kv_idx):
        return (q_idx >= kv_idx) & (q_idx - kv_idx <= n_local_band)

    # v1 global branch: plain causal mask, used with the query sliced past the local
    # band, so q_idx here is an offset into that slice.
    def global_causal_v1(b, h, q_idx, kv_idx):
        return q_idx > kv_idx

    # v2 global branch: full query range, but only keys strictly beyond the local band.
    def global_causal_v2(b, h, q_idx, kv_idx):
        return (q_idx >= kv_idx) & (q_idx - kv_idx > n_local_band)

    sliding_window_causal_mask = torch.nn.attention.flex_attention.create_block_mask(
        sliding_window_causal, B=None, H=None, Q_LEN=512, KV_LEN=512
    )
    global_causal_mask_v1 = torch.nn.attention.flex_attention.create_block_mask(
        global_causal_v1, B=None, H=None, Q_LEN=512 - n_local_band, KV_LEN=512
    )
    global_causal_mask_v2 = torch.nn.attention.flex_attention.create_block_mask(
        global_causal_v2, B=None, H=None, Q_LEN=512, KV_LEN=512
    )
    return sliding_window_causal_mask, global_causal_mask_v1, global_causal_mask_v2
def attn_v1(query, key, value, sliding_window_causal_mask, global_causal_mask):
    n_batch, n_ctx, d_model = query.shape
    n_head, n_local_band = 16, 128
    query = query.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    key = key.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    value = value.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    local_attn = functools.partial(torch.nn.attention.flex_attention.flex_attention, block_mask=sliding_window_causal_mask, return_lse=True)
    global_attn = functools.partial(torch.nn.attention.flex_attention.flex_attention, block_mask=global_causal_mask, return_lse=True)
    x_local, lse_local = local_attn(query, key, value)
    # The global branch only sees queries past the local band (sliced query, full key/value).
    x_global, lse_global = global_attn(query[:, :, n_local_band:, :], key, value)
    x, lse = merge_attn(
        x_local[:, :, n_local_band:, :], lse_local[:, :, n_local_band:], x_global, lse_global
    )
    x = torch.concat([x_local[:, :, :n_local_band, :], x], dim=2)
    x = x.transpose(1, 2).contiguous().reshape(n_batch, n_ctx, d_model)
    return x
def attn_v2(query, key, value, sliding_window_causal_mask, global_causal_mask):
    n_batch, n_ctx, d_model = query.shape
    n_head = 16
    query = query.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    key = key.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    value = value.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    local_attn = functools.partial(torch.nn.attention.flex_attention.flex_attention, block_mask=sliding_window_causal_mask, return_lse=True)
    global_attn = functools.partial(torch.nn.attention.flex_attention.flex_attention, block_mask=global_causal_mask, return_lse=True)
    x_local, lse_local = local_attn(query, key, value)
    # The global branch sees the full query range; the block mask does all the filtering.
    x_global, lse_global = global_attn(query, key, value)
    x, lse = merge_attn(x_local, lse_local, x_global, lse_global)
    x = x.transpose(1, 2).contiguous().reshape(n_batch, n_ctx, d_model)
    return x
def run_comparison(compile=False):
    n_local_band = 128
    sliding_window_causal_mask, global_causal_mask_v1, global_causal_mask_v2 = create_masks(n_local_band)
    if compile:
        attn_v1_func = torch.compile(attn_v1)
        attn_v2_func = torch.compile(attn_v2)
    else:
        attn_v1_func = attn_v1
        attn_v2_func = attn_v2
    # Identical fp32 inputs for both variants.
    query_v1 = torch.randn(2, 512, 512, requires_grad=True)
    key_v1 = torch.randn(2, 512, 512, requires_grad=True)
    value_v1 = torch.randn(2, 512, 512, requires_grad=True)
    query_v2 = query_v1.clone().detach().requires_grad_(True)
    key_v2 = key_v1.clone().detach().requires_grad_(True)
    value_v2 = value_v1.clone().detach().requires_grad_(True)
    out_v1 = attn_v1_func(query_v1, key_v1, value_v1, sliding_window_causal_mask, global_causal_mask_v1)
    out_v1.sum().backward()
    out_v2 = attn_v2_func(query_v2, key_v2, value_v2, sliding_window_causal_mask, global_causal_mask_v2)
    out_v2.sum().backward()
    print(f"Output difference - Min: {(out_v1 - out_v2).min():.2e}, Max: {(out_v1 - out_v2).max():.2e}")
    for name, grad_1, grad_2 in zip(["query", "key", "value"], [query_v1, key_v1, value_v1], [query_v2, key_v2, value_v2]):
        print(f"{name} gradient difference - Min: {(grad_1.grad - grad_2.grad).min():.2e}, Max: {(grad_1.grad - grad_2.grad).max():.2e}")
        print(f"{name} gradients close: {torch.allclose(grad_1.grad, grad_2.grad)}")


print("Without compile:")
run_comparison(compile=False)
print("\nWith compile:")
run_comparison(compile=True)
It should output:
Without compile:
Output difference - Min: 0.00e+00, Max: 0.00e+00
query gradient difference - Min: 0.00e+00, Max: 0.00e+00
query gradients close: True
key gradient difference - Min: 0.00e+00, Max: 0.00e+00
key gradients close: True
value gradient difference - Min: 0.00e+00, Max: 0.00e+00
value gradients close: True
With compile:
Output difference - Min: 0.00e+00, Max: 0.00e+00
query gradient difference - Min: -1.04e+00, Max: 1.12e+00
query gradients close: False
key gradient difference - Min: -1.44e+00, Max: 1.68e+00
key gradients close: False
value gradient difference - Min: -1.23e+00, Max: 1.17e+00
value gradients close: False
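Since the two masks together cover exactly the causal pattern (q_idx >= kv_idx), one way to tell which set of gradients is actually wrong (a sketch of my own, assuming flex_attention's default 1/sqrt(head_dim) scaling) is to compare both variants against a plain eager causal reference:

import math
import torch

def reference_causal_attention(query, key, value, n_head=16):
    # Pure-PyTorch causal attention with the same reshaping as attn_v1/attn_v2.
    n_batch, n_ctx, d_model = query.shape
    q = query.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    k = key.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    v = value.reshape(n_batch, n_ctx, n_head, -1).transpose(1, 2)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    causal = torch.ones(n_ctx, n_ctx, dtype=torch.bool, device=query.device).tril()
    scores = scores.masked_fill(~causal, float("-inf"))
    x = torch.softmax(scores, dim=-1) @ v
    return x.transpose(1, 2).contiguous().reshape(n_batch, n_ctx, d_model)

Backpropagating .sum() through this reference on cloned inputs, as in run_comparison, and comparing its query/key/value gradients with those from attn_v1_func and attn_v2_func should show whether the compiled or the eager gradients are the ones drifting.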
On nightly I can repro this. Just getting back from PTO, so I will take another look. I appreciate you diving deeper on this.
Hey, just wanted to follow up here.
I'm happy to take a stab at debugging this again, but wanted to see if you have any findings you can share so I'm not duplicating any progress you've made.
Still debugging, TBH. I do agree that this is a high-priority issue and want to get to the bottom of it.
Glad to hear it - I wish I had something to report after a decent chunk of yesterday spent trying to debug this, but I've got nothing new 😔
🐛 Describe the bug
I struggled a bit to get a repro, but I think this is in the realm of reasonable and identifies the behavior that causes my runs to diverge.
The difference between the compiled and eager versions is on the order of 1.37e+01.
This might not seem like a big deal (maybe the matrix initialization is just too big after all), but consider the contrast if you use torch.nn.functional.scaled_dot_product_attention (with the same initialization and precision).

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng
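A rough sketch of that SDPA contrast (arbitrary shapes and a plain sum loss, not the exact code from my run):

import torch
import torch.nn.functional as F

torch.set_default_device("cuda")

def sdpa_grads(fn):
    # Same fp32 inputs for the eager and compiled runs (seeded), causal attention.
    torch.manual_seed(0)
    q = torch.randn(2, 16, 512, 32, requires_grad=True)
    k = torch.randn(2, 16, 512, 32, requires_grad=True)
    v = torch.randn(2, 16, 512, 32, requires_grad=True)
    fn(q, k, v, is_causal=True).sum().backward()
    return q.grad, k.grad, v.grad

eager = sdpa_grads(F.scaled_dot_product_attention)
compiled = sdpa_grads(torch.compile(F.scaled_dot_product_attention))
for name, g_e, g_c in zip("qkv", eager, compiled):
    print(name, (g_e - g_c).abs().max().item())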
Versions
'2.5.0.dev20240904+cu124'