google-deepmind / recurrentgemma

Open weights language model from Google DeepMind, based on Griffin.
Apache License 2.0

PyTorch scan kernel #3

Closed Fr0do closed 2 months ago

Fr0do commented 2 months ago

Thank you for the amazing work! Is there any chance that you will release a PyTorch CUDA kernel for the custom linear scan, or maybe provide some instructions for adapting Mamba's parallel associative scan?

SamSmithGDM commented 2 months ago

We don't currently have plans for this, as unfortunately we don't have much PyTorch/CUDA expertise in the team. The PyTorch code is primarily intended as a reference and is currently significantly less efficient than the JAX code.

However, I think some external groups might have a go at this, which we are very excited about!

Fr0do commented 2 months ago

Maybe @tridao can take a look at your Pallas kernel?

tridao commented 2 months ago

> Maybe @tridao can take a look at your Pallas kernel?

I don't have experience with Pallas, sorry.

botev commented 2 months ago

Pallas is JAX specific, so it can't really be used with PyTorch.
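
For readers unfamiliar with the two approaches named in the question, here is a minimal pure-Python sketch of how a linear scan and a parallel associative scan relate. It assumes a diagonal linear recurrence of the form h_t = a_t * h_{t-1} + x_t with a zero initial state. This is only an illustration of the idea: it is not the Pallas kernel from this repository, not Mamba's CUDA kernel, and not a PyTorch implementation. The names `sequential_scan`, `combine`, and `associative_scan` are hypothetical.

```python
from typing import List, Tuple

# A pair (a, b) stands for the affine map h -> a * h + b.
Pair = Tuple[float, float]


def sequential_scan(a: List[float], x: List[float]) -> List[float]:
    """Linear scan: h_t = a_t * h_{t-1} + x_t, starting from h = 0."""
    h, out = 0.0, []
    for a_t, x_t in zip(a, x):
        h = a_t * h + x_t
        out.append(h)
    return out


def combine(left: Pair, right: Pair) -> Pair:
    """Compose two affine maps, applying `left` first and `right` second."""
    a_l, b_l = left
    a_r, b_r = right
    return (a_r * a_l, a_r * b_l + b_r)


def associative_scan(a: List[float], x: List[float]) -> List[float]:
    """Hillis-Steele style inclusive scan over the (a_t, x_t) pairs.

    Within one pass of the while loop every output element depends only on
    the previous pass, so a GPU kernel could compute them all in parallel.
    Here the passes are simply emulated with a list comprehension.
    """
    elems: List[Pair] = list(zip(a, x))
    step = 1
    while step < len(elems):
        elems = [
            combine(elems[i - step], elems[i]) if i >= step else elems[i]
            for i in range(len(elems))
        ]
        step *= 2
    # The offset of each composed map is the hidden state for h_{-1} = 0.
    return [b for _, b in elems]
```

Because `combine` is associative, the number of dependent passes grows logarithmically with the sequence length rather than linearly, at the cost of more total arithmetic. Which variant is faster on a given accelerator depends on memory bandwidth and sequence length, so treat this as a conceptual starting point only.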