flashinfer-ai / flashinfer

FlashInfer: Kernel Library for LLM Serving
https://flashinfer.ai
Apache License 2.0

perf: accelerate JIT compilation speed #618

Closed: yzh119 closed this PR 1 week ago

yzh119 commented 1 week ago

JIT compilation is currently slow because we rely on the monolithic header `<torch/extension.h>`, which is far heavier than our use case requires.

This PR refactors the codebase to include only the headers pybind actually needs, and moves most torch runtime API calls from C++ to Python.
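
For illustration, here is a minimal sketch of what such a binding could look like. The `rmsnorm_launcher` name, module name, and raw-pointer signature are hypothetical, not this PR's actual interface; the point is that only `<pybind11/pybind11.h>` is included, with the Python caller assumed to do the torch-side work (dtype checks, output allocation, device handling) before passing raw pointers:

```cpp
// Hypothetical minimal binding, assuming the torch runtime work (dtype
// checks, output allocation, stream/device handling) now happens in Python.
// Only pybind11 is included; <torch/extension.h> is no longer needed.
#include <pybind11/pybind11.h>
#include <cstdint>

// Kernel launcher defined in a separate .cu file (name is illustrative).
void rmsnorm_launcher(float* out, const float* in, const float* weight,
                      int64_t num_rows, int64_t hidden_dim, float eps);

PYBIND11_MODULE(norm_ops, m) {
  // Python passes raw device pointers (e.g. tensor.data_ptr()) and shapes.
  m.def("rmsnorm",
        [](int64_t out_ptr, int64_t in_ptr, int64_t weight_ptr,
           int64_t num_rows, int64_t hidden_dim, double eps) {
          rmsnorm_launcher(reinterpret_cast<float*>(out_ptr),
                          reinterpret_cast<const float*>(in_ptr),
                          reinterpret_cast<const float*>(weight_ptr),
                          num_rows, hidden_dim, static_cast<float>(eps));
        });
}
```

The pybind11 header alone is far cheaper to parse than `<torch/extension.h>`, which transitively pulls in most of ATen plus torch's pybind type casters; keeping the tensor bookkeeping in Python keeps those heavy headers out of every JIT-compiled translation unit.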

This cut compilation time from 48 seconds to 18 seconds for lightweight operators such as `norm`.

yzh119 commented 1 week ago

Using nanobind or ctypes/cython for the FFI could speed up compilation even further, but that would require fundamental changes to the codebase, so I prefer to stick with pybind for now.
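
For comparison, a hypothetical nanobind counterpart of the sketch above would look nearly identical at the binding site (again, all names are illustrative); the header and module macro change, but the overall shape does not, which is part of why the switch is tempting despite the build changes it would require:

```cpp
// Hypothetical nanobind counterpart: nanobind's header and NB_MODULE
// macro replace their pybind11 equivalents, but the binding body is
// essentially unchanged.
#include <nanobind/nanobind.h>
#include <cstdint>

// Same illustrative launcher as before, defined in a .cu file.
void rmsnorm_launcher(float* out, const float* in, const float* weight,
                      int64_t num_rows, int64_t hidden_dim, float eps);

NB_MODULE(norm_ops, m) {
  m.def("rmsnorm",
        [](int64_t out_ptr, int64_t in_ptr, int64_t weight_ptr,
           int64_t num_rows, int64_t hidden_dim, double eps) {
          rmsnorm_launcher(reinterpret_cast<float*>(out_ptr),
                          reinterpret_cast<const float*>(in_ptr),
                          reinterpret_cast<const float*>(weight_ptr),
                          num_rows, hidden_dim, static_cast<float>(eps));
        });
}
```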