void-main opened this issue 1 year ago
Hi @ptillet, could you please help take a look? Are there any temporary fixes or workarounds for this overhead?
You can try out cudagraph.
P.S. Oh, I just noticed that it's not an option for you... AOT could help then.
Thanks for the reply @Jokeren, I'll take a look at AOT.
Could you please share some tutorials / sample code that I could refer to for the AOT process? Thank you very much!
Hi @Jokeren, I tried compiling the kernel, but the compiled kernel turns out to use the CUDA Driver API. How could I build a torch extension (which seems to use the CUDA Runtime API) around it?
In particular, how could I construct a CUdeviceptr out of an at::Tensor?
I think it's the other way around: get the data_ptr of an at::Tensor and pass it to the kernel.
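For illustration, here is a rough sketch of that idea using ctypes against an AOT-built shared library; the library name (libadd.so) and launcher signature are hypothetical, not from this thread:

import ctypes
import torch

# Hypothetical setup: assumes the AOT-generated launcher .c was built into libadd.so and
# exposes add_fp32(CUstream stream, CUdeviceptr x, CUdeviceptr y, CUdeviceptr out, int32_t n).
lib = ctypes.CDLL("./libadd.so")
lib.add_fp32.argtypes = [ctypes.c_void_p] * 4 + [ctypes.c_int32]

x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
out = torch.empty_like(x)

# A CUdeviceptr is just an integer device address; Tensor.data_ptr() returns exactly that.
# (In a C++ extension the same idea is reinterpret_cast<CUdeviceptr>(tensor.data_ptr()).)
stream = torch.cuda.current_stream().cuda_stream
lib.add_fp32(ctypes.c_void_p(stream),
             ctypes.c_void_p(x.data_ptr()),
             ctypes.c_void_p(y.data_ptr()),
             ctypes.c_void_p(out.data_ptr()),
             x.numel())
torch.cuda.synchronize()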
@void-main Been building some tools around AOT compilation -- see PR #2642.
Let me know how I can be helpful -- happy to refine these tools for your use case.
@jeromeku great work!
After reading through the description of PR #2642, if I understand correctly, the tool helps ease the process of calling the AOT compiler (by tracing the dtypes of tensors).
But my problem is how to actually invoke the compiled .c functions via a torch extension. Could the tool help with this part?
@void-main
Sure -- I can extend the tool to automatically codegen a torch extension module from the compiled kernels. Do you have a particular kernel of interest along with some test inputs / outputs?
So that I understand your use case, you'd like to create a custom kernel using triton, then call this kernel independent of the triton runtime (i.e., in a standalone pytorch script)?
It might be helpful to create an independent repo for this functionality. I can see it could simplify deployment of triton kernels in many cases.
@Jokeren
Yes - that was the intent of the original PR #2642. What concrete use cases / scenarios -- frameworks, backends, kernels -- would you recommend targeting initially that would be most beneficial to triton proper and its users?
Also, are there any resources you'd recommend for diving deeper into MLIR as it pertains to the triton pipeline? I've been through the official toy tutorial and also found Jeremy Kun's tutorial to be particularly helpful. Also fairly familiar with cutlass and the GEMM hierarchy.
FYI, @ThomasRaoux was a leading developer of IREE's GPU backend and can provide further suggestions
Are you worried that such a new tool would duplicate IREE? I don't have a concrete idea yet, but I think you could start with a simple tool that makes AOT more user-friendly.
@Jokeren
Not worried about being duplicated, just want to create a simple, useful tool.
Wanted to get your thoughts on technical and architectural similarities / differences between nod.ai and triton (e.g., see this blogpost for more context).
@jeromeku
Sure -- I can extend the tool to automatically codegen a torch extension module from the compiled kernels. Do you have a particular kernel of interest along with some test inputs / outputs?
That would be great! Could we just use the matmul example in the test code?
So that I understand your use case, you'd like to create a custom kernel using triton, then call this kernel independent of the triton runtime (i.e., in a standalone pytorch script)?
Yep. The root cause in my scenario is that the triton.jit kernel launch overhead is way too much compared to the kernel itself (see the nsys capture in my description).
Hey @Jokeren, I got the AOT kernel wrapped in a torch extension. The code runs, however the AOT-compiled kernel produces wrong results. How could I debug such a case...
I've double-checked that the input data, grid config, and constants are the same for the AOT and JIT kernels. I wonder if there are any constraints on the AOT compiler?
@void-main: I can take a look if you can provide the kernel and some test data.
Wanted to get your thoughts on technical and architectural similarities / differences between nod.ai and triton (e.g., see this blogpost for more context).
d128 has been improved since the commit he compared against. I haven't looked into cases with short contexts. Technically, the solutions are quite similar if you look at the MLIR representation.
I'd like to avoid commenting on architectural differences/similarities to avoid conflicts, since I know only a little about iree.
@void-main: I can take a look if you can provide the kernel and some test data.
Hey @jeromeku, thank you! But since the code belongs to my company, I can't paste it publicly for now. Is there any way we could get in touch (e.g., Slack, Telegram, Discord) or by email (my email address is on my GitHub profile page)?
Thank you very much!
Hi @void-main, I've encountered the same issue with high Triton kernel launch overhead. Could you please share any solutions or workarounds that have worked for you? Thank you!
Hi @chenzhengda, I believe there are 3 solutions for now:
1. Use CUDA graphs.
2. Cache the compiled kernel after cached_bin = _your_triton_kernel[grid](xxx) returns, and call the cached bin directly.
3. Use the AOT compiler.
Hi @void-main,
First of all, thank you very much for your suggestions!
I have a couple of questions. In my scenario, I have dynamic shaped inputs, so I wonder if CUDA graphs are not applicable in this case?
Additionally, I tried the second method you mentioned, but encountered an error: triton.compiler.compiler.CompiledKernel object is not callable. Could you please provide further clarification on this issue? I would greatly appreciate it.
Thank you!
Hi @chenzhengda, sorry for the late reply.
In my scenario, I have dynamic shaped inputs, so I wonder if CUDA graphs are not applicable in this case?
What are the shape ranges of your inputs? If they only vary in one dimension, maybe you could build multiple CUDA graphs along the varying dimension. For example, if you have a batch_size dimension and the batch_size is dynamic, you could build one graph each for batch sizes of (1, 2, 4, 8, 16, ...) and pad your input to the nearest CUDA graph batch. This is how vllm creates CUDA graphs for the decoding stage.
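To make that concrete, here is a rough sketch of per-batch-size CUDA graph buckets; my_triton_op, HIDDEN, and the bucket sizes are placeholders, not from this thread:

import torch

HIDDEN = 1024                  # placeholder hidden size
BUCKETS = [1, 2, 4, 8, 16]     # one captured graph per batch-size bucket

def my_triton_op(x):           # stand-in for whatever wraps your Triton kernel
    return x * 2

graphs, static_in, static_out = {}, {}, {}
for bs in BUCKETS:
    static_in[bs] = torch.zeros(bs, HIDDEN, device="cuda")
    # warm up on a side stream before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        my_triton_op(static_in[bs])
    torch.cuda.current_stream().wait_stream(s)
    # capture one replayable launch per bucket
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out[bs] = my_triton_op(static_in[bs])
    graphs[bs] = g

def run_padded(x):
    bs = next(b for b in BUCKETS if b >= x.shape[0])  # nearest captured batch size
    static_in[bs].zero_()
    static_in[bs][: x.shape[0]].copy_(x)               # pad by copying into the static input
    graphs[bs].replay()
    return static_out[bs][: x.shape[0]]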
Additionally, I tried the second method you mentioned, but encountered an error: triton.compiler.compiler.CompiledKernel object is not callable.
The cached object is not directly callable; you need to call its c_wrapper method with a lot of parameters. Here's a rough example:
if cached_bin is not None:
    bin = cached_bin
    stream = torch.cuda.current_stream().cuda_stream
    args = [
        # your triton args
    ]
    bin.c_wrapper(grid[0], grid[1], grid[2], bin.num_warps, bin.num_ctas,
                  bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2],
                  bin.shared, stream, bin.cu_function,
                  CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook,
                  bin, *args)
else:
    cached_bin = _fwd_kernel[grid](...)  # your triton args
@void-main Trying your code above but there are a couple of issues:
- cached_bin doesn't have a c_wrapper() call
- num_ctas, clusterDims, etc.
Do you by any chance have well-tested code? Currently, launch overhead is an issue for smaller tensors. Thank you very much!
Hi @void-main, thanks so much for your advice! I've fixed my launch issue through cached_bin.
@mobicham & @chenzhengda, you can take the following steps to see if your problems can be alleviated:
1. Check /triton/compiler/compiler.py in your triton version, or your source code. Here I take the latest branch as an example:
   launch_metadata = self.launch_metadata(grid, stream, *args)
   self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata,
            CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args)
   This self.run call is the key point. Try to adapt your cached bin code to this format (see the sketch below).
2. Generally, your args should be variables + tl.constexpr. kwargs like num_warps & num_stages should not be involved after your kernel is cached.
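A rough sketch of adapting a cached bin to that interface, assuming a triton version where CompiledKernel exposes launch_metadata, run, function, and packed_metadata as in the snippet above (names vary across versions, so check your own compiler.py):

import torch
from triton.compiler import CompiledKernel  # import path may differ across triton versions

def launch_cached(cached_bin, grid, *args):
    # cached_bin is the CompiledKernel returned by an earlier _your_triton_kernel[grid](...) call;
    # grid is assumed to be a 3-tuple here.
    stream = torch.cuda.current_stream().cuda_stream
    launch_metadata = cached_bin.launch_metadata(grid, stream, *args)
    cached_bin.run(grid[0], grid[1], grid[2], stream, cached_bin.function,
                   cached_bin.packed_metadata, launch_metadata,
                   CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook,
                   *args)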
So I have been looking into this issue, and I can confirm that @xinji1's solution does work to some extent. However, calling run() directly is not compatible with torch.compile.
Instead, I found that putting the kernel call in a torch model and using torch compile with max-autotune (no-cudagraph) helps:
model.forward = torch.compile(model.forward, mode='max-autotune-no-cudagraphs', fullgraph=True)
However, you need PyTorch 2.5.0, and you cannot use some Triton features like prune_configs_by.
If you use mode='reduce-overhead', which is the recommended setting for many applications, the performance is pretty bad.
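For reference, a minimal sketch of this wrapping approach with a toy @triton.jit add kernel (the kernel, shapes, and block size are placeholders):

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # simple elementwise add, one block of BLOCK elements per program
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(out_ptr + offs,
             tl.load(x_ptr + offs, mask=mask) + tl.load(y_ptr + offs, mask=mask),
             mask=mask)

class KernelWrapper(torch.nn.Module):
    def forward(self, x, y):
        out = torch.empty_like(x)
        grid = (triton.cdiv(x.numel(), 1024),)
        add_kernel[grid](x, y, out, x.numel(), BLOCK=1024)
        return out

model = KernelWrapper().cuda()
# compile without cudagraphs, as suggested above
model.forward = torch.compile(model.forward, mode='max-autotune-no-cudagraphs', fullgraph=True)

x = torch.randn(4096, device='cuda')
y = torch.randn(4096, device='cuda')
out = model.forward(x, y)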
Hey team, I'm suffering from high triton kernel launch overhead. Here's my nsys capture:
The kernel executes in around 80us on the GPU; however, it takes 220us to launch, which causes the performance degradation. I've checked that the kernel is compiled and cached, so it's pure triton JIT launcher overhead.
I noticed there's a tracker issue in the community; I wonder if there are any updates?
Besides, I wonder if the AOT compiler could help?
PS: in my scenario, CUDA graphs are not an option (due to dynamically shaped inputs), so I have to reduce the kernel launch time.