flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

JIT compilation #170

Open yzh119 opened 3 months ago

yzh119 commented 3 months ago

As the number of shape and configuration combinations increases, our pip wheel size grows and compilation times become long.

PyTorch supports Just-In-Time compilation of extensions: https://pytorch.org/tutorials/advanced/cpp_extension.html#jit-compiling-extensions. This makes it possible to compile only the kernels corresponding to the configurations/shapes actually requested, reducing both the wheel size and the development overhead on the codebase.

We can release a flashinfer_jit wheel where all kernels are compiled with JIT.
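As a rough sketch of the mechanism (not flashinfer's actual build code), torch.utils.cpp_extension.load compiles an extension at first use and caches the result, so only the kernel variant a user actually requests ever gets built. The source file and extension name below are hypothetical placeholders:

    from torch.utils.cpp_extension import load

    # Build (or reuse a cached build of) a single kernel configuration on demand.
    # "decode_group8.cu" stands in for one generated flashinfer source file.
    flashinfer_jit = load(
        name="flashinfer_decode_group8",
        sources=["decode_group8.cu"],
        extra_cuda_cflags=["-O3"],
        verbose=True,
    )
    # flashinfer_jit now exposes whatever functions the .cu file registers
    # via PYBIND11_MODULE / TORCH_LIBRARY.

With this approach the wheel only needs to ship kernel sources/templates; the first call for a new configuration pays a one-time compile cost, and the build artifact is cached (under TORCH_EXTENSIONS_DIR or ~/.cache/torch_extensions) for subsequent runs.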

Qubitium commented 1 month ago

For reference, the following CI command on main compiles in 142 minutes on a 32-core Zen 3 4 GHz server container.

# time FLASHINFER_BUILD_VERSION="999.0.4+cu124torch2.3_gpPadded8_v4_arch8x" FLASHINFER_GROUP_SIZES="1,4,5,6,7,8"  TORCH_CUDA_ARCH_LIST="8.0 8.9" python -m build --no-isolation

real    141m49.680s
user    2389m59.725s
sys     200m31.997s

Without flashinfer_jit, the only way to speed up flashinfer wheel compilation to a reasonable timeframe is to lock the group size to the single value needed by the intended model; that cuts compilation steps/time by roughly 8x in my tests (see the example command below).
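For example, a build locked to a single group size might look like this (a sketch reusing the group-size and arch variables from the command above; the value 1 is only an illustration, pick whatever the target model needs):

# time FLASHINFER_GROUP_SIZES="1" TORCH_CUDA_ARCH_LIST="8.0 8.9" python -m build --no-isolation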

Changing TORCH_CUDA_ARCH_LIST doesn't have much impact on build time.