tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0
9.07k stars 450 forks

scatter_mul operation and related #2044

Open McArthur-Alford opened 4 months ago

McArthur-Alford commented 4 months ago

Feature description

It would be really nice to have a scatter_mul op, identical to scatter except that it uses multiplication as the reduction. Similarly, scatter_max and scatter_min would be nice. Other functions (such as select_assign) could probably benefit from this as well. I suspect it would be better to have an enum for the reduction strategy and pass it as an argument to scatter, rather than adding many separate functions.
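To make the enum idea concrete, here is a minimal sketch in plain Rust over 1D slices. It is not Burn's API: the `Reduction` enum, its variants, and the `scatter_1d` function are hypothetical names used only for illustration, and a real implementation would operate on tensors along a chosen dimension.

```rust
/// Hypothetical reduction strategy for a scatter-style op (not part of Burn).
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Reduction {
    Sum,
    Mul,
    Max,
    Min,
}

impl Reduction {
    /// Combine the value already stored at the target position with an incoming value.
    fn apply(self, acc: f32, value: f32) -> f32 {
        match self {
            Reduction::Sum => acc + value,
            Reduction::Mul => acc * value,
            Reduction::Max => acc.max(value),
            Reduction::Min => acc.min(value),
        }
    }
}

/// Hypothetical 1D scatter: for each i, out[indices[i]] is combined with
/// values[i] using the chosen reduction. Repeated indices accumulate.
pub fn scatter_1d(
    input: &[f32],
    indices: &[usize],
    values: &[f32],
    reduction: Reduction,
) -> Vec<f32> {
    assert_eq!(
        indices.len(),
        values.len(),
        "indices and values must have the same length"
    );
    let mut out = input.to_vec();
    for (&idx, &value) in indices.iter().zip(values) {
        // Panics on an out-of-bounds index; a real op would validate shapes up front.
        out[idx] = reduction.apply(out[idx], value);
    }
    out
}
```

With a single function taking a `Reduction` argument, scatter_mul, scatter_max and scatter_min become `scatter_1d(.., Reduction::Mul)` and so on, and the same enum could in principle be reused by other ops such as select_assign.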

Feature motivation

PyTorch has this functionality, so it would be nice for parity reasons. Personally, I've wanted this a few times while working on #1998, but it isn't a pressing issue. From looking around, it seems like it would be quite a bit of work, so for now I'm not doing it myself (just making a feature request). Once I'm done with all the other sparse stuff, I hope to get around to this if it hasn't been done by then.

louisfd commented 4 months ago

Hi @McArthur-Alford, I'll be working on porting both the reduction kernels and the scatter/select operations using CubeCL in the coming days. I'll try to think of an elegant way to generalize this.