JunMa11 / SegWithDistMap

How Distance Transform Maps Boost Segmentation CNNs: An Empirical Study
https://openreview.net/forum?id=hM4pNbXWst
Apache License 2.0

sdf aaai loss implementation - summation order #25

Open kretes opened 10 months ago

kretes commented 10 months ago

Hello,

I'm looking at the function

https://github.com/JunMa11/SegWithDistMap/blob/153dabf3bc5d9e48058e1497857ac6b00c7abab8/code/train_LA_AAAISDF.py#L95C1-L108C1

and I don't understand the summation order. The code sums each part of the equation in isolation (intersection, pd_sum, gt_sum) and only then forms the ratio, while according to the equation from the paper:

[image: equation from the paper]

the summation should be outside the ratio, i.e. the ratio is computed per element and then summed.
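Written out, the two orderings being compared look roughly as follows. This is only a sketch of what the code in this post computes, not a transcription of the paper's equation: the symbols $p_i$ (for `net_output`), $y_i$ (for `gt_sdm`) and $\epsilon$ (for `smooth`) are notation assumed here for illustration.

```latex
% Ratio of sums: what the repository function computes
L_{\text{orig}} =
  \frac{\sum_i p_i y_i + \epsilon}
       {\sum_i p_i y_i + \sum_i p_i^2 + \sum_i y_i^2 + \epsilon}

% Sum of ratios: the per-element reading of the equation
L_{\text{changed}} =
  \sum_i \frac{p_i y_i + \epsilon}
              {p_i y_i + p_i^2 + y_i^2 + \epsilon}
```

The two expressions coincide only in special cases (for example a single element), so in general they produce different loss values.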

It does matter to the calculation:

import torch

smooth = 1e-5

net_output = torch.tensor([0.3, 0.4])
gt_sdm = torch.tensor([0.05, -0.05])

# original
intersect = torch.sum(net_output * gt_sdm)
pd_sum = torch.sum(net_output ** 2)
gt_sum = torch.sum(gt_sdm ** 2)

L_product_orig = (intersect + smooth) / (intersect + pd_sum + gt_sum + smooth)

# summed at the end according to the equation
intersect = net_output * gt_sdm
pd_sum = net_output ** 2
gt_sum = gt_sdm ** 2

L_product_changed = (intersect + smooth) / (intersect + pd_sum + gt_sum + smooth)

print(L_product_orig, L_product_changed.sum())

shows:

tensor(-0.0200) tensor(-0.0007)
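The same comparison can be cross-checked without torch. The sketch below is a minimal pure-Python version using the same two-element inputs; the function names `ratio_of_sums` and `sum_of_ratios` are hypothetical and do not appear in the repository.

```python
# Pure-Python cross-check of the two summation orders (no torch required).
# Function names are illustrative only and are not part of SegWithDistMap.

def ratio_of_sums(pred, target, eps):
    """Sum each term over all elements first, then form a single ratio."""
    intersect = sum(p * t for p, t in zip(pred, target))
    pd_sum = sum(p ** 2 for p in pred)
    gt_sum = sum(t ** 2 for t in target)
    return (intersect + eps) / (intersect + pd_sum + gt_sum + eps)


def sum_of_ratios(pred, target, eps):
    """Form the ratio per element, then sum the per-element ratios."""
    return sum(
        (p * t + eps) / (p * t + p ** 2 + t ** 2 + eps)
        for p, t in zip(pred, target)
    )


smooth = 1e-5
net_output = [0.3, 0.4]
gt_sdm = [0.05, -0.05]

orig = ratio_of_sums(net_output, gt_sdm, smooth)
changed = sum_of_ratios(net_output, gt_sdm, smooth)
print(round(orig, 4), round(changed, 4))
```

In the per-element version the two ratios are roughly 0.1396 and -0.1403, so they nearly cancel when summed, which is why the result is so much closer to zero than the ratio of sums.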