triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

atomic_xchg doesn't behave as expected #1452

Closed pommedeterresautee closed 1 year ago

pommedeterresautee commented 1 year ago

The kernel below deadlocks (when the lock matrix is big enough), and my understanding is that it should not. If you comment out the else branch and print the lock matrix, you will notice that it is full of 1s and 2s, whereas it is expected to be full of 2s.

import torch
import triton
import triton.language as tl

@triton.jit()
def dead_lock(
        locks,
        size: tl.constexpr,
        ):
    pid = tl.program_id(0)
    pid = pid % size  # every lock position is accessed by 2 CTAs
    if tl.atomic_xchg(locks + pid, 1) == 0:  # original value is 0
        tl.atomic_xchg(locks + pid, 2)
    else:
        while tl.atomic_xchg(locks + pid, 3) != 2:  # supposed to return the old value (2 at this point)
            pass

lock = torch.zeros(800, dtype=torch.int32, device='cuda')

dead_lock[(lock.numel()*2,)](lock, lock.numel())
print(lock)

Tested on an RTX 3090 with the Triton main branch.

related to https://github.com/openai/triton/issues/1393 and https://triton-lang.slack.com/archives/C050BRSAH4Z/p1680168403507209

Jokeren commented 1 year ago

I cannot reproduce this problem using triton/main. Can you double-check?

pommedeterresautee commented 1 year ago

I can reproduce it on a new machine with an A100 from Lambdalabs (no apt update) and triton main from this morning (28ea484dab08a0ec76164fec39323252f5db3eeb):

ubuntu@132-145-138-145:~$ nvidia-smi 
Fri Mar 31 13:49:45 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:06:00.0 Off |                    0 |
| N/A   35C    P0    69W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
ubuntu@132-145-138-145:~$ python script.py 
Traceback (most recent call last):
  File "script.py", line 25, in <module>
    assert (lock == 2).all().item(), lock
AssertionError: tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2,
        1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 2, 1, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        1, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 1, 1, 1,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 1, 2, 2, 2, 1, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2,
        1, 2, 2, 2, 1, 2, 2, 1, 1, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2,
        2, 1, 1, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 1, 1, 1, 2, 1,
        2, 2, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 2, 2, 2, 2, 1,
        1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 1, 2, 1, 1, 2, 1, 2, 1, 1, 1, 1,
        1, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2,
        2, 2, 2, 2, 2, 1, 2, 1], device='cuda:0', dtype=torch.int32)
ubuntu@132-145-138-145:~$ cat script.py 
import torch
import triton
import triton.language as tl

@triton.jit()
def dead_lock(
        locks,
        size: tl.constexpr,
        ):
    pid = tl.program_id(0)
    pid = pid % size
    if tl.atomic_xchg(locks + pid, 1) == 0:  # original value is 0
        tl.atomic_xchg(locks + pid, 2)
    # else:
    #     while tl.atomic_xchg(locks + pid, 3) != 2:  # supposed to return the old value (2 at this point)
    #         pass

lock = torch.zeros(800, dtype=torch.int32, device='cuda')

for _ in range(1000):
    lock.zero_()
    dead_lock[(lock.numel()*2,)](lock, lock.numel())
    assert (lock == 2).all().item(), lock

I guess a bigger matrix and a for loop to retry may help to trigger the issue?

pommedeterresautee commented 1 year ago

Also, the PTX is a bit strange: at the end there are two consecutive bar.sync calls, and the second one seems redundant.

//
// Generated by LLVM NVPTX Back-End
//

.version 8.0
.target sm_86
.address_size 64

    // .globl   dead_lock_0d
.extern .shared .align 1 .b8 global_smem[];

.visible .entry dead_lock_0d(
    .param .u64 dead_lock_0d_param_0
)
.maxntid 128, 1, 1
{
    .reg .pred  %p<6>;
    .reg .b32   %r<18>;
    .reg .b64   %rd<6>;

    ld.param.u64    %rd3, [dead_lock_0d_param_0];
    mov.u32     %r6, %ctaid.x;
    mul.hi.s32  %r7, %r6, 1374389535;
    shr.u32     %r8, %r7, 31;
    shr.s32     %r9, %r7, 8;
    add.s32     %r10, %r9, %r8;
    mul.lo.s32  %r11, %r10, 800;
    sub.s32     %r12, %r6, %r11;
    mul.wide.s32    %rd4, %r12, 4;
    add.s64     %rd5, %rd3, %rd4;
    mov.u32     %r1, %tid.x;
    membar.gl ;
    setp.eq.s32     %p4, %r1, 0;
    mov.u32     %r3, 1;
    mov.u32 %r2, 0x0;
    @%p4 atom.global.gpu.exch.b32 %r2, [ %rd5 + 0 ], %r3;
    mov.u32     %r4, global_smem;
    @%p4 st.shared.b32 [ %r4 + 0 ], %r2;
    bar.sync    0;
    ld.shared.u32   %r13, [global_smem];
    bar.sync    0;
    setp.ne.s32     %p3, %r13, 0;
    @%p3 bra    $L__BB0_2;
    bar.sync    0;
    membar.gl ;
    mov.u32     %r15, 2;
    mov.u32 %r17, 0x0;
    @%p4 atom.global.gpu.exch.b32 %r17, [ %rd5 + 0 ], %r15;
    @%p4 st.shared.b32 [ %r4 + 0 ], %r17;
    bar.sync    0;
    bar.sync    0;
$L__BB0_2:
    ret;

}

Jokeren commented 1 year ago

I think the result is actually expected.

Let's say you have block_id 0 and block_id 800.

block_id 0 first calls atomic_xchg([0], 1) -> 0, then it calls atomic_xchg([0], 2) -> 1. Next, block_id 800 calls atomic_xchg([0], 1) -> 2. Shouldn't locks[0] == 1 at that point?
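
A minimal host-side sketch of that interleaving, in plain Python rather than Triton, with a hypothetical atomic_xchg helper standing in for tl.atomic_xchg (assuming block_id 0 finishes both of its exchanges before block_id 800 runs its first one):

lock = [0]

def atomic_xchg(buf, idx, new):
    # hypothetical stand-in for tl.atomic_xchg: swap in the new value, return the old one
    old = buf[idx]
    buf[idx] = new
    return old

# block_id 0
assert atomic_xchg(lock, 0, 1) == 0  # old value 0 -> takes the if branch
atomic_xchg(lock, 0, 2)              # lock[0] is now 2

# block_id 800 (same position after pid % size)
assert atomic_xchg(lock, 0, 1) == 2  # old value 2 -> would take the else branch
# with the else branch commented out nothing rewrites the slot, so lock[0] stays 1
print(lock)  # [1]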

Jokeren commented 1 year ago

Just a suggestion, probably this is what you want?

@triton.jit()
def dead_lock(
        locks,
        size: tl.constexpr,
        ):
    pid = tl.program_id(0)
    pid = pid % size  # every lock position is accessed by 2 CTAs
    if tl.atomic_cas(locks + pid, 0, 1) == 0:  # original value is 0
        tl.atomic_xchg(locks + pid, 2)
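
With atomic_cas(locks + pid, 0, 1), only the first CTA to arrive sees 0 and takes the branch; the loser's compare-and-swap fails without modifying the value, so every position ends up at 2. A hypothetical driver sketch, mirroring the original script (assuming dead_lock here refers to the atomic_cas version defined just above, compiled with the same imports as the original repro):

import torch

lock = torch.zeros(800, dtype=torch.int32, device='cuda')
for _ in range(1000):
    lock.zero_()
    dead_lock[(lock.numel() * 2,)](lock, lock.numel())
    # the losing CTA's atomic_cas returns 1 or 2 without writing, so no 1s survive
    assert (lock == 2).all().item(), lock
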
pommedeterresautee commented 1 year ago

Thank you for your point, I think I now understand: even though the ops are atomic, the design of this snippet does not provide any guarantee on the order in which the ops execute (the first xchg of the second CTA can run before or after the second xchg of the first CTA), so it is expected to end up with a mix of 1s and 2s.

There is still the second question about the double sync barrier in the PTX: is it expected?

pommedeterresautee commented 1 year ago

The suggested snippet is the one I use in the stream-k kernel; with your help I now understand better why it works :-)

ptillet commented 1 year ago

There is still the second question about the double sync barrier in the PTX: is it expected?

We know about it. It's not a big deal, since (1) ptxas may remove them; (2) even if not, two barriers back-to-back shouldn't have much overhead since the threads are already synchronized by the time they reach the second barrier.

pommedeterresautee commented 1 year ago

Thank you, it makes sense