flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

[DO NOT MERGE] cudagraph: use cuda dynamic parallelism to dispatch kernels #288

Open yzh119 opened 1 month ago

yzh119 commented 1 month ago

Background

FlashInfer dispatches to different kernels depending on input shapes, which is not CUDA Graph friendly at the moment. Capturing a graph records only the kernel chosen at capture time, so later replays cannot switch kernels when the shapes change. CUDA 12.4 announced conditional nodes in CUDA Graphs, but it will take some time for PyTorch to fully support this feature (i.e., to capture conditional nodes).
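To make the capture problem concrete, here is a minimal sketch (not FlashInfer code) of host-side shape dispatch under CUDA Graph capture. The kernels `small_kernel`/`large_kernel`, the 1024 threshold, and the helper functions are hypothetical, and it assumes a CUDA 12.x toolkit with a CUDA-capable GPU. Eager launches pick the kernel that matches `n`. A graph captured at a small shape keeps replaying the small kernel, because the `if` ran on the CPU during capture and only the launch it selected was recorded.

```cuda
// Build (assumption: CUDA 12.x toolkit and a CUDA-capable GPU):
//   nvcc graph_dispatch.cu -o graph_dispatch
#include <cassert>
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-ins for shape-specialized kernels; each one only
// records which variant ran.
__global__ void small_kernel(int* tag) { *tag = 1; }
__global__ void large_kernel(int* tag) { *tag = 2; }

// Host-side dispatch: the branch is evaluated on the CPU at launch time.
void dispatch(int n, int* d_tag, cudaStream_t s) {
  if (n < 1024) {
    small_kernel<<<1, 1, 0, s>>>(d_tag);
  } else {
    large_kernel<<<1, 1, 0, s>>>(d_tag);
  }
}

// Copies back the variant id written by the last kernel.
int read_tag(const int* d_tag) {
  int h = 0;
  cudaMemcpy(&h, d_tag, sizeof(int), cudaMemcpyDeviceToHost);
  return h;
}

// Eager execution: every call re-evaluates the branch with the current n.
int run_eager(int n) {
  int* d_tag = nullptr;
  cudaMalloc(&d_tag, sizeof(int));
  dispatch(n, d_tag, 0);
  int tag = read_tag(d_tag);
  cudaFree(d_tag);
  return tag;
}

// Graph execution: the branch runs once, during capture with capture_n.
// The replay stands in for a later step whose shape may differ; the graph
// cannot observe that and re-launches the recorded kernel.
int replay_captured(int capture_n) {
  int* d_tag = nullptr;
  cudaMalloc(&d_tag, sizeof(int));
  cudaStream_t s;
  cudaStreamCreate(&s);

  cudaGraph_t graph;
  cudaGraphExec_t exec;
  cudaStreamBeginCapture(s, cudaStreamCaptureModeGlobal);
  dispatch(capture_n, d_tag, s);  // only the chosen launch is recorded
  cudaStreamEndCapture(s, &graph);
  cudaGraphInstantiateWithFlags(&exec, graph, 0);

  cudaGraphLaunch(exec, s);
  cudaStreamSynchronize(s);
  int tag = read_tag(d_tag);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(s);
  cudaFree(d_tag);
  return tag;
}

int main() {
  int small = run_eager(16);         // small_kernel
  int large = run_eager(1 << 20);    // large_kernel
  int replay = replay_captured(16);  // small_kernel on every replay
  printf("eager n=16: %d, eager n=2^20: %d, graph replay: %d\n",
         small, large, replay);
  assert(small == 1 && large == 2 && replay == 1);
  return 0;
}
```

This is why shape-dependent dispatch on the host and graph replay do not mix: the decision has to move somewhere that is re-evaluated on every replay.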

Dynamic Parallelism

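In #187, @AgrawalAmey proposed using CUDA's Dynamic Parallelism feature to dispatch to different child kernels from a parent launcher kernel. Because the branch then runs on the device inside the launcher, the captured graph only contains the launcher itself, so we don't need to capture conditional nodes.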
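A minimal sketch of the parent-launcher idea follows. It is not the implementation in this PR: the kernel names, the 1024 threshold, and `replay_for_shapes` are hypothetical, and it assumes CUDA 12.x on an sm_70 or newer GPU. The shape lives in device memory; the `launcher` kernel branches on it and launches a child kernel from the device. One captured graph (a single launcher node) therefore picks a different child on each replay. The build line also shows the `-rdc=true` requirement that dynamic parallelism brings with it.

```cuda
// Build (assumption: CUDA 12.x, sm_70 or newer). Dynamic parallelism needs
// relocatable device code and the device runtime library:
//   nvcc -arch=sm_70 -rdc=true dp_dispatch.cu -o dp_dispatch -lcudadevrt
#include <cassert>
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical child kernels standing in for shape-specialized kernels.
__global__ void small_kernel(int* tag) { *tag = 1; }
__global__ void large_kernel(int* tag) { *tag = 2; }

// Parent launcher: reads the shape from device memory and launches the
// matching child from the GPU, so the decision is made at execution time.
__global__ void launcher(const int* n, int* tag) {
  if (*n < 1024) {
    small_kernel<<<1, 1>>>(tag);
  } else {
    large_kernel<<<1, 1>>>(tag);
  }
}

// Captures the launcher into a graph once, then replays that same graph for
// each shape in `shapes`, recording which child ran into `out`.
void replay_for_shapes(const int* shapes, int count, int* out) {
  int* d_n = nullptr;
  int* d_tag = nullptr;
  cudaMalloc(&d_n, sizeof(int));
  cudaMalloc(&d_tag, sizeof(int));
  cudaStream_t s;
  cudaStreamCreate(&s);

  cudaGraph_t graph;
  cudaGraphExec_t exec;
  cudaStreamBeginCapture(s, cudaStreamCaptureModeGlobal);
  launcher<<<1, 1, 0, s>>>(d_n, d_tag);  // the graph holds one kernel node
  cudaStreamEndCapture(s, &graph);
  cudaGraphInstantiateWithFlags(&exec, graph, 0);

  for (int i = 0; i < count; ++i) {
    // New shape for this step; the graph itself is never modified.
    cudaMemcpy(d_n, &shapes[i], sizeof(int), cudaMemcpyHostToDevice);
    cudaGraphLaunch(exec, s);
    // A parent grid completes only after its child grids, so this also
    // waits for whichever child the launcher picked.
    cudaStreamSynchronize(s);
    cudaMemcpy(&out[i], d_tag, sizeof(int), cudaMemcpyDeviceToHost);
  }

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(s);
  cudaFree(d_n);
  cudaFree(d_tag);
}

int main() {
  const int shapes[2] = {16, 1 << 20};
  int out[2] = {0, 0};
  replay_for_shapes(shapes, 2, out);
  printf("n=16 -> child %d, n=2^20 -> child %d\n", out[0], out[1]);
  assert(out[0] == 1 && out[1] == 2);
  return 0;
}
```

The trade-off in this sketch is that every replay pays for an extra device-side launch, on top of whatever the `-rdc=true` build costs the child kernels themselves.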

Discussion

This PR implements the dynamic parallelism launcher. However, dynamic parallelism requires enabling -rdc (relocatable device code) in the compiler options, which introduces non-negligible runtime overhead for small kernels (there are some discussions on the NVIDIA forum).

Considering the potential negative consequences of enabling rdc (e.g., performance degradation), I'll hold this PR and turn to an alternative approach (one that does not require dispatching) to support CUDA Graphs. We won't merge this PR for now, and we welcome discussion.

AgrawalAmey commented 1 month ago

This is really fascinating! I'm curious to learn about the alternative approach. Let me know if I can help in any way. Thanks!