wilson-labs / cola

Compositional Linear Algebra
Apache License 2.0

Observation about new Kernel Operator #87

Status: Open. adam-hartshorne opened this issue 4 months ago.

adam-hartshorne commented 4 months ago

I notice the new kernel operator makes use of nested for loops and update operations. For JAX that is a very bad idea, and Python for loops should be avoided at all costs.
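To make the concern concrete, here is a minimal JAX sketch contrasting a nested-loop construction of an RBF kernel matrix with a broadcast version. It is not the actual cola kernel operator (whose code is not shown in this thread); the function names and the choice of an RBF kernel are illustrative assumptions. Run eagerly, the loop version dispatches n² small operations one at a time. Under jax.jit, Python loops are unrolled at trace time, so the traced program would contain all n² updates and compile time would grow with n.

```python
import jax
import jax.numpy as jnp


def rbf_kernel_loops(x, lengthscale=1.0):
    """Illustrative anti-pattern: nested Python loops with functional .at[].set updates.

    Each iteration creates a separate update op; under jax.jit these loops
    would be unrolled into n * n ops in the traced program.
    """
    n = x.shape[0]
    K = jnp.zeros((n, n))
    for i in range(n):
        for j in range(n):
            sq_dist = jnp.sum((x[i] - x[j]) ** 2)
            K = K.at[i, j].set(jnp.exp(-0.5 * sq_dist / lengthscale**2))
    return K


@jax.jit
def rbf_kernel_vectorized(x, lengthscale=1.0):
    """Same matrix via broadcasting: the op count does not depend on n."""
    diff = x[:, None, :] - x[None, :, :]      # shape (n, n, d)
    sq_dist = jnp.sum(diff**2, axis=-1)        # shape (n, n)
    return jnp.exp(-0.5 * sq_dist / lengthscale**2)


x = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
K_loop = rbf_kernel_loops(x)
K_vec = rbf_kernel_vectorized(x)
print(jnp.allclose(K_loop, K_vec, atol=1e-6))  # True
```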

mfinzi commented 4 months ago

That's a good point. Do you think we can replace these with a scan, @AndPotap?
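Following the scan suggestion, one possible pattern is to compute a kernel matrix-vector product one row block at a time with jax.lax.scan. The loop body is traced and compiled once, and only a (block_size, n) slice of the kernel matrix is materialized per step. This is a hedged sketch: kernel_matvec_scan, rbf_block, and the requirement that n be divisible by block_size are hypothetical choices for illustration, not part of cola.

```python
from functools import partial

import jax
import jax.numpy as jnp


def rbf_block(xa, xb, lengthscale=1.0):
    """Hypothetical helper: dense RBF kernel block between xa (m, d) and xb (n, d)."""
    sq_dist = jnp.sum((xa[:, None, :] - xb[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dist / lengthscale**2)


@partial(jax.jit, static_argnames="block_size")
def kernel_matvec_scan(x, v, block_size):
    """Hypothetical sketch: compute K(x, x) @ v one row block at a time.

    Assumes x.shape[0] is divisible by block_size (no padding is done here).
    """
    n, d = x.shape
    x_blocks = x.reshape(n // block_size, block_size, d)

    def step(carry, xb):
        # The carry is unused; each step emits one block of the output vector.
        out_block = rbf_block(xb, x) @ v     # shape (block_size,)
        return carry, out_block

    # scan stacks the per-step outputs: shape (n // block_size, block_size)
    _, out = jax.lax.scan(step, None, x_blocks)
    return out.reshape(n)


x = jax.random.normal(jax.random.PRNGKey(0), (6, 3))
v = jnp.arange(6.0)
dense = rbf_block(x, x) @ v
blocked = kernel_matvec_scan(x, v, block_size=2)
print(jnp.allclose(dense, blocked, atol=1e-5))  # True
```

The block size trades memory for parallelism: larger blocks give each scan step more work to vectorize, while smaller blocks keep the materialized slice of the kernel matrix small.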

AndPotap commented 4 months ago

Thanks for pointing this out. I'll think about how to do this for it to also be compatible with PyTorch.

mfinzi commented 4 months ago

I think we would just want to add an xnp.scan function to the backends and then use that. Also, in places where we currently use xnp.for_loop, we should probably switch to scan where possible.
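A backend-level scan could follow the contract of jax.lax.scan: f(carry, x) returns (new_carry, y), and the call returns the final carry together with the stacked ys. The sketch below is a hypothetical Python-loop fallback for a backend that lacks a compiled scan primitive. The name scan and where it would live are assumptions, not cola's existing backend API. It uses NumPy; a PyTorch version would stack the outputs with torch.stack instead.

```python
import numpy as np


def scan(f, init, xs, length=None):
    """Hypothetical fallback with the same contract as jax.lax.scan.

    f(carry, x) -> (new_carry, y). Returns (final_carry, stacked_ys).
    Only single-array xs and ys are handled, and at least one step is assumed
    (np.stack fails on an empty list).
    """
    if xs is None:
        # Mirror jax.lax.scan: with no xs, run `length` steps with x = None.
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def cumsum_step(total, x):
    # Carry the running total and also emit it as this step's output.
    total = total + x
    return total, total


final, partial_sums = scan(cumsum_step, 0.0, np.array([1.0, 2.0, 3.0, 4.0]))
print(final)         # 10.0
print(partial_sums)  # [ 1.  3.  6. 10.]
```

For the JAX backend, the same name could simply point at jax.lax.scan. One reason to prefer scan over a fori_loop-style loop in JAX: jax.lax.fori_loop only supports reverse-mode differentiation when its trip count is known at trace time (in which case it is lowered to a scan), whereas scan supports it directly.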