triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

High kernel launch overhead #2637

Status: Open. void-main opened this issue 1 year ago.

void-main commented 1 year ago

Hey team, I'm suffering from high Triton kernel launch overhead. Here's my nsys capture:

[Screenshot: CleanShot 2023-11-10 at 10 28 53]

The kernel executes in around 80us on the GPU; however, it takes 220us to launch, which degrades performance. I've checked that the kernel is compiled and cached, so this is purely Triton JIT launcher overhead.
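A rough back-of-the-envelope model shows why this matters. It assumes launches are issued back to back from a single host thread with nothing else overlapping them, which is a simplification. Under that assumption each kernel effectively costs max(launch time, GPU time), so the GPU is busy only kernel_us / max(launch_us, kernel_us) of the time. The helper name `gpu_utilization` is just for illustration:

```python
def gpu_utilization(launch_us: float, kernel_us: float) -> float:
    """Fraction of time the GPU is busy when launches are issued back to back.

    Simplified model (an assumption): launches are asynchronous, so the
    steady-state period per kernel is whichever is slower, the host-side
    launch or the GPU execution.
    """
    period_us = max(launch_us, kernel_us)
    return kernel_us / period_us


# Numbers from the description: ~220us to launch, ~80us on the GPU.
util = gpu_utilization(launch_us=220.0, kernel_us=80.0)
print(f"GPU busy {util:.0%} of the time")  # GPU busy 36% of the time
```

In other words, once the launch path is slower than the kernel itself, shaving host-side launch time translates almost directly into end-to-end speedup.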

I noticed there's a tracker issue in the community; are there any updates on it?

Also, could the AOT compiler help here?

PS: in my scenario, CUDA graphs are not an option (due to dynamically shaped inputs), so I have to reduce the kernel launch time.

void-main commented 1 year ago

Hi @ptillet , could you please help take a look? Are there any temp fixes/workarounds for such overhead?

Jokeren commented 1 year ago

You can try out cudagraph.

P.S., oh, just noticed that it's not an option for you...aot could help then

void-main commented 1 year ago

Thanks for the reply @Jokeren , I'll take a look at AOT.

Could you please share some tutorials / sample code that I could refer to for the AOT process? Thank you very much!

void-main commented 1 year ago

Hi @Jokeren, I tried compiling the kernel, but the generated code turns out to use the CUDA Driver API. How could I build a torch extension (which seems to use the CUDA runtime API)?

Especially, how could I construct a CUdeviceptr out of at::Tensor?

Jokeren commented 1 year ago

I think it's the other way around. Get the data_ptr of an at::Tensor, cast it to CUdeviceptr, and pass it to the kernel.

jeromeku commented 12 months ago

@void-main Been building some tools around AOT compilation -- see PR #2642.

Let me know how I can be helpful -- happy to refine these tools for your use case.

void-main commented 12 months ago

@jeromeku great work!

After reading through the description of PR #2642 , if I understand correctly, the tool helps to ease the process of calling the aot compiler (through tracing the dtype of tensors).

But my problem is how to actually invoke the compiled .c functions via torch extension, could the tool help on this part?

jeromeku commented 12 months ago

@void-main

Sure -- I can extend the tool to automatically codegen a torch extension module from the compiled kernels. Do you have a particular kernel of interest along with some test inputs / outputs?

So that I understand your use case, you'd like to create a custom kernel using triton then call this kernel independent of the triton runtime (i.e., in a standalone pytorch script)?

Jokeren commented 12 months ago

It might be helpful to create an independent repo for this functionality. I can see it simplifying deployment of Triton kernels in many cases.

jeromeku commented 12 months ago

@Jokeren

Yes - that was the intent of the original PR #2642. What concrete use cases / scenarios -- frameworks, backends, kernels -- would you recommend targeting initially that would be most beneficial to triton proper and its users?

Also, are there any resources you'd recommend for diving deeper into MLIR as it pertains to the triton pipeline? I've been through the official toy tutorial and also found Jeremy Kun's tutorial to be particularly helpful. Also fairly familiar with cutlass and the GEMM hierarchy.

Thoughts on nod.ai and its fork of IREE?

Jokeren commented 12 months ago

FYI, @ThomasRaoux was a leading developer of IREE's GPU backend and can provide further suggestions

Are you worried that such a new tool would duplicate IREE? I don't have a concrete idea yet, but I think you could start with a simple tool that makes AOT more user-friendly.

jeromeku commented 12 months ago

@Jokeren

Not worried about duplication; I just want to create a simple, useful tool.

Wanted to get your thoughts on technical and architectural similarities / differences between nod.ai and triton (e.g., see this blogpost for more context).

void-main commented 12 months ago

@jeromeku

> Sure -- I can extend the tool to automatically codegen a torch extension module from the compiled kernels. Do you have a particular kernel of interest along with some test inputs / outputs?

That would be great! Could we just use the matmul example in the test code?

> So that I understand your use case, you'd like to create a custom kernel using triton then call this kernel independent of the triton runtime (i.e., in a standalone pytorch script)?

Yep. The root cause in my scenario is that the triton.jit kernel launch overhead is waaaaay too high compared to my kernel's runtime (see the nsys capture in my description).

void-main commented 12 months ago

Hey @Jokeren, I got the AOT kernel wrapped in a torch extension. The code runs; however, the AOT-compiled kernel produces wrong results. How could I debug such a case?

I've double-checked that the input data, grid config, and constants are the same for the AOT and JIT kernels. Are there any constraints on the AOT compiler?
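One way to narrow this down is to run both kernels on identical inputs and locate the first element where they disagree. Where that index falls (for example, whether it lines up with a block boundary) may hint at a tiling, masking, or stride problem. Below is a minimal pure-Python sketch. It assumes both outputs have been copied to the host as flat lists, for example via `.flatten().tolist()` on the tensors. The helper name `first_mismatch` and the sample data are hypothetical:

```python
import math
from typing import Optional, Sequence, Tuple


def first_mismatch(
    expected: Sequence[float],
    actual: Sequence[float],
    rtol: float = 1e-3,
    atol: float = 1e-5,
) -> Optional[Tuple[int, float, float]]:
    """Return (index, expected, actual) for the first element outside tolerance.

    Uses math.isclose's rule: |a - b| <= max(rtol * max(|a|, |b|), atol).
    Returns None if every element matches.
    """
    if len(expected) != len(actual):
        raise ValueError(f"length mismatch: {len(expected)} vs {len(actual)}")
    for i, (e, a) in enumerate(zip(expected, actual)):
        if not math.isclose(e, a, rel_tol=rtol, abs_tol=atol):
            return i, e, a
    return None


# Hypothetical flattened outputs from the JIT and AOT launches.
jit_out = [1.0, 2.0, 3.0, 4.0]
aot_out = [1.0, 2.0, 0.0, 4.0]
print(first_mismatch(jit_out, aot_out))  # (2, 3.0, 0.0)
```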

jeromeku commented 12 months ago

@void-main: I can take a look if you can provide the kernel and some test data.

Jokeren commented 12 months ago

> Wanted to get your thoughts on technical and architectural similarities / differences between nod.ai and triton (e.g., see this blogpost for more context).

d128 has been improved after the commit he compared against. I haven't looked into cases of short contexts. Technically, the solutions are quite similar if you look at the MLIR representation.

I'd rather not comment on architectural differences/similarities, to avoid conflicts, since I know only a little about IREE.

void-main commented 12 months ago

> @void-main: I can take a look if you can provide the kernel and some test data.

Hey @jeromeku, thank you! But since the code belongs to my company, I can't paste it publicly for now. Is there any way we could get in touch (e.g., Slack, Telegram, Discord) or by email (my email address is on my GitHub profile page)?

Thank you very much!

chenzhengda commented 4 months ago

Hi @void-main, I've encountered the same issue with high Triton kernel launch overhead. Could you please share any solutions or workarounds that have worked for you? Thank you!

void-main commented 4 months ago

Hi @chenzhengda , I believe there are 3 solutions for now:

  1. wrap your Triton kernel in a CUDA graph
  2. cache the CompiledKernel returned by cached_bin = _your_triton_kernel[grid](xxx), and launch the cached binary directly
  3. try the AOT compiler, which is not well documented
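
The idea behind option 2 can be sketched in plain Python. As far as I know, a CompiledKernel is specialized for the arguments it was first compiled with. That includes dtypes, constexpr values, and hints such as whether an integer equals 1 or is divisible by 16, or whether a pointer is 16-byte aligned. A cached binary should therefore only be reused for arguments with the same specialization. `specialization_key`, `CachedLauncher`, and `fake_compile` below are hypothetical names. The key is a simplified stand-in for what Triton's JIT actually computes, which varies between versions:

```python
from typing import Any, Callable, Dict, Hashable, Tuple


def specialization_key(args: Tuple[Any, ...]) -> Tuple[Hashable, ...]:
    """Hypothetical, simplified stand-in for the per-call key Triton's JIT computes."""
    key = []
    for a in args:
        if isinstance(a, bool):  # check bool first: bool is a subclass of int
            key.append(("bool", a))
        elif isinstance(a, int):
            # Integers: specialized on being 1 and on divisibility by 16.
            key.append(("int", a == 1, a % 16 == 0))
        elif isinstance(a, float):
            key.append(("float",))
        elif hasattr(a, "data_ptr") and hasattr(a, "dtype"):
            # Tensor-like objects: dtype plus 16-byte alignment of the base pointer.
            key.append((str(a.dtype), a.data_ptr() % 16 == 0))
        else:
            raise TypeError(f"unsupported argument type: {type(a).__name__}")
    return tuple(key)


class CachedLauncher:
    """Compile once per specialization key, then reuse the cached launcher."""

    def __init__(self, compile_fn: Callable[[Tuple[Any, ...]], Callable[..., Any]]):
        self._compile_fn = compile_fn
        self._cache: Dict[Tuple[Hashable, ...], Callable[..., Any]] = {}

    def __call__(self, *args: Any) -> Any:
        key = specialization_key(args)
        launcher = self._cache.get(key)
        if launcher is None:
            launcher = self._compile_fn(args)  # slow path: first call per key
            self._cache[key] = launcher
        return launcher(*args)  # fast path afterwards


# Hypothetical "compiler" that records how often it runs.
compiles = []


def fake_compile(args):
    compiles.append(args)
    return lambda *a: sum(a)


launch = CachedLauncher(fake_compile)
launch(32, 64)   # compiles
launch(48, 128)  # same key (all divisible by 16, none equal to 1): reused
launch(33, 64)   # 33 % 16 != 0: new key, compiles again
print(len(compiles))  # 2
```

Even this simplified key costs some Python time on every call. The biggest savings likely come when you already know the specialization cannot change and can launch the cached binary without any checks.
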
chenzhengda commented 4 months ago

Hi @void-main,

First of all, thank you very much for your suggestions!

I have a couple of questions. In my scenario, I have dynamically shaped inputs, so are CUDA graphs applicable in this case at all?

Additionally, I tried the second method you mentioned, but encountered an error: triton.compiler.compiler.CompiledKernel object is not callable. Could you please provide further clarification on this issue? I would greatly appreciate it.

Thank you!

void-main commented 3 months ago

Hi @chenzhengda, sorry for the late reply.

> In my scenario, I have dynamic shaped inputs, so I wonder if CUDA graphs are not applicable in this case?

What are the shape ranges of your inputs? If only one dimension varies, you could build multiple CUDA graphs along that dimension. For example, if batch_size is dynamic, you could capture one graph for each batch_size in (1, 2, 4, 8, 16, ...) and pad your input up to the next captured batch size. This is how vLLM creates CUDA graphs for the decoding stage.
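The bucketing step can be sketched as follows. The sketch assumes one CUDA graph has already been captured per bucket size (in PyTorch, typically with `torch.cuda.graph`). `pick_bucket`, `pad_batch`, and `BUCKETS` are illustrative names only. A real implementation would pad tensors rather than Python lists, replay the graph for the chosen bucket, and slice the outputs back to the real batch size:

```python
import bisect
from typing import List, Sequence


def pick_bucket(batch_size: int, buckets: Sequence[int]) -> int:
    """Smallest captured batch size >= batch_size (buckets sorted ascending)."""
    i = bisect.bisect_left(buckets, batch_size)
    if i == len(buckets):
        raise ValueError(f"batch_size {batch_size} exceeds largest bucket {buckets[-1]}")
    return buckets[i]


def pad_batch(rows: List[List[float]], bucket: int, pad_value: float = 0.0) -> List[List[float]]:
    """Append filler rows so the batch has exactly `bucket` rows."""
    width = len(rows[0]) if rows else 0
    return rows + [[pad_value] * width for _ in range(bucket - len(rows))]


BUCKETS = [1, 2, 4, 8, 16, 32]  # one captured graph per entry (assumption)
print(pick_bucket(5, BUCKETS))   # 8
print(pick_bucket(16, BUCKETS))  # 16
padded = pad_batch([[1.0, 2.0]] * 5, pick_bucket(5, BUCKETS))
print(len(padded))  # 8
```

Power-of-two buckets keep the number of captured graphs small, at the cost of padding up to roughly 2x the real batch in the worst case.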

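> Additionally, I tried the second method you mentioned, but encountered an error: triton.compiler.compiler.CompiledKernel object is not callable.

The cached object is not directly callable; you need to call its c_wrapper method with a lot of parameters. Here's a rough example (it relies on internal attributes of CompiledKernel, so the exact argument list depends on your Triton version):

```python
import torch
from triton.compiler.compiler import CompiledKernel

if cached_bin is not None:
    # Fast path: launch the cached binary directly, skipping the JIT dispatch logic.
    compiled = cached_bin
    stream = torch.cuda.current_stream().cuda_stream
    args = [
        # your triton args (non-constexpr arguments only)
    ]
    compiled.c_wrapper(grid[0], grid[1], grid[2], compiled.num_warps, compiled.num_ctas,
                       compiled.clusterDims[0], compiled.clusterDims[1], compiled.clusterDims[2],
                       compiled.shared, stream, compiled.cu_function,
                       CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook,
                       compiled, *args)
else:
    # Slow path (first call): launch through the JIT and keep the CompiledKernel it returns.
    cached_bin = _fwd_kernel[grid](...)  # your triton args, including constexpr meta-parameters
```
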
mobicham commented 1 month ago

@void-main I'm trying your code above, but I ran into a couple of issues.

Do you by any chance have well-tested code? Currently, launch overhead is an issue for smaller tensors. Thank you very much!

xinji1 commented 1 month ago

Hi @void-main, thanks so much for your advice! I've fixed my launch issue through cached_bin. @mobicham & @chenzhengda, you can take the following steps to see whether your problems can be alleviated:

mobicham commented 1 week ago

So I have been looking into this issue, and I can confirm that @xinji1's solution does work to some extent. However, calling run() directly is not compatible with torch.compile. Instead, I found that putting the kernel call in a torch model and compiling it with torch.compile in mode='max-autotune-no-cudagraphs' helps:

```python
import torch
model.forward = torch.compile(model.forward, mode='max-autotune-no-cudagraphs', fullgraph=True)
```

However, you need PyTorch 2.5.0, and you cannot use some Triton features like prune_configs_by. If you use mode='reduce-overhead', which is the recommended setting for many applications, the performance is pretty bad.