adam-hartshorne opened this issue 4 months ago
That's a good point. I think we can replace these with a scan, @AndPotap?
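Here is a minimal sketch of what "replace these with a scan" could look like. It assumes the loops fill an RBF kernel matrix one entry at a time; the RBF kernel and the names `rbf_row` and `rbf_kernel_scan` are hypothetical illustrations, not the operator's actual code. The outer loop over rows becomes a `jax.lax.scan`, which is traced once rather than unrolled. The inner loop is replaced by broadcasting.

```python
import jax
import jax.numpy as jnp


def rbf_row(x1_i, x2, lengthscale):
    # One kernel row k(x1_i, x2_j) for all j, vectorized over the points of x2.
    sq = jnp.sum((x1_i[None, :] - x2) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq / lengthscale**2)


@jax.jit
def rbf_kernel_scan(x1, x2, lengthscale=1.0):
    # The outer loop over rows of x1 becomes a lax.scan. The body is traced once,
    # so the compiled program does not grow with the number of rows.
    def step(carry, x1_i):
        # No state is carried between rows; each step just emits one row.
        return carry, rbf_row(x1_i, x2, lengthscale)

    _, K = jax.lax.scan(step, None, x1)  # stacked outputs have shape (n, m)
    return K
```

When no state is carried between iterations, `jax.lax.map(f, xs)` expresses the same thing more directly. Scanning over rows also keeps peak memory at one row of intermediates, compared with building the full `(n, m, d)` difference tensor at once.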
Thanks for pointing this out. I'll think about how to do this in a way that is also compatible with PyTorch.
I think we would just want to add an `xnp.scan` function to the backends and then use that. Also, wherever we currently use `xnp.for_loop`, we should probably replace it with a scan where possible.
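A rough sketch of how a backend-level scan might be split follows. It is not CoLA's actual `xnp` API: `scan_jax` and `scan_loop` are hypothetical names. The JAX backend delegates to `jax.lax.scan`. An eager backend such as PyTorch falls back to a plain loop with the same `(carry, y) = f(carry, x)` contract and stacks the per-step outputs. That backend would pass `torch.stack`; NumPy's `np.stack` is used here so the sketch has no PyTorch dependency.

```python
import jax
import jax.numpy as jnp
import numpy as np


def scan_jax(f, init, xs):
    # JAX backend: lax.scan traces f once and compiles a single loop.
    return jax.lax.scan(f, init, xs)


def scan_loop(f, init, xs, stack):
    # Eager fallback (e.g. PyTorch with stack=torch.stack). Same semantics as
    # lax.scan for a single array xs: iterate over its leading axis, thread the
    # carry through, and stack the per-step outputs.
    # Unlike lax.scan, this sketch does not handle pytrees of xs or ys.
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, stack(ys)


def cumsum_step(carry, x):
    # Example body: running sum, emitting the partial sum at each step.
    carry = carry + x
    return carry, carry
```

For example, `scan_loop(cumsum_step, 0.0, np.arange(1.0, 5.0), np.stack)` and `scan_jax(cumsum_step, jnp.zeros((), jnp.float32), jnp.arange(1.0, 5.0))` should both give a final carry of `10.0` and partial sums `[1, 3, 6, 10]`. The JAX call uses an explicitly typed `float32` initial carry because `lax.scan` requires the carry's type to be the same on input and output.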
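I notice the new kernel operator uses nested `for` loops with element-wise update operations. For JAX that is a very bad idea. Under `jit`, Python `for` loops are unrolled during tracing, so compile time and program size grow with the number of iterations. Outside `jit`, every iteration dispatches its operations one at a time. Python `for` loops should be avoided at all costs.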
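To make the contrast concrete, here is a sketch assuming the operator builds an RBF kernel matrix. The kernel choice and function names are hypothetical, not taken from the operator itself. The first version is the nested-loop pattern with `.at[].set()` updates. The second computes the same matrix with broadcasting and compiles to a handful of array operations regardless of `n` and `m`.

```python
import jax
import jax.numpy as jnp


def rbf_kernel_loops(x1, x2, lengthscale=1.0):
    # Anti-pattern: n * m Python iterations. Under jit each one would be unrolled
    # into the traced program; eagerly, each .at[].set() returns a new array.
    n, m = x1.shape[0], x2.shape[0]
    K = jnp.zeros((n, m))
    for i in range(n):
        for j in range(m):
            sq = jnp.sum((x1[i] - x2[j]) ** 2)
            K = K.at[i, j].set(jnp.exp(-0.5 * sq / lengthscale**2))
    return K


@jax.jit
def rbf_kernel_vectorized(x1, x2, lengthscale=1.0):
    # Broadcast (n, 1, d) against (1, m, d) to get all pairwise squared distances.
    sq = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq / lengthscale**2)
```

Both functions return the same `(n, m)` matrix up to floating-point rounding. When the pairwise intermediate is too large for memory, a scan or `jax.lax.map` over rows is the usual middle ground.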