Does the num_stages kernel config need to be adjusted in tandem with the tl.range-specific num_stages? (I believe the kernel-level num_stages is for tl.dot.)
No
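For reference, a minimal sketch (not from this issue; names like `BLOCK` and `LOOP_STAGES` are placeholders) of the two independent knobs: the kernel-level num_stages, set via triton.Config or at launch, and the per-loop num_stages passed to tl.range. The sketch only shows where each one is written down.

```python
import triton
import triton.language as tl

@triton.autotune(
    # Kernel-level num_stages lives in the launch/autotune config.
    configs=[triton.Config({"BLOCK": 128}, num_warps=4, num_stages=3)],
    key=["n"],
)
@triton.jit
def _sketch(x_ptr, out_ptr, n, BLOCK: tl.constexpr, LOOP_STAGES: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    acc = tl.zeros((BLOCK,), dtype=tl.float32)
    # Loop-level pipelining hint, separate from the config-level num_stages above.
    for i in tl.range(0, n, BLOCK, num_stages=LOOP_STAGES):
        acc += tl.load(x_ptr + i + offs, mask=i + offs < n, other=0.0)
    tl.store(out_ptr + offs, acc)
```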
Are there any debugging flags that can be set for diagnosing segfaults?
It would be better if you could provide error messages. I'm not sure whether the core dump happens at kernel runtime or at compile time.
@Jokeren
The error happens at compile time, as running the kernel in warmup mode results in a segfault.
Attached is the output of TRITON_ENABLE_LLVM_DEBUG before the segfault.
segfault.txt
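For reference, a compile-only invocation along these lines (kernel name, args, and grid are placeholders for the reproducer; the env vars must be set before the kernel is compiled) can be used to check whether the crash happens without launching anything:

```python
import os

# Verbose compiler output. TRITON_ENABLE_LLVM_DEBUG is the flag whose output is
# attached above; MLIR_ENABLE_DUMP=1 additionally dumps the IR around each pass
# in recent Triton builds.
os.environ["TRITON_ENABLE_LLVM_DEBUG"] = "1"
os.environ["MLIR_ENABLE_DUMP"] = "1"

# Compile-only (no launch): if this alone segfaults, the crash is in a compiler
# pass rather than at kernel runtime. Names below are placeholders.
_pipeline_kernel.warmup(*kernel_args, **kernel_constexprs, grid=(1,))
```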
I think this is more of an error in the pipeline pass.
We can take a look if you have a reproducer.
cc @pawelszczerbuk @htyu @ThomasRaoux
@Jokeren @pawelszczerbuk @htyu @ThomasRaoux
Much appreciated!
Here's a stripped-down version of the kernel.
Setting NUM_PIPELINE_STAGES > 1 results in a segfault; the kernel works otherwise.
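(This is not the attached reproducer; it's a minimal sketch, with placeholder names, of the failing pattern described here: a persistent-style loop whose tl.range carries num_stages=NUM_PIPELINE_STAGES.)

```python
import triton
import triton.language as tl

@triton.jit
def _persistent_sketch(x_ptr, w_ptr, out_ptr, num_tiles,
                       BLOCK: tl.constexpr, NUM_PIPELINE_STAGES: tl.constexpr):
    pid = tl.program_id(0)
    num_programs = tl.num_programs(0)
    offs = tl.arange(0, BLOCK)
    # Persistent loop: each program (one per SM) walks its share of tiles,
    # with the pipelining hint attached to tl.range.
    for tile in tl.range(pid, num_tiles, num_programs, num_stages=NUM_PIPELINE_STAGES):
        idx = tile * BLOCK * BLOCK + offs[:, None] * BLOCK + offs[None, :]
        x = tl.load(x_ptr + idx)
        w = tl.load(w_ptr + offs[:, None] * BLOCK + offs[None, :])
        y = tl.dot(x, w)
        tl.store(out_ptr + idx, y.to(tl.float16))
```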
@Jokeren
Anything I can do to help debug?
There might be more than one problem. Since I'm working on something else, I could be slow, so I was trying to find others who might be more available to handle this issue.
@Jokeren
FWIW, the kernel no longer compiles at all on the triton nightly (3.0.0.post20240716052845):
loc("/home/jeromeku/Cpp/CUDA/cgcg-dev/experiments/segfault.py":68:46): error: 'scf.yield' op must be the last operation in the parent block
Traceback (most recent call last):
File "/home/jeromeku/Cpp/CUDA/cgcg-dev/experiments/segfault.py", line 168, in <module>
_pipeline_kernel[grid](*kernel_args, **kernel_constexprs)
File "/home/jeromeku/miniconda3/envs/test-env/lib/python3.11/site-packages/triton/runtime/jit.py", line 326, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jeromeku/miniconda3/envs/test-env/lib/python3.11/site-packages/triton/runtime/jit.py", line 643, in run
kernel = self.compile(
^^^^^^^^^^^^^
File "/home/jeromeku/miniconda3/envs/test-env/lib/python3.11/site-packages/triton/compiler/compiler.py", line 287, in compile
next_module = compile_ir(module, metadata)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jeromeku/miniconda3/envs/test-env/lib/python3.11/site-packages/triton/backends/nvidia/compiler.py", line 329, in <lambda>
stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jeromeku/miniconda3/envs/test-env/lib/python3.11/site-packages/triton/backends/nvidia/compiler.py", line 201, in make_ttgir
pm.run(mod)
RuntimeError: PassManager::run failed
Before, I was running with the following env, in which case the kernel ran for NUM_PIPELINE_STAGES < 2:
pytorch-triton 3.0.0+dedb7bdf33
torch 2.5.0.dev20240803+cu121
I installed the triton nightly in this env, which is now leading to the above error.
Also, using the previous environment:
pytorch-triton 3.0.0+dedb7bdf33
torch 2.5.0.dev20240803+cu121
The kernel segfaults if I change tl.range to range (and remove NUM_PIPELINE_STAGES). See here.
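For clarity, the change in question is roughly the following (an illustrative kernel with placeholder names, not the actual one):

```python
import triton
import triton.language as tl

@triton.jit
def _loop_variant(x_ptr, out_ptr, n_iters, BLOCK: tl.constexpr,
                  USE_TL_RANGE: tl.constexpr, NUM_PIPELINE_STAGES: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    acc = tl.zeros((BLOCK,), dtype=tl.float32)
    if USE_TL_RANGE:
        # Original form: tl.range with an explicit pipelining hint.
        for i in tl.range(0, n_iters, num_stages=NUM_PIPELINE_STAGES):
            acc += tl.load(x_ptr + i * BLOCK + offs)
    else:
        # Modified form: plain Python range, no num_stages hint.
        for i in range(0, n_iters):
            acc += tl.load(x_ptr + i * BLOCK + offs)
    tl.store(out_ptr + offs, acc)
```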
The problem you mentioned has been fixed by this PR. I'm waiting for @ThomasRaoux to take the second pass.
Hi @Jokeren. Can you confirm that this bug has been fixed? If yes, can we close this?
Yes, closed
Trying to implement pipelining using tl.range(..., num_stages=num_pipeline_stages) for a persistent kernel.

Each SM executes the following operations per iteration:
- Memory: 7 loads (32 x 32 tiles of fp16) and one store of the same size and dtype
- Compute: 4 elementwise ops (3 mul, 1 add) and 2 dot, all of size 32x32

The tile size (32x32) should be configurable, but even at this smallest size, setting num_pipeline_stages > 1 results in Segmentation fault (core dumped). (The kernel executes correctly across a range of args as long as num_pipeline_stages <= 1.) Any suggestions on how to debug would be greatly appreciated!
Questions (on using tl.range with pipelining):
- Does the num_stages kernel config need to be adjusted in tandem with the tl.range-specific num_stages (I believe the kernel num_stages is for tl.dot)?

Note: in the fused softmax tutorial, setting num_pipeline_stages to exceed smem limits results in an OutOfResources error and not a Segfault, so something else is going on.
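For what it's worth, a rough back-of-the-envelope estimate (assuming the pipeliner multi-buffers all seven input tiles, one copy per stage; this is an approximation, not what the compiler actually allocates) suggests the loop is nowhere near smem limits:

```python
# Rough shared-memory estimate for the pipelined loop described above:
# seven 32x32 fp16 input tiles, one copy per pipeline stage.
TILE = 32
BYTES_PER_FP16 = 2
LOADS_PER_ITER = 7

def smem_estimate_bytes(num_stages: int) -> int:
    per_stage = LOADS_PER_ITER * TILE * TILE * BYTES_PER_FP16  # 14 KiB per stage
    return num_stages * per_stage

for stages in range(1, 5):
    print(f"num_stages={stages}: ~{smem_estimate_bytes(stages) / 1024:.0f} KiB")
# Even num_stages=4 is only ~56 KiB, well under typical per-SM shared memory,
# so exhausting smem is unlikely to be what causes the segfault.
```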
Here is some additional context on hardware and launch config:
HW Info:
Here are some basic kernel stats from AOT-compiling the kernel with num_warps=4, num_pipeline_stages=1, and grid set to NUM_SM:

where: