Open CHDev93 opened 1 year ago
You may not be bounds checking correctly in the matmul. Really hard to say without the kernel.
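For context, a minimal sketch of the kind of mask-based bounds checking meant here (illustrative only, with made-up names like copy_tile; not the reporter's kernel):

# Illustrative only: mask every load/store so out-of-range rows/cols are skipped.
import torch
import triton
import triton.language as tl


@triton.jit
def copy_tile(x_ptr, y_ptr, M, N, stride_m, stride_n,
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs = rows[:, None] * stride_m + cols[None, :] * stride_n
    # The mask guards both dimensions; tl.load / tl.store never touch memory
    # outside the (M, N) tensor even when M or N is not a multiple of the block.
    mask = (rows[:, None] < M) & (cols[None, :] < N)
    tile = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(y_ptr + offs, tile, mask=mask)


x = torch.randn(100, 70, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(100, 32), triton.cdiv(70, 32))
copy_tile[grid](x, y, 100, 70, x.stride(0), x.stride(1), BLOCK_M=32, BLOCK_N=32)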
Sorry @ptillet, got pulled away for much longer than expected!
Finally got around to making a minimal repro for this one. The kernel below is basically doing a matmul but masking out the upper diagonal (the output is stored in a weird format which I need for later). The real kernel has a bunch more bits and bobs, but this is the thrust of it.
"""CUDA_LAUNCH_BLOCKING=1 python illegal_address_repro.py"""
import torch
import triton
import triton.language as tl
_BLOCK_HEIGHT = 16
def cdiv(x: int, y: int) -> int:
    """Compute ceiling of x / y."""
    return (x + y - 1) // y
@triton.jit
def kernel_op(
    K,
    out,
    stride_bz: int,
    stride_bn: int,
    stride_bk: int,
    stride_oz: int,
    stride_om: int,
    stride_on: int,
    n_dim: int,
    BLOCK_HEIGHT: tl.constexpr,
    BLOCK_WIDTH: tl.constexpr,
    K_DIM: tl.constexpr,
):
    lower_bandwidth = BLOCK_WIDTH - BLOCK_HEIGHT
    block_id = tl.program_id(0)

    # initialise offsets
    row_offsets = block_id * BLOCK_HEIGHT + tl.arange(0, BLOCK_HEIGHT)
    col_offsets = block_id * BLOCK_HEIGHT - lower_bandwidth + tl.arange(0, BLOCK_WIDTH)
    head_offsets = tl.arange(0, K_DIM)
    block_width_offsets = tl.arange(0, BLOCK_WIDTH)
    b_offsets = col_offsets[:, None] * stride_bn + head_offsets[None, :] * stride_bk

    # initialise pointers to B
    b_ptrs = K + b_offsets
    b_mask = (col_offsets[:, None] >= 0) & (col_offsets[:, None] < n_dim)

    # compute a @ b.T for a tile and mask out the appropriate parts
    a_block = tl.full((BLOCK_HEIGHT, K_DIM), value=1, dtype=tl.float32)
    b_block = tl.load(b_ptrs, mask=b_mask, other=0.0)  # (BLOCK_WIDTH, K_DIM)
    ab_prod_block = tl.dot(a_block.to(tl.float16), tl.trans(b_block.to(tl.float16)))  # (BLOCK_HEIGHT, BLOCK_WIDTH)
    # ab_prod_block = tl.zeros((BLOCK_HEIGHT, BLOCK_WIDTH), dtype=tl.float32)  # this will not cause a failure

    out_offsets = row_offsets[:, None] * stride_om + block_width_offsets[None, :] * stride_on
    out_ptrs = out + out_offsets
    out_mask = row_offsets[:, None] < n_dim
    tl.store(out_ptrs, ab_prod_block, mask=out_mask)
# params
batch_size = 1
n_dim = 64
k_dim = 64
lower_bandwidth = 16
block_width = lower_bandwidth + _BLOCK_HEIGHT

B = torch.ones((batch_size, n_dim, k_dim), dtype=torch.float32, device="cuda")
AB_product = torch.empty((batch_size, n_dim, block_width), dtype=torch.float32, device="cuda")

num_warps = 4 if k_dim <= 64 else 8
grid_x = cdiv(n_dim, _BLOCK_HEIGHT)
grid = (grid_x,)

kernel_op[grid](
    B,
    AB_product,
    B.stride(0),
    B.stride(1),
    B.stride(2),
    AB_product.stride(0),
    AB_product.stride(1),
    AB_product.stride(2),
    n_dim,
    BLOCK_HEIGHT=_BLOCK_HEIGHT,
    BLOCK_WIDTH=block_width,
    K_DIM=k_dim,
    num_warps=num_warps,
    num_stages=1,
)
print(AB_product)
A couple of observations:
- If I set k_dim to 32 or 16, I see no error.
- If I set b_block to tl.full((BLOCK_WIDTH, K_DIM), value=1, dtype=tl.float32) rather than load it, I see no error.
- This happens on both triton==2.0.0 and triton==2.0.0.post1.
Was able to get around this error by breaking up the matmul along the K_DIM and accumulating into a temporary buffer (like in the matrix multiplication tutorial); a rough sketch is below. Still unclear why there's a regression relative to the pre-MLIR release, and why it works with matrices that I instantiate in Triton (but fails with ones that I load from an input). Now everything compiles, but I'm left with some larger numerical diffs than I had with the previous Triton-IR based library.
It does look similar to this issue, though instead of K being too small, it's K being too large.
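Roughly what the workaround looks like, as a sketch only: it reuses the names from the repro kernel above (col_offsets, stride_bn, stride_bk, n_dim, out_ptrs, out_mask), and K_CHUNK is a hypothetical extra tl.constexpr argument for the chunk size (something like 32, so it still satisfies tl.dot's minimum dimension).

    # Sketch of the K-split workaround inside the kernel body above.
    # K_CHUNK is a hypothetical tl.constexpr chunk size that divides K_DIM.
    acc = tl.zeros((BLOCK_HEIGHT, BLOCK_WIDTH), dtype=tl.float32)
    for k_start in range(0, K_DIM, K_CHUNK):
        k_offsets = k_start + tl.arange(0, K_CHUNK)
        b_offsets = col_offsets[:, None] * stride_bn + k_offsets[None, :] * stride_bk
        b_mask = (col_offsets[:, None] >= 0) & (col_offsets[:, None] < n_dim)
        b_block = tl.load(K + b_offsets, mask=b_mask, other=0.0)  # (BLOCK_WIDTH, K_CHUNK)
        a_block = tl.full((BLOCK_HEIGHT, K_CHUNK), value=1, dtype=tl.float16)
        acc += tl.dot(a_block, tl.trans(b_block.to(tl.float16)))  # (BLOCK_HEIGHT, BLOCK_WIDTH)
    tl.store(out_ptrs, acc, mask=out_mask)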
I am seeing a similar error, which goes away if the batch size I use is 1. Anything greater than that causes illegal memory access. Any updates on the resolution of this issue?
@shivam-msft which version of triton are you on? Did you try breaking up the reduction dimension? Also, what precision are you using? I find that f16 works reliably, but f32 and tf32 can give really large numerical diffs.
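In case it helps narrow down the precision side, here's a small sketch (assuming Triton 2.x, where tl.dot takes an allow_tf32 flag; dot_kernel and its names are made up) of forcing the dot to full fp32 and comparing against a torch reference with TF32 disabled:

import torch
import triton
import triton.language as tl


@triton.jit
def dot_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr, ALLOW_TF32: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    # With ALLOW_TF32=False the multiply-accumulate stays in full fp32.
    c = tl.dot(a, b, allow_tf32=ALLOW_TF32)
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)


torch.backends.cuda.matmul.allow_tf32 = False  # fp32 torch reference without TF32
a = torch.randn(64, 64, device="cuda", dtype=torch.float32)
b = torch.randn(64, 64, device="cuda", dtype=torch.float32)
c = torch.empty_like(a)
dot_kernel[(1,)](a, b, c, BLOCK=64, ALLOW_TF32=False)
print((c - a @ b).abs().max())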
Hello! In case it's helpful for anyone, I'm getting this one if I use nanoGPT from https://github.com/Lightning-Universe/lightning-GPT with DeepSpeed and torch.compile, and only for batch sizes > 40 (on an A100 40GB). For smaller batch sizes it works fine, so maybe it's related to just OOMs?
I found it can be triggered by OOM; the test code is:
"""
Vector Addition
===============
In this tutorial, you will write a simple vector addition using Triton.
In doing so, you will learn about:
* The basic programming model of Triton.
* The `triton.jit` decorator, which is used to define Triton kernels.
* The best practices for validating and benchmarking your custom ops against native reference implementations.
"""
# %%
# Compute Kernel
# --------------
import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
# %%
# Let's also declare a helper function to (1) allocate the `z` tensor
# and (2) enqueue the above kernel with appropriate grid/block sizes:
def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output.
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks:
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    # NOTE:
    #  - Each torch.tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
    #  - Don't forget to pass meta-parameters as keywords arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously at this point.
    return output
# %%
# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness:
torch.manual_seed(0)
size = 4096*1024*520
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y
output_triton = add(x, y)
print(output_torch)
print(output_triton)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')
This is the vector-add example from the official OpenAI Triton tutorial. Note that the size is very big, and on my A100 80GB it gives the following error:
Traceback (most recent call last):
File "/data/nfs/hanhaowen/cat_triton_sin_cos.py", line 87, in <module>
output_triton = add(x, y)
^^^^^^^^^
File "/data/nfs/hanhaowen/cat_triton_sin_cos.py", line 73, in add
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
File "/opt/conda/lib/python3.11/site-packages/triton/runtime/jit.py", line 345, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/triton/runtime/jit.py", line 691, in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
File "/opt/conda/lib/python3.11/site-packages/triton/backends/nvidia/driver.py", line 365, in __call__
self.launch(*args, **kwargs)
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
When I make the size smaller, it runs without any problem.
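One thing worth checking (just an assumption on my part, not verified): 4096 * 1024 * 520 is about 2.18 billion elements, which is above the int32 range (2**31 - 1), so the offsets computed inside the kernel could overflow before they are ever compared against n_elements. A sketch of a variant of the add_kernel above that promotes the offsets to 64-bit:

import triton
import triton.language as tl


@triton.jit
def add_kernel_i64(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    # Promote to int64 before multiplying so block_start cannot wrap around
    # for tensors with more than 2**31 - 1 elements.
    block_start = pid.to(tl.int64) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)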
I get the error in the title on some code that's essentially just doing a matmul. I've found that when the inner dimension k is 32, everything is fine. When k is 64 I get the error below.
Interestingly, this code used to work on previous pre-release versions of Triton. I even built the MLIR branch (before it was merged) and had this code running without issue.
I have num_warps = 4 if k <= 64 else 8 and tried setting num_stages to 1, 2, 3 and 4 with no effect. Any hints as to what could've changed recently or how to fix?