Hi @danieldk, thanks for bringing this up!
> When infrastructure is scaled up, we want to avoid delaying/slowing down processing of user requests due to JIT compilation and since infrastructure is often heterogeneous (both in the models served and in the GPUs used), we would have to compile most kernels anyway.
This is a reasonable concern. I think we can keep both JIT and AOT (AOT for a set of "core" kernels, ~200 MB): use the "core" kernels whenever possible, and fall back to JIT for the remaining kernels (new head dimensions, some attention variants, etc.). WDYT?
I agree with @danieldk that AOT is important. For production use, we typically build a Docker image, and Kubernetes will then spawn a pod running a container of that image. The container is ephemeral, so if AOT is missing, every time the pod restarts we'll have to JIT compile, which would slow down the start time significantly.
For PyPI, we can ship an sdist and do JIT only. This keeps the PyPI package size small.
For our hosted wheels, I agree with @yzh119 that AOT-compiling "core" kernels is a good idea. I think the "core" kernels should include the kernels that popular pretrained models use (e.g., Llama, Qwen, DeepSeek).
I have a few additional suggestions:
First, for a better user experience, output a log line when JIT-compiling a kernel (maybe also include the elapsed time). This way, if we experience an unexpectedly long start time, we know it comes from JIT-compiling FlashInfer kernels. Logging the JIT kernel names can also help us decide what to include in the "core" kernels.
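A minimal sketch of what such logging could look like (the `log_jit` helper and the `flashinfer.jit` logger name are illustrative, not an existing FlashInfer API):

```python
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger("flashinfer.jit")  # hypothetical logger name

@contextmanager
def log_jit(kernel_name: str):
    """Log the kernel name and elapsed time around a JIT compilation call."""
    start = time.perf_counter()
    logger.info("JIT compiling %s ...", kernel_name)
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        logger.info("JIT compiled %s in %.1f s", kernel_name, elapsed)

# Usage (wrapping whatever the actual JIT entry point is):
# with log_jit("batch_prefill_hd256"):
#     compile_kernel(...)
```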
Second, wheels shouldn't be pinned to PyTorch versions. We can compile kernels that link against a particular CUDA version and expose a C ABI, plus a separate .cpp file that extracts PyTorch Tensor metadata and calls the kernels. When installing the wheel, we compile only the Python binding. This ensures that compilation takes minimal time when pip installing the wheel (only the time to build the pybind layer).
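A rough sketch of what the install-time build could look like under this scheme, as a hypothetical setup.py (the file and library names are placeholders, not the actual FlashInfer layout):

```python
# Only csrc/binding.cpp -- the thin layer that extracts torch.Tensor metadata
# and forwards it to the C-ABI kernels -- is compiled at install time.
# The kernel library ships prebuilt in the wheel and depends only on CUDA.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name="flashinfer",
    version="0.0.0",  # placeholder
    packages=["flashinfer"],
    ext_modules=[
        CppExtension(
            name="flashinfer._binding",
            sources=["csrc/binding.cpp"],       # pybind layer only
            libraries=["flashinfer_kernels"],   # prebuilt C-ABI kernel .so
            library_dirs=["flashinfer/lib"],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
```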
Shipping wheels tied to the PyTorch version takes time and storage, and it might even be wrong: I don't think PyTorch explicitly guarantees that the torch.Tensor ABI stays the same, even across minor versions.
Third, it would also be good to provide a customizable AOT script, just in case some users want AOT beyond the "core" kernels.
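A minimal sketch of such a script, assuming a hypothetical `compile_kernel` entry point into FlashInfer's JIT machinery (the real function name and configuration fields would differ):

```python
# Hypothetical AOT build script: users edit the config lists to pre-build the
# kernels they need beyond "core", so the JIT cache is warm when the image ships.
from itertools import product

HEAD_DIMS = [64, 128, 256]
DTYPES = ["float16", "bfloat16"]

def compile_kernel(head_dim: int, dtype: str) -> None:
    # Placeholder for FlashInfer's JIT entry point; compiling here puts the
    # artifact in the cache ahead of time.
    raise NotImplementedError

if __name__ == "__main__":
    for head_dim, dtype in product(HEAD_DIMS, DTYPES):
        compile_kernel(head_dim, dtype)
```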
Fourth, as for the wheel size, I think even 2GB is acceptable, because CUDA + PyTorch already take up maybe 10GB of container size. The image is already huge even without FlashInfer, so we shouldn't worry about FlashInfer taking up additional space.
Thanks @abcdabcd987 for your thoughts, here is my tentative plan:

- Maintain two packages: flashinfer_aot, which ships the pre-built binaries in its sdist, and flashinfer, which is an sdist package that runs JIT by default.
- flashinfer will be hosted on PyPI. flashinfer_aot will use a self-hosted index (https://flashinfer.ai/whl) and ship binary kernels; when users pip install it, it will be linked against the user's PyTorch installation. It's not mandatory to install flashinfer_aot to use flashinfer.
- The two packages share version numbers. The flashinfer package will first check whether flashinfer_aot is installed; if so, flashinfer will prioritize the pre-compiled kernels in flashinfer_aot and only use JIT when a kernel configuration is not found there. Otherwise, it will always compile kernels with JIT (see the sketch below).
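A minimal sketch of the lookup order described in the last point (the `get_kernel` accessor and JIT fallback are placeholders, not actual APIs of either package):

```python
import importlib.util

def _jit_compile(config):
    # Placeholder: FlashInfer's JIT path would build and cache the kernel here.
    raise NotImplementedError(f"would JIT-compile kernel for {config!r}")

def load_kernel(config):
    # Prefer pre-compiled kernels from flashinfer_aot when it is installed.
    if importlib.util.find_spec("flashinfer_aot") is not None:
        import flashinfer_aot
        kernel = flashinfer_aot.get_kernel(config)  # placeholder lookup
        if kernel is not None:
            return kernel
    # Config not covered by the AOT build, or flashinfer_aot not installed:
    # fall back to JIT compilation.
    return _jit_compile(config)
```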
@danieldk how does this plan sound to you?
ping @comaniac @zhyncs @Ying1123 @merrymercy @WoosukKwon
Sounds good to me. It would be even better if we could allow flashinfer to install flashinfer_aot; otherwise most users would probably suffer from long compile times (and require nvcc in the environment). Ideally something like the following (not sure if it's achievable or makes sense):
pip install "flashinfer[aot]" # Implicitly call `pip install flashinfer_aot`
I'd push back against a separate flashinfer_aot package name. It's possible that users install both and observe confusing behavior, especially when they upgrade one but not the other. Having a single package name will at least ensure that only one copy is installed.
> Third, it would also be good to provide a customizable AOT script, just in case some users want AOT beyond the "core" kernels.
This sounds great! I don't think we mind compiling FlashInfer ourselves to get all the kernels AOT. For development we are caching builds anyway through Nix, and for production Docker containers we are also looking into improving build caching.
I think even outside applications like TGI, AOTing the most-used kernels and JITing the rest sounds like a good strategy.
After offline discussion, I think the updated plan is as follows (@yzh119 please confirm):

- Host the flashinfer package as an sdist. It does not contain any precompiled kernels; users will JIT compile when a kernel is invoked.
- Host precompiled wheels on a self-hosted index: https://flashinfer.ai/whl/cu???/.
- The precompiled kernels (.so files) are linked against a specific CUDA version.
- When users pip install flashinfer -i https://flashinfer.ai/whl/cu???/, they will compile a PyTorch extension that links against the precompiled kernel .so files. Since this only compiles the pybind layer, not the kernels, the compilation will be fast.

That sounds awesome. Thank you for taking our use case into account!
Both JIT mode and AOT mode are supported in #507.
We saw that support for JIT compilation will be added in #507. We were wondering what the plans are for ahead-of-time compilation. We are happily using FlashInfer in Text Generation Inference; the support for KV caches with block_size=1 has really been helpful for us to support fine-grained prefix caching.

For many of our users it's pretty important that compilation is done ahead of time. When infrastructure is scaled up, we want to avoid delaying/slowing down the processing of user requests due to JIT compilation, and since infrastructure is often heterogeneous (both in the models served and in the GPUs used), we would have to compile most kernels anyway. So for us it would be really useful if AOT is supported going forward.
Thank you for your awesome work on flashinfer 🤗.