jeejeelee closed this 1 week ago
Hi @jeejeelee, thanks for bringing this up. Would you mind adding device guards to the JIT templates as well?
- https://github.com/flashinfer-ai/flashinfer/blob/a3360ff9c85e7d0bae8bb4ca6dbbca69cfadea37/python/flashinfer/jit/batch_decode_mla_templ.py
- https://github.com/flashinfer-ai/flashinfer/blob/a3360ff9c85e7d0bae8bb4ca6dbbca69cfadea37/python/flashinfer/jit/batch_decode_templ.py
- https://github.com/flashinfer-ai/flashinfer/blob/a3360ff9c85e7d0bae8bb4ca6dbbca69cfadea37/python/flashinfer/jit/batch_prefill_templ.py
- https://github.com/flashinfer-ai/flashinfer/blob/a3360ff9c85e7d0bae8bb4ca6dbbca69cfadea37/python/flashinfer/jit/single_decode_templ.py
- https://github.com/flashinfer-ai/flashinfer/blob/a3360ff9c85e7d0bae8bb4ca6dbbca69cfadea37/python/flashinfer/jit/single_prefill_templ.py
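A device guard is usually added so that a kernel launches on the GPU that holds its input tensors rather than on whatever device happens to be current, and so that the previous device is restored afterwards. The following is a minimal Python sketch of that pattern. It assumes a simulated "current device" global in place of the CUDA runtime, and `get_device`, `set_device`, `device_guard` and `launch_kernel` are hypothetical names. The templates linked above generate C++ code, where PyTorch extensions typically use `c10::cuda::OptionalCUDAGuard` for this role.

```python
from contextlib import contextmanager

# Hypothetical stand-in for the CUDA runtime's per-thread "current device".
# Real code would query and set it via cudaGetDevice / cudaSetDevice.
_current_device = 0


def get_device() -> int:
    return _current_device


def set_device(device: int) -> None:
    global _current_device
    _current_device = device


@contextmanager
def device_guard(device: int):
    """Switch to `device` for the duration of the block, then restore the old one."""
    prev = get_device()
    set_device(device)
    try:
        yield
    finally:
        # Restore even if the body raised, like a C++ guard's destructor would.
        set_device(prev)


def launch_kernel(tensor_device: int) -> int:
    """Hypothetical launcher: returns the device the launch would target."""
    with device_guard(tensor_device):
        # A real kernel launch here would run on get_device().
        return get_device()
```

For example, with device 0 current, `launch_kernel(1)` targets device 1, and device 0 is current again once the call returns. Using a scoped guard rather than a bare `set_device` call is what guarantees the restore happens on every exit path.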
Okay, I will.
@yzh119 I have checked all the kernels and added device guards. I also tested the ROPE and NORM kernels in a CUDA 11.8 environment and verified that they pass the tests. Please let me know if any additional testing is needed.
FIX: https://github.com/flashinfer-ai/flashinfer/issues/452