Closed yzh119 closed 1 week ago
Current JIT compilation is slow because we rely on the huge header <torch/extension.h>, which is too heavy for our use case.
This PR refactors the codebase to include only the headers necessary for pybind, and moves most torch runtime API calls from C++ to Python.
The compilation time was reduced from 48 seconds to 18 seconds for lightweight operators such as norm.
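To illustrate what moving runtime calls out of C++ can look like, here is a minimal sketch, not the code in this PR. It assumes a hypothetical RMSNorm-style operator: the Python wrapper does the shape checks and the output allocation, and the kernel only sees raw addresses and sizes. The kernel here (`_rmsnorm_kernel`) is a pure-Python stand-in so the sketch runs without a compiler, and `ctypes` float arrays stand in for torch tensors. In the real codebase the wrapper would presumably use calls such as `torch.empty_like` and `Tensor.data_ptr()` instead.

```python
import ctypes
import math


def _rmsnorm_kernel(x_ptr, w_ptr, out_ptr, rows, cols, eps):
    """Hypothetical stand-in for the compiled kernel.

    It receives only raw addresses and sizes, so a real binding with this
    signature would not need any tensor types from the torch headers.
    """
    n = rows * cols
    x = (ctypes.c_float * n).from_address(x_ptr)
    w = (ctypes.c_float * cols).from_address(w_ptr)
    out = (ctypes.c_float * n).from_address(out_ptr)
    for r in range(rows):
        row = x[r * cols:(r + 1) * cols]  # slicing copies into a Python list
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / cols + eps)
        for c in range(cols):
            out[r * cols + c] = row[c] * inv_rms * w[c]


def rmsnorm(x, weight, rows, cols, eps=1e-6):
    """Python-side wrapper: validation and allocation happen here."""
    # Checks that would otherwise live in the C++ binding layer.
    if len(x) != rows * cols:
        raise ValueError("x must hold rows * cols elements")
    if len(weight) != cols:
        raise ValueError("weight must hold cols elements")

    # Output allocation also happens in Python (zero-initialized buffer).
    out = (ctypes.c_float * len(x))()

    # Only plain integers and floats cross the language boundary.
    _rmsnorm_kernel(
        ctypes.addressof(x),
        ctypes.addressof(weight),
        ctypes.addressof(out),
        rows,
        cols,
        eps,
    )
    return out


x = (ctypes.c_float * 4)(3.0, 4.0, 3.0, 4.0)
weight = (ctypes.c_float * 2)(1.0, 1.0)
result = rmsnorm(x, weight, rows=2, cols=2)
```

With this split, the binding layer only has to expose functions that take integers and floats, which is what allows it to avoid the heavy header.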
Using nanobind or ctypes/Cython for FFI could make compilation even faster, but it would require fundamental changes to the codebase, so I prefer sticking with pybind for now.