jansel opened this issue 1 year ago (status: Open)
Here is a slightly smaller reproducer. The relevant part is a load that is broadcast over RBLOCK: when RBLOCK is distributed over multiple warps, the read may happen after the store on some threads.
import torch
import triton
import triton.language as tl


@triton.jit
def triton_(in_out_ptr0, out_ptr0, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
    xindex = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    # In-place read-modify-write: the load and the store touch the same addresses.
    tmp0 = tl.load(in_out_ptr0 + xindex, None)
    tl.store(in_out_ptr0 + xindex, tmp0 + 1, None)
    # buf7 is zero-filled on the host, so this should always hold; it fires when
    # some thread's load observes the incremented value written by the store.
    tl.device_assert(tmp0 == 0)
    # Broadcasting tmp0 over RBLOCK spreads the loaded value across multiple
    # warps, which is where the missing ordering between load and store shows up.
    rindex = tl.arange(0, RBLOCK)[None, :]
    res = tl.sum(tmp0 + rindex, 0)[None, :]
    tl.store(out_ptr0 + rindex, res)


xblock, rblock = 8, 512
grid_size = 1000
for _ in range(100):
    buf7 = torch.zeros(grid_size * xblock, device="cuda")
    rout = torch.empty(rblock, device="cuda")
    triton_[(grid_size,)](buf7, rout, xblock, rblock)
    torch.cuda.synchronize()
Yeah... Triton would need to insert a memory fence here. We have tl.debug_barrier
as a workaround for now... Fixing it properly would require some heavy alias analysis.
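For reference, here is a hedged sketch of that workaround applied to the reproducer above: a tl.debug_barrier() between the aliasing load and store, so that every warp finishes its read of in_out_ptr0 before any warp overwrites it. This is only an illustration of the workaround mentioned here, not a claim about where Triton itself would place a fence.

@triton.jit
def triton_(in_out_ptr0, out_ptr0, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
    xindex = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)[:, None]
    tmp0 = tl.load(in_out_ptr0 + xindex, None)
    # Workaround: block-wide barrier so no thread stores to in_out_ptr0 until
    # every thread in the block has finished loading from it.
    tl.debug_barrier()
    tl.store(in_out_ptr0 + xindex, tmp0 + 1, None)
    tl.device_assert(tmp0 == 0)
    rindex = tl.arange(0, RBLOCK)[None, :]
    res = tl.sum(tmp0 + rindex, 0)[None, :]
    tl.store(out_ptr0 + rindex, res)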
How about exposing an argument on tl.store so that users can specify whether the addresses being written could alias addresses that have already been read?
For our specific use case (this is a result of inductor inplacing buffers and reusing an input's memory for an output), we will always read and write the exact same address. We also know when we are generating that pattern, so it would be easy to pass an extra arg to tl.store.
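To make the proposal concrete, a purely hypothetical spelling (the keyword below does not exist in tl.store today; the name is invented for illustration):

# Hypothetical sketch only: `may_alias_reads` is NOT a real tl.store parameter.
# The idea is that a code generator such as Inductor, which knows it is writing
# back to the exact addresses it just read (in-place buffer reuse), could state
# that explicitly and let the compiler insert whatever fence it needs, instead
# of requiring heavy alias analysis or a manual tl.debug_barrier().
tl.store(in_out_ptr0 + xindex, tmp0 + 1, None, may_alias_reads=True)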
Hi @ptillet, what do you think of the proposed approach?
I still need to think about it a little more, but I'm leaning towards having a well-defined but weaker memory model, and ensuring memory consistency via fences.
This one seems pretty serious unless I am missing something.
Repro:
Output:
However, with this change:
It runs fine.
It seems almost like the mutation of in_out_ptr0 is corrupting the value of tmp7.