XintianHan opened this issue 9 months ago
Hello, I have just updated the code, and it now supports head_dim=192. Can you try the example again?
Hi. Thanks for the quick reply. I think I still have a problem with dimensions that are not powers of 2.
Here is what I ran:
import torch
from lightning_attn.ops import lightning_attn_func

dtype = torch.bfloat16
device = torch.device("cuda")
# batch, heads, sequence length, q/k head dim, v head dim
b, h, n, d, e = 1, 16, 2, 192, 96
q = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
k = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
v = torch.randn((b, h, n, e), dtype=dtype, device=device).requires_grad_()
s = torch.randn(h, 1, 1).to(q)  # per-head slope/decay parameter
o = torch.sum(lightning_attn_func(q, k, v, s))
o.backward()
The error occurs during the backward pass:
loc("/opt/tiger/mariana/lightning-attention-main/lightning_attn/ops/triton/lightning_attn2.py":123:64): error: Number of elements must be power-of-two, but %49 = "tt.make_range"() <{end = 96 : i32, start = 0 : i32}> : () -> tensor<96xi32> doesn't follow the rule (96) elements
Any thoughts here? Thank you so much!
Working on this.
This problem has been temporarily solved. The current workaround is to use F.pad to pad the head dimension up to the next power of two. I will provide a more efficient solution in the future.
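For reference, here is a minimal sketch of the F.pad workaround described above (the helpers lightning_attn_padded and next_pow2 are illustrative names, not part of the library): zero-pad the head dimensions up to the next power of two before calling the kernel, then slice the padded columns off the output. Zero columns in q/k leave the attention scores unchanged, and zero columns in v only add zero output columns, so the sliced result matches the unpadded computation.

import torch
import torch.nn.functional as F
from lightning_attn.ops import lightning_attn_func

def next_pow2(x: int) -> int:
    # Smallest power of two >= x, e.g. 96 -> 128, 192 -> 256.
    return 1 << (x - 1).bit_length()

def lightning_attn_padded(q, k, v, s):
    # Hypothetical wrapper: pad q/k/v head dims to powers of two, then slice back.
    d, e = q.shape[-1], v.shape[-1]
    d_pad, e_pad = next_pow2(d), next_pow2(e)
    # F.pad takes (left, right) padding amounts for the last dimension.
    q = F.pad(q, (0, d_pad - d))
    k = F.pad(k, (0, d_pad - d))
    v = F.pad(v, (0, e_pad - e))
    o = lightning_attn_func(q, k, v, s)
    return o[..., :e]  # drop the padded output columns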
Thank you for the nice implementation! It seems that dim=192 is not in supports_dim. Why is that the case? Could you add support for dim=192?
I tried this script and got this error.