This work looks promising!
I am wondering whether it's feasible to port the CUDA code to JAX and achieve similar speed.
Do you think this is achievable?
I'm less experienced with JAX. The custom CUDA code is necessary to achieve the speed.
But I'm not sure how difficult it would be to adapt my custom CUDA code to JAX, or whether it could simply be called from JAX.