SeKim12 opened this issue 1 week ago
Yes, we could definitely add this op. Do you think it can be supported in all backends? Are you able to start a PR?
From what I can tell, it's supported in TensorFlow (via TensorFlow Probability's `tfp.math.scan_associative`) and JAX (`jax.lax.associative_scan`), but not yet in PyTorch or NumPy.
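For the backends without a native op, a generic fallback could be built from any associative binary `fn`. Below is a minimal pure-Python sketch of the recursive pairwise scheme: combine adjacent pairs, scan the half-length result recursively, then fill in the even positions. As far as I can tell, `jax.lax.associative_scan` is similar in structure internally. It assumes plain lists rather than tensors, and the name `associative_scan` here is hypothetical, not any library's API. A real backend version would slice tensors along an axis instead of looping over list indices.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def associative_scan(fn: Callable[[T, T], T], elems: List[T]) -> List[T]:
    """Inclusive scan of `elems` under an associative `fn` (hypothetical helper).

    `fn(left, right)` must be associative; `left` always precedes `right`,
    so non-commutative operators are handled correctly.
    """
    n = len(elems)
    if n < 2:
        return list(elems)
    # Level step: combine adjacent pairs (e0, e1), (e2, e3), ...
    reduced = [fn(elems[i], elems[i + 1]) for i in range(0, n - 1, 2)]
    # Recurse on the half-length list: odd_scan[k] is the prefix ending at 2k + 1.
    odd_scan = associative_scan(fn, reduced)
    out = list(elems)  # out[0] == elems[0] is already the first prefix
    for k, prefix in enumerate(odd_scan):
        out[2 * k + 1] = prefix
    # Even positions 2k (k >= 1): prefix ending at 2k - 1, then elems[2k].
    for i in range(2, n, 2):
        out[i] = fn(out[i - 1], elems[i])
    return out


print(associative_scan(lambda a, b: a + b, [1, 2, 3, 4, 5]))  # [1, 3, 6, 10, 15]
print(associative_scan(lambda a, b: a + b, ["a", "b", "c"]))  # ['a', 'ab', 'abc']
```

The `fn` calls within each level are independent of each other. With a vectorized `fn` applied to tensor slices, the recursion depth, and therefore the parallel depth, is O(log n) instead of the O(n) of a sequential loop.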
I'd be interested in working on this op. I also see a lot of great PyTorch implementations in this thread: https://github.com/pytorch/pytorch/issues/95408
Hello!
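I was looking for an equivalent of `jax.lax.associative_scan`. This is an important operation for recent state space models such as S5 and Mamba, which rely on parallel scans to evaluate their linear recurrences efficiently over the sequence.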
Maybe it already exists and I'm missing it. If not, I could start working on it. Thanks!
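To illustrate why this op matters for state space models, consider the linear recurrence `h_t = a_t * h_{t-1} + b_t` with `h_{-1} = 0`. It can be written as a scan over `(a_t, b_t)` pairs with an associative combine operator; S5 uses this operator with `jax.lax.associative_scan`. The minimal scalar sketch below uses `itertools.accumulate` as a sequential stand-in for an associative scan. Because `combine` is associative, a parallel scan would produce the same result, up to floating-point rounding. The names `combine`, `linear_recurrence_scan`, and `linear_recurrence_loop` are hypothetical illustrations, not any library's API.

```python
from itertools import accumulate
from typing import List, Tuple

Pair = Tuple[float, float]


def combine(left: Pair, right: Pair) -> Pair:
    """Compose two steps of h -> a * h + b; `left` is the earlier step."""
    a_l, b_l = left
    a_r, b_r = right
    # Applying left then right: a_r * (a_l * h + b_l) + b_r
    return (a_r * a_l, a_r * b_l + b_r)


def linear_recurrence_scan(a: List[float], b: List[float]) -> List[float]:
    """h_t = a_t * h_{t-1} + b_t with h_{-1} = 0, computed as a scan over `combine`."""
    prefixes = accumulate(zip(a, b), combine)
    # The second component of each prefix is h_t (the first is the product of a's).
    return [h for _, h in prefixes]


def linear_recurrence_loop(a: List[float], b: List[float]) -> List[float]:
    """Reference: the same recurrence evaluated step by step."""
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out


print(linear_recurrence_scan([0.5, 0.5, 0.5], [1.0, 2.0, 3.0]))  # [1.0, 2.5, 4.25]
print(linear_recurrence_loop([0.5, 0.5, 0.5], [1.0, 2.0, 3.0]))  # [1.0, 2.5, 4.25]
```

Models such as S5 apply the same operator elementwise to diagonal state matrices, so swapping `accumulate` for an associative scan is what turns the O(n) sequential loop into an O(log n)-depth parallel computation.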