triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

how to use atomic_add in shared memory? #4276

Open zzq96 opened 4 months ago

zzq96 commented 4 months ago

I have a tensor with shape [128,]. We want to store this tensor in shared memory and then use atomic_add to update it, like this:

def test(output_ptr, input_ptr, input_index_ptr):
    # hypothetical: a [128,] accumulator allocated in shared memory
    shm_output_ptr = tensor_in_shm
    for i in range(N):
        input = tl.load(input_ptr + i)
        input_idx = tl.load(input_index_ptr + i)
        tl.atomic_add(shm_output_ptr + input_idx, input)

    # hypothetical: read the accumulator back out of shared memory
    shm_output = tl.load_from_shm(shm_output_ptr)
    tl.store(output_ptr, shm_output)

Any suggestions? Thanks.

taoroalin commented 4 months ago

tl.atomic_add only works on global memory, not shared memory. Within a single SM there would be almost no point in atomically adding to shared memory even if it were possible.
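
That said, contention on global atomics can often be reduced by accumulating within each program first and issuing only one atomic per output element at the end. A minimal sketch of that pattern for the [128,] accumulator above, assuming a histogram-style update (the kernel name, BINS, and the one-hot reduction are illustrative, not an existing shared-memory API):

import triton
import triton.language as tl

@triton.jit
def block_local_scatter_add(out_ptr, idx_ptr, val_ptr, N,
                            BINS: tl.constexpr, BLOCK: tl.constexpr):
    # Each program loads BLOCK elements, reduces them into a block-local
    # accumulator (kept in registers/shared memory by the compiler), and
    # issues only BINS global atomics instead of BLOCK of them.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    idx = tl.load(idx_ptr + offs, mask=mask, other=0)
    val = tl.load(val_ptr + offs, mask=mask, other=0.0)  # masked lanes add 0.0
    bins = tl.arange(0, BINS)
    # One-hot reduction: acc[b] = sum of val over lanes where idx == b.
    acc = tl.sum(tl.where(idx[:, None] == bins[None, :], val[:, None], 0.0), axis=0)
    tl.atomic_add(out_ptr + bins, acc)

The one-hot reduction costs O(BLOCK * BINS) work per program, so it only pays off when the number of output elements is small, as it is for a [128,] accumulator.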

vickyandpiggy commented 2 months ago

I was implementing the backward computation for position weights and found it rather slow to use atomic_add on global memory only. Here is a piece of benchmarking code.

cuda_source = '''
__global__ void scatter_add_kernel(int* dw, int block_m, int block_n,
                                   int max_seq_len, long long stride_dws,
                                   long long stride_dwzh) {
  // Number of relative-position buckets: offsets range over
  // [-(max_seq_len - 1), max_seq_len - 1].
  int w_len = 2 * max_seq_len - 1;
  int thread_idx = threadIdx.x * block_n + threadIdx.y;
  // Block-local accumulator in shared memory, zero-initialized below.
  extern __shared__ int w[];
  if (thread_idx < w_len) {
    w[thread_idx] = 0;
  }
  __syncthreads();
  int start_n = blockIdx.x * block_n;
  int off_n = start_n + threadIdx.y;
  int low = start_n;
  int off_m;
  int idx;
  // Accumulate counts per relative position; these atomics contend only
  // within the thread block and stay in shared memory.
  for (int start_m = low; start_m < max_seq_len; start_m += block_m) {
    off_m = start_m + threadIdx.x;
    idx = off_n - off_m + max_seq_len - 1;
    atomicAdd(w + idx, 1);
  }
  __syncthreads();
  // Flush the block-local histogram to global memory, one write per bucket.
  if (thread_idx < w_len) {
    dw[blockIdx.x * stride_dws + blockIdx.y * stride_dwzh + thread_idx] =
        w[thread_idx];
  }
}

torch::Tensor scatter_add(torch::Tensor input, int block_m, int block_n,
                          int max_seq_len, int num_seq_block, int num_z_h,
                          long long stride_dws, long long stride_dwzh) {
  dim3 threads_per_block(block_m, block_n);
  dim3 number_of_blocks(num_seq_block, num_z_h);

  scatter_add_kernel<<<number_of_blocks, threads_per_block,
                       (2 * max_seq_len + 1) * sizeof(int)>>>(
      input.data_ptr<int>(), block_m, block_n, max_seq_len, stride_dws,
      stride_dwzh);

  return input;
}
'''

cpp_source = '''
            torch::Tensor scatter_add(torch::Tensor input, int block_m, int block_n,
                          int max_seq_len, int num_seq_block, int num_z_h,
                          long long stride_dws, long long stride_dwzh);
            '''
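
(The test_extension module used below is presumably built from these two sources with torch.utils.cpp_extension.load_inline; a minimal sketch of that assumed build step, with the actual setup in the linked script:)

import math
import torch
from torch.utils.cpp_extension import load_inline

# Assumed build step: compile the CUDA/C++ sources above into a module
# that exposes scatter_add().
test_extension = load_inline(
    name='test_extension',
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    functions=['scatter_add'],
)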

def test_cuda(N):
  BLOCK_M = 16
  BLOCK_N = 64
  Z = 1024
  H = 2
  num_sequence_block = math.ceil(N / BLOCK_N)
  num_z_h = Z * H
  dw = torch.zeros([num_sequence_block, num_z_h, 2 * N - 1],
                   dtype=torch.int32,
                   device='cuda')
  stride_dws = dw.stride(0)
  stride_dwzh = dw.stride(1)
  test_extension.scatter_add(dw, BLOCK_M, BLOCK_N, N, num_sequence_block,
                             num_z_h, stride_dws, stride_dwzh)
  dw = dw.sum(dim=0)
  dw = dw.sum(dim=0)
  return dw

@triton.jit
def _kernel_one_block(
    start_n,
    DW,
    MAX_SEQ_LEN,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
  start_n = tl.multiple_of(start_n, BLOCK_N)
  offs_n = start_n + tl.arange(0, BLOCK_N)
  offs_m_base = tl.arange(0, BLOCK_M)
  low = start_n // BLOCK_M * BLOCK_M
  high = MAX_SEQ_LEN
  ds = tl.full([BLOCK_M, BLOCK_N], 1, dtype=tl.int32)
  for start_m in range(low, high, BLOCK_M):
    start_m = tl.multiple_of(start_m, BLOCK_M)
    offs_m = start_m + offs_m_base
    # Relative-position bucket for each (m, n) pair, in [0, 2*MAX_SEQ_LEN - 2];
    # every atomic here goes straight to global memory.
    offs_w = offs_n[None, :] - offs_m[:, None] + MAX_SEQ_LEN - 1
    tl.atomic_add(DW + offs_w, ds)

@triton.jit
def _kernel(
    DW,
    MAX_SEQ_LEN,
    stride_dws,
    stride_dwzh,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
  off_s = tl.program_id(0).to(tl.int64)
  start_n = off_s * BLOCK_N
  DW += stride_dws * off_s + stride_dwzh * tl.program_id(1).to(tl.int64)
  _kernel_one_block(
      start_n=start_n,
      DW=DW,
      MAX_SEQ_LEN=MAX_SEQ_LEN,
      BLOCK_M=BLOCK_M,
      BLOCK_N=BLOCK_N,
  )

def test_triton(N):
  num_warps = 8
  num_stages = 1
  Z = 1024
  H = 2
  BLOCK_M = 16
  BLOCK_N = 64
  num_sequence_block = math.ceil(N / BLOCK_N)
  num_z_h = Z * H
  dw = torch.zeros([num_sequence_block, num_z_h, 2 * N - 1],
                   dtype=torch.int32,
                   device='cuda')
  grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), num_z_h)
  _kernel[grid](
      DW=dw,
      MAX_SEQ_LEN=N,
      stride_dws=dw.stride(0),
      stride_dwzh=dw.stride(1),
      BLOCK_M=BLOCK_M,
      BLOCK_N=BLOCK_N,
      num_stages=num_stages,
      num_warps=num_warps,
  )
  dw = dw.sum(dim=0)
  dw = dw.sum(dim=0)
  return dw

(The complete test code is in https://github.com/vickyandpiggy/torch_and_cuda/blob/main/compare_atomic_add.py)
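
The timings below were presumably collected with something like triton.testing.do_bench, which reports milliseconds; a minimal sketch of such a harness, assuming the two test functions above:

import triton

# Assumed timing harness; the linked script contains the actual benchmark.
for N in (128, 256, 512, 1024, 2048):
    t_triton = triton.testing.do_bench(lambda: test_triton(N))
    t_cuda = triton.testing.do_bench(lambda: test_cuda(N))
    print(N, t_triton, t_cuda)

The measured results: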

     N       Triton        Cuda
   128     0.229376    0.130016
   256     0.807936    0.394240
   512     2.814976    1.384448
  1024    10.505216    4.634624
  2048    40.797184   16.866304

The results show that doing atomic_add/scatter_add in shared memory can be more than 2x faster than doing so in global memory. I suspect this is because of the severe address conflicts in this case: many threads keep hitting the same few counters, and the faster read/write of shared memory greatly reduces the time threads spend waiting on each other.

So I wonder whether you have a plan to support this in Triton? Or is there a specific reason why scatter/atomic_add on shared memory is not supported yet? @taoroalin @Jokeren