teddykoker / torchsort

Fast, differentiable sorting and ranking in PyTorch
https://pypi.org/project/torchsort/
Apache License 2.0

Sorting in more than 2 dimensions #37

Closed zimonitrome closed 2 years ago

zimonitrome commented 2 years ago

I love this library now that I got it to work with my code!

I was wondering though, are there any plans to make it work with ambiguously shaped tensors? What work would that entail?

My current training scheme has tensors shaped [B, H, W], so I am currently doing torch.stack([soft_sort(item) for item in batch]). This library is fast, but running it sequentially like that is not. Maybe there is a way to parallelize it or extend the function to use more dimensions?

teddykoker commented 2 years ago

Hi! What I would recommend is reshaping the tensor to two dimensions, e.g. [B, H, W] -> [B * H, W], and then reshaping it back after the operation. This is how some of the torch functions operate behind the scenes as well. This will be much faster than using a for loop.

Here is some example code:

In [1]: import torch

In [2]: import torchsort

In [3]: B, H, W = 16, 10, 15

In [4]: x = torch.randn(B, H, W)

In [5]: x = x.view(B * H, W)

In [6]: x = torchsort.soft_sort(x)

In [7]: x = x.view(B, H, W)

In [8]: x.shape
Out[8]: torch.Size([16, 10, 15])
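The pattern above generalizes to any number of leading dimensions. A minimal sketch of a reusable wrapper (not part of torchsort's API; `sort_last_dim` is a hypothetical helper, and `torch.sort` is used here only as a stand-in with the same `[N, W] -> [N, W]` signature as `torchsort.soft_sort`):

```python
import torch

def sort_last_dim(op, x):
    # Flatten all leading dimensions into one batch dimension,
    # apply a 2-D row-wise op, then restore the original shape.
    shape = x.shape
    out = op(x.reshape(-1, shape[-1]))
    return out.reshape(shape)

# Demo with torch.sort as a stand-in for torchsort.soft_sort.
x = torch.randn(16, 10, 15)
y = sort_last_dim(lambda t: torch.sort(t, dim=-1).values, x)
```

Using `reshape` instead of `view` also covers non-contiguous inputs, since `view` requires a contiguous tensor while `reshape` will copy only when necessary.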
zimonitrome commented 2 years ago

This is clever, and I actually thought of this after making the issue too. Will try it out, thanks!

Edit: For anyone interested, I was able to speed up my sorting by ~30x by calling soft_sort only once.
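For reference, the two approaches can be compared directly; the sketch below uses `torch.sort` as a stand-in for `torchsort.soft_sort` (same row-wise signature), so the exact speedup will differ from the ~30x reported above:

```python
import time
import torch

x = torch.randn(64, 32, 128)

# Sequential: one call per batch element (the original approach).
t0 = time.perf_counter()
seq = torch.stack([torch.sort(item, dim=-1).values for item in x])
t_seq = time.perf_counter() - t0

# Batched: a single call on the flattened tensor.
t0 = time.perf_counter()
flat = torch.sort(x.reshape(-1, x.shape[-1]), dim=-1).values
batched = flat.reshape(x.shape)
t_batch = time.perf_counter() - t0

# Both produce identical results; only the call overhead differs.
```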