zzq96 opened this issue 4 months ago
Atomic_add only works in global memory, not shared memory. Within one SM there would be almost no point in atomically adding to shared memory even if it were possible.
I was implementing the backward calculation for position weights and found it rather slow to use atomic_add only on global memory. Here is a piece of benchmarking code.
import math

import torch
import triton
import triton.language as tl
from torch.utils.cpp_extension import load_inline

cuda_source = '''
__global__ void scatter_add_kernel(int* dw, int block_m, int block_n,
                                   int max_seq_len, long long stride_dws,
                                   long long stride_dwzh) {
    int w_len = 2 * max_seq_len - 1;
    int num_threads = block_m * block_n;
    int thread_idx = threadIdx.x * block_n + threadIdx.y;
    extern __shared__ int w[];
    // Zero the block-local accumulator (strided so it also covers w_len > num_threads).
    for (int i = thread_idx; i < w_len; i += num_threads) {
        w[i] = 0;
    }
    __syncthreads();
    int start_n = blockIdx.x * block_n;
    int off_n = start_n + threadIdx.y;
    int low = start_n;
    int off_m;
    int idx;
    // Accumulate into shared memory; conflicting indices serialize on
    // shared-memory atomics instead of global-memory atomics.
    for (int start_m = low; start_m < max_seq_len; start_m += block_m) {
        off_m = start_m + threadIdx.x;
        idx = off_n - off_m + max_seq_len - 1;
        atomicAdd(w + idx, 1);
    }
    __syncthreads();
    // Write the block-local result back to global memory (one row per block).
    for (int i = thread_idx; i < w_len; i += num_threads) {
        dw[blockIdx.x * stride_dws + blockIdx.y * stride_dwzh + i] = w[i];
    }
}

torch::Tensor scatter_add(torch::Tensor input, int block_m, int block_n,
                          int max_seq_len, int num_seq_block, int num_z_h,
                          long long stride_dws, long long stride_dwzh) {
    dim3 threads_per_block(block_m, block_n);
    dim3 number_of_blocks(num_seq_block, num_z_h);
    scatter_add_kernel<<<number_of_blocks, threads_per_block,
                         (2 * max_seq_len + 1) * sizeof(int)>>>(
        input.data_ptr<int>(), block_m, block_n, max_seq_len, stride_dws,
        stride_dwzh);
    return input;
}
'''
cpp_source = '''
torch::Tensor scatter_add(torch::Tensor input, int block_m, int block_n,
                          int max_seq_len, int num_seq_block, int num_z_h,
                          long long stride_dws, long long stride_dwzh);
'''
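The extension referenced below as test_extension is assumed to be built with torch.utils.cpp_extension.load_inline, along these lines (a sketch; the linked script may build it differently):

# Assumed build step: compile the inline CUDA/C++ sources above and expose
# scatter_add as test_extension.scatter_add.
test_extension = load_inline(name='test_extension',
                             cpp_sources=cpp_source,
                             cuda_sources=cuda_source,
                             functions=['scatter_add'])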
def test_cuda(N):
    BLOCK_M = 16
    BLOCK_N = 64
    Z = 1024
    H = 2
    num_sequence_block = math.ceil(N / BLOCK_N)
    num_z_h = Z * H
    dw = torch.zeros([num_sequence_block, num_z_h, 2 * N - 1],
                     dtype=torch.int32,
                     device='cuda')
    stride_dws = dw.stride(0)
    stride_dwzh = dw.stride(1)
    test_extension.scatter_add(dw, BLOCK_M, BLOCK_N, N, num_sequence_block,
                               num_z_h, stride_dws, stride_dwzh)
    # Reduce over the per-block and per-(Z*H) dimensions to get the final [2N-1] weights.
    dw = dw.sum(dim=0)
    dw = dw.sum(dim=0)
    return dw
@triton.jit
def _kernel_one_block(
    start_n,
    DW,
    MAX_SEQ_LEN,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_n = tl.multiple_of(start_n, BLOCK_N)
    offs_n = start_n + tl.arange(0, BLOCK_N)
    offs_m_base = tl.arange(0, BLOCK_M)
    low = start_n // BLOCK_M * BLOCK_M
    high = MAX_SEQ_LEN
    ds = tl.full([BLOCK_M, BLOCK_N], 1, dtype=tl.int32)
    for start_m in range(low, high, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m = start_m + offs_m_base
        offs_w = offs_n[None, :] - offs_m[:, None] + MAX_SEQ_LEN - 1
        # Every program adds straight into global memory; duplicate offsets
        # within offs_w serialize on global-memory atomics.
        tl.atomic_add(DW + offs_w, ds)


@triton.jit
def _kernel(
    DW,
    MAX_SEQ_LEN,
    stride_dws,
    stride_dwzh,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    off_s = tl.program_id(0).to(tl.int64)
    start_n = off_s * BLOCK_N
    DW += stride_dws * off_s + stride_dwzh * tl.program_id(1).to(tl.int64)
    _kernel_one_block(
        start_n=start_n,
        DW=DW,
        MAX_SEQ_LEN=MAX_SEQ_LEN,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
    )
def test_triton(N):
    num_warps = 8
    num_stages = 1
    Z = 1024
    H = 2
    BLOCK_M = 16
    BLOCK_N = 64
    num_sequence_block = math.ceil(N / BLOCK_N)
    num_z_h = Z * H
    dw = torch.zeros([num_sequence_block, num_z_h, 2 * N - 1],
                     dtype=torch.int32,
                     device='cuda')
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]), num_z_h)
    _kernel[grid](
        DW=dw,
        MAX_SEQ_LEN=N,
        stride_dws=dw.stride(0),
        stride_dwzh=dw.stride(1),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_stages=num_stages,
        num_warps=num_warps,
    )
    dw = dw.sum(dim=0)
    dw = dw.sum(dim=0)
    return dw
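For reference, a timing harness along the following lines can produce such a comparison (a minimal sketch using CUDA events; this is not the harness from the linked script, which may measure differently):

# Hypothetical timing loop: average milliseconds per call for each path.
def bench(fn, iters=10):
    fn()                       # warm-up / first-call compilation
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters   # milliseconds per call

for n in (128, 256, 512, 1024, 2048):
    print(n, bench(lambda: test_triton(n)), bench(lambda: test_cuda(n)))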
(The complete test code is in https://github.com/vickyandpiggy/torch_and_cuda/blob/main/compare_atomic_add.py)
    N       Triton        CUDA
  128     0.229376    0.130016
  256     0.807936    0.394240
  512     2.814976    1.384448
 1024    10.505216    4.634624
 2048    40.797184   16.866304
The results show that doing atomic_add/scatter_add on shared memory can be more than 2x faster than doing it on global memory. I suspect this is because of the severe index conflicts in this workload: the much lower read/write latency of shared memory greatly reduces the time threads spend waiting on contended atomics.
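To make the conflict claim concrete, a quick check of a single BLOCK_M x BLOCK_N tile with the same index formula used in the kernels shows how many updates land on the same w entry (a small illustrative script, not part of the benchmark):

# How many distinct w indices does one 16x64 tile touch?
import torch

BLOCK_M, BLOCK_N, N = 16, 64, 1024
offs_m = torch.arange(BLOCK_M)[:, None]
offs_n = torch.arange(BLOCK_N)[None, :]
offs_w = offs_n - offs_m + N - 1          # same formula as in the kernels
total = BLOCK_M * BLOCK_N                 # 1024 atomic updates per tile
unique = offs_w.unique().numel()          # only 79 distinct addresses
print(total, unique, total / unique)      # ~13 atomics per address on average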
So I wonder: do you have a plan to support this in Triton? Or is there a specific reason that scatter/atomic_add on shared memory is not supported yet? @taoroalin @Jokeren
I have a tensor of shape [128]. We want to store it in shared memory and then use atomic_add to update it, like this:
Any suggestions? Thanks.
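A minimal sketch of that pattern, written in the same inline-CUDA style as the benchmark above (the histogram-style use case and every name here, e.g. shared_hist and hist_ext, are illustrative assumptions, not code from the original comment):

import torch
from torch.utils.cpp_extension import load_inline

hist_cuda = '''
__global__ void shared_hist_kernel(const int* bins, int* out, int n) {
    // Hypothetical example: 128-entry accumulator kept in shared memory.
    __shared__ int acc[128];
    int tid = threadIdx.x;        // launched with 128 threads per block
    acc[tid] = 0;
    __syncthreads();
    // Each thread atomically bumps the shared-memory slot chosen by its input
    // (values are assumed to lie in [0, 128); the mask keeps it in range).
    for (int i = blockIdx.x * blockDim.x + tid; i < n; i += gridDim.x * blockDim.x) {
        atomicAdd(&acc[bins[i] & 127], 1);
    }
    __syncthreads();
    // One global-memory atomic per block per slot to merge the block-local result.
    atomicAdd(&out[tid], acc[tid]);
}

torch::Tensor shared_hist(torch::Tensor bins, torch::Tensor out) {
    int n = bins.numel();
    shared_hist_kernel<<<64, 128>>>(bins.data_ptr<int>(), out.data_ptr<int>(), n);
    return out;
}
'''

hist_cpp = 'torch::Tensor shared_hist(torch::Tensor bins, torch::Tensor out);'

hist_ext = load_inline(name='hist_ext', cpp_sources=hist_cpp,
                       cuda_sources=hist_cuda, functions=['shared_hist'])

bins = torch.randint(0, 128, (1 << 20,), dtype=torch.int32, device='cuda')
out = torch.zeros(128, dtype=torch.int32, device='cuda')
hist_ext.shared_hist(bins, out)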