keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Adding `ops.associative_scan`? #19904

Open: SeKim12 opened this issue 1 week ago

SeKim12 commented 1 week ago

Hello!

I was looking for an equivalent of `jax.lax.associative_scan` in Keras. It is an important operation for recent state space models such as S5 and Mamba.

Maybe it's there and I'm missing it. Otherwise, I could start working on it. Thanks!
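For context on why the op matters for these models: at their core is a linear recurrence of the form h_t = a_t * h_{t-1} + b_t. Treat each step as an affine map (a, b). Composing two such maps is associative, so every h_t can be computed with a parallel scan instead of a sequential loop. The following is a minimal sketch. It assumes JAX is installed and uses one scalar state per step for simplicity, whereas real models use vector states. `linear_recurrence_op` and `scan_linear_recurrence` are hypothetical helper names, not part of JAX or Keras.

```python
import jax
import jax.numpy as jnp
import numpy as np


def linear_recurrence_op(left, right):
    # Each element is a pair (a, b) standing for the affine map h -> a * h + b.
    # Applying `left` first and then `right` gives h -> a_r * (a_l * h + b_l) + b_r,
    # which is again an affine map, and this composition is associative.
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r


def scan_linear_recurrence(a, b):
    # Computes h_t = a_t * h_{t-1} + b_t for every t, assuming h_{-1} = 0.
    # The b-component of each inclusive prefix is exactly h_t.
    _, h = jax.lax.associative_scan(linear_recurrence_op, (a, b))
    return h


def sequential_reference(a, b):
    # Plain Python loop computing the same recurrence, for comparison.
    h, out = 0.0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return np.array(out)


a = jnp.array([0.5, 0.9, 0.1, 2.0])
b = jnp.array([1.0, 2.0, 3.0, 4.0])
h = scan_linear_recurrence(a, b)  # approximately [1.0, 2.9, 3.29, 10.58]
```

A naive loop needs n dependent steps. The scan only needs a logarithmic number of parallel combine rounds, which is why the op matters on accelerators.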

fchollet commented 1 week ago

Yes, we could definitely add this op. Do you think it can be supported in all backends? Are you able to start a PR?

SeKim12 commented 1 week ago

From what I could tell, it's supported in TensorFlow and JAX (see above), but not yet in PyTorch or NumPy.

I'd be interested in working on this op. There are also several good PyTorch implementations in this thread: https://github.com/pytorch/pytorch/issues/95408
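For backends without a built-in, such as NumPy, one option is a fallback written purely in array ops. The following is a minimal sketch assuming NumPy. It uses the simple Hillis-Steele doubling scheme, which needs O(log n) rounds but O(n log n) total work, rather than a work-efficient scan. It only scans along axis 0 and has no `reverse` option. `associative_scan_numpy` is a hypothetical name, not an existing Keras or NumPy function.

```python
import numpy as np


def associative_scan_numpy(fn, elems):
    """Inclusive scan along axis 0 using Hillis-Steele doubling.

    `fn(left, right)` must be associative and vectorized: it receives two
    tuples of arrays with the same leading length and returns a tuple of
    arrays. `left` always holds the earlier elements, so `fn` does not
    need to be commutative.
    """
    elems = tuple(np.asarray(e) for e in elems)
    n = elems[0].shape[0]
    offset = 1
    while offset < n:
        # Position i (for i >= offset) combines the partial result ending at
        # i - offset with its own partial result. Positions below `offset`
        # already hold complete prefixes and are kept unchanged.
        left = tuple(e[:-offset] for e in elems)
        right = tuple(e[offset:] for e in elems)
        combined = fn(left, right)
        elems = tuple(
            np.concatenate([e[:offset], c], axis=0) for e, c in zip(elems, combined)
        )
        offset *= 2
    return elems


def add_op(x, y):
    # Addition on 1-tuples: the scan reduces to a cumulative sum.
    return (x[0] + y[0],)


def linear_recurrence_op(left, right):
    # Composition of affine maps h -> a * h + b (non-commutative).
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r


(prefix_sums,) = associative_scan_numpy(add_op, (np.arange(1, 6),))
# prefix_sums holds [1, 3, 6, 10, 15]

a = np.array([0.5, 0.9, 0.1, 2.0])
b = np.array([1.0, 2.0, 3.0, 4.0])
_, h = associative_scan_numpy(linear_recurrence_op, (a, b))
# h approximately [1.0, 2.9, 3.29, 10.58], i.e. h_t = a_t * h_{t-1} + b_t
```

The same structure maps onto PyTorch by swapping `np.concatenate` for `torch.cat`. A production version would likely prefer a work-efficient scheme, like the ones discussed in the linked PyTorch thread.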