teddykoker / torchsort

Fast, differentiable sorting and ranking in PyTorch
https://pypi.org/project/torchsort/
Apache License 2.0

Apply torchsort to learn real permutation for downstream tasks #55

Closed Guaishou74851 closed 1 year ago

Guaishou74851 commented 1 year ago

Hi, your torchsort is quite a solid, interesting and inspiring piece of work!

I am trying to apply it to my task. Specifically, I want to use it to learn a permutation for an image x with a fixed shape of [b, c, h, w] (batch size, number of channels, height, width). The following toy code shows my basic idea:

...
perm_param = nn.Parameter(torch.rand(h * w))  # learnable parameter for permutation
...
# soft_rank operates on a 2-d (batch, n) tensor, so add and then drop a batch dimension
perm = torchsort.soft_rank(perm_param.unsqueeze(0)).squeeze(0)  # generate a learnable permutation via torchsort
...  # some discretization operations
x_p = x.reshape(b, c, h * w)  # flatten the spatial dimensions
x_p = x_p[:, :, perm]  # permute the image
x_p = x_p.reshape(b, c, h, w)  # restore the image shape
loss = my_loss(Net(x_p), target)
...
loss.backward()
...

In my implementation, I want to use torchsort to learn a permutation, based on a single learnable parameter tensor perm_param, for images of a fixed size. However, the basic implementation shown above cannot learn the permutation: the gradient from loss.backward() never reaches perm_param, because of non-differentiable operations such as integer indexing and the .long() cast used for discretization.
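To make the gradient problem concrete, here is a minimal sketch in plain PyTorch. It does not call torchsort; the line `perm_param * n` is only an assumed, differentiable stand-in for the soft ranks, and the sizes are arbitrary.

```python
import torch

n = 6
perm_param = torch.rand(n, requires_grad=True)

# Differentiable stand-in for soft ranks (an assumption, not torchsort's output).
soft = perm_param * n

# Discretization: casting to an integer dtype detaches the result from the graph.
hard = soft.long().clamp(0, n - 1)  # clamp only keeps the indices in range

x = torch.randn(2, 3, n)
x_p = x[:, :, hard]  # integer indexing: x_p depends on `hard` only as an index

# The graph is intact up to `soft` and is cut at the integer cast.
print(soft.requires_grad, hard.requires_grad, x_p.requires_grad)
```

Here `soft.requires_grad` is True while `hard.requires_grad` and `x_p.requires_grad` are both False, so a loss computed from `x_p` has no path back to `perm_param`.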

I am quite sure that an optimal permutation exists for my task. However, finding it by brute-force search could take O((h * w)!) time. Is there a way to learn the permutation using torchsort? I am still trying and thinking ...

I am looking forward to your reply. Thank you very much for reading such a long post!
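One direction that might be worth trying, sketched here under stated assumptions rather than as a method confirmed by torchsort or its author, is to replace the hard index with a soft permutation matrix built from soft ranks, so that the permuted image is a weighted mix of pixels. In the sketch below, `pairwise_soft_rank` and `soft_permutation_matrix` are hypothetical helpers written in plain PyTorch; the pairwise-sigmoid rank is a stand-in for `torchsort.soft_rank`, `tau` is an assumed temperature, and the loss is a placeholder for `my_loss(Net(x_p), target)`.

```python
import torch


def pairwise_soft_rank(scores, tau=0.1):
    """Approximate 0-based ranks of a 1-d tensor (stand-in for torchsort.soft_rank)."""
    diff = scores.unsqueeze(1) - scores.unsqueeze(0)  # diff[i, j] = s_i - s_j
    # Each j with s_j < s_i adds about 1; the j == i term adds exactly 0.5.
    return torch.sigmoid(diff / tau).sum(dim=1) - 0.5


def soft_permutation_matrix(scores, tau=0.1):
    """Return an (n, n) matrix whose column k softly selects the pixel of rank k."""
    n = scores.shape[0]
    ranks = pairwise_soft_rank(scores, tau)  # (n,)
    slots = torch.arange(n, dtype=scores.dtype, device=scores.device)
    logits = -((ranks.unsqueeze(1) - slots.unsqueeze(0)) ** 2) / tau  # (n, n)
    return torch.softmax(logits, dim=0)  # every column sums to 1


torch.manual_seed(0)
b, c, h, w = 2, 3, 4, 4
perm_param = torch.rand(h * w, requires_grad=True)
x = torch.randn(b, c, h, w)

P = soft_permutation_matrix(perm_param)
# (b, c, n) @ (n, n): output pixel k is a weighted mix of the input pixels.
x_p = (x.reshape(b, c, h * w) @ P).reshape(b, c, h, w)

loss = x_p.pow(2).mean()  # placeholder for my_loss(Net(x_p), target)
loss.backward()           # the gradient now reaches perm_param
```

With distinct scores, P should approach a hard permutation matrix as `tau` shrinks, at the price of smaller gradients. Note that P has (h * w)^2 entries, which may be too large for big images.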