Hi @yanghedada ,
Sorry for my late reply.
The direct answer is No.
We only use torch.no_grad() when something should not take part in back-propagation (BP), e.g., computing the distance transform maps.
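To illustrate the pattern, here is a minimal sketch of the idea (compute_dtm is just an illustrative name, assuming scipy's distance_transform_edt; it is not the exact code in this repo):

import numpy as np
import torch
from scipy.ndimage import distance_transform_edt

@torch.no_grad()  # the distance maps never take part in back-propagation
def compute_dtm(mask: torch.Tensor) -> torch.Tensor:
    # mask: (B, H, W) binary ground truth or thresholded prediction
    dtm = np.zeros(mask.shape, dtype=np.float32)
    for b, m in enumerate(mask.detach().cpu().numpy()):
        if m.any():
            dtm[b] = distance_transform_edt(m)  # Euclidean distance transform
    return torch.from_numpy(dtm)  # plain tensor, no grad_fn

The returned maps behave like constants inside the loss, so they do not block gradients flowing through the prediction itself.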
I still don't get it.
Sorry, I am not very familiar with how PyTorch handles gradients, but shouldn't every loss function be differentiable? Is it really OK to include a no_grad block?
@EJShim Yes!
Thank you for your reply. So is it possible to use non-differentiable operations in a loss function when they are wrapped in torch.no_grad()?
I cannot understand how pred_dt can still have a grad_fn after it has passed through the self.distance_field() function.
pred_dt = torch.from_numpy(self.distance_field(pred.detach().cpu().numpy())).float()
target_dt = torch.from_numpy(self.distance_field(target.cpu().numpy())).float()
print(pred_dt) #grad_fn exists!
I made a simple test function that returns the same np.ndarray as the input array:
@torch.no_grad()
def test(self, img: np.ndarray) -> np.ndarray:
    return img
and the resulting torch.Tensor has lost its grad_fn:
pred = torch.from_numpy(self.test(pred.detach().cpu().numpy())).float()
print(pred)  # no grad
What is the difference between the test() and distance_field() functions? How does distance_field() not lose its grad function?
Hi @EJShim,
I see your concern. Actually, we (and also PyTorch) do not need to compute the gradient of the distance transform. You can just regard the distance maps as pre-computed constants. Please check the equations in the original papers (HD loss, BD loss) or recent studies to get more details of these loss functions.
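As a rough sketch of the boundary-loss idea (illustrative only, not the exact implementation in this repo): the ground-truth distance map enters the loss as a constant weight, and gradients flow only through the softmax prediction:

import torch

def boundary_loss_sketch(logits: torch.Tensor, gt_sdf: torch.Tensor) -> torch.Tensor:
    # logits: (B, C, H, W) raw network output
    # gt_sdf: signed distance map of the ground truth, pre-computed under
    #         torch.no_grad(), so autograd treats it as a constant
    probs = torch.softmax(logits, dim=1)
    # element-wise product; the gradient only flows through `probs`
    return (probs * gt_sdf).mean()

So wrapping the distance transform in torch.no_grad() does not truncate the gradient of the prediction; it only marks those maps as constants.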
Also, please feel free to raise questions; I'm happy to try my best to help you :)
Best, Jun
@JunMa11 Thank you very much for your kind reply. At least now I can see that this works fine.
The HDDT loss optimizes interestingly compared to the others: https://www.youtube.com/watch?v=oHLykKOqytI&ab_channel=EJShim
Thanks for collecting and sharing this loss function work. There are a lot of torch.no_grad() calls in the code. Will this truncate the gradient, making it impossible to update the weights?