teddykoker / torchsort

Fast, differentiable sorting and ranking in PyTorch
https://pypi.org/project/torchsort/
Apache License 2.0

Is there a way to use this to find the index of the biggest number in a torch vector? #21

Closed. roeeben closed this issue 3 years ago.

roeeben commented 3 years ago

When you use the regular torch.sort, the indices vector it returns is ordered by value, so if I want the index of the maximum value I just take the last element of that vector (ranks_vec[-1]). Same for the index of the 2nd biggest element: I just take ranks_vec[-2].

Unfortunately, the indices returned by the regular torch.sort do not carry a gradient. I've been trying to think of a way to achieve this with torchsort; do you have any idea how?
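A minimal sketch of the behaviour described above, in plain PyTorch with made-up input values: `torch.sort` returns a `(values, indices)` pair. The sorted values stay in the autograd graph, while the indices are an integer tensor that cannot carry a gradient.

```python
import torch

# Illustrative input; any 1-D float tensor works.
a = torch.tensor([0.3, 0.9, 0.1, 0.7], requires_grad=True)

# Ascending sort: values are the sorted entries, indices their original positions.
values, indices = torch.sort(a)

largest_idx = indices[-1].item()         # position of the maximum (0.9)
second_largest_idx = indices[-2].item()  # position of the second-largest (0.7)

print(largest_idx, second_largest_idx)       # 1 3
# The sorted values are differentiable...
print(values.requires_grad)                  # True
# ...but the indices are int64 and have no gradient.
print(indices.dtype, indices.requires_grad)  # torch.int64 False
```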

Appreciate it!

teddykoker commented 3 years ago

Torchsort doesn't have a differentiable argsort. Because argsort returns indices, it would be difficult to have a "soft" version, as the output is discrete.

That being said, just because argsort itself isn't differentiable doesn't mean you can't use it:

In [1]: import torch

In [2]: import torchsort

In [3]: a = torch.rand(10, requires_grad=True)

In [4]: a
Out[4]:
tensor([0.6678, 0.5866, 0.1417, 0.7602, 0.3445, 0.5952, 0.9163, 0.7174, 0.4133,
        0.3033], requires_grad=True)

In [5]: b = a * 2

In [6]: largest_idx = torch.argsort(b)[-1]

In [7]: torch.autograd.grad(b[largest_idx], a)
Out[7]: (tensor([0., 0., 0., 0., 0., 0., 2., 0., 0., 0.]),)
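The same idea as a standalone script, extended to the second-largest element the original question asks about. This is a minimal sketch with made-up input values: the indices from `torch.argsort` only decide *which* entries of `b` are read, and the gradient then flows through those selected entries back to `a`.

```python
import torch

a = torch.tensor([0.6, 0.2, 0.9, 0.4], requires_grad=True)
b = a * 2

# Ascending order of b; argsort itself is not differentiable.
order = torch.argsort(b)
largest_idx, second_idx = order[-1], order[-2]

# Indexing with the (integer) indices still gives differentiable values.
top_two_sum = b[largest_idx] + b[second_idx]
(grad,) = torch.autograd.grad(top_two_sum, a)

print(largest_idx.item(), second_idx.item())  # 2 0
print(grad)  # tensor([2., 0., 2., 0.])
```

Only the selected positions receive a gradient (here 2, from `b = a * 2`); every other entry gets zero, exactly as in the transcript above.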

Additionally, if you do not care about the index, only the value itself, you can use torchsort.soft_sort:

In [1]: import torch

In [2]: import torchsort

In [3]: a = torch.rand(1, 10, requires_grad=True)

In [4]: a
Out[4]:
tensor([[0.4250, 0.5007, 0.6709, 0.1663, 0.3852, 0.7698, 0.6450, 0.5475, 0.4423,
         0.1594]], requires_grad=True)

In [5]: b = a * 2

In [6]: largest = torchsort.soft_sort(b)[0, -1]

In [7]: torch.autograd.grad(largest, a)
Out[7]: (tensor([[0., 0., 0., 0., 0., 2., 0., 0., 0., 0.]]),)
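If a differentiable stand-in for the index itself is needed, one common workaround is a temperature-scaled softmax, sometimes called "soft-argmax". This is not part of torchsort. It weights each position by `softmax(x / tau)` and returns the expected position. The helper name `soft_argmax` and the temperature `tau` below are hypothetical choices for illustration. The result is a float rather than an integer, so it cannot be used directly to index a tensor. Smaller `tau` moves it closer to the hard argmax but makes the gradients more peaked.

```python
import torch

def soft_argmax(x: torch.Tensor, tau: float = 0.1) -> torch.Tensor:
    """Hypothetical helper: expected position under softmax(x / tau) along the last dim."""
    weights = torch.softmax(x / tau, dim=-1)  # sharper as tau -> 0
    positions = torch.arange(x.shape[-1], dtype=x.dtype, device=x.device)
    return (weights * positions).sum(dim=-1)  # float "index", differentiable w.r.t. x

a = torch.tensor([0.1, 0.9, 0.3], requires_grad=True)
idx = soft_argmax(a * 2, tau=0.1)  # close to 1.0, the position of the maximum
(grad,) = torch.autograd.grad(idx, a)
print(idx, grad)
```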

Whatever you are trying to do can likely be achieved with one of these methods, but I would need more details about your use case to help further.

roeeben commented 3 years ago

Thank you :)