teddykoker / torchsort

Fast, differentiable sorting and ranking in PyTorch
https://pypi.org/project/torchsort/
Apache License 2.0
765 stars, 33 forks

Some bug #60

Closed: Senwang98 closed this issue 1 year ago

Senwang98 commented 1 year ago

@teddykoker Hello, I am using torchsort, but I found something I can't understand when I ran the following code:

    import torch
    import torchsort

    x = torch.tensor([[1., -2., 2., 3., 0.5, -1.]])
    print(torchsort.soft_rank(x))

I got tensor([[3.8750, 1.0000, 4.8750, 5.8750, 3.3750, 2.0000]]) rather than the integer ranks [4, 1, 5, 6, 3, 2]. Why?
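For comparison, the discrete ranks the question expects can be computed with a double argsort. The sketch below is plain Python with a hypothetical helper, hard_rank, which is not part of torchsort. In PyTorch the same idea is usually written as torch.argsort(torch.argsort(x)) + 1. Because the result only changes when two entries swap order, it is piecewise constant, so its derivative is zero almost everywhere and gives no useful gradient signal:

```python
def hard_rank(values):
    """Hypothetical helper: 1-based ranks, smallest value gets rank 1 (ties broken by index)."""
    order = sorted(range(len(values)), key=lambda i: values[i])  # argsort
    ranks = [0] * len(values)
    for position, i in enumerate(order):
        ranks[i] = position + 1  # inverting the argsort gives each entry's rank
    return ranks


x = [1., -2., 2., 3., 0.5, -1.]
print(hard_rank(x))  # [4, 1, 5, 6, 3, 2]

# Nudging an entry slightly does not change any rank, so the finite-difference slope is 0.
nudged = [x[0] + 1e-3] + x[1:]
print(hard_rank(nudged) == hard_rank(x))  # True
```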

teddykoker commented 1 year ago

Hi @Senwang98, this is the expected behavior. For the function to remain differentiable, its output is not necessarily discrete. How far the output is smoothed away from the true integer ranks is controlled by the regularization_strength parameter:

    In [2]: import torch
    In [3]: import torchsort
    In [4]: x = torch.tensor([[1., -2., 2., 3., 0.5, -1.]])

    # default regularization strength
    In [5]: torchsort.soft_rank(x, regularization_strength=1.0)
    Out[5]: tensor([[3.8750, 1.0000, 4.8750, 5.8750, 3.3750, 2.0000]])

    # decreased regularization brings values closer to true rank
    In [6]: torchsort.soft_rank(x, regularization_strength=0.1)
    Out[6]: tensor([[4., 1., 5., 6., 3., 2.]])

    # increased regularization brings values closer together
    In [7]: torchsort.soft_rank(x, regularization_strength=10)
    Out[7]: tensor([[3.5417, 3.2417, 3.6417, 3.7417, 3.4917, 3.3417]])
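Where the fractional values come from: a minimal pure-Python sketch, assuming torchsort's default L2 regularization follows the soft rank of Blondel et al. (2020), "Fast Differentiable Sorting and Ranking", which the library is based on. Under that assumption the soft ranks are the Euclidean projection of x / regularization_strength onto the permutahedron (the convex hull of all permutations of 1, ..., n), and that projection reduces to a sort followed by isotonic regression with pool adjacent violators. The names soft_rank_l2 and isotonic_decreasing are hypothetical, not torchsort's API, and the code handles a single 1-D list rather than a batched tensor:

```python
def isotonic_decreasing(y):
    """Least-squares fit of a non-increasing sequence to y (pool adjacent violators)."""
    sums, counts = [], []
    for value in y:
        sums.append(value)
        counts.append(1)
        # Merge blocks while the previous block's mean is below the last one (a violation).
        while len(sums) > 1 and sums[-2] / counts[-2] < sums[-1] / counts[-1]:
            block_sum, block_count = sums.pop(), counts.pop()
            sums[-1] += block_sum
            counts[-1] += block_count
    fitted = []
    for block_sum, block_count in zip(sums, counts):
        fitted.extend([block_sum / block_count] * block_count)
    return fitted


def soft_rank_l2(theta, regularization_strength=1.0):
    """Hypothetical L2 soft rank: project theta / strength onto the permutahedron of 1..n."""
    n = len(theta)
    z = [t / regularization_strength for t in theta]
    order = sorted(range(n), key=lambda i: z[i], reverse=True)  # indices of z, descending
    s = [z[i] for i in order]
    w = list(range(n, 0, -1))  # ranks n, n-1, ..., 1, matched to s in descending order
    v = isotonic_decreasing([si - wi for si, wi in zip(s, w)])
    ranks = [0.0] * n
    for position, i in enumerate(order):
        ranks[i] = s[position] - v[position]  # undo the sort
    return ranks


x = [1., -2., 2., 3., 0.5, -1.]
for strength in (1.0, 0.1, 10):
    print(strength, [round(r, 4) for r in soft_rank_l2(x, strength)])
```

Run on the x from the thread, this reproduces the three outputs in the session above. With regularization_strength=0.1 the entries of x / 0.1 are spread far apart, no blocks get pooled, and the result equals the hard ranks. With regularization_strength=10 every entry lands in one pooled block, so the output is just x / 10 shifted to mean (n + 1) / 2 = 3.5, which is why the values crowd together. At the default of 1.0 the four largest entries share one block, and within a block each output moves linearly with its input, which is what keeps the gradients non-zero.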

Hopefully this addresses your concern. Please let me know if you have any other questions. Happy new year!

Senwang98 commented 1 year ago

@teddykoker Great, thanks very much, I understand now! Good luck and happy new year!