JunMa11 / SegLossOdyssey

A collection of loss functions for medical image segmentation
Apache License 2.0

Will torch.no_grad() make it impossible to update the weights, so the value is only calculated for printing? #17

Closed yanghedada closed 3 years ago

yanghedada commented 3 years ago

Thanks for collecting and sharing this loss function work. torch.no_grad() is used in many places in the code. Will this truncate the gradient and make it impossible to update the weights?

JunMa11 commented 3 years ago

Hi @yanghedada ,

Sorry for my late reply.

The direct answer is No. We only use torch.no_grad() for computations that should not take part in backpropagation, e.g., computing the distance transform maps.
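
To illustrate the pattern (a minimal sketch written for this reply, not the exact code from this repo): the distance transform is computed on detached NumPy copies inside torch.no_grad(), so it becomes a constant tensor, while the prediction keeps its grad_fn and gradients still flow to the network through the final loss term.

    import numpy as np
    import torch
    from scipy.ndimage import distance_transform_edt

    def toy_distance_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Hedged sketch: a simplified distance-weighted loss, not the repo's exact loss.
        with torch.no_grad():
            # The distance map is computed on a detached NumPy copy, so it is just
            # a constant tensor (requires_grad=False). Per-sample and per-channel
            # handling is omitted for brevity.
            target_np = target.detach().cpu().numpy()
            dist_map = torch.from_numpy(
                distance_transform_edt(target_np == 0)
            ).float().to(pred.device)

        # `pred` still carries its grad_fn, so this product is differentiable
        # w.r.t. the network parameters; `dist_map` is treated as a constant.
        return (pred * dist_map).mean()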

EJShim commented 3 years ago

I still don't get it.

Sorry, I am not very familiar with PyTorch's gradient mechanics, but shouldn't every loss function be differentiable? Is it really OK to have torch.no_grad() inside one?

JunMa11 commented 3 years ago

@EJShim Yes!

EJShim commented 3 years ago

Thank you for your reply. So is it possible to use non-differentiable operations in a loss function as long as they are wrapped in torch.no_grad()?

EJShim commented 3 years ago

I cannot understand how pred_dt can still have a grad_fn after it has passed through the self.distance_field() function.

        pred_dt = torch.from_numpy(self.distance_field(pred.detach().cpu().numpy())).float()
        target_dt = torch.from_numpy(self.distance_field(target.cpu().numpy())).float()
        print(pred_dt) #grad_fn exists!

I made a simple test function that returns the same np.ndarray as the input array:

    @torch.no_grad()
    def test(self, img : np.ndarray) -> np.ndarray:
        return img

and the resulting torch.Tensor has lost its grad_fn:

        pred = torch.from_numpy(self.test( pred.detach().cpu().numpy() )).float()
        print(pred) # no grad

What is the difference between the test() and distance_field() functions? Why doesn't distance_field() lose its grad_fn?

JunMa11 commented 3 years ago

Hi @EJShim,

I see your concern. Actually, we (and PyTorch) do not need to compute the gradient of the distance transform; you can simply regard it as a pre-computed constant. Please check the equations in the original papers (HD loss, BD loss) or a recent study for more details on these loss functions.
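
To make that concrete, here is a tiny sketch (my own illustration following the idea above, not the repo's exact code): the distance map acts as a constant per-pixel weight, so the gradient of the loss with respect to the prediction is just that constant, and no gradient ever has to flow through the distance transform itself.

    import torch

    # Toy example: a precomputed distance map acts as a constant weight.
    pred = torch.rand(1, 1, 4, 4, requires_grad=True)  # network output
    dist_map = torch.rand(1, 1, 4, 4)                  # precomputed constant, requires_grad=False

    loss = (pred * dist_map).mean()
    loss.backward()

    # The gradient w.r.t. the prediction is simply dist_map / pred.numel();
    # the distance transform itself never needs a gradient.
    print(torch.allclose(pred.grad, dist_map / pred.numel()))  # True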

Also, please feel free to raise questions; I'm happy to do my best to help you :)

Best, Jun

EJShim commented 3 years ago

@JunMa11 Thank you very much for your kind reply. At least now I can see that this works fine.

The HDDT loss optimizes interestingly compared to the others: https://www.youtube.com/watch?v=oHLykKOqytI&ab_channel=EJShim