Open PheelaV opened 4 months ago
We debugged it a bit and found that this issue only occurs with lengths < 32. I am going to send a PR to block these sequences for now and also try to debug why that happens.
Are there any updates on this? I have also observed this behavior with lengths < 32, on both AMD and NVIDIA GPUs.
@Hprairie Not yet as far as I know, but look at the workarounds in the details of the report. `reverse=False` works well; one can load the memory pointers in reversed order and store them back the same way, or for fewer than 32 elements you might be better off with a linear scan using a simple for loop (see the sketch below).
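For reference, a minimal sketch of that for-loop fallback for short sequences, assuming the first-order recurrence combine op `(a1, v1) ∘ (a2, v2) = (a1*a2, v1*a2 + v2)` that the outputs later in this thread are consistent with; the kernel name and launch parameters are illustrative only:

```python
import torch
import triton
import triton.language as tl


# Sequential reverse scan for short sequences: walk from the last element to the
# first and apply the combine rule one step at a time, storing the running carry.
@triton.jit
def linear_reverse_scan(exp_ptr, vals_ptr, seqlen: tl.constexpr):
    carry_a = 1.0  # identity for the multiplicative gate
    carry_v = 0.0  # identity for the accumulated value
    for i in range(seqlen):
        j = seqlen - 1 - i  # iterate from the back
        a = tl.load(exp_ptr + j)
        v = tl.load(vals_ptr + j)
        carry_a, carry_v = carry_a * a, carry_v * a + v
        tl.store(exp_ptr + j, carry_a)
        tl.store(vals_ptr + j, carry_v)


exp = torch.tensor([1.0, 1.5, 0.8, 2.0], device="cuda")
vals = torch.tensor([1.0, -1.0, 0.5, 2.0], device="cuda")
linear_reverse_scan[(1,)](exp, vals, seqlen=4)
# Should match the rev(scan(rev(x))) reference reported below:
# exp -> [2.4, 2.4, 1.6, 2.0], vals -> [3.15, 2.15, 2.1, 2.0]
```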
Yeah, I think under 32 this is not going to be fast anyway. That being said, I am looking into a fix.
I have found that tl.flip doesn't fix the problem: it still persists when doing something like tl.flip(tl.associative_scan(tl.flip(x), ...)). It may also be happening for lengths > 32 when running a scan with a 2D tensor. I will work on getting a reproducible script to help with debugging.
I also have the same issue when using tl.cumsum (which uses the scan) with reverse=True.
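A minimal repro sketch of that case (kernel name and block size are illustrative); a block size below 32 is the regime where the wrong results were reported:

```python
import torch
import triton
import triton.language as tl


# Reverse cumulative sum over a single small block via tl.cumsum(reverse=True).
@triton.jit
def reverse_cumsum_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    y = tl.cumsum(x, axis=0, reverse=True)
    tl.store(out_ptr + offs, y)


x = torch.arange(1, 5, dtype=torch.float32, device="cuda")  # [1, 2, 3, 4]
out = torch.empty_like(x)
reverse_cumsum_kernel[(1,)](x, out, BLOCK=4)
# A reverse cumsum is the suffix sum, so the expected result is [10, 9, 7, 4];
# compare against the flip-cumsum-flip reference:
print(out, x.flip(0).cumsum(0).flip(0))
```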
Hi,
I believe triton.language.associative_scan is returning incorrect results when reverse=True, or I could not figure out the desired behaviour. In the original PR it is described as "jax like", i.e. rev(scan(rev(x))). So I compared different variants and also did a JAX run.
Spotted a problem
Here I thought that maybe not using out pointers was causing some sort of race condition, as the result for the "exponent" (f) part is very clearly wrong. So I tested them pair-wise, like so:
Output:
The reverse is incorrect.
What I believe is the correct result
Not trusting it, I created a manual version as a reference and compared it to a JAX implementation.
Manual reference:
Output:
JAX reference:
Output:
My intuition seems to be correct: both yield the same results.
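For completeness, a sketch of what such a JAX reference could look like, using jax.lax.associative_scan with reverse=True and a combine op inferred from the correct outputs reported in the workaround sections below; this is an illustrative reconstruction, not the exact script from the report:

```python
import jax
import jax.numpy as jnp


# Combine op consistent with the outputs shown further down:
# (a1, v1) ∘ (a2, v2) = (a1*a2, v1*a2 + v2).
def op(left, right):
    a1, v1 = left
    a2, v2 = right
    return a1 * a2, v1 * a2 + v2


exp = jnp.array([1.0, 1.5, 0.8, 2.0])
vals = jnp.array([1.0, -1.0, 0.5, 2.0])

fwd = jax.lax.associative_scan(op, (exp, vals))
# reverse=True gives the "jax like" rev(scan(rev(x))) behaviour
rev = jax.lax.associative_scan(op, (exp, vals), reverse=True)

print(fwd)  # approximately ([1.0, 1.5, 1.2, 2.4], [1.0, 0.5, 0.9, 3.8])
print(rev)  # approximately ([2.4, 2.4, 1.6, 2.0], [3.15, 2.15, 2.1, 2.0])
```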
Any ideas? Related: #3177, #2930. Kindly referencing @srush and the original usage in the Mamba implementation.
Python: 3.10 and 3.12; Triton: 3.0 and triton-nightly.
Two workarounds: flipping the inputs and outputs with tl.flip, and reversing the load/store indices (both investigated below).
Investigation into the workarounds and further testing.
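The kernels below all use a combine function `op` that is not shown in the snippets. A definition consistent with the outputs reported further down (a first-order linear recurrence over a gate/value pair) would look like the following; treat it as an inferred reconstruction rather than the exact code from the report, together with the imports the snippets assume:

```python
import torch
import triton
import triton.language as tl


# Inferred combine function: op((a1, v1), (a2, v2)) = (a1 * a2, v1 * a2 + v2),
# the mamba-style first-order recurrence. Triton passes the left and right tuple
# elements as flat arguments.
@triton.jit
def op(a1, v1, a2, v2):
    return a1 * a2, v1 * a2 + v2
```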
## Naive flip of axis: one possible workaround, but I do not trust it. Why does it suddenly change behavior compared to the out_ref variant?

```python
@triton.jit
def kernel3(exp_ref, vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
    input_range = tl.arange(0, BS)
    exp = tl.load(exp_ref + input_range)
    vals = tl.load(vals_ref + input_range)
    exp, vals = tl.associative_scan(
        (tl.flip(exp), tl.flip(vals)), axis=0, combine_fn=op,  # reverse=reverse
    ) if reverse else tl.associative_scan(
        (exp, vals), axis=0, combine_fn=op
    )
    tl.store(exp_ref + input_range, tl.flip(exp) if reverse else exp)
    tl.store(vals_ref + input_range, tl.flip(vals) if reverse else vals)


@triton.jit
def kernel4(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
    input_range = tl.arange(0, BS)
    exp = tl.load(exp_ref + input_range)
    vals = tl.load(vals_ref + input_range)
    exp, vals = tl.associative_scan(
        (tl.flip(exp), tl.flip(vals)), axis=0, combine_fn=op,  # reverse=reverse
    ) if reverse else tl.associative_scan(
        (exp, vals), axis=0, combine_fn=op
    )
    # note: these stores write back to exp_ref/vals_ref, not to out_exp_ref/out_vals_ref
    tl.store(exp_ref + input_range, tl.flip(exp) if reverse else exp)
    tl.store(vals_ref + input_range, tl.flip(vals) if reverse else vals)


def init():
    BS = 4
    device = torch.device("cuda")
    exp = torch.tensor([1.0, 1.5, 0.8, 2.0]).to(device)
    vals = torch.tensor([1.0, -1.0, 0.5, 2.0]).to(device)
    out_exp = torch.empty_like(exp)
    out_vals = torch.empty_like(vals)
    return BS, device, exp, vals, out_exp, out_vals


# Act
reverse = False
BS, device, exp, vals, out_exp, out_vals = init()
kernel3[(1,)](exp, vals, BS, reverse)
out_exp1, out_vals1 = exp, vals
BS, device, exp, vals, out_exp2, out_vals2 = init()
kernel4[(1,)](exp, vals, out_exp2, out_vals2, BS, reverse)

print(f"{out_exp1=}")
print(f"{out_vals1=}")
print()
print(f"{out_exp2=}")
print(f"{out_vals2=}")
print()
print()

reverse = True
BS, device, exp, vals, out_exp, out_vals = init()
kernel3[(1,)](exp, vals, BS, reverse)
out_exp_reverse1, out_vals_reverse1 = exp, vals
BS, device, exp, vals, out_exp_reverse2, out_vals_reverse2 = init()
kernel4[(1,)](exp, vals, out_exp_reverse2, out_vals_reverse2, BS, reverse)

print(f"{out_exp_reverse1=}")
print(f"{out_vals_reverse1=}")
print()
print(f"{out_exp_reverse2=}")
print(f"{out_vals_reverse2=}")
```

output

```console
out_exp1=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals1=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp2=tensor([2.8924e+28, 2.0208e+00, 0.0000e+00, 0.0000e+00], device='cuda:0')
out_vals2=tensor([0., 0., 0., 0.], device='cuda:0')

out_exp_reverse1=tensor([2.4000, 2.4000, 1.6000, 2.0000], device='cuda:0')
out_vals_reverse1=tensor([3.1500, 2.1500, 2.1000, 2.0000], device='cuda:0')

out_exp_reverse2=tensor([1.0842e-19, 2.1437e+00, 2.0000e+00, 2.2844e+00], device='cuda:0')
out_vals_reverse2=tensor([2.0000e+00, 2.2844e+00, 1.4013e-45, 0.0000e+00], device='cuda:0')
```

Somehow only the in-place replacement seems to have gotten the correct result.

## Usage in the mamba blog with shifted gates

```python
@triton.jit
def kernel5(exp_ref, vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
    input_range = tl.arange(0, BS)
    exp = tl.load(exp_ref + input_range)
    vals = tl.load(vals_ref + input_range)
    exp, vals = tl.associative_scan(
        (exp, vals), axis=0, combine_fn=op, reverse=reverse
    )
    tl.store(exp_ref + input_range, exp)
    tl.store(vals_ref + input_range, vals)


@triton.jit
def kernel6(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
    input_range = tl.arange(0, BS)
    exp = tl.load(exp_ref + input_range)
    vals = tl.load(vals_ref + input_range)
    exp, vals = tl.associative_scan(
        (exp, vals), axis=0, combine_fn=op, reverse=reverse
    )
    tl.store(out_exp_ref + input_range, exp)
    tl.store(out_vals_ref + input_range, vals)


# Act
reverse = False
BS, device, exp, vals, out_exp, out_vals = init()
kernel5[(1,)](exp, vals, BS, reverse)
out_exp1, out_vals1 = exp, vals
BS, device, exp, vals, out_exp2, out_vals2 = init()
kernel6[(1,)](exp, vals, out_exp2, out_vals2, BS, reverse)

print(f"{out_exp1=}")
print(f"{out_vals1=}")
print()
print(f"{out_exp2=}")
print(f"{out_vals2=}")
print()
print()

reverse = True
BS, device, exp, vals, out_exp, out_vals = init()
kernel5[(1,)](exp, vals, BS, reverse)
out_exp_reverse1, out_vals_reverse1 = exp, vals
BS, device, exp, vals, out_exp_reverse2, out_vals_reverse2 = init()
exp = torch.cat([exp[1:], torch.tensor([1]).to(device)])
kernel6[(1,)](exp, vals, out_exp_reverse2, out_vals_reverse2, BS, reverse)

print(f"{out_exp_reverse1=}")
print(f"{out_vals_reverse1=}")
print()
print(f"{out_exp_reverse2=}")
print(f"{out_vals_reverse2=}")
```

output

```console
out_exp1=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals1=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp2=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals2=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp_reverse1=tensor([2.4000, 2.4000, 2.4000, 2.4000], device='cuda:0')
out_vals_reverse1=tensor([3.1500, 4.5500, 2.1000, 3.5000], device='cuda:0')

out_exp_reverse2=tensor([2.4000, 2.4000, 2.4000, 2.4000], device='cuda:0')
out_vals_reverse2=tensor([4.9000, 4.2000, 3.5000, 2.1000], device='cuda:0')
```

Both reverses are wrong.

## Workaround 1: reversing indexes

```python
@triton.jit
def kernel7(exp_ref, vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
    input_range = tl.arange(0, BS)
    input_range = (BS - 1 - input_range) if reverse else input_range
    exp = tl.load(exp_ref + input_range)
    vals = tl.load(vals_ref + input_range)
    exp, vals = tl.associative_scan(
        (exp, vals), axis=0, combine_fn=op
    )
    tl.store(exp_ref + input_range, exp)
    tl.store(vals_ref + input_range, vals)


@triton.jit
def kernel8(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
    input_range = tl.arange(0, BS)
    input_range = (BS - 1 - input_range) if reverse else input_range
    exp = tl.load(exp_ref + input_range)
    vals = tl.load(vals_ref + input_range)
    exp, vals = tl.associative_scan(
        (exp, vals), axis=0, combine_fn=op
    )
    tl.store(out_exp_ref + input_range, exp)
    tl.store(out_vals_ref + input_range, vals)


# Act
reverse = False
BS, device, exp, vals, out_exp, out_vals = init()
kernel7[(1,)](exp, vals, BS, reverse)
out_exp1, out_vals1 = exp, vals
BS, device, exp, vals, out_exp2, out_vals2 = init()
kernel8[(1,)](exp, vals, out_exp2, out_vals2, BS, reverse)

print(f"{out_exp1=}")
print(f"{out_vals1=}")
print()
print(f"{out_exp2=}")
print(f"{out_vals2=}")
print()
print()

reverse = True
BS, device, exp, vals, out_exp, out_vals = init()
kernel7[(1,)](exp, vals, BS, reverse)
out_exp_reverse1, out_vals_reverse1 = exp, vals
BS, device, exp, vals, out_exp_reverse2, out_vals_reverse2 = init()
kernel8[(1,)](exp, vals, out_exp_reverse2, out_vals_reverse2, BS, reverse)

print(f"{out_exp_reverse1=}")
print(f"{out_vals_reverse1=}")
print()
print(f"{out_exp_reverse2=}")
print(f"{out_vals_reverse2=}")
```

output

```console
out_exp1=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals1=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp2=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals2=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp_reverse1=tensor([2.4000, 2.4000, 1.6000, 2.0000], device='cuda:0')
out_vals_reverse1=tensor([3.1500, 2.1500, 2.1000, 2.0000], device='cuda:0')

out_exp_reverse2=tensor([2.4000, 2.4000, 1.6000, 2.0000], device='cuda:0')
out_vals_reverse2=tensor([3.1500, 2.1500, 2.1000, 2.0000], device='cuda:0')
```

Both are correct.

## Performance of the variants that yielded correct results

```python
import torch
import triton.language as tl
import triton

torch.random.manual_seed(1155)


def init_rand(seqlen):
    BS = 4
    device = torch.device("cuda")
    exp = torch.rand(seqlen).to(device)
    vals = torch.rand(seqlen).to(device)
    out_exp = torch.empty_like(exp)
    out_vals = torch.empty_like(vals)
    return BS, device, exp, vals, out_exp, out_vals


lines = ["in_place_flip", "in_place_rev", "out_ref_rev"]


@triton.testing.perf_report([
    triton.testing.Benchmark(
        x_names=["seqlen"],
        x_vals=[2**i for i in range(7, 20)],
        xlabel='sequence length',
        ylabel='ms',
        x_log=True,
        y_log=True,
        line_arg="benched",
        line_vals=lines,
        line_names=lines,
        plot_name="reversing scan",
        args={},
    ),
])
def bench(benched, seqlen, device="cuda"):
    # note: init_rand always returns BS=4, so the kernels below are launched with
    # a 4-element block regardless of seqlen
    BS, device, exp, vals, out_exp, out_vals = init_rand(seqlen)
    match benched:
        case "in_place_flip":
            subject = lambda: kernel3[(1,)](exp, vals, BS, True)
        case "in_place_rev":
            subject = lambda: kernel7[(1,)](exp, vals, BS, True)
        case "out_ref_rev":
            subject = lambda: kernel8[(1,)](exp, vals, out_exp, out_vals, BS, True)
    ms = triton.testing.do_bench(subject, warmup=1000, rep=200)
    print(f"{seqlen=};\t{benched=};{ms=}")
    return ms


bench.run(save_path="./bench_results", print_data=True)
```

output

```console
reversing scan:
      seqlen  in_place_flip  in_place_rev  out_ref_rev
0      128.0       0.007101      0.004218     0.003976
1      256.0       0.004264      0.003972     0.003923
2      512.0       0.004228      0.003937     0.004008
3     1024.0       0.004271      0.004030     0.003985
4     2048.0       0.004283      0.004002     0.003961
5     4096.0       0.004266      0.004031     0.003958
6     8192.0       0.004274      0.003972     0.004049
7    16384.0       0.004264      0.003959     0.004006
8    32768.0       0.004260      0.003984     0.004011
9    65536.0       0.004251      0.004020     0.004054
10  131072.0       0.004317      0.004037     0.004085
11  262144.0       0.004282      0.004010     0.004020
12  524288.0       0.004318      0.004024     0.004000
```

![image](https://github.com/user-attachments/assets/d2fd8da3-9d1a-40f2-8ad0-73425405fccd)

And for reverse=False:

```console
reversing scan:
      seqlen  in_place_flip  in_place_rev  out_ref_rev
0      128.0       0.006323      0.004214     0.004002
1      256.0       0.003979      0.004000     0.003952
2      512.0       0.003979      0.004011     0.004018
3     1024.0       0.004027      0.003972     0.004007
4     2048.0       0.003967      0.004033     0.003976
5     4096.0       0.004024      0.003985     0.004023
6     8192.0       0.004042      0.004050     0.004083
7    16384.0       0.003983      0.004045     0.003997
8    32768.0       0.004036      0.004020     0.004061
9    65536.0       0.004019      0.004021     0.004017
10  131072.0       0.004057      0.004016     0.004059
11  262144.0       0.004033      0.004002     0.004021
12  524288.0       0.003994      0.003968     0.004008
```

![image](https://github.com/user-attachments/assets/a85b6c4d-32f8-4db2-9eb4-5d33da97f0cd)

With `reverse=True`, the results on a T4 tend to favour the in-place reversed-index variant, while on an A100 they were inconclusive (mind you, this was run on Google Colab); the in-place flip does worse. With `reverse=False`, they are all mostly the same.