sustcsonglin / flash-linear-attention

Efficient implementations of state-of-the-art linear attention models in PyTorch and Triton
MIT License

bug in treatment of scale for fused_chunk_linear_attn #47

Closed SmerkyG closed 1 month ago

SmerkyG commented 1 month ago

Thanks for the amazing library!

I discovered an error in https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/ops/linear_attn/chunk_fuse.py

It works great with scale=1.0, but with scale=-1 (the sentinel for the default q.shape[-1] ** -0.5) it gives significantly incorrect results that mismatch the output of chunk_linear_attn. chunk_linear_attn itself seems to work fine with other scales such as -1.

Unfortunately, I'm not sure where in the Triton code the scale is applied incorrectly.
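
For reference, a comparison along these lines can be sketched as follows. This is a hypothetical repro, not taken from the issue: the shapes and dtype are illustrative, and the scale=-1 sentinel and return types may differ between library versions.

```python
# Hypothetical repro sketch comparing the two kernels at different scales.
import torch
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn

B, H, T, D = 2, 4, 256, 64
q, k, v = (torch.randn(B, H, T, D, device='cuda') for _ in range(3))

def out_only(x):
    # some versions return (o, final_state); keep just the output tensor
    return x[0] if isinstance(x, tuple) else x

for scale in (1.0, -1):  # -1 is the sentinel for q.shape[-1] ** -0.5
    o_ref = out_only(chunk_linear_attn(q, k, v, scale=scale))
    o_fused = out_only(fused_chunk_linear_attn(q, k, v, scale=scale))
    print(f"scale={scale}: max abs diff = {(o_ref - o_fused).abs().max().item():.6f}")
```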

yzhangcs commented 1 month ago

@SmerkyG Hi, just fixed the grad. Thanks for reporting this bug.
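
For context, in a fused kernel the scale factor has to appear consistently in the forward pass and in the hand-written backward pass; dropping it in the backward is invisible whenever scale == 1.0. Below is a naive, non-chunked sketch of where the factor enters, assuming scale is applied to q and ignoring any output normalization; it is an illustration of this class of bug, not the library's actual Triton code.

```python
import torch

def linear_attn_ref(q, k, v, scale):
    # Causal linear attention without a feature map, per head:
    #   o_t = (scale * q_t) @ sum_{s<=t} k_s v_s^T
    S = torch.tril((q * scale) @ k.transpose(-1, -2))  # scale enters the forward pass here
    return S @ v

# If the backward pass is written by hand (as in a Triton kernel), the same
# factor must reappear in the gradients:
#   S  = tril((scale * q) @ k^T)
#   dS = tril(do @ v^T)
#   dq = scale * (dS @ k)      # omitting `scale` here is a no-op only when scale == 1.0
#   dk = scale * (dS^T @ q)
#   dv = S^T @ do
```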