triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Segfault when using `tl.range(...,num_stages)` #4368

Closed. jeromeku closed this issue 4 weeks ago.

jeromeku commented 3 months ago

Trying to implement pipelining using `tl.range(..., num_stages=num_pipeline_stages)` for a persistent kernel.

Each SM executes the following operations per iteration:
- Memory: 7 loads (32x32 tiles of fp16) and one store of the same size and dtype.
- Compute: 4 elementwise ops (3 mul, 1 add) and 2 dots, all of size 32x32.
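To make the per-iteration workload concrete, the byte and FLOP counts can be tallied directly. This is a minimal sketch, assuming each dot multiplies two 32x32 tiles (K = 32) and counting each elementwise op as one FLOP per element. The shared-memory figure at the end is only a rough budget based on SIZE_SMEM from the HW info below, not a description of how Triton's pipeliner actually allocates buffers.

```python
# Per-iteration workload from the issue: 7 loads + 1 store of 32x32 fp16 tiles,
# 4 elementwise ops (3 mul, 1 add) and 2 dots on 32x32 tiles.
TILE_M = TILE_N = 32
FP16_BYTES = 2
NUM_LOADS, NUM_STORES = 7, 1
NUM_ELEMENTWISE, NUM_DOTS = 4, 2
DOT_K = 32  # assumption: each dot is (32x32) @ (32x32)

tile_bytes = TILE_M * TILE_N * FP16_BYTES   # bytes in one fp16 tile
load_bytes = NUM_LOADS * tile_bytes         # global-memory bytes read per iteration
store_bytes = NUM_STORES * tile_bytes       # global-memory bytes written per iteration
total_bytes = load_bytes + store_bytes

# A dot of (M x K) @ (K x N) costs 2*M*N*K FLOPs (one multiply + one add per term).
dot_flops = NUM_DOTS * 2 * TILE_M * TILE_N * DOT_K
elementwise_flops = NUM_ELEMENTWISE * TILE_M * TILE_N
total_flops = dot_flops + elementwise_flops

# FLOPs per byte of global-memory traffic.
arithmetic_intensity = total_flops / total_bytes

# Rough upper bound on how many full copies of the 7 input tiles would fit in
# one SM's shared memory, ignoring every other shared-memory user and the
# pipeliner's real buffering scheme (assumption for illustration only).
SIZE_SMEM = 101376
max_buffered_copies = SIZE_SMEM // load_bytes

print(tile_bytes, total_bytes, total_flops, arithmetic_intensity, max_buffered_copies)
# prints "2048 16384 135168 8.25 7"
```

At 32x32 the loop is dominated by memory traffic rather than compute, which is the situation where overlapping loads across iterations with `num_stages` would be expected to help.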

The tile size (32x32) should be configurable, but even at this smallest size, setting num_pipeline_stages > 1 results in `Segmentation fault (core dumped)`. (The kernel executes correctly across a range of args as long as num_pipeline_stages <= 1.)

Any suggestions on how to debug would be greatly appreciated!


Here is some additional context on hardware and launch config:

HW Info:

target=GPUTarget(backend='cuda', arch=86, warp_size=32)
NUM_SM=82
NUM_REGS=65536
SIZE_SMEM=101376
WARP_SIZE=32

Here are some basic kernel stats from AOT-compiling the kernel with num_warps=4, num_pipeline_stages=1, and the grid set to NUM_SM:

n_regs: 80, size_smem: 8192, register occupancy: 6, smem occupancy: 12

where:

import torch
from triton.runtime import driver

# Device limits queried from the active driver
device = torch.cuda.current_device()
properties = driver.active.utils.get_device_properties(device)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]

# Per-kernel stats; `kernel` is the compiled kernel from the AOT/warmup compile
n_regs = kernel.n_regs
size_smem = kernel.metadata.shared
register_occ = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
smem_occ = SIZE_SMEM // size_smem
Jokeren commented 3 months ago

> Does the num_stages kernel config need to be adjusted in tandem with the `tl.range`-specific num_stages (I believe the kernel num_stages is for tl.dot)?

No

> Are there any debugging flags that can be set for diagnosing segfaults?

It would help if you could provide the error messages. I'm not sure whether the core dump happens at kernel runtime or during compilation.

jeromeku commented 3 months ago

@Jokeren

The error happens at compile time: running the kernel in warmup mode, which only compiles it, also results in a segfault.

Attached is the output of `TRITON_ENABLE_LLVM_DEBUG` before the segfault (attachment: segfault.txt).

Jokeren commented 3 months ago

I think this is more likely an error in the pipeline pass.

We can take a look if you have a reproducer.

cc @pawelszczerbuk @htyu @ThomasRaoux

jeromeku commented 3 months ago

@Jokeren @pawelszczerbuk @htyu @ThomasRaoux

Much appreciated!

Here's a stripped-down version of the kernel.

Setting NUM_PIPELINE_STAGES > 1 results in a segfault; otherwise the kernel works.

jeromeku commented 3 months ago

@Jokeren

Anything I can do to help debug?

Jokeren commented 3 months ago

There might be more than one problem. Since I'm working on something else, I may be slow, so I was trying to find others who are more available to handle this issue.

jeromeku commented 3 months ago

@Jokeren FWIW, the kernel no longer compiles at all on the Triton nightly (3.0.0.post20240716052845):

loc("/home/jeromeku/Cpp/CUDA/cgcg-dev/experiments/segfault.py":68:46): error: 'scf.yield' op must be the last operation in the parent block
Traceback (most recent call last):
  File "/home/jeromeku/Cpp/CUDA/cgcg-dev/experiments/segfault.py", line 168, in <module>
    _pipeline_kernel[grid](*kernel_args, **kernel_constexprs)
  File "/home/jeromeku/miniconda3/envs/test-env/lib/python3.11/site-packages/triton/runtime/jit.py", line 326, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeromeku/miniconda3/envs/test-env/lib/python3.11/site-packages/triton/runtime/jit.py", line 643, in run
    kernel = self.compile(
             ^^^^^^^^^^^^^
  File "/home/jeromeku/miniconda3/envs/test-env/lib/python3.11/site-packages/triton/compiler/compiler.py", line 287, in compile
    next_module = compile_ir(module, metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeromeku/miniconda3/envs/test-env/lib/python3.11/site-packages/triton/backends/nvidia/compiler.py", line 329, in <lambda>
    stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jeromeku/miniconda3/envs/test-env/lib/python3.11/site-packages/triton/backends/nvidia/compiler.py", line 201, in make_ttgir
    pm.run(mod)
RuntimeError: PassManager::run failed

Previously, I was running in the following environment, where the kernel ran for NUM_PIPELINE_STAGES < 2:

pytorch-triton           3.0.0+dedb7bdf33
torch                    2.5.0.dev20240803+cu121

I installed the Triton nightly in this env, which now leads to the error above.

jeromeku commented 3 months ago

Also, using the previous environment:

pytorch-triton           3.0.0+dedb7bdf33
torch                    2.5.0.dev20240803+cu121

The kernel segfaults if I change tl.range to range (and remove NUM_PIPELINE_STAGES). See here.

Jokeren commented 3 months ago

The problem you mentioned has been fixed by this PR. I'm waiting for @ThomasRaoux to do a second review pass.

https://github.com/triton-lang/triton/pull/4438

zeb209 commented 4 weeks ago

Hi @Jokeren. Can you confirm that this bug has been fixed? If yes, can we close this?

Jokeren commented 4 weeks ago

Yes, closed