JunMa11 / SegWithDistMap

How Distance Transform Maps Boost Segmentation CNNs: An Empirical Study
https://openreview.net/forum?id=hM4pNbXWst
Apache License 2.0

Questions about the hd loss #22

Closed nohsoocheol closed 2 years ago

nohsoocheol commented 2 years ago

Thanks for sharing your great work! I have some questions about the Hausdorff loss.

I saw your implementation of the Hausdorff loss in train_LA_HD.py. There, two arguments (seg_dtm and gt_dtm) of the function hd_loss are computed with numpy and scipy. To my knowledge, using numpy and scipy (and likewise torch.no_grad()) breaks the pytorch backward graph for those arguments, so I think they can only serve to scale the gradients coming from the other inputs (seg_soft, gt).

My questions are: if I'm right, is it okay to compute the Hausdorff loss this way in train_LA_HD.py? Are numpy and scipy used because the distance transform is non-differentiable? And if the distance transform is in fact differentiable (I think it is, based on Karimi's paper, but I'm not sure), could computing seg_dtm and gt_dtm with pytorch operations help improve the model?
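A tiny self-contained check illustrates the premise of the question: detouring through numpy produces a fresh tensor that is cut off from the autograd graph.

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)

# Round-tripping through numpy creates a brand-new leaf tensor:
# autograd cannot trace operations performed outside pytorch.
y = torch.from_numpy(x.detach().numpy() ** 2)

print(y.requires_grad)  # False: the backward graph is broken here
```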

dy2728 commented 2 years ago

The loss has the form (seg_soft - gt)^2 x (seg_dtm^2 + gt_dtm^2). The first factor, (seg_soft - gt)^2, is a pytorch tensor that carries gradients. The second factor, (seg_dtm^2 + gt_dtm^2), is computed with numpy and scipy and can be viewed as a multiplicative weight. Note that gradients propagate only through the first factor, (seg_soft - gt)^2.
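For reference, here is a minimal sketch of how such a loss can be wired up. This is a paraphrase, not the repository's exact code: `compute_dtm` is a hypothetical helper, and the distance-map convention (distance to the nearest foreground voxel, zero inside the foreground) is simplified.

```python
import numpy as np
import torch
from scipy.ndimage import distance_transform_edt

def compute_dtm(mask):
    """Distance transform of a binary mask (numpy/scipy, so no gradients).

    Each background voxel gets its Euclidean distance to the nearest
    foreground voxel; foreground voxels get 0. Hypothetical helper for
    illustration only.
    """
    fg = mask.astype(bool)
    if fg.any():
        # distance_transform_edt measures distance to the nearest zero,
        # so invert the mask to measure distance to the foreground.
        return distance_transform_edt(~fg)
    return np.zeros_like(mask, dtype=np.float64)

def hd_loss(seg_soft, gt, seg_dtm, gt_dtm):
    """(seg_soft - gt)^2 * (seg_dtm^2 + gt_dtm^2).

    Gradients flow only through (seg_soft - gt)^2; the DTM factor is a
    plain tensor (requires_grad=False) acting as a per-voxel weight.
    """
    delta_s = (seg_soft - gt.float()) ** 2
    dtm = seg_dtm ** 2 + gt_dtm ** 2
    return (delta_s * dtm).mean()
```

Because the DTM factor carries no gradient, backpropagation sees it as a constant weight that emphasizes errors far from the object boundary.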

nohsoocheol commented 2 years ago

Thank you for the answer. So the parts calculated with numpy and scipy effectively act as weights. Thanks!!