Torchsort doesn't have a differentiable argsort. Because argsort returns indices, it would be difficult to have a "soft" version, as the output is discrete.
That being said, just because argsort itself isn't differentiable doesn't mean you can't use it:
In [1]: import torch
In [2]: import torchsort
In [3]: a = torch.rand(10, requires_grad=True)
In [4]: a
Out[4]:
tensor([0.6678, 0.5866, 0.1417, 0.7602, 0.3445, 0.5952, 0.9163, 0.7174, 0.4133,
        0.3033], requires_grad=True)
In [5]: b = a * 2
In [6]: largest_idx = torch.argsort(b)[-1]
In [7]: torch.autograd.grad(b[largest_idx], a)
Out[7]: (tensor([0., 0., 0., 0., 0., 0., 2., 0., 0., 0.]),)
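The same trick extends to any rank, not just the largest. A minimal sketch along the same lines (not part of the session above; torch.argsort sorts ascending by default):

import torch

a = torch.rand(10, requires_grad=True)
b = a * 2
# argsort is only used to *select* an element; the gradient
# flows through the indexing operation, not through argsort
second_idx = torch.argsort(b)[-2]  # index of the 2nd largest value of b
grad, = torch.autograd.grad(b[second_idx], a)
# grad is zero everywhere except a 2.0 at position second_idx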
Additionally, if you do not care about the index, only the value itself, you can use torchsort.soft_sort:
In [1]: import torch
In [2]: import torchsort
In [3]: a = torch.rand(1, 10, requires_grad=True)
In [4]: a
Out[4]:
tensor([[0.4250, 0.5007, 0.6709, 0.1663, 0.3852, 0.7698, 0.6450, 0.5475, 0.4423,
         0.1594]], requires_grad=True)
In [5]: b = a * 2
In [6]: largest = torchsort.soft_sort(b)[0, -1]
In [7]: torch.autograd.grad(largest, a)
Out[7]: (tensor([[0., 0., 0., 0., 0., 2., 0., 0., 0., 0.]]),)
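Note that soft_sort only approximates the true sorted values: the regularization_strength argument (per the torchsort README) trades gradient smoothness against accuracy, so the "largest" entry is a soft blend of its neighbors unless the strength is small. A rough sketch of that knob, assuming the keyword from the README:

import torch
import torchsort

b = torch.rand(1, 10) * 2
# smaller strength -> output closer to a hard sort
near_hard = torchsort.soft_sort(b, regularization_strength=0.1)
# larger strength -> smoother output with more informative gradients
smoother = torchsort.soft_sort(b, regularization_strength=1.0)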
Whatever you are trying to do can likely be achieved with one of these methods, but I would need to know more about your use case to offer further help.
Thank you :)
When you use the regular torch.sort, the ranks vector that's returned is sorted, so if I want the index of the maximum value I just take the last element of the ranks vector. The same goes for the index of the 2nd biggest element: I just take ranks_vec[-2].
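For example, a small sketch of what I mean (variable names are just for illustration):

import torch

x = torch.rand(10)
values, ranks_vec = torch.sort(x)  # ascending values and their indices
max_idx = ranks_vec[-1]            # index of the maximum value of x
second_idx = ranks_vec[-2]         # index of the 2nd biggest value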
Unfortunately, the regular torch.sort does not support a gradient through those indices. I've been trying to think of a way to achieve this with your torchsort; any chance you have an idea?
Appreciate it!