Open jubueche opened 1 week ago
This code
import triton import triton.language as tl from torch import Tensor @triton.jit def get_three_d_weights( currents, # [B_M, B_N] weight_block, # [B_K, B_N] BLOCK_SIZE_K: tl.constexpr, ): poly_arange = tl.arange(0, BLOCK_SIZE_K)[:, None] # [B_K,1] x [B_M,1,B_N] -> [B_M, B_ ] approx_v = poly_arange * currents[:, None, :] # this will broadcast to [BLOCK_SIZE_M*next_pow_2(n_inp_bits),BLOCK_SIZE_K,BLOCK_SIZE_N] ir_dropped_weights = weight_block[None, :, :] * approx_v return ir_dropped_weights @triton.jit def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, currents_ptr, # M, N, K, # stride_am, stride_ak, # stride_bk, stride_bn, # stride_cm, stride_cn, # stride_currentsm, stride_currentsn, # BLOCK_SIZE_M: tl.constexpr, # BLOCK_SIZE_N: tl.constexpr, # BLOCK_SIZE_K: tl.constexpr, # GROUP_SIZE_M: tl.constexpr, # NUM_SMS: tl.constexpr, # ): start_pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) k_tiles = tl.cdiv(K, BLOCK_SIZE_K) num_tiles = num_pid_m * num_pid_n tiles_per_SM = num_tiles // NUM_SMS if start_pid < num_tiles % NUM_SMS: tiles_per_SM += 1 tile_id = start_pid - NUM_SMS ki = -1 offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) num_pid_in_group = GROUP_SIZE_M * num_pid_n pid_m = 0 pid_n = 0 offs_am = tl.arange(0, BLOCK_SIZE_M) offs_bn = tl.arange(0, BLOCK_SIZE_N) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) currents = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float16) for _ in range(0, k_tiles * tiles_per_SM): ki = tl.where(ki == k_tiles - 1, 0, ki + 1) if ki == 0: tile_id += NUM_SMS group_id = tile_id // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + (tile_id % group_size_m) pid_n = (tile_id % num_pid_in_group) // group_size_m start_m = pid_m * BLOCK_SIZE_M start_n = pid_n * BLOCK_SIZE_N offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) offs_am = tl.where(offs_am < M, offs_am, 0) offs_bn = tl.where(offs_bn < N, offs_bn, 0) offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M) offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N) currents_ptrs = currents_ptr + (offs_am[:, None] * stride_currentsm + offs_bn[None, :] * stride_currentsn) currents = tl.load(currents_ptrs, mask=(offs_am[:, None] < M) & (offs_bn[None, :] < N), other=0.0) offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0) modified_weights = get_three_d_weights(currents, b, BLOCK_SIZE_K) a_broadcasted = tl.broadcast_to(a[:, None], (BLOCK_SIZE_M, 16, BLOCK_SIZE_K)) three_d_dot = tl.dot(a_broadcasted, modified_weights) three_d_dot = tl.sum(three_d_dot, axis=1) / 16. accumulator += three_d_dot if ki == k_tiles - 1: offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) if (c_ptr.dtype.element_ty == tl.float8e4nv): c = accumulator.to(tl.float8e4nv) else: c = accumulator.to(tl.float16) tl.store(c_ptrs, c, mask=c_mask) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # fmt: on def matmul_persistent(a: Tensor, b: Tensor, currents: Tensor): configs = { torch.float8_e4m3fn: { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_stages": 4, "num_warps": 8 }, torch.float16: { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 16, "GROUP_SIZE_M": 8, "num_stages": 3, "num_warps": 8 } } # Check constraints. assert a.shape[1] == b.shape[0], "Incompatible dimensions" assert a.shape[0] == currents.shape[0], "Wrong current shape" assert currents.shape[1] == b.shape[1], "Wrong current shape" assert a.dtype == b.dtype, "Incompatible dtypes" assert a.dtype == currents.dtype, "Incompatible dtypes" NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count M, K = a.shape K, N = b.shape dtype = a.dtype # Allocates output. c = torch.empty((M, N), device=a.device, dtype=dtype) # 1D launch kernel where each block gets its own program. grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), ) matmul_kernel_persistent[grid]( a, b, c, currents, # M, N, K, # a.stride(0), a.stride(1), # b.stride(0), b.stride(1), # c.stride(0), c.stride(1), # currents.stride(0), currents.stride(1), # BLOCK_SIZE_M=configs[dtype]["BLOCK_SIZE_M"], # BLOCK_SIZE_N=configs[dtype]["BLOCK_SIZE_N"], # BLOCK_SIZE_K=configs[dtype]["BLOCK_SIZE_K"], # GROUP_SIZE_M=configs[dtype]["GROUP_SIZE_M"], # NUM_SMS=NUM_SMS, # num_stages=configs[dtype]["num_stages"], # num_warps=configs[dtype]["num_warps"], # ) return c if __name__ == "__main__": import torch matmul_persistent( torch.randn((10, 4), device="cuda", dtype=torch.float16), torch.randn((4, 5), device="cuda", dtype=torch.float16), torch.randn((10, 5), device="cuda", dtype=torch.float16), )
throws the following compilation error:
python: /root/.triton/llvm/llvm-ed4e505c-centos-x64/include/llvm/ADT/SmallVector.h:304: T& llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::operator[](llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type) [with T = unsigned int; <template-parameter-1-2> = void; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::reference = unsigned int&; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type = long unsigned int]: Assertion `idx < size()' failed.
I don't know how to debug this (a tutorial on this or pointers would be nice). I am using triton v3.0.0. installed via pip from the webpage.
This takes me to the page detailing debugging the frontend. This code runs when I run it with TRITON_INTERPRET=1. So I don't think your "suggestion" works here. I think it's a bug so I am opening an issue here.
TRITON_INTERPRET=1
This code
throws the following compilation error:
I don't know how to debug this (a tutorial on this or pointers would be nice). I am using triton v3.0.0. installed via pip from the webpage.