Closed: pommedeterresautee closed this issue 1 year ago
I cannot reproduce this problem using triton/main. Can you double check?
I can reproduce it on a new machine with an A100 from Lambdalabs (no apt update) and triton main from this morning (28ea484dab08a0ec76164fec39323252f5db3eeb):
ubuntu@132-145-138-145:~$ nvidia-smi
Fri Mar 31 13:49:45 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01 Driver Version: 515.65.01 CUDA Version: 11.7 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA A100-SXM... On | 00000000:06:00.0 Off | 0 |
| N/A 35C P0 69W / 400W | 0MiB / 40960MiB | 0% Default |
| | | Disabled |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
ubuntu@132-145-138-145:~$ python script.py
Traceback (most recent call last):
File "script.py", line 25, in <module>
assert (lock == 2).all().item(), lock
AssertionError: tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2,
1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 1, 2,
2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 1, 1, 1,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2,
1, 2, 2, 2, 1, 2, 2, 1, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2,
1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2,
2, 1, 1, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 1, 1, 1, 2, 1,
2, 2, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 2, 2, 2, 2, 1,
1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2, 1, 1, 2, 1, 2, 1, 1, 1, 1,
1, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2,
2, 2, 2, 2, 2, 1, 2, 1], device='cuda:0', dtype=torch.int32)
ubuntu@132-145-138-145:~$ cat script.py
```python
import torch
import triton
import triton.language as tl


@triton.jit()
def dead_lock(
    locks,
    size: tl.constexpr,
):
    pid = tl.program_id(0)
    pid = pid % size
    if tl.atomic_xchg(locks + pid, 1) == 0:  # original value is 0
        tl.atomic_xchg(locks + pid, 2)
    # else:
    #     while tl.atomic_xchg(locks + pid, 3) != 2:  # supposed to return the old value (2 at this point)
    #         pass


lock = torch.zeros(800, dtype=torch.int32, device='cuda')
for _ in range(1000):
    lock.zero_()
    dead_lock[(lock.numel() * 2,)](lock, lock.numel())
    assert (lock == 2).all().item(), lock
```
I guess that maybe a bigger matrix and a for loop to retry would help to trigger the issue?
Also, the PTX is a bit strange: at the end there are 2 consecutive calls to barrier sync, and the second one seems redundant.
```
//
// Generated by LLVM NVPTX Back-End
//
.version 8.0
.target sm_86
.address_size 64

	// .globl dead_lock_0d
.extern .shared .align 1 .b8 global_smem[];

.visible .entry dead_lock_0d(
	.param .u64 dead_lock_0d_param_0
)
.maxntid 128, 1, 1
{
	.reg .pred %p<6>;
	.reg .b32 %r<18>;
	.reg .b64 %rd<6>;

	ld.param.u64 %rd3, [dead_lock_0d_param_0];
	mov.u32 %r6, %ctaid.x;
	mul.hi.s32 %r7, %r6, 1374389535;
	shr.u32 %r8, %r7, 31;
	shr.s32 %r9, %r7, 8;
	add.s32 %r10, %r9, %r8;
	mul.lo.s32 %r11, %r10, 800;
	sub.s32 %r12, %r6, %r11;
	mul.wide.s32 %rd4, %r12, 4;
	add.s64 %rd5, %rd3, %rd4;
	mov.u32 %r1, %tid.x;
	membar.gl ;
	setp.eq.s32 %p4, %r1, 0;
	mov.u32 %r3, 1;
	mov.u32 %r2, 0x0;
	@%p4 atom.global.gpu.exch.b32 %r2, [ %rd5 + 0 ], %r3;
	mov.u32 %r4, global_smem;
	@%p4 st.shared.b32 [ %r4 + 0 ], %r2;
	bar.sync 0;
	ld.shared.u32 %r13, [global_smem];
	bar.sync 0;
	setp.ne.s32 %p3, %r13, 0;
	@%p3 bra $L__BB0_2;
	bar.sync 0;
	membar.gl ;
	mov.u32 %r15, 2;
	mov.u32 %r17, 0x0;
	@%p4 atom.global.gpu.exch.b32 %r17, [ %rd5 + 0 ], %r15;
	@%p4 st.shared.b32 [ %r4 + 0 ], %r17;
	bar.sync 0;
	bar.sync 0;
$L__BB0_2:
	ret;
}
```
I think the result is actually expected.
Let's say you have block_id 0 and block_id 800:

1. block_id 0 first calls `atomic_xchg([0], 1) -> 0`,
2. then it calls `atomic_xchg([0], 2) -> 1`.
3. Next, block_id 800 calls `atomic_xchg([0], 1) -> 2`.

Shouldn't `location[0] == 1` then?
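That interleaving can be sketched with a plain-Python stand-in for the exchange (a single-threaded simulation for illustration only; `xchg` is a made-up helper mirroring the return-old-value semantics of `tl.atomic_xchg`, not the real Triton API):

```python
# One lock word shared by the two CTAs that map to the same pid.
lock = [0]

def xchg(new):
    """Stand-in for tl.atomic_xchg: store `new`, return the old value."""
    old = lock[0]
    lock[0] = new
    return old

assert xchg(1) == 0  # block_id 0: takes the lock (old value was 0)
assert xchg(2) == 1  # block_id 0: publishes 2 (old value was its own 1)
assert xchg(1) == 2  # block_id 800: runs last and clobbers the 2 with 1
print(lock[0])       # the lock word ends up as 1, matching the stray 1s in the dump
```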
Just a suggestion, but probably this is what you want:

```python
@triton.jit()
def dead_lock(
    locks,
    size: tl.constexpr,
):
    pid = tl.program_id(0)
    pid = pid % size  # every lock position gets 2 CTAs accessing it
    if tl.atomic_cas(locks + pid, 0, 1) == 0:  # original value is 0
        tl.atomic_xchg(locks + pid, 2)
```
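A quick way to see why the `atomic_cas` variant is safe: whichever CTA loses the CAS cannot overwrite the lock word, so the 2 written by the winner survives under either interleaving. Below is a plain-Python sketch (single-threaded simulation; `cas` and `xchg` are made-up helpers mirroring the semantics of `tl.atomic_cas` and `tl.atomic_xchg`, not the real API):

```python
def run(winner_publishes_first):
    """Simulate two CTAs hitting one lock word, under both interleavings."""
    lock = [0]

    def cas(expected, new):
        """Stand-in for tl.atomic_cas: swap only if the current value matches."""
        old = lock[0]
        if old == expected:
            lock[0] = new
        return old

    def xchg(new):
        old = lock[0]
        lock[0] = new
        return old

    assert cas(0, 1) == 0  # CTA A wins the CAS (0 -> 1)
    if winner_publishes_first:
        xchg(2)                # A publishes 2 first...
        assert cas(0, 1) == 2  # ...then B's CAS fails and changes nothing
    else:
        assert cas(0, 1) == 1  # B's CAS fails while the word is still 1
        xchg(2)                # A publishes 2 afterwards
    return lock[0]

print(run(True), run(False))  # both interleavings leave the word at 2
```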
Thank you for your point, I think I now understand: even if the ops are atomic, the design of this snippet does not provide any guarantee on the order of execution of the ops (the first xchg of the second CTA can be executed before or after the second xchg of the first CTA), meaning it's expected to have a mix of 1s and 2s.
There is still the 2nd question about the double sync barrier in the PTX: is it expected?
> Thank you for your point, I think I now understand: even if the ops are atomic, the design of this snippet does not provide any guarantee on the order of execution of the ops (the first xchg of the second CTA can be executed before or after the second xchg of the first CTA), meaning it's expected to have a mix of 1s and 2s.

The suggested snippet is the one I use in the stream-k kernel; with your help I now understand better why that one works :-)
There is still the 2nd question about the double sync barrier in the PTX: is it expected?
> There is still the 2nd question about the double sync barrier in the PTX: is it expected?
We know about it. It's not a big deal, since (1) ptxas may remove them; (2) even if not, two back-to-back barriers shouldn't add much overhead, since the threads are already synchronized by the time they reach the second barrier.
Thank you, it makes sense
The kernel below deadlocks (when the lock matrix is big enough), and my understanding is that it should not. If you comment out the else branch and print the lock matrix, you will notice that it's full of 1s and 2s, where it's expected to be full of 2s.
Tested on an RTX 3090 and the triton main branch.
related to https://github.com/openai/triton/issues/1393 and https://triton-lang.slack.com/archives/C050BRSAH4Z/p1680168403507209