flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

Will AOT compilation still be supported after JIT compilation is added? #510

Closed: danieldk closed this issue 1 month ago

danieldk commented 1 month ago

We saw that support for JIT compilation will be added in #507. We were wondering what the plans are for ahead-of-time (AOT) compilation. We are happily using flashinfer in Text Generation Inference (TGI); the support for KV caches with block_size=1 has been really helpful for supporting fine-grained prefix caching.

For many of our users, it's pretty important that compilation is done ahead of time. When infrastructure is scaled up, we want to avoid delaying or slowing down the processing of user requests due to JIT compilation. And since infrastructure is often heterogeneous (both in the models served and in the GPUs used), we would have to compile most kernels anyway. So it would be really useful for us if AOT continues to be supported going forward.

Thank you for your awesome work on flashinfer 🤗.

yzh119 commented 1 month ago

Hi @danieldk , thanks for bringing this up!

"When infrastructure is scaled up, we want to avoid delaying/slowing down processing of user requests due to JIT compilation and since infrastructure is often heterogeneous (both in the models served and in the GPUs used), we would have to compile most kernels anyway"

This is a reasonable concern. I think we can keep both JIT and AOT (AOT for a set of "core" kernels, ~200 MB). We should use the "core" kernels whenever possible and use JIT for the remaining kernels (new head dimensions, some attention variants, etc.). WDYT?
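The "core kernels first, JIT for the rest" policy can be sketched as a small registry. This is a minimal illustration, not FlashInfer's actual dispatch code: `KernelConfig`, `KernelRegistry`, and the `jit_compile` callback are hypothetical names, and a real kernel key would likely carry more fields than head dimension, dtype, and variant.

```python
from typing import Callable, Dict, NamedTuple


class KernelConfig(NamedTuple):
    """Hypothetical key for one compiled kernel specialization."""
    head_dim: int
    dtype: str
    variant: str


Kernel = Callable[..., object]


class KernelRegistry:
    """Use precompiled "core" kernels when available; JIT-compile the rest once."""

    def __init__(self, precompiled: Dict[KernelConfig, Kernel],
                 jit_compile: Callable[[KernelConfig], Kernel]) -> None:
        self._precompiled = dict(precompiled)
        self._jit_compile = jit_compile
        self._jit_cache: Dict[KernelConfig, Kernel] = {}

    def get(self, cfg: KernelConfig) -> Kernel:
        # Core kernels shipped ahead of time win whenever they match.
        if cfg in self._precompiled:
            return self._precompiled[cfg]
        # Anything else (new head dim, new variant) is compiled on first use, then reused.
        if cfg not in self._jit_cache:
            self._jit_cache[cfg] = self._jit_compile(cfg)
        return self._jit_cache[cfg]
```

Caching JIT results by configuration means each non-core kernel is compiled at most once per process, while kernels in the precompiled set never touch the compiler.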

abcdabcd987 commented 1 month ago

I agree with @danieldk that AOT is important. For production use, we typically build a Docker image, and Kubernetes then spawns a pod running a container of that image. The container is ephemeral, so without AOT we would have to JIT compile every time the pod restarts. This would slow down start-up time significantly.

For PyPI, we can ship an sdist and do JIT only. This keeps the PyPI package small.

For our hosted wheels, I agree with @yzh119 that AOT-compiling "core" kernels is a good idea. I think the "core" kernels should include the kernels that popular pretrained models use (e.g., Llama, Qwen, DeepSeek).

I also have a few additional suggestions:

First, for a better user experience, emit a log message when JIT-compiling a kernel (perhaps including the elapsed time). This way, if we experience an unexpectedly long start-up time, we can tell that it comes from JIT-compiling FlashInfer kernels. Logging the names of JIT-compiled kernels can also help us decide what to include in the "core" kernels.
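A minimal sketch of the suggested log line, wrapping whatever function performs the compilation. The logger name `flashinfer.jit`, the helper `log_jit`, and the kernel name used below are hypothetical; FlashInfer's actual logging may look different.

```python
import logging
import time
from typing import Callable, TypeVar

logger = logging.getLogger("flashinfer.jit")  # hypothetical logger name

T = TypeVar("T")


def log_jit(kernel_name: str, compile_fn: Callable[[], T]) -> T:
    """Run compile_fn, logging the kernel name and the wall-clock compile time."""
    logger.info("JIT compiling %s ...", kernel_name)
    start = time.perf_counter()
    result = compile_fn()
    elapsed = time.perf_counter() - start
    logger.info("JIT compiled %s in %.2f s", kernel_name, elapsed)
    return result
```

With `logging.basicConfig(level=logging.INFO)`, a call such as `log_jit("batch_prefill_hd128", build)` would print both lines, and grepping them out of start-up logs gives the list of candidates for the "core" set.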

Second, wheels shouldn't be pinned to PyTorch versions. We can compile the kernels against a particular CUDA version and expose a C ABI, then write a separate .cpp file that extracts PyTorch tensor metadata and calls the kernels. When installing the wheel, only the Python binding is compiled. This keeps compilation time minimal during pip install (only the time to build the pybind module).
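To make the C ABI idea concrete: the kernel library would accept only plain C types, and the thin binding layer would unpack each `torch.Tensor` into them. The sketch below mirrors such a descriptor with Python's `ctypes`. `TensorDesc`, `make_desc`, the fixed `MAX_DIMS`, and the integer dtype codes are all hypothetical, and in the proposal this unpacking would live in the separate .cpp file rather than in Python.

```python
import ctypes
from typing import Sequence

MAX_DIMS = 8  # hypothetical upper bound on tensor rank


class TensorDesc(ctypes.Structure):
    """Plain-C view of a tensor: no PyTorch types cross the ABI boundary."""
    _fields_ = [
        ("data", ctypes.c_void_p),                 # raw (device) pointer
        ("ndim", ctypes.c_int32),
        ("dtype", ctypes.c_int32),                 # hypothetical code, e.g. 0 = fp16
        ("shape", ctypes.c_int64 * MAX_DIMS),
        ("strides", ctypes.c_int64 * MAX_DIMS),    # in elements, not bytes
    ]


def make_desc(data_ptr: int, shape: Sequence[int],
              strides: Sequence[int], dtype_code: int) -> TensorDesc:
    """Copy tensor metadata into a fixed-layout C struct."""
    if len(shape) != len(strides) or len(shape) > MAX_DIMS:
        raise ValueError("shape/strides length mismatch or rank too large")
    desc = TensorDesc()
    desc.data = data_ptr
    desc.ndim = len(shape)
    desc.dtype = dtype_code
    for i, (size, stride) in enumerate(zip(shape, strides)):
        desc.shape[i] = size
        desc.strides[i] = stride
    return desc
```

With PyTorch, a call would look like `make_desc(t.data_ptr(), t.shape, t.stride(), 0)`. Since only pointers and integers cross the boundary, the kernel `.so` never depends on PyTorch's C++ ABI.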

Shipping wheels tied to PyTorch versions costs build time and storage, and it might even be incorrect: I don't think PyTorch explicitly guarantees that the torch.Tensor ABI stays the same, even across minor versions.

Third, it would also be good to provide a customizable AOT script, just in case some users want AOT beyond the "core" kernels.
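A customizable AOT script could start from a user-editable list of values per dimension and expand it into the concrete kernels to build. This sketch assumes a kernel is identified by head dimension, dtype, and page size; `AOTSpec`, `expand_aot_config`, and the naming scheme are hypothetical.

```python
import itertools
from dataclasses import dataclass
from typing import Iterator, Sequence


@dataclass(frozen=True)
class AOTSpec:
    """Hypothetical description of one kernel to compile ahead of time."""
    head_dim: int
    dtype: str
    page_size: int

    @property
    def name(self) -> str:
        return f"attn_hd{self.head_dim}_{self.dtype}_ps{self.page_size}"


def expand_aot_config(
    head_dims: Sequence[int], dtypes: Sequence[str], page_sizes: Sequence[int]
) -> Iterator[AOTSpec]:
    """Yield every combination listed in the user's AOT configuration."""
    for hd, dt, ps in itertools.product(head_dims, dtypes, page_sizes):
        yield AOTSpec(hd, dt, ps)
```

For example, `expand_aot_config([64, 128], ["fp16", "bf16"], [1, 16])` yields 2 × 2 × 2 = 8 kernels. Because the count multiplies with every added axis, shipping a small "core" set and letting users extend it is cheaper than precompiling every combination.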

Fourth, as for wheel size, I think even 2 GB is acceptable. CUDA + PyTorch already take up maybe 10 GB of container size, which is huge even without FlashInfer, so we shouldn't worry much about the additional space FlashInfer takes up.

yzh119 commented 1 month ago

Thanks @abcdabcd987 for your thoughts, here is my tentative plan:

Maintain two packages: flashinfer_aot, which ships the pre-built binaries in its sdist, and flashinfer, an sdist package that runs JIT by default.

The two packages share version numbers. The flashinfer package will first check whether flashinfer_aot is installed. If so, flashinfer will prioritize the pre-compiled kernels in flashinfer_aot and fall back to JIT only when a kernel configuration is not found there; otherwise, it will always compile kernels with JIT.
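The "check whether flashinfer_aot is installed" step can be done without triggering an ImportError by using `importlib.util.find_spec`. In this sketch, `load_optional` and `choose_backend` are hypothetical helpers, and the `AVAILABLE_KERNELS` attribute is an assumed convention for how the AOT package might advertise its kernels.

```python
import importlib
import importlib.util
from types import ModuleType
from typing import Optional


def load_optional(module_name: str) -> Optional[ModuleType]:
    """Return the imported module if it is installed, else None."""
    # find_spec returns None for a missing top-level module instead of raising.
    if importlib.util.find_spec(module_name) is None:
        return None
    return importlib.import_module(module_name)


def choose_backend(kernel_key: str, aot: Optional[object]) -> str:
    """Prefer a precompiled kernel when the AOT package provides it, else JIT."""
    # Assumed convention: the AOT package lists its kernels in AVAILABLE_KERNELS.
    # getattr on None (package absent) also falls back to the empty default.
    available = getattr(aot, "AVAILABLE_KERNELS", frozenset())
    return "aot" if kernel_key in available else "jit"


# Resolved once, e.g. at import time of the flashinfer package.
_AOT = load_optional("flashinfer_aot")
```

Each kernel lookup then calls `choose_backend(key, _AOT)`; when flashinfer_aot is absent, every key resolves to `"jit"`.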

@danieldk how does this plan sound to you?

ping @comaniac @zhyncs @Ying1123 @merrymercy @WoosukKwon

comaniac commented 1 month ago

Sounds good to me. It would be even better if installing flashinfer could also install flashinfer_aot; otherwise most users would probably suffer from long compile times (and would need nvcc in their environment). Ideally, something like the following (not sure whether it's achievable or makes sense):

pip install "flashinfer[aot]"  # implicitly calls `pip install flashinfer_aot`

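The `pip install "flashinfer[aot]"` form maps onto setuptools' `extras_require`. Below is a hypothetical setup.py excerpt, assuming the two packages keep sharing version numbers so the extra can pin an exact match; `VERSION` is a placeholder, not a real release.

```python
# Hypothetical excerpt from flashinfer's setup.py.
VERSION = "0.0.0"  # placeholder; both packages would share this number

extras_require = {
    # `pip install "flashinfer[aot]"` also installs the prebuilt kernels,
    # pinned to exactly the same version as flashinfer itself.
    "aot": [f"flashinfer_aot=={VERSION}"],
}

# In a real setup.py this would be passed along, e.g.:
# setuptools.setup(name="flashinfer", version=VERSION, extras_require=extras_require)
```

The exact `==` pin keeps the two versions in lockstep when they are installed through the extra.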
abcdabcd987 commented 1 month ago

I'd push back against a separate flashinfer_aot package name. Users could install both and observe confusing behavior, especially if they upgrade one but not the other. Having a single package name at least ensures that only one copy is installed.
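If two packages were kept, the "upgraded one but not the other" case could at least be detected at import time. A sketch using `importlib.metadata`; the helper names are hypothetical, and raising an error (rather than warning) is only one possible policy.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def versions_compatible(main_v: Optional[str], aot_v: Optional[str]) -> bool:
    # Only two installed packages can disagree; a missing AOT package means JIT-only.
    return main_v is None or aot_v is None or main_v == aot_v


def check_aot_version(main: str = "flashinfer", aot: str = "flashinfer_aot") -> None:
    """Raise if both packages are installed at different versions."""
    main_v, aot_v = installed_version(main), installed_version(aot)
    if not versions_compatible(main_v, aot_v):
        raise RuntimeError(
            f"{aot} {aot_v} does not match {main} {main_v}; "
            "install both packages at the same version"
        )
```

This only catches the mismatch after the fact; a single package name, as argued above, prevents it from arising in the first place.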

danieldk commented 1 month ago

"Third, it would also be good to provide a customizable AOT script, just in case some users want AOT beyond the "core" kernels."

This sounds great! I don't think we mind compiling flashinfer ourselves to get all the kernels AOT-compiled. For development, we are caching builds through Nix anyway, and for production Docker containers we are also looking to improve build caching.

I think even outside applications like TGI, AOTing the most-used kernels and JITing the rest sounds like a good strategy.

abcdabcd987 commented 1 month ago

After offline discussion, I think the updated plan is as follows (@yzh119 please confirm):

  1. For PyPI, we publish the flashinfer package as an sdist. It does not contain any precompiled kernels; kernels are JIT-compiled when they are invoked.
  2. For local development, it's the same as PyPI: JIT only.
  3. Perhaps most importantly, we will provide a "precompiled sdist" under the pip index URL https://flashinfer.ai/whl/cu???/.
    • The "precompiled sdist" will contain precompiled kernels for common uses.
    • These precompiled kernels (.so files) are linked against a specific CUDA version.
    • We don't want to link the kernels against PyTorch, because PyTorch might change its ABI.
    • When running pip install flashinfer -i https://flashinfer.ai/whl/cu???/, users will compile a PyTorch extension that links to the precompiled kernel .so files. Since this only compiles the pybind bindings, not the kernels, the compilation will be fast.
    • Kernels that are not precompiled should still be JIT-compiled on demand.
  4. The script for producing the precompiled sdist should be customizable, so users who want a different set of kernels AOT-compiled can produce their own precompiled sdist. (I think this would satisfy @danieldk's needs.)
  5. There will be no bdist/wheel anymore; it is replaced by the precompiled sdist.
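In the precompiled sdist, the install step compiles only the binding and links it against kernel libraries built ahead of time. The sketch below assembles the linker flags for that step; the function name, library names, and directory are hypothetical. The resulting keyword arguments could, for example, be passed to `torch.utils.cpp_extension.load(name=..., **kwargs)`, which accepts `sources` and `extra_ldflags`.

```python
import os
from typing import Dict, List, Sequence


def binding_build_kwargs(
    binding_sources: Sequence[str],
    kernel_lib_dir: str,
    kernel_libs: Sequence[str],
) -> Dict[str, List[str]]:
    """Kwargs for compiling only the binding, linked to prebuilt kernel .so files."""
    lib_dir = os.path.abspath(kernel_lib_dir)
    ldflags = [f"-L{lib_dir}"]                        # linker search path
    ldflags += [f"-l{name}" for name in kernel_libs]  # links lib<name>.so
    ldflags.append(f"-Wl,-rpath,{lib_dir}")           # runtime loader finds the .so files
    return {"sources": list(binding_sources), "extra_ldflags": ldflags}
```

Because the kernel `.so` files are only linked, not rebuilt, install time is dominated by compiling the small binding source.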
danieldk commented 1 month ago

That sounds awesome. Thank you for taking our use case into account!

yzh119 commented 1 month ago

Both JIT mode and AOT mode are supported in #507.