fxmarty opened this issue 1 year ago
```python
import torch

import triton
import triton.language as tl

from matmul_perf_model import early_config_prune


def init_to_zero(name):
    # pre-hook used by the SPLIT_K > 1 configs: the output must be
    # zero-initialized because split-K programs accumulate into it
    return lambda nargs: nargs[name].zero_()


def get_configs_io_bound():
    for num_stages in range(2, 7):
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    yield triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
                                        num_stages=num_stages, num_warps=num_warps)
                    # split_k
                    for split_k in [2, 4, 8, 16]:
                        yield triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
                                            num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_M': 128, 'BLOCK ... (the remaining configs and the kernel body are cut off here in the original post)
```
Hi, I doubt this is a duplicate of https://github.com/openai/triton/issues/1058, because I am not overflowing the program ids: the largest PID should be `cdiv(8192, 16) * cdiv(8192, 32) = 131072` in my case.
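As a quick sanity check on that bound (assuming the usual `cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)` programs on grid axis 0):

```python
import triton

# number of programs on grid axis 0 for M = N = 8192, BLOCK_M = 16, BLOCK_N = 32
print(triton.cdiv(8192, 16) * triton.cdiv(8192, 32))  # 131072 -- nowhere near a 32-bit overflow
```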
I extended https://github.com/openai/triton/blob/main/python/triton/ops/matmul.py to support 4D tensors, e.g. of shapes `(2, 8, 1024, 1024)` and `(2, 8, 1024, 512)`. For large enough values in dimension 1 (while keeping all other dims equal), a

```
CUDA error: an illegal memory access was encountered
```

is raised. Is it an issue in my implementation, or expected at first glance? I have no issues on smaller shapes. This is on an A100-SXM4-80GB, on triton `main`.
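A minimal sketch of the call pattern, assuming a `matmul_4d` wrapper around the kernel (the wrapper name and the failing size for dimension 1 are hypothetical placeholders for illustration, not values from the post; a sketch of such a wrapper is at the end of this post):

```python
import torch

a = torch.randn(2, 8, 1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(2, 8, 1024, 512, device="cuda", dtype=torch.float16)
c = matmul_4d(a, b)  # fine on shapes like these

a = torch.randn(2, 128, 1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(2, 128, 1024, 512, device="cuda", dtype=torch.float16)
c = matmul_4d(a, b)  # a large enough dim 1 -> CUDA illegal memory access
```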
My kernel is pasted above. You can see that the only change is grabbing a `pid_batch` and a `pid_dim1`, and changing the offsets to account for the many GEMMs we have with 4D tensors. Thank you!
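Since the kernel body is cut off in the paste above, below is a minimal self-contained sketch of the indexing change being described. All names, strides, and the wrapper are hypothetical (simplified from triton's matmul template, without SPLIT_K), not the exact code from the post:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def matmul_4d_kernel(A, B, C, M, N, K,
                     stride_a0, stride_a1, stride_am, stride_ak,
                     stride_b0, stride_b1, stride_bk, stride_bn,
                     stride_c0, stride_c1, stride_cm, stride_cn,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # axis 0 enumerates the M x N tiles; axes 1 and 2 carry the two leading dims
    pid = tl.program_id(0)
    pid_batch = tl.program_id(1)
    pid_dim1 = tl.program_id(2)
    grid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    # the 4D-specific change: shift the base pointers to the GEMM
    # selected by (pid_batch, pid_dim1)
    A += pid_batch * stride_a0 + pid_dim1 * stride_a1
    B += pid_batch * stride_b0 + pid_dim1 * stride_b1
    C += pid_batch * stride_c0 + pid_dim1 * stride_c1

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A_ptrs = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
    B_ptrs = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(A_ptrs, mask=(rm[:, None] < M) & (rk[None, :] < K - k), other=0.0)
        b = tl.load(B_ptrs, mask=(rk[:, None] < K - k) & (rn[None, :] < N), other=0.0)
        acc += tl.dot(a, b)
        A_ptrs += BLOCK_K * stride_ak
        B_ptrs += BLOCK_K * stride_bk

    C_ptrs = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    mask = (rm[:, None] < M) & (rn[None, :] < N)
    tl.store(C_ptrs, acc.to(C.dtype.element_ty), mask=mask)


def matmul_4d(a, b, BLOCK_M=64, BLOCK_N=64, BLOCK_K=32):
    # a: (d0, d1, M, K), b: (d0, d1, K, N); leading dims must match
    d0, d1, M, K = a.shape
    _, _, _, N = b.shape
    c = torch.empty((d0, d1, M, N), device=a.device, dtype=a.dtype)
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), d0, d1)
    matmul_4d_kernel[grid](a, b, c, M, N, K,
                           *a.stride(), *b.stride(), *c.stride(),
                           BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K)
    return c
```

The launch grid places the two leading dimensions on axes 1 and 2, which is what the `pid_batch` / `pid_dim1` reads in the kernel index into.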