We don't currently have plans for this, as unfortunately we don't have much PyTorch/CUDA expertise in the team. The PyTorch code is primarily intended as a reference and is currently significantly less efficient than the JAX code.
However, I think some external groups might have a go at this, which we are very excited about!
Maybe @tridao can take a look at your Pallas kernel?
I don't have experience with Pallas, sorry
Pallas is JAX specific, so it can't really be used with PyTorch.
Thank you for the amazing work! Any chance you will release a PyTorch CUDA kernel for the custom linear scan, or maybe provide some instructions for adapting Mamba's parallel associative scan?
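For anyone attempting the port, here is a minimal sketch of the idea behind a parallel associative scan. It assumes the scan is the diagonal linear recurrence `h_t = a_t * h_{t-1} + x_t`. That form is an assumption on my part, since the thread doesn't spell the recurrence out. Each step is treated as an affine map `(a, b)`, and composing two such maps is associative, so the whole sequence can be combined in `log2(n)` rounds instead of `n` sequential steps. This is plain Python for clarity, not a CUDA or Pallas kernel, and all function names are hypothetical.

```python
from typing import List, Sequence, Tuple

# Assumption: the recurrence is h_t = a_t * h_{t-1} + x_t (elementwise / scalar here).
Pair = Tuple[float, float]


def combine(left: Pair, right: Pair) -> Pair:
    """Compose two affine maps h -> a*h + b, applying `left` first, then `right`.

    right(left(h)) = a2 * (a1 * h + b1) + b2 = (a1 * a2) * h + (a2 * b1 + b2).
    This composition is associative, which is what makes a parallel scan possible.
    """
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)


def sequential_scan(a: Sequence[float], x: Sequence[float], h0: float = 0.0) -> List[float]:
    """Reference loop: one step at a time, O(n) sequential depth."""
    h = h0
    out = []
    for a_t, x_t in zip(a, x):
        h = a_t * h + x_t
        out.append(h)
    return out


def associative_scan(a: Sequence[float], x: Sequence[float], h0: float = 0.0) -> List[float]:
    """Hillis-Steele inclusive scan using `combine`: log2(n) rounds.

    Within a round, every update reads only the previous round's values,
    so on a GPU each round's inner loop could run fully in parallel.
    """
    elems: List[Pair] = list(zip(a, x))
    n = len(elems)
    step = 1
    while step < n:
        new = list(elems)
        for i in range(step, n):
            # elems[i - step] covers earlier positions, so it goes on the left.
            new[i] = combine(elems[i - step], elems[i])
        elems = new
        step *= 2
    # elems[i] is now the composed map for positions 0..i; apply it to h0.
    return [A * h0 + B for A, B in elems]


if __name__ == "__main__":
    a = [0.5, 0.9, 0.1, 0.7, 0.3]
    x = [1.0, 2.0, 3.0, 4.0, 5.0]
    print(sequential_scan(a, x))
    print(associative_scan(a, x))
```

Both functions should print the same values up to floating-point rounding. A real kernel would typically use a work-efficient (Blelloch-style) or chunked scan rather than this Hillis-Steele variant, and would operate on tensors rather than Python lists; the sketch only shows that the combine operator reproduces the sequential loop.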