jubueche opened this issue 6 days ago
It seems to work (for now) when I don't invoke this kernel:
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_OUT": 256, "BLOCK_SIZE_HIDDEN": 64}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE_OUT": 256, "BLOCK_SIZE_HIDDEN": 32}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_OUT": 128, "BLOCK_SIZE_HIDDEN": 32}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_OUT": 64, "BLOCK_SIZE_HIDDEN": 32}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_OUT": 128, "BLOCK_SIZE_HIDDEN": 32}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_OUT": 32, "BLOCK_SIZE_HIDDEN": 32}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_OUT": 32, "BLOCK_SIZE_HIDDEN": 32}, num_stages=5, num_warps=2),
        triton.Config({"BLOCK_SIZE_OUT": 64, "BLOCK_SIZE_HIDDEN": 32}, num_stages=5, num_warps=2),
    ],
    key=["hidden_size", "out_size"],
)
@triton.jit
def modifier_kernel(
    # pointers to tensors
    weights_ptr,  # 2D [hidden_size, out_size]
    assumed_wmax_ptr,  # 2D [num_slices, out_size]
    reduced_assumed_wmax_ptr,  # 2D [num_slices, 1]
    upper_end_of_slices_ptr,  # 1D [num_slices]
    # sizes
    hidden_size,
    out_size,
    num_slices,
    # strides
    stride_weights_hidden_size,
    stride_weights_out_size,
    stride_assumed_wmax_num_slices,
    stride_assumed_wmax_out_size,
    # miscellaneous
    modifier_type: tl.constexpr,  # str
    modifier_weight_res: tl.constexpr,  # float
    modifier_seed: tl.constexpr,  # int
    modifier_std: tl.constexpr,  # float
    # block sizes
    BLOCK_SIZE_HIDDEN: tl.constexpr,
    BLOCK_SIZE_OUT: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    offs_bn = (pid * BLOCK_SIZE_OUT + tl.arange(0, BLOCK_SIZE_OUT)) % out_size
    offs_assumed_wmax = pid * BLOCK_SIZE_OUT + tl.arange(0, BLOCK_SIZE_OUT)
    # for random number generation of output
    increase_weight_offsets_by = BLOCK_SIZE_HIDDEN * BLOCK_SIZE_OUT
    weight_random_offsets = tl.arange(0, BLOCK_SIZE_HIDDEN * BLOCK_SIZE_OUT).reshape(
        (BLOCK_SIZE_HIDDEN, BLOCK_SIZE_OUT), can_reorder=True
    )
    ir_range_lower = 0
    for slice_idx in range(0, num_slices):
        # load the abs-max we need
        abs_max_slice_ptrs = (
            assumed_wmax_ptr
            + slice_idx * stride_assumed_wmax_num_slices
            + offs_bn * stride_assumed_wmax_out_size
        )
        if modifier_type == "AddNormal" or (
            modifier_type == "Discretize" or modifier_type == "DiscretizeAddNormal"
        ):
            assumed_wmax_per_slice = tl.load(reduced_assumed_wmax_ptr + slice_idx)
        else:
            assumed_wmax_per_slice = tl.load(
                abs_max_slice_ptrs, mask=offs_assumed_wmax < out_size, other=float("-inf")
            )
            assumed_wmax_per_slice = assumed_wmax_per_slice[None, :]
        ir_range_upper = tl.load(upper_end_of_slices_ptr + slice_idx)
        current_lower = ir_range_lower
        num_k = tl.cdiv(ir_range_upper - ir_range_lower, BLOCK_SIZE_HIDDEN)
        for k in range(0, num_k):
            current_upper = min(
                ir_range_upper, ir_range_lower + (k + 1) * BLOCK_SIZE_HIDDEN, hidden_size
            )
            offs_k = current_lower + tl.arange(0, BLOCK_SIZE_HIDDEN)
            b_ptrs = weights_ptr + (
                offs_k[:, None] * stride_weights_hidden_size
                + offs_bn[None, :] * stride_weights_out_size
            )
            b = tl.load(b_ptrs, mask=offs_k[:, None] < current_upper, other=0.0)
            if (modifier_type == "Discretize" or modifier_type == "DiscretizeAddNormal") or (
                modifier_type == "DiscretizePerChannel"
                or modifier_type == "DiscretizeAddNormalPerChannel"
            ):
                if modifier_weight_res > 0:
                    n_states = max(modifier_weight_res, 1 / modifier_weight_res)
                    res = 2 * assumed_wmax_per_slice / n_states
                    b = b / res
                    b = tl.extra.cuda.libdevice.rint(b)
                    b = b * res
            if (modifier_type == "AddNormal" or modifier_type == "AddNormalPerChannel") or (
                modifier_type == "DiscretizeAddNormal"
                or modifier_type == "DiscretizeAddNormalPerChannel"
            ):
                randn_block = tl.randn(modifier_seed + pid, weight_random_offsets)
                weight_random_offsets += increase_weight_offsets_by
                randn_block = assumed_wmax_per_slice * modifier_std * randn_block
                b += randn_block
            # store b back to DRAM...
            tl.store(
                b_ptrs,
                b,
                mask=(offs_k[:, None] < current_upper) & (offs_assumed_wmax[None, :] < out_size),
            )
            current_lower = current_upper
        ir_range_lower = ir_range_upper
I call this kernel like so:
if apply_weight_modifier:
    modifier_std = rpu_config.modifier.std_dev
    modifier_seed = randint(2**31, (1,)).item()
    # bring the weight resolution into a state that we can interpret
    modifier_weight_res = rpu_config.modifier.res
    modifier_kernel[weight_modifier_grid](
        # pointers to tensors
        weights.T,  # 2D [hidden_size, out_size]
        assumed_wmax,  # 2D [num_slices, out_size]
        reduced_assumed_wmax,  # 2D [num_slices, 1]
        upper_end_of_slices,  # 1D [num_slices]
        # sizes
        hidden_size,
        out_size,
        num_slices,
        # strides
        weights.stride(1),  # flipped because of transpose
        weights.stride(0),
        assumed_wmax.stride(0),
        assumed_wmax.stride(1),
        # miscellaneous
        rpu_config.modifier.type.value,
        modifier_weight_res,
        modifier_seed,
        modifier_std,
        # block sizes
        # 32,  # for debugging
        # 32,
    )
Note that I change the weights in-place. I essentially add noise to them.
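For reference, weight_modifier_grid is defined elsewhere in my code and is not shown above. Since the kernel only reads tl.program_id(axis=0) and tiles the output dimension in BLOCK_SIZE_OUT-sized chunks, a matching 1-D grid over out_size looks roughly like this (a sketch, not the exact definition):

# Sketch only: the exact weight_modifier_grid is not included in this issue.
# The kernel launches one program per BLOCK_SIZE_OUT-wide tile of the output dimension.
weight_modifier_grid = lambda META: (triton.cdiv(out_size, META["BLOCK_SIZE_OUT"]),)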
I think I have found my bug. In the function definition, modifier_seed is declared as a tl.constexpr, although it changes on every call (the seed is always randomly generated). Changing this to
modifier_seed,  # int
solves the issue.
For the other parameters that I pass, is it correct that I use tl.constexpr? What happens when the function has already been compiled with, for example, modifier_std = 0.3 and I later call it with modifier_std = 0.4? Does the function get re-compiled? Or does it just use the "wrong" function with 0.3?
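To make the difference concrete: a tl.constexpr value is baked into the compiled binary, so every distinct value produces a new specialization (calling with 0.4 after 0.3 triggers a fresh compile rather than reusing the 0.3 binary), whereas a plain runtime argument is covered by a single compilation. A minimal toy example, not taken from my code:

import torch
import triton
import triton.language as tl

@triton.jit
def add_scalar_constexpr(x_ptr, n, scalar: tl.constexpr, BLOCK: tl.constexpr):
    # `scalar` is baked into the binary: every distinct value triggers a fresh compilation
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) + scalar, mask=mask)

@triton.jit
def add_scalar_runtime(x_ptr, n, scalar, BLOCK: tl.constexpr):
    # `scalar` is an ordinary kernel argument: one compilation serves all values
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) + scalar, mask=mask)

x = torch.zeros(1024, device="cuda")
grid = (triton.cdiv(x.numel(), 256),)
for value in range(100):
    # the constexpr variant would compile 100 specializations here; this one compiles once
    add_scalar_runtime[grid](x, x.numel(), float(value), BLOCK=256)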
I developed a Triton kernel that replaces the Linear layer in a framework that I am currently developing. This kernel is integrated in the standard way:
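(The integration code itself is not shown here. "The standard way" refers to the usual torch.autograd.Function plus nn.Module pattern; the snippet below is only a sketch of that pattern, with a plain matmul standing in for the actual Triton kernel launches.)

import torch

class TritonLinearFunction(torch.autograd.Function):
    # Sketch only: the real forward/backward would launch the custom Triton kernels.
    @staticmethod
    def forward(ctx, x, weight, bias):
        ctx.save_for_backward(x, weight)
        ctx.has_bias = bias is not None
        out = x @ weight.t()  # placeholder for the Triton forward kernel
        if bias is not None:
            out = out + bias
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_bias = grad_out.sum(0) if ctx.has_bias else None
        return grad_out @ weight, grad_out.t() @ x, grad_bias

class TritonLinear(torch.nn.Linear):
    def forward(self, x):
        return TritonLinearFunction.apply(x, self.weight, self.bias)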
My framework also has a "fallback" PyTorch mode. When I use this mode, my model trains fine, but when I use the Triton kernel, I eventually get:
I searched online and only found this issue in the google/jax repo: https://github.com/google/jax/issues/16272. There, they observed many LLVM worker threads being spawned, eventually leading to this error. I wanted to look at the processes using
top -Hp $(pgrep ...)
but couldn't manage to see anything useful. I then wanted to compare the CPU memory development between the Triton and the PyTorch version: the top plot is the Triton run and the bottom one is the PyTorch run. Since the memory profiler also records memory recursively for every spawned child process, I could also see that the Triton run had many more child processes. Triton run: 3700, Torch: 124. One can also see that the memory grows linearly for Triton. Given that one node in my cluster has >700 GB of RAM, I am quite certain that I am not running out.
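To reproduce these numbers without a full profiler, something like the following psutil sketch counts the descendants of the training process and sums their resident memory (psutil is an assumption here; the plots above were produced with a separate memory profiler):

import psutil

def report(pid: int) -> None:
    # Count the training process plus all of its descendants and sum their RSS.
    root = psutil.Process(pid)
    procs = [root] + root.children(recursive=True)
    rss, threads = 0, 0
    for p in procs:
        try:
            rss += p.memory_info().rss
            threads += p.num_threads()
        except psutil.NoSuchProcess:
            pass  # a child may exit while we iterate
    print(f"processes={len(procs)} threads={threads} rss={rss / 1e9:.1f} GB")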
When running top, my virtual memory is ~90 GB per GPU process (8 in parallel), but the physical memory used per process corresponds to what can be seen in the plot above. I also wanted to rule out a silent GPU OOM, so I profiled the GPU memory:
The top plot is Triton and the bottom is PyTorch. I verified that there are no spurious tensors that are not released. The memory usage of the Triton kernel is actually slightly lower than that of PyTorch.
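As a cross-check, the allocator counters can also be logged directly from PyTorch per training step (this sketch is only that cross-check, not the profiler that produced the plots above):

import torch

def log_gpu_mem(step: int) -> None:
    # Allocator counters for the current CUDA device, in GB.
    alloc = torch.cuda.memory_allocated() / 1e9
    reserved = torch.cuda.memory_reserved() / 1e9
    peak = torch.cuda.max_memory_allocated() / 1e9
    print(f"step {step}: allocated={alloc:.2f} GB reserved={reserved:.2f} GB peak={peak:.2f} GB")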
I also did a run where I disabled the auto-tuning, but it didn't change anything. For this run, the number of child processes even exceeded 6000; I don't know why that is. Also, the total memory used seems to be lower this time. The picture below is a zoomed-in view of the period just before the crash.
On my cluster, each user can only have 4096 processes. I am not sure, but I guess sub-processes also count toward that limit (?). I can't increase it since I am not an admin.
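The limit can be read from Python; note that on Linux RLIMIT_NPROC counts tasks per real user ID, i.e. threads as well as processes, so LLVM compiler worker threads count toward the 4096 as well:

import resource

# On Linux, RLIMIT_NPROC limits the number of tasks (processes and threads)
# that the real user ID may have, so worker threads count toward it too.
soft, hard = resource.getrlimit(resource.RLIMIT_NPROC)
print(f"per-user task limit: soft={soft}, hard={hard}")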
Setup
Torch: 2.3.0+cu121
Triton: installed via
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
OS: CentOS 7, kernel 3.10.0-1160.76.1.el7.x86_64 (yes, it's old)
Training: 1 node with 8 V100s using HF accelerate + DeepSpeed ZeRO stage 2. This stage shards the optimizer state (offloaded to CPU as well) and the gradients.