boubezariali opened this issue 2 weeks ago.
Thanks for the proposal!
~Isn't the main PyTorch repo better suited for this function?~ I see you've already opened a PR there as well.
How is this transformation used in computer vision? Could you give a small example?
Yeah, I opened a PR in the main PyTorch repo, but the response was that we already have scatter_reduce and index_put, and that this operation is a bit more vision-specific.
The main vision application is lidar voxelization (example). To turn a lidar point cloud into a voxel grid, we compute the voxel each point falls into, then scatter the points into a large grid via sparse updates. Most voxelization pipelines also need a reduction to handle multiple points landing in the same voxel, for example keeping the maximum feature value per voxel.
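A minimal sketch of the voxelization step described above, in PyTorch, assuming points in an [N, 3] tensor, per-point features in [N, C], cubic voxels, and max reduction. `voxelize_max` and its parameters are hypothetical names, not an existing torchvision or PyTorch API. It relies only on the existing `Tensor.scatter_reduce_` applied to a flattened grid; points outside the grid are routed to one extra dummy voxel (instead of being filtered out) so that no tensor shape depends on the data.

```python
import torch


def voxelize_max(points, features, voxel_size, grid_min, grid_shape):
    """Max-pool per-point features into a dense voxel grid (hypothetical helper).

    points:     [N, 3] float xyz coordinates
    features:   [N, C] per-point features
    voxel_size: edge length of a cubic voxel, in the same units as points
    grid_min:   [3] xyz coordinate of the grid's lower corner
    grid_shape: (X, Y, Z) number of voxels along each axis
    Returns an [X, Y, Z, C] grid; voxels that receive no points stay 0.
    """
    X, Y, Z = grid_shape
    C = features.shape[1]
    num_voxels = X * Y * Z

    # 1. Integer voxel coordinate of every point.
    coords = torch.floor((points - grid_min) / voxel_size).long()
    upper = torch.tensor(grid_shape, device=points.device)
    inside = ((coords >= 0) & (coords < upper)).all(dim=1)

    # 2. Row-major linear voxel index; out-of-grid points go to a dummy voxel
    #    at index num_voxels, so there is no boolean filtering (no dynamic shapes).
    lin = coords[:, 0] * (Y * Z) + coords[:, 1] * Z + coords[:, 2]
    lin = lin.masked_fill(~inside, num_voxels)

    # 3. Scatter with max reduction. include_self=False makes each voxel the max
    #    over its points only; voxels that receive nothing keep the initial 0.
    grid = features.new_zeros(num_voxels + 1, C)
    grid.scatter_reduce_(
        0, lin.unsqueeze(1).expand(-1, C), features,
        reduce="amax", include_self=False,
    )
    # Drop the dummy voxel and restore the 3D layout.
    return grid[:num_voxels].reshape(X, Y, Z, C)
```

For example, with a 1.0 voxel size, two points in voxel (0, 0, 0) carrying features 1.0 and 3.0 produce 3.0 in that voxel, while a point beyond the grid's upper bound is discarded.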
I appreciate the proposal and don't want to hold it back. However, I'm not sure torchvision is the right place for this, since it currently has no voxel-specific or 3D computer vision operations or losses. But who knows, maybe this could be the start of something new!
Yeah, a lot of 3D-specific utilities could be built on top of this, including lidar voxelization itself! If we can start a 3D/voxel library in torchvision, Nuro would be happy to contribute and help build it up. Let me know!
I think adding operations for 3D data would require a detailed RFC and approval from the maintainers.
On a side note, pytorch3d would be better suited for this. (Sorry for redirecting you to yet another repository) :p
🚀 The feature
Hey everyone, I'm Ali, a software engineer at Nuro, and we've been migrating a lot of vision models from TensorFlow to PyTorch. One particularly critical operation for these models is tf.tensor_scatter_nd_update, along with its reduction counterparts (e.g. tf.tensor_scatter_nd_min, tf.tensor_scatter_nd_max, etc.).
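To make the requested semantics concrete, here is a minimal, out-of-place sketch of tf.tensor_scatter_nd_update and its add/min/max counterparts built on existing PyTorch primitives (`Tensor.index_put_` and `Tensor.scatter_reduce_`). The name `scatter_nd` and its `reduce` argument are hypothetical, not a PyTorch API, and this is not the attached sample implementation. It assumes TF-style indices of shape [..., K] that address the first K dimensions of `tensor`, assumes every index is in range and non-negative, and does not reproduce TF's error handling. The output shape always equals the input shape, and no intermediate size depends on index values.

```python
import math

import torch

# Map TF-style reduction names to Tensor.scatter_reduce_ reduce strings.
_REDUCE = {"add": "sum", "min": "amin", "max": "amax"}


def scatter_nd(tensor, indices, updates, reduce="update"):
    """Out-of-place ScatterND sketch (hypothetical helper, not a PyTorch API).

    tensor:  [d0, ..., d_{n-1}]
    indices: integer tensor [..., K], K <= n; each row addresses the first K dims
    updates: indices.shape[:-1] + tensor.shape[K:]
    reduce:  "update", "add", "min" or "max"
    Assumes all indices are in range and non-negative.
    """
    k = indices.shape[-1]
    lead_shape, tail_shape = tensor.shape[:k], tensor.shape[k:]
    tail_numel = math.prod(tail_shape)

    # One row per update: [M, K] indices and [M, tail_numel] update slices.
    flat_idx = indices.reshape(-1, k).long()
    flat_upd = updates.reshape(flat_idx.shape[0], tail_numel).to(tensor.dtype)

    # Row-major linearization of the K leading coordinates.
    strides = torch.tensor(
        [math.prod(lead_shape[i + 1:]) for i in range(k)],
        dtype=torch.long, device=tensor.device,
    )
    lin = (flat_idx * strides).sum(dim=-1)

    out = tensor.reshape(math.prod(lead_shape), tail_numel).clone()
    if reduce == "update":
        # Plain overwrite; with duplicate indices, which update wins is not guaranteed.
        out.index_put_((lin,), flat_upd)
    else:
        # The existing value participates too (include_self=True).
        index = lin.unsqueeze(-1).expand(-1, tail_numel)
        out.scatter_reduce_(0, index, flat_upd, reduce=_REDUCE[reduce], include_self=True)
    return out.reshape(tensor.shape)
```

Because the K-dimensional index is linearized first, a single scatter along dim 0 handles any K, including batched index tensors and slice updates (K < n).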
We'd like to contribute a robust implementation of ScatterND to PyTorch. I've attached sample implementations in the "Additional context" section.
Motivation, pitch
The ScatterND operation was a major roadblock in Nuro's TensorFlow-to-PyTorch migration. If this operation were widely available, more organizations might be encouraged to adopt PyTorch (rather than JAX, which is the default migration path from TensorFlow).
Alternatives
There are a few scatter_nd implementations scattered (pardon the pun) around the web, but none of them quite fits:
- https://github.com/rusty1s/pytorch_scatter uses custom CUDA kernels, which are hard to use and maintain.
- https://gist.github.com/airalcorn2/c7846d6fcb58a30b25ea6d97e16fe025 has a bug.
- https://gist.github.com/Ending2015a/b034ebbedc55fec1d8ec3b7230a95f1e doesn't support batching and doesn't keep static shapes, which are important for torch.compile/torch-xla.
Additional context
Forum post: https://discuss.pytorch.org/t/contribution-proposal-scatternd-implementation-for-pytorch/211874/1
Sample implementation: https://gist.github.com/boubezariali/2e3d9650461f302a541235e33d7cded2
Simple test: https://gist.github.com/boubezariali/a4a22736a14414404b45b17f62da9c2d
We also have extensive testing implemented that we're willing to contribute.