Thanks for the amazing library!
I discovered an error in https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/ops/linear_attn/chunk_fuse.py
The fused chunk kernel works great with scale=1.0, but with scale=-1 (which should resolve to q.shape[-1] ** -0.5) it gives significantly incorrect results that mismatch the output of chunk_linear_attn. chunk_linear_attn itself seems to handle non-default scales such as -1 just fine.
Unfortunately, I'm not sure where in the Triton code the scale is applied incorrectly.
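For reference, here is a minimal sketch of the comparison I'm describing (assuming `fused_chunk_linear_attn` is the public entrypoint backed by chunk_fuse.py, that both ops accept (B, H, T, D) tensors plus a `scale` kwarg, and that a CUDA device is available for Triton; adjust names/shapes to your version):

```python
import torch
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn

B, H, T, D = 2, 4, 128, 64
q, k, v = (torch.randn(B, H, T, D, device='cuda') for _ in range(3))

for scale in (1.0, -1):  # -1 should resolve to q.shape[-1] ** -0.5
    ref = chunk_linear_attn(q, k, v, scale=scale)
    out = fused_chunk_linear_attn(q, k, v, scale=scale)
    if isinstance(ref, tuple):  # some versions return (output, final_state)
        ref, out = ref[0], out[0]
    print(f"scale={scale}: max abs diff = {(ref - out).abs().max().item():.6f}")
```

With scale=1.0 the two outputs agree, but with scale=-1 the difference is large.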