christianpayer / MedicalDataAugmentationTool

GNU General Public License v3.0

About the optimal sigma #11

Open ChessQian opened 4 years ago

ChessQian commented 4 years ago

Hi, I've read your nice work "Integrating spatial configuration into heatmap regression based CNNs for landmark localization". In the paper, you say "we enable learning of the optimal heatmap peak width separately for each landmark, depending on the prediction confidences of the network". I am really confused: how can the prediction confidences be used to optimize the sigma of the heatmaps?

christianpayer commented 4 years ago

Hi, thanks for your interest in this. To see how our network manages to learn the optimal sigma (heatmap peak width), look at the spine localization experiment.

Most of the relevant parts are in https://github.com/christianpayer/MedicalDataAugmentationTool/blob/master/bin/experiments/localization/spine/main.py. At line 128 you can see the initialization of the sigma variables. At line 137, the sigmas are used to create the target heatmap images. Line 138 contains the regularization loss of the sigmas, and line 143 the L2 loss function that compares the network output with the target heatmaps.
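The four steps above (sigma variables, target heatmaps, sigma regularization, heatmap L2 loss) can be sketched for a 1D toy problem in plain Python. This is not the repository's code: generate_heatmap_1d, total_loss and the weighting factor alpha are hypothetical names, and the normalized Gaussian target (peak height gamma / (sqrt(2*pi) * sigma)) is an assumption modeled on the paper's heatmap definition. l2_loss follows the tf.nn.l2_loss convention of half the sum of squares.

```python
import math

def generate_heatmap_1d(center, sigma, size, gamma=1.0):
    """Hypothetical 1D stand-in for generate_heatmap_target: a Gaussian
    normalized so that its peak height is gamma / (sqrt(2*pi) * sigma)."""
    norm = gamma / (math.sqrt(2.0 * math.pi) * sigma)
    return [norm * math.exp(-((x - center) ** 2) / (2.0 * sigma ** 2))
            for x in range(size)]

def l2_loss(values):
    """Same convention as tf.nn.l2_loss: half the sum of squares."""
    return 0.5 * sum(v * v for v in values)

def total_loss(network_outputs, landmarks, sigmas, size, alpha):
    """Heatmap L2 loss over all landmarks plus alpha * L2 regularization of sigmas."""
    heatmap_term = 0.0
    for output, center, sigma in zip(network_outputs, landmarks, sigmas):
        # Each landmark gets its own target, built from its own sigma.
        target = generate_heatmap_1d(center, sigma, size)
        heatmap_term += l2_loss([o - t for o, t in zip(output, target)])
    sigma_term = alpha * l2_loss(sigmas)
    return heatmap_term + sigma_term

# One learnable sigma per landmark, all initialized to the same value.
size = 32
landmarks = [10, 20]
sigmas = [2.0, 2.0]
# Pretend the network predicts the targets exactly: only the sigma term remains.
outputs = [generate_heatmap_1d(c, s, size) for c, s in zip(landmarks, sigmas)]
print(total_loss(outputs, landmarks, sigmas, size, alpha=0.1))
```

Keeping one sigma per landmark is what lets each landmark end up with its own peak width.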

Note that all of the TensorFlow functions used provide gradients and allow backpropagation to their inputs, i.e., through the network to the input image and through the function generate_heatmap_target to the sigma variables. If you want to investigate this further, look at the TensorFlow function tf.gradients and pass the variable sigmas as its xs argument.
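In TensorFlow, tf.gradients computes the derivative of the loss with respect to sigmas automatically. To make visible what it computes, here is a hand-derived version for a hypothetical 1D normalized Gaussian target, checked against a central finite-difference estimate. The function names, the made-up network output and alpha are illustrative assumptions, not taken from main.py.

```python
import math

def heatmap_and_dsigma(center, sigma, size, gamma=1.0):
    """Hypothetical 1D normalized Gaussian target g(x) and its derivative dg/dsigma.

    g(x)      = gamma / (sqrt(2*pi) * sigma) * exp(-(x - c)^2 / (2 * sigma^2))
    dg/dsigma = g(x) * (-1 / sigma + (x - c)^2 / sigma^3)
    """
    norm = gamma / (math.sqrt(2.0 * math.pi) * sigma)
    g, dg = [], []
    for x in range(size):
        r2 = (x - center) ** 2
        value = norm * math.exp(-r2 / (2.0 * sigma ** 2))
        g.append(value)
        dg.append(value * (-1.0 / sigma + r2 / sigma ** 3))
    return g, dg

def loss(output, center, sigma, size, alpha):
    """0.5 * sum((h - g)^2) + alpha * 0.5 * sigma^2 (tf.nn.l2_loss convention)."""
    g, _ = heatmap_and_dsigma(center, sigma, size)
    return 0.5 * sum((h - t) ** 2 for h, t in zip(output, g)) + alpha * 0.5 * sigma ** 2

def loss_grad_sigma(output, center, sigma, size, alpha):
    """Chain rule: dL/dsigma = sum((g - h) * dg/dsigma) + alpha * sigma."""
    g, dg = heatmap_and_dsigma(center, sigma, size)
    return sum((t - h) * d for h, t, d in zip(output, g, dg)) + alpha * sigma

# Fixed, made-up network output: a flat bump around the landmark.
size, center, alpha = 32, 15, 0.05
output = [0.1 if abs(x - center) <= 2 else 0.0 for x in range(size)]

sigma, eps = 2.5, 1e-6
analytic = loss_grad_sigma(output, center, sigma, size, alpha)
numeric = (loss(output, center, sigma + eps, size, alpha)
           - loss(output, center, sigma - eps, size, alpha)) / (2.0 * eps)
print(analytic, numeric)
```

The two printed values should agree closely; an autodiff framework produces the analytic value without the manual derivation.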

The backpropagation to the sigma variables can be seen as two paths: (1) l2_loss(network_output, target_heatmap) -> generate_heatmap_target(sigma) -> sigma variable, and (2) l2_loss(sigma) -> sigma variable.
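Written out for a single landmark, the two paths correspond to the two terms of the chain rule below. This assumes a d-dimensional normalized Gaussian target g(x; σ) with scale γ and landmark position x*, a network output h(x), and a hypothetical weighting factor α on the sigma regularization; these symbols are not defined in the comment above.

```latex
L(\sigma) = \frac{1}{2}\sum_{x}\bigl(h(x) - g(x;\sigma)\bigr)^{2} + \frac{\alpha}{2}\,\sigma^{2},
\qquad
g(x;\sigma) = \frac{\gamma}{(2\pi)^{d/2}\,\sigma^{d}}\exp\!\left(-\frac{\lVert x - x^{*}\rVert^{2}}{2\sigma^{2}}\right)

\frac{\partial L}{\partial \sigma}
= \underbrace{\sum_{x}\bigl(g(x;\sigma) - h(x)\bigr)\,\frac{\partial g(x;\sigma)}{\partial \sigma}}_{\text{path (1): heatmap loss}}
+ \underbrace{\alpha\,\sigma}_{\text{path (2): regularization}},
\qquad
\frac{\partial g(x;\sigma)}{\partial \sigma} = g(x;\sigma)\left(-\frac{d}{\sigma} + \frac{\lVert x - x^{*}\rVert^{2}}{\sigma^{3}}\right)
```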

Minimizing the first part alone would lead to very large sigma values, as the best (minimal and trivial) solution for l2_loss(network_output, target_heatmap) would be a target heatmap with values close to 0 everywhere. Minimizing the second part alone (l2_loss(sigma)) would drive the sigma values to 0. The combination of both terms allows finding the optimal heatmap peak width for every landmark.
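A toy experiment can illustrate these three cases. As a simplification, the network output is frozen at the trivial all-zero prediction, whereas in the actual method the network weights and the sigmas are optimized jointly. It also assumes a 1D normalized Gaussian target whose peak height falls as 1/sigma, which is why a wider target moves closer to 0 everywhere. All names, the learning rate and alpha are hypothetical.

```python
import math

def heatmap_loss_grad(sigma, size=64, center=32, gamma=1.0):
    """d/dsigma of 0.5 * sum(g(x; sigma)^2), i.e. the heatmap L2 loss against an
    all-zero network output, for a hypothetical 1D normalized Gaussian g."""
    norm = gamma / (math.sqrt(2.0 * math.pi) * sigma)
    grad = 0.0
    for x in range(size):
        r2 = (x - center) ** 2
        g = norm * math.exp(-r2 / (2.0 * sigma ** 2))
        # (g - 0) * dg/dsigma, with dg/dsigma = g * (-1/sigma + r2/sigma^3)
        grad += g * g * (-1.0 / sigma + r2 / sigma ** 3)
    return grad

def optimize_sigma(alpha, use_heatmap_loss=True, sigma=1.0, lr=10.0, steps=200):
    """Plain gradient descent on sigma alone; the network output stays fixed at 0."""
    for _ in range(steps):
        grad = alpha * sigma  # gradient of alpha * 0.5 * sigma^2
        if use_heatmap_loss:
            grad += heatmap_loss_grad(sigma)
        sigma = max(sigma - lr * grad, 1e-3)  # keep sigma strictly positive
    return sigma

sigma_heatmap_only = optimize_sigma(alpha=0.0)                        # keeps growing
sigma_reg_only = optimize_sigma(alpha=0.01, use_heatmap_loss=False)  # collapses to the floor
sigma_both = optimize_sigma(alpha=0.01)                               # settles at a finite width
print(sigma_heatmap_only, sigma_reg_only, sigma_both)
```

With the regularizer alone, sigma collapses to the 1e-3 floor used to keep it positive; with the heatmap term alone, it keeps growing for as long as the loop runs; with both, it settles near the continuous-approximation optimum sigma* = (gamma^2 / (4 * sqrt(pi) * alpha))^(1/3), about 2.42 for gamma = 1 and alpha = 0.01.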

I hope my rough explanation helps you figure out how our method and code work.