yzhangcs opened this issue 10 months ago
Mine is normal on an NVIDIA A100 80GB PCIe with the Triton nightly release:
Testing BFloat16...
tensor([[ -3.0000, 4.3125, -5.2812, ..., -3.1094, 4.4062, 1.4141],
[ -4.1875, 8.4375, 7.3750, ..., -4.2188, 0.7227, 4.2188],
[ -8.3750, 9.2500, -3.0938, ..., -10.4375, 3.5312, -1.4688],
...,
[ 33.0000, -45.5000, -27.5000, ..., 47.5000, 11.7500, -42.5000],
[-12.8125, -13.3125, -29.5000, ..., -28.2500, -17.8750, -8.5625],
[ -8.0625, 19.0000, -25.1250, ..., 44.0000, 31.8750, 0.7148]],
device='cuda:0', dtype=torch.bfloat16)
tensor([[ -3.0000, 4.3125, -5.2812, ..., -3.1094, 4.4062, 1.4141],
[ -4.1875, 8.4375, 7.3750, ..., -4.2188, 0.7227, 4.2188],
[ -8.3750, 9.2500, -3.0938, ..., -10.4375, 3.5312, -1.4688],
...,
[ 33.0000, -45.5000, -27.5000, ..., 47.5000, 11.7500, -42.5000],
[-12.8125, -13.3125, -29.5000, ..., -28.2500, -17.8750, -8.5625],
[ -8.0625, 19.0000, -25.1250, ..., 44.0000, 31.8750, 0.7148]],
device='cuda:0', dtype=torch.bfloat16)
Diff: tensor(0., device='cuda:0', dtype=torch.bfloat16)
Testing Float...
tensor([[ -2.9934, 4.3165, -5.2756, ..., -3.1058, 4.4168, 1.4156],
[ -4.1945, 8.4496, 7.3791, ..., -4.2401, 0.7162, 4.2193],
[ -8.4092, 9.2836, -3.0875, ..., -10.4867, 3.5389, -1.4640],
...,
[ 32.8595, -45.5166, -27.4276, ..., 47.6876, 11.7916, -42.2014],
[-12.8712, -13.3366, -29.3908, ..., -28.2100, -17.8268, -8.5858],
[ -8.0902, 19.0255, -25.1759, ..., 44.1477, 31.9002, 0.8230]],
device='cuda:0')
tensor([[ -2.9934, 4.3165, -5.2756, ..., -3.1058, 4.4168, 1.4156],
[ -4.1945, 8.4496, 7.3791, ..., -4.2401, 0.7162, 4.2193],
[ -8.4092, 9.2836, -3.0875, ..., -10.4867, 3.5389, -1.4640],
...,
[ 32.8595, -45.5166, -27.4276, ..., 47.6876, 11.7916, -42.2014],
[-12.8712, -13.3366, -29.3908, ..., -28.2100, -17.8268, -8.5858],
[ -8.0902, 19.0255, -25.1759, ..., 44.1477, 31.9002, 0.8230]],
device='cuda:0')
Diff: tensor(0., device='cuda:0')
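For reference, a minimal sketch of the kind of comparison harness that would produce output in this form. The wrapper callables and shapes here are my assumptions for illustration, not the actual test script:

import torch

def check(fwd_nostore, fwd_store, dtype, T=1024, D=64):
    # hypothetical harness: run both kernel variants on identical inputs
    # and report the largest absolute difference between their outputs
    torch.manual_seed(0)
    q = torch.randn(T, D, device='cuda', dtype=dtype)
    k = torch.randn(T, D, device='cuda', dtype=dtype)
    v = torch.randn(T, D, device='cuda', dtype=dtype)
    o_ref = fwd_nostore(q, k, v)  # variant without the tl.store
    o_cmp = fwd_store(q, k, v)    # variant with the tl.store
    print(o_ref)
    print(o_cmp)
    print('Diff:', (o_ref - o_cmp).abs().max())

It would be called once per dtype, e.g. check(nostore_fn, store_fn, torch.bfloat16) after printing "Testing BFloat16...".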
@Jokeren Hi, @sustcsonglin and I found that when using Triton 2.2 & H100 (A100 works fine), running this check still gives very strange results.
>>> python tests/test_fused_chunk.py
DTYPE STORE IFCOND DIFF
torch.float32 False False 0.0
torch.float32 False True 0.0
torch.float32 True False 0.0
torch.float32 True True 0.0
torch.bfloat16 False False 218.81393432617188
torch.bfloat16 False True 0.6739959716796875
torch.bfloat16 True False 218.81393432617188
torch.bfloat16 True True 0.6739959716796875
Can you figure out what's happening?
The bug can be bypassed by adding a conditional check at the beginning of the for loop, as in the IFCOND=True rows above; a rough sketch of such a guard follows.
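For concreteness, here is a minimal sketch of what such a guard can look like in a chunked forward loop. Everything below (kernel name, block-pointer layout, parameters) is my assumption for illustration, not the actual code from tests/test_fused_chunk.py; the point is only the always-true if at the top of the loop body.

import triton
import triton.language as tl

@triton.jit
def chunk_fwd_kernel_guarded(q, k, v, o,
                             T: tl.constexpr, K: tl.constexpr,
                             V: tl.constexpr, BT: tl.constexpr):
    # running [K, V] state carried across chunks in fp32 registers
    # (sketch assumes a single batch/head and T a multiple of BT)
    b_h = tl.zeros([K, V], dtype=tl.float32)
    for i_t in range(0, tl.cdiv(T, BT)):
        # IFCOND: an always-true guard at the top of the loop body;
        # logically a no-op, but it reportedly shrinks the bf16 diff
        if i_t >= 0:
            p_q = tl.make_block_ptr(q, (T, K), (K, 1), (i_t * BT, 0), (BT, K), (1, 0))
            p_k = tl.make_block_ptr(k, (K, T), (1, K), (0, i_t * BT), (K, BT), (0, 1))
            p_v = tl.make_block_ptr(v, (T, V), (V, 1), (i_t * BT, 0), (BT, V), (1, 0))
            p_o = tl.make_block_ptr(o, (T, V), (V, 1), (i_t * BT, 0), (BT, V), (1, 0))
            b_q = tl.load(p_q)
            b_k = tl.load(p_k)
            b_v = tl.load(p_v)
            b_o = tl.dot(b_q, b_h.to(b_q.dtype))   # inter-chunk contribution
            b_h += tl.dot(b_k, b_v.to(b_k.dtype))  # state update
            tl.store(p_o, b_o.to(p_o.dtype.element_ty))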
Can you try with triton/main?
Hi, I found a big difference in results when using tl.store (under bfloat16).
I have pasted the tailored code here for ease of reproduction. The only difference between attention_nostore_fwd_kernel and attention_store_fwd_kernel is tl.store(p_h, b_h.to(p_h.dtype.element_ty)), which saves the intermediate results to HBM; the outputs are shown above. The results are consistent under float. With this minor code change, however, there is a big, unacceptable difference in the final outputs under bfloat16. Also, the bfloat16 results become identical if the inputs are restricted to a very small range, e.g., divided by 1024.
I guess the evil stems from the precision of bfloat16, but I can't figure out why tl.store brings such a big difference, or how to solve this issue. Could you give me some hints? The environment is Triton 2.1 & A100-SXM4-40GB.
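A rough sketch of the store variant, so the difference is easy to see. The shapes, names, and checkpoint layout below are my assumptions for illustration, not the actual code from the paste; only the quoted tl.store line is taken from the issue.

import triton
import triton.language as tl

@triton.jit
def attention_store_fwd_kernel(q, k, v, o, h,
                               T: tl.constexpr, K: tl.constexpr,
                               V: tl.constexpr, BT: tl.constexpr):
    # running [K, V] state carried across chunks in fp32 registers
    # (sketch assumes a single batch/head and T a multiple of BT)
    b_h = tl.zeros([K, V], dtype=tl.float32)
    for i_t in range(0, tl.cdiv(T, BT)):
        p_q = tl.make_block_ptr(q, (T, K), (K, 1), (i_t * BT, 0), (BT, K), (1, 0))
        p_k = tl.make_block_ptr(k, (K, T), (1, K), (0, i_t * BT), (K, BT), (0, 1))
        p_v = tl.make_block_ptr(v, (T, V), (V, 1), (i_t * BT, 0), (BT, V), (1, 0))
        p_o = tl.make_block_ptr(o, (T, V), (V, 1), (i_t * BT, 0), (BT, V), (1, 0))
        # assumed checkpoint buffer h of shape [cdiv(T, BT), K, V]
        p_h = tl.make_block_ptr(h + i_t * K * V, (K, V), (V, 1), (0, 0), (K, V), (1, 0))

        b_q = tl.load(p_q)
        b_k = tl.load(p_k)
        b_v = tl.load(p_v)

        b_o = tl.dot(b_q, b_h.to(b_q.dtype))   # inter-chunk contribution
        b_h += tl.dot(b_k, b_v.to(b_k.dtype))  # state update

        # the ONLY line that attention_nostore_fwd_kernel lacks:
        tl.store(p_h, b_h.to(p_h.dtype.element_ty))

        tl.store(p_o, b_o.to(p_o.dtype.element_ty))

In the sketch above, the store only copies b_h out to HBM and does not feed back into the loop-carried state, which is why such a large bfloat16-only difference in the final outputs is surprising.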
Thanks.