yzh119 opened this issue 3 months ago
For reference, the following CI command on `main` compiles in about 142 minutes on a 32-core Zen 3 4 GHz server container:

```shell
# time FLASHINFER_BUILD_VERSION="999.0.4+cu124torch2.3_gpPadded8_v4_arch8x" FLASHINFER_GROUP_SIZES="1,4,5,6,7,8" TORCH_CUDA_ARCH_LIST="8.0 8.9" python -m build --no-isolation
real    141m49.680s
user    2389m59.725s
sys     200m31.997s
```
Without `flashinfer_jit`, the only way to speed up flashinfer wheel compilation to a reasonable timeframe is to lock `FLASHINFER_GROUP_SIZES` to the single value needed by the intended model; in my tests that cuts compilation steps/time by roughly 8x.
Changing `TORCH_CUDA_ARCH_LIST` has little impact on compile time.
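To see why locking the group size helps so much, here is a back-of-the-envelope model of the build matrix implied by the command above. The axis names come from the env vars shown; the real build templates over additional axes (head dims, dtypes, etc.), so the observed ~8x is larger than the 6x this toy model predicts, and none of this is FlashInfer's actual build logic.

```python
from itertools import product

# Illustrative model only. Each group size produces separate template
# instantiations (translation units), while multiple CUDA archs are folded
# into one fat binary per unit, which is consistent with the observation
# that TORCH_CUDA_ARCH_LIST barely moves wall-clock time.
group_sizes = ["1", "4", "5", "6", "7", "8"]  # FLASHINFER_GROUP_SIZES
archs = ["8.0", "8.9"]                        # TORCH_CUDA_ARCH_LIST

translation_units = len(group_sizes)          # scales compile steps ~linearly
device_binaries = len(list(product(group_sizes, archs)))

print(translation_units)  # 6
print(device_binaries)    # 12
print(f"locking one group size cuts translation units by {translation_units}x")
```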
As the combination of shapes and configurations increases, our pip wheel size grows and the compilation time becomes long.
PyTorch supports Just-In-Time compilation of extensions: https://pytorch.org/tutorials/advanced/cpp_extension.html#jit-compiling-extensions, which makes it possible to only compile kernels corresponding to certain configurations/shapes, thus reducing both the wheel size and the development overhead on the codebase.
We can release a `flashinfer_jit` wheel where all kernels are compiled with JIT.
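The `flashinfer_jit` idea boils down to compiling a kernel variant only the first time a given configuration is requested, then caching the artifact. A minimal, dependency-free sketch of that pattern follows; `compile_kernel` and `get_kernel` are hypothetical names, not FlashInfer API, and the string stands in for a real nvcc build.

```python
from functools import lru_cache

def compile_kernel(group_size: int, arch: str) -> str:
    # Stand-in for an expensive nvcc/JIT build of one kernel variant.
    # A real flashinfer_jit wheel would invoke the CUDA toolchain here.
    return f"kernel<group_size={group_size}, arch={arch}>"

@lru_cache(maxsize=None)  # build each configuration at most once
def get_kernel(group_size: int, arch: str) -> str:
    return compile_kernel(group_size, arch)

# Only configurations a model actually uses get built:
k = get_kernel(1, "8.9")
print(k)                          # kernel<group_size=1, arch=8.9>
print(get_kernel(1, "8.9") is k)  # True: served from cache, no recompile
```

This is why a JIT wheel can be both small and fast to ship: the ahead-of-time build matrix disappears, at the cost of a one-time compile on first use of each configuration.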