triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/

[BUG] triton.language.associative_scan returning incorrect results when `reverse=True` #4362

Open PheelaV opened 4 months ago

PheelaV commented 4 months ago

Hi,

I believe `triton.language.associative_scan` returns incorrect results when `reverse=True`, or else I could not figure out the intended behaviour. In the original PR it is described as "jax like", i.e. `rev(scan(rev(x)))`. So I compared several variants and also did a JAX run.
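
For plain addition, the behaviour I would expect from `reverse=True` is exactly that flip-scan-flip equivalence. A tiny PyTorch sketch, just to pin down the semantics I am assuming:

```python
import torch

x = torch.arange(4, dtype=torch.float32)  # [0, 1, 2, 3]

forward = torch.cumsum(x, dim=0)  # [0, 1, 3, 6]
# reverse scan == rev(scan(rev(x)))
reverse = torch.flip(torch.cumsum(torch.flip(x, [0]), dim=0), [0])  # [6, 6, 5, 3]
```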

Spotted a problem

At first I thought that not using separate output pointers might be causing some sort of race condition, since the result for the "exponent" `f` part is very clearly wrong. So I tested the two variants pair-wise, like so:

import torch
import triton.language as tl
import triton

# Setup
@triton.jit
def op(fl, xl, fr, xr) -> tuple[torch.float, torch.float]:
  """First order linear recurrence operation.
     source: https://srush.github.io/annotated-mamba/hard.html
    """
  f = fr * fl
  x = fr * xl + xr
  return f, x

@triton.jit
def kernel1(exp_ref, vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op, reverse=reverse
  )

  tl.store(exp_ref + input_range, exp)
  tl.store(vals_ref + input_range, vals)

@triton.jit
def kernel2(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op, reverse=reverse
  )

  tl.store(out_exp_ref + input_range, exp)
  tl.store(out_vals_ref + input_range, vals)

def init():
    BS=4
    device = torch.device("cuda")
    exp = torch.tensor([1.0, 1.5, 0.8, 2.0]).to(device)
    vals = torch.tensor([1.0, -1.0, 0.5, 2.0]).to(device)

    out_exp = torch.empty_like(exp)
    out_vals = torch.empty_like(vals)

    return BS, device, exp, vals, out_exp, out_vals

# Act
reverse = False
BS, device, exp, vals, out_exp, out_vals = init()
kernel1[(1,)](exp, vals, BS, reverse)
out_exp1, out_vals1 = exp, vals

BS, device, exp, vals, out_exp2, out_vals2 = init()
kernel2[(1,)](exp, vals, out_exp2, out_vals2, BS, reverse)
print(f"{out_exp1=}")
print(f"{out_vals1=}")
print()
print(f"{out_exp2=}")
print(f"{out_vals2=}")

print()
print()
reverse=True
BS, device, exp, vals, out_exp, out_vals = init()
kernel1[(1,)](exp, vals, BS, reverse)
out_exp_reverse1, out_vals_reverse1 = exp, vals

BS, device, exp, vals, out_exp_reverse2, out_vals_reverse2 = init()
kernel2[(1,)](exp, vals, out_exp_reverse2, out_vals_reverse2, BS, reverse)
print(f"{out_exp_reverse1=}")
print(f"{out_vals_reverse1=}")
print()
print(f"{out_exp_reverse2=}")
print(f"{out_vals_reverse2=}")

Output:

out_exp1=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals1=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp2=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals2=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp_reverse1=tensor([2.4000, 2.4000, 2.4000, 2.4000], device='cuda:0')
out_vals_reverse1=tensor([3.1500, 4.5500, 2.1000, 3.5000], device='cuda:0')

out_exp_reverse2=tensor([2.4000, 2.4000, 2.4000, 2.4000], device='cuda:0')
out_vals_reverse2=tensor([3.1500, 4.5500, 2.1000, 3.5000], device='cuda:0')

The reverse is incorrect.

What I believe is the correct result

Not trusting my intuition alone, I created a manual version as a reference and compared it against a JAX implementation.

Manual reference:

import torch

def op(fl, xl, fr, xr):
  f = fr * fl
  x = fr * xl + xr
  return f, x

def forward_scan(exp, vals):
  exp = torch.tensor(exp)
  vals = torch.tensor(vals)
  state_exp = [torch.tensor(1.0)]
  state_vals = [torch.tensor(0.0)]
  for i in range(len(exp)):
    new_exp, new_val = op(state_exp[-1], state_vals[-1], exp[i], vals[i])
    state_exp.append(new_exp)
    state_vals.append(new_val)
  return torch.stack(state_exp), torch.stack(state_vals)

def reverse_scan(exp, vals):
  exp = torch.tensor(exp)
  vals = torch.tensor(vals)

  state_f = torch.tensor(1.0)
  state_x = torch.tensor(0.0)

  f_results = [state_f]
  x_results = [state_x]

  # iterate in reverse, taking the accumulated state as the left operand
  # and the newly visited element as the right operand
  for i in range(len(exp) - 1, -1, -1):
    state_f, state_x = op(state_f, state_x, exp[i], vals[i])
    f_results.append(state_f)
    x_results.append(state_x)

  return torch.stack(f_results[1:][::-1]), torch.stack(x_results[1:][::-1])

exp = [1.0, 1.5, 0.8, 2.0]
vals = [1.0, -1.0, 0.5, 2.0]

forward_exp, forward_vals = forward_scan(exp, vals)
backward_exp, backward_vals = reverse_scan(exp, vals)

print("Forward scan results:")
print("gates", forward_exp)
print("tokens", forward_vals)

print("Backward scan results:")
print("gates", backward_exp)
print("tokens", backward_vals)

output

Forward scan results:
gates tensor([1.0000, 1.0000, 1.5000, 1.2000, 2.4000])
tokens tensor([0.0000, 1.0000, 0.5000, 0.9000, 3.8000])
Backward scan results:
gates tensor([2.4000, 2.4000, 1.6000, 2.0000])
tokens tensor([3.1500, 2.1500, 2.1000, 2.0000])

JAX reference

from jax import lax
import jax.numpy as jnp

result_add_1 = lax.associative_scan(jnp.add, jnp.arange(0, 4))
result_add_1_reverse = lax.associative_scan(jnp.add, jnp.arange(0, 4), reverse=True)
print(f"{result_add_1=}")
print(f"{result_add_1_reverse=}")

def op(left, right) -> tuple[float, float]:
    fl, xl = left
    fr, xr = right
    f = fl * fr
    x = fr * xl + xr
    return f, x

exp = jnp.array([1.0, 1.5, 0.8, 2.0])
vals = jnp.array([1.0, -1.0, 0.5, 2.0])

result_jax_normal = lax.associative_scan(op, (exp, vals))
print(result_jax_normal)
result_jax_reversed = lax.associative_scan(op, (exp, vals), reverse=True)
print(result_jax_reversed)

output

result_add_1=Array([0, 1, 3, 6], dtype=int32)
result_add_1_reverse=Array([6, 6, 5, 3], dtype=int32)
(Array([1. , 1.5, 1.2, 2.4], dtype=float32), Array([1. , 0.5, 0.9, 3.8], dtype=float32))
(Array([2.4, 2.4, 1.6, 2. ], dtype=float32), Array([3.1499999, 2.1499999, 2.1 , 2.], dtype=float32))

My intuition seems to be correct: the manual reference and JAX yield the same results.

Any ideas? #3177 #2930. Kindly tagging @srush, referencing the original usage in the mamba implementation.

Python: 3.10 and 3.12; Triton: 3.0 and triton-nightly

Two workarounds:

  1. flip the inputs with `tl.flip()` before the scan and flip the results back
  2. read from and write to the pointers through a reversed index range (a minimal sketch follows; full variants and benchmarks are in the details below)
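
The sketch below is essentially workaround 2 condensed into one kernel; it assumes a single program and a power-of-two `BS`, and reuses the `op` combine function from above:

```python
@triton.jit
def reverse_scan_kernel(exp_ref, vals_ref, BS: tl.constexpr):
  # load the block back-to-front, run a forward scan, store back-to-front
  idx = BS - 1 - tl.arange(0, BS)
  exp = tl.load(exp_ref + idx)
  vals = tl.load(vals_ref + idx)
  exp, vals = tl.associative_scan((exp, vals), axis=0, combine_fn=op)
  tl.store(exp_ref + idx, exp)
  tl.store(vals_ref + idx, vals)
```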
**Investigation into the workarounds and further testing**

## Naive flip of the axis: a possible workaround, but I do not trust it. Why does it suddenly change behaviour compared to the out_ref variant?

```python
@triton.jit
def kernel3(exp_ref, vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (tl.flip(exp), tl.flip(vals)), axis=0, combine_fn=op,
    # reverse=reverse
  ) if reverse else tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op
  )

  tl.store(exp_ref + input_range, tl.flip(exp) if reverse else exp)
  tl.store(vals_ref + input_range, tl.flip(vals) if reverse else vals)

@triton.jit
def kernel4(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (tl.flip(exp), tl.flip(vals)), axis=0, combine_fn=op,
    # reverse=reverse
  ) if reverse else tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op
  )

  tl.store(exp_ref + input_range, tl.flip(exp) if reverse else exp)
  tl.store(vals_ref + input_range, tl.flip(vals) if reverse else vals)

def init():
    BS = 4
    device = torch.device("cuda")
    exp = torch.tensor([1.0, 1.5, 0.8, 2.0]).to(device)
    vals = torch.tensor([1.0, -1.0, 0.5, 2.0]).to(device)

    out_exp = torch.empty_like(exp)
    out_vals = torch.empty_like(vals)

    return BS, device, exp, vals, out_exp, out_vals

# Act
reverse = False
BS, device, exp, vals, out_exp, out_vals = init()
kernel3[(1,)](exp, vals, BS, reverse)
out_exp1, out_vals1 = exp, vals

BS, device, exp, vals, out_exp2, out_vals2 = init()
kernel4[(1,)](exp, vals, out_exp2, out_vals2, BS, reverse)
print(f"{out_exp1=}")
print(f"{out_vals1=}")
print()
print(f"{out_exp2=}")
print(f"{out_vals2=}")

print()
print()
reverse = True
BS, device, exp, vals, out_exp, out_vals = init()
kernel3[(1,)](exp, vals, BS, reverse)
out_exp_reverse1, out_vals_reverse1 = exp, vals

BS, device, exp, vals, out_exp_reverse2, out_vals_reverse2 = init()
kernel4[(1,)](exp, vals, out_exp_reverse2, out_vals_reverse2, BS, reverse)
print(f"{out_exp_reverse1=}")
print(f"{out_vals_reverse1=}")
print()
print(f"{out_exp_reverse2=}")
print(f"{out_vals_reverse2=}")
```

output

```console
out_exp1=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals1=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp2=tensor([2.8924e+28, 2.0208e+00, 0.0000e+00, 0.0000e+00], device='cuda:0')
out_vals2=tensor([0., 0., 0., 0.], device='cuda:0')

out_exp_reverse1=tensor([2.4000, 2.4000, 1.6000, 2.0000], device='cuda:0')
out_vals_reverse1=tensor([3.1500, 2.1500, 2.1000, 2.0000], device='cuda:0')

out_exp_reverse2=tensor([1.0842e-19, 2.1437e+00, 2.0000e+00, 2.2844e+00], device='cuda:0')
out_vals_reverse2=tensor([2.0000e+00, 2.2844e+00, 1.4013e-45, 0.0000e+00], device='cuda:0')
```

Somehow only the in-place variant produces the correct result.

## Usage as in the mamba blog, with shifted gates

```python
@triton.jit
def kernel5(exp_ref, vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op, reverse=reverse
  )

  tl.store(exp_ref + input_range, exp)
  tl.store(vals_ref + input_range, vals)

@triton.jit
def kernel6(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op, reverse=reverse
  )

  tl.store(out_exp_ref + input_range, exp)
  tl.store(out_vals_ref + input_range, vals)

# Act
reverse = False
BS, device, exp, vals, out_exp, out_vals = init()
kernel5[(1,)](exp, vals, BS, reverse)
out_exp1, out_vals1 = exp, vals

BS, device, exp, vals, out_exp2, out_vals2 = init()
kernel6[(1,)](exp, vals, out_exp2, out_vals2, BS, reverse)
print(f"{out_exp1=}")
print(f"{out_vals1=}")
print()
print(f"{out_exp2=}")
print(f"{out_vals2=}")

print()
print()
reverse = True
BS, device, exp, vals, out_exp, out_vals = init()
kernel5[(1,)](exp, vals, BS, reverse)
out_exp_reverse1, out_vals_reverse1 = exp, vals

BS, device, exp, vals, out_exp_reverse2, out_vals_reverse2 = init()
exp = torch.cat([exp[1:], torch.tensor([1]).to(device)])
kernel6[(1,)](exp, vals, out_exp_reverse2, out_vals_reverse2, BS, reverse)
print(f"{out_exp_reverse1=}")
print(f"{out_vals_reverse1=}")
print()
print(f"{out_exp_reverse2=}")
print(f"{out_vals_reverse2=}")
```

output

```console
out_exp1=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals1=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp2=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals2=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp_reverse1=tensor([2.4000, 2.4000, 2.4000, 2.4000], device='cuda:0')
out_vals_reverse1=tensor([3.1500, 4.5500, 2.1000, 3.5000], device='cuda:0')

out_exp_reverse2=tensor([2.4000, 2.4000, 2.4000, 2.4000], device='cuda:0')
out_vals_reverse2=tensor([4.9000, 4.2000, 3.5000, 2.1000], device='cuda:0')
```

Both reverse results are wrong.

## Workaround 2: reversing indexes

```python
@triton.jit
def kernel7(exp_ref, vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  input_range = (BS - 1 - input_range) if reverse else input_range
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op
  )

  tl.store(exp_ref + input_range, exp)
  tl.store(vals_ref + input_range, vals)

@triton.jit
def kernel8(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr, reverse: tl.constexpr) -> None:
  input_range = tl.arange(0, BS)
  input_range = (BS - 1 - input_range) if reverse else input_range
  exp = tl.load(exp_ref + input_range)
  vals = tl.load(vals_ref + input_range)
  exp, vals = tl.associative_scan(
    (exp, vals), axis=0, combine_fn=op
  )

  tl.store(out_exp_ref + input_range, exp)
  tl.store(out_vals_ref + input_range, vals)

# Act
reverse = False
BS, device, exp, vals, out_exp, out_vals = init()
kernel7[(1,)](exp, vals, BS, reverse)
out_exp1, out_vals1 = exp, vals

BS, device, exp, vals, out_exp2, out_vals2 = init()
kernel8[(1,)](exp, vals, out_exp2, out_vals2, BS, reverse)
print(f"{out_exp1=}")
print(f"{out_vals1=}")
print()
print(f"{out_exp2=}")
print(f"{out_vals2=}")

print()
print()
reverse = True
BS, device, exp, vals, out_exp, out_vals = init()
kernel7[(1,)](exp, vals, BS, reverse)
out_exp_reverse1, out_vals_reverse1 = exp, vals

BS, device, exp, vals, out_exp_reverse2, out_vals_reverse2 = init()
kernel8[(1,)](exp, vals, out_exp_reverse2, out_vals_reverse2, BS, reverse)
print(f"{out_exp_reverse1=}")
print(f"{out_vals_reverse1=}")
print()
print(f"{out_exp_reverse2=}")
print(f"{out_vals_reverse2=}")
```

output

```console
out_exp1=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals1=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp2=tensor([1.0000, 1.5000, 1.2000, 2.4000], device='cuda:0')
out_vals2=tensor([1.0000, 0.5000, 0.9000, 3.8000], device='cuda:0')

out_exp_reverse1=tensor([2.4000, 2.4000, 1.6000, 2.0000], device='cuda:0')
out_vals_reverse1=tensor([3.1500, 2.1500, 2.1000, 2.0000], device='cuda:0')

out_exp_reverse2=tensor([2.4000, 2.4000, 1.6000, 2.0000], device='cuda:0')
out_vals_reverse2=tensor([3.1500, 2.1500, 2.1000, 2.0000], device='cuda:0')
```

Both are correct.

## Performance of the variants that yielded correct results

```python
import torch
import triton.language as tl
import triton

torch.random.manual_seed(1155)

def init_rand(seqlen):
    BS = 4
    device = torch.device("cuda")
    exp = torch.rand(seqlen).to(device)
    vals = torch.rand(seqlen).to(device)

    out_exp = torch.empty_like(exp)
    out_vals = torch.empty_like(vals)

    return BS, device, exp, vals, out_exp, out_vals

lines = ["in_place_flip", "in_place_rev", "out_ref_rev"]

@triton.testing.perf_report([
    triton.testing.Benchmark(
        x_names=["seqlen"],
        x_vals=[2**i for i in range(7, 20)],
        xlabel='sequence length',
        ylabel='ms',
        x_log=True,
        y_log=True,
        line_arg="benched",
        line_vals=lines,
        line_names=lines,
        plot_name="reversing scan",
        args={}
    ),
])
def bench(benched, seqlen, device="cuda"):
    BS, device, exp, vals, out_exp, out_vals = init_rand(seqlen)
    match benched:
        case "in_place_flip":
            subject = lambda: kernel3[(1,)](exp, vals, BS, True)
        case "in_place_rev":
            subject = lambda: kernel7[(1,)](exp, vals, BS, True)
        case "out_ref_rev":
            subject = lambda: kernel8[(1,)](exp, vals, out_exp, out_vals, BS, True)
    ms = triton.testing.do_bench(subject, warmup=1000, rep=200)
    print(f"{seqlen=};\t{benched=};{ms=}")
    return ms

bench.run(save_path="./bench_results", print_data=True)
```

output

```console
reversing scan:
      seqlen  in_place_flip  in_place_rev  out_ref_rev
0      128.0       0.007101      0.004218     0.003976
1      256.0       0.004264      0.003972     0.003923
2      512.0       0.004228      0.003937     0.004008
3     1024.0       0.004271      0.004030     0.003985
4     2048.0       0.004283      0.004002     0.003961
5     4096.0       0.004266      0.004031     0.003958
6     8192.0       0.004274      0.003972     0.004049
7    16384.0       0.004264      0.003959     0.004006
8    32768.0       0.004260      0.003984     0.004011
9    65536.0       0.004251      0.004020     0.004054
10  131072.0       0.004317      0.004037     0.004085
11  262144.0       0.004282      0.004010     0.004020
12  524288.0       0.004318      0.004024     0.004000
```

![image](https://github.com/user-attachments/assets/d2fd8da3-9d1a-40f2-8ad0-73425405fccd)

and for reverse=False

```console
reversing scan:
      seqlen  in_place_flip  in_place_rev  out_ref_rev
0      128.0       0.006323      0.004214     0.004002
1      256.0       0.003979      0.004000     0.003952
2      512.0       0.003979      0.004011     0.004018
3     1024.0       0.004027      0.003972     0.004007
4     2048.0       0.003967      0.004033     0.003976
5     4096.0       0.004024      0.003985     0.004023
6     8192.0       0.004042      0.004050     0.004083
7    16384.0       0.003983      0.004045     0.003997
8    32768.0       0.004036      0.004020     0.004061
9    65536.0       0.004019      0.004021     0.004017
10  131072.0       0.004057      0.004016     0.004059
11  262144.0       0.004033      0.004002     0.004021
12  524288.0       0.003994      0.003968     0.004008
```

![image](https://github.com/user-attachments/assets/a85b6c4d-32f8-4db2-9eb4-5d33da97f0cd)

With `reverse=True`, the results on a T4 tend to favour the in-place reversed-index variant, while on an A100 they were inconclusive (mind you, this was run on Google Colab); the in-place flip does worse. With `reverse=False`, they are all mostly the same.
srush commented 3 months ago

We debugged it a bit and found that this issue only occurs with lengths < 32. I am going to send a PR to block these sequences for now and also try to debug why that happens.

Hprairie commented 3 months ago

Are there any updates on this? I have also observed this behavior with lengths < 32, on both AMD and NVIDIA GPUs.

PheelaV commented 3 months ago

@Hprairie Not yet as far as I know, but have a look at the workarounds in the details of the report. Scanning with `reverse=False` works well: one can load from the pointers in reversed order and store the results back the same way, or for < 32 elements you might be better off with a linear scan using a simple for loop.
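
For such short blocks, the sequential version could look something like this (an untested sketch, reusing the `op` combine function from the report, one program per block):

```python
@triton.jit
def linear_reverse_scan(exp_ref, vals_ref, out_exp_ref, out_vals_ref, BS: tl.constexpr):
  # the last element's reverse-scan result is the element itself
  state_f = tl.load(exp_ref + BS - 1)
  state_x = tl.load(vals_ref + BS - 1)
  tl.store(out_exp_ref + BS - 1, state_f)
  tl.store(out_vals_ref + BS - 1, state_x)
  # walk the remaining elements right to left, carrying the accumulated state
  for i in range(1, BS):
    idx = BS - 1 - i
    f = tl.load(exp_ref + idx)
    x = tl.load(vals_ref + idx)
    state_f, state_x = op(state_f, state_x, f, x)
    tl.store(out_exp_ref + idx, state_f)
    tl.store(out_vals_ref + idx, state_x)
```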

srush commented 3 months ago

Yeah, I think under 32 this is not going to be fast anyway. That being said, I am looking into a fix.

Hprairie commented 3 months ago

I have found that `tl.flip` doesn't fix the problem: it still persists when doing something like `tl.flip(tl.scan(tl.flip()))`. It may also be happening for > 32 lengths when running a scan over a 2D tensor. I will work on getting a reproducible script to help with debugging.
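
For illustration, a 2D repro might look something like this (an untested sketch; simple addition as the combine function, scanning along `axis=1` with `reverse=True`, compared against a flip-based torch reference):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_op(a, b):
  return a + b

@triton.jit
def scan2d_kernel(x_ptr, out_ptr, R: tl.constexpr, BS: tl.constexpr, reverse: tl.constexpr):
  # one program scans an (R, BS) tile along the last axis
  offs = tl.arange(0, R)[:, None] * BS + tl.arange(0, BS)[None, :]
  x = tl.load(x_ptr + offs)
  y = tl.associative_scan(x, axis=1, combine_fn=add_op, reverse=reverse)
  tl.store(out_ptr + offs, y)

R, BS = 4, 64
x = torch.rand(R, BS, device="cuda")
out = torch.empty_like(x)
scan2d_kernel[(1,)](x, out, R, BS, True)
# expected: rev(cumsum(rev(x))) along the last axis
ref = torch.flip(torch.cumsum(torch.flip(x, [1]), dim=1), [1])
print(torch.allclose(out, ref, atol=1e-5))
```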

anasiri commented 3 months ago

I also have the same issue when using tl.cumsum (which uses the scan) with reverse=True.
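
For reference, the check might look something like this (an untested sketch; single program, power-of-two block, compared against a flip-based torch cumsum):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def cumsum_rev_kernel(x_ptr, out_ptr, BS: tl.constexpr):
  offs = tl.arange(0, BS)
  x = tl.load(x_ptr + offs)
  tl.store(out_ptr + offs, tl.cumsum(x, axis=0, reverse=True))

x = torch.rand(16, device="cuda")
out = torch.empty_like(x)
cumsum_rev_kernel[(1,)](x, out, BS=16)
# expected: rev(cumsum(rev(x)))
ref = torch.flip(torch.cumsum(torch.flip(x, [0]), dim=0), [0])
print(out)
print(ref)
```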