google-research / fast-soft-sort

Fast Differentiable Sorting and Ranking
Apache License 2.0

Gradients not backpropagated in Pytorch #7

Closed abhi1kumar closed 4 years ago

abhi1kumar commented 4 years ago

Hi Authors, thank you for releasing your code. I tried checking the numerical gradients in PyTorch. With the soft_sort module, I do not obtain any gradients after calling loss.backward(). Without the soft_sort module, however, I do obtain them. Below are the code snippets for the two situations.

With the soft_sort module

import torch
import torch.nn as nn
from fast_soft_sort.pytorch_ops import *

conf = torch.tensor([[0.1, 0.7, 0.2]], dtype=torch.float64, requires_grad= True)

# Sorting in descending order
conf = soft_sort(conf, direction="DESCENDING", regularization_strength= 1)

ideal = torch.tensor([[1.0, 0.0, 0]], dtype=torch.float64)
l2    = nn.MSELoss()

print("Grad before doing loss.backward()")
print(conf.grad)

loss = l2(conf, ideal)
loss.backward()

print("Loss = {:.2f}".format(loss))
print("Grad after loss.backward()")
print(conf.grad)

The outputs are:

Grad before doing loss.backward()
None
Loss = 0.05
Grad after loss.backward()
None

Without soft_sort

import torch
import torch.nn as nn
from fast_soft_sort.pytorch_ops import *

conf = torch.tensor([[0.1, 0.7, 0.2]], dtype=torch.float64, requires_grad= True)

# Sorting in descending order
# conf = soft_sort(conf, direction="DESCENDING", regularization_strength= 1)

ideal = torch.tensor([[1.0, 0.0, 0]], dtype=torch.float64)
l2    = nn.MSELoss()

print("Grad before doing loss.backward()")
print(conf.grad)

loss = l2(conf, ideal)
loss.backward()

print("Loss = {:.2f}".format(loss))
print("Grad after loss.backward()")
print(conf.grad)

The outputs are:

Grad before doing loss.backward()
None
Loss = 0.45
Grad after loss.backward()
tensor([[-0.6000,  0.4667,  0.1333]], dtype=torch.float64)

josipd commented 4 years ago

Hey Abhinav,

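This is because you are re-using the variable name conf (conf = soft_sort(...)). After the reassignment, conf refers to the output of soft_sort, which is a non-leaf tensor, so its .grad stays None; the gradient is accumulated on the original input tensor, which is no longer reachable under that name. I tried it, and it works if you give the sorted variable a different name.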
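For illustration, here is a minimal sketch of the renamed-variable fix. To keep it runnable without the fast_soft_sort package, it uses torch.sort as a stand-in for soft_sort (that substitution is an assumption of this sketch, not what the original snippet does), and the name sorted_conf is made up for the example. The leaf/non-leaf behaviour should be the same for any differentiable op.

```python
import torch
import torch.nn as nn

# Leaf tensor: created directly by the user with requires_grad=True.
conf = torch.tensor([[0.1, 0.7, 0.2]], dtype=torch.float64, requires_grad=True)

# Stand-in for soft_sort (assumption): torch.sort is also differentiable
# with respect to the sorted values. The key point is the NEW variable name,
# so that `conf` keeps pointing at the leaf tensor.
sorted_conf = torch.sort(conf, dim=1, descending=True).values

ideal = torch.tensor([[1.0, 0.0, 0.0]], dtype=torch.float64)
l2 = nn.MSELoss()

loss = l2(sorted_conf, ideal)
loss.backward()

# The input is a leaf, the sorted output is not.
print("conf is leaf:", conf.is_leaf)
print("sorted_conf is leaf:", sorted_conf.is_leaf)

# Gradients are accumulated on the leaf tensor, so this is no longer None.
print("Grad after loss.backward()")
print(conf.grad)
```

If the gradient with respect to the sorted output itself is needed, calling sorted_conf.retain_grad() before loss.backward() should make sorted_conf.grad available as well.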

mblondel commented 4 years ago

@abhi1kumar If you can confirm that this solved your problem, we'll close the issue.

abhi1kumar commented 4 years ago

Thank you. Yes, that solves my issue. You can close this thread.