DIAGNijmegen / pathology-hooknet

MIT License

Question about hooking mechanism #20

Open mdeleeuw1 opened 9 months ago

mdeleeuw1 commented 9 months ago

Hi Mart,

I was wondering how the weights of all branches are updated, given that the model uses only the high-resolution mask as its target. Going through the code, it seems that the weights are updated through the hooking mechanism, but I'm very curious to hear your take on this.

Kind regards,

Mike

martvanrijthoven commented 9 months ago

Dear Mike,

You are correct: the loss is backpropagated via the hooking mechanism, starting from the target (high-resolution) branch. However, you can also configure the model so that it backpropagates via the context (low-resolution) branch as well. In the TensorFlow model you can set the loss weight for each branch when instantiating the model. Please let me know if you have any further questions about this.
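Per-branch loss weighting amounts to a weighted sum of the branch losses. The notation below ($\lambda_t$, $\lambda_c$, $\theta_c^{\text{post}}$) is ours, not taken from the repository: $\theta_c^{\text{post}}$ stands for context-branch weights located after the hook point. Because the target loss $\mathcal{L}_t$ does not depend on those weights, their gradient comes only from the context term:

```latex
\mathcal{L} = \lambda_t \mathcal{L}_t + \lambda_c \mathcal{L}_c,
\qquad
\frac{\partial \mathcal{L}}{\partial \theta_c^{\text{post}}}
  = \lambda_c \, \frac{\partial \mathcal{L}_c}{\partial \theta_c^{\text{post}}}
```

Setting $\lambda_c = 0$ therefore recovers target-only training, in which those context weights receive no gradient at all.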

Best wishes, Mart

mdeleeuw1 commented 9 months ago

Dear Mart,

Thank you for getting back to me. Does this mean that the decoder layers beyond the hooking point in the low- and mid-resolution branches are not used, i.e. not updated through backpropagation? I'm using the PyTorch model, by the way, but I believe its hooking mechanism is similar to the TF model's (apart from the additional branch and the absence of per-branch loss weights). I'd be interested to hear your perspective on these technical details!

Kind regards,

Mike

martvanrijthoven commented 9 months ago

Hi Mike,

With the PyTorch model you get the outputs of all branches, so you can compute and combine the losses however you want and use the result for backpropagation. In that case the weights of all branches are updated. However, if you ignore the outputs/losses of the other branches, then you are correct: those weights will not be updated. I would advise experimenting with combinations of multiple losses to see which setting works best for your problem. Let me know if you have any further questions.
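To make the gradient flow concrete, here is a minimal pure-Python sketch of a hypothetical two-branch model with one scalar weight per stage, with hand-written backpropagation. It is a toy, not HookNet's code: the weight names (`ctx_pre`, `ctx_post`, `tgt_pre`, `tgt_post`), the squared-error losses and the context-loss weight `lam_c` are all illustrative assumptions, and the hook is modelled as an addition, whereas HookNet hooks (cropped) context feature maps into the target branch. In PyTorch the analogous choice is whether you call `.backward()` on the target loss alone or on a weighted sum such as `loss_target + lam_c * loss_context`.

```python
# Hypothetical toy model (not HookNet's implementation): scalar weights only.
def forward(w, x_t, x_c):
    h_c = w["ctx_pre"] * x_c          # context layers before the hook point
    y_c = w["ctx_post"] * h_c         # context layers after the hook point
    h_t = w["tgt_pre"] * x_t + h_c    # hook: context features injected (modelled as addition)
    y_t = w["tgt_post"] * h_t         # target-branch output
    return h_c, y_c, h_t, y_t

def loss(w, x_t, x_c, t_t, t_c, lam_c):
    # Target loss plus an optional, weighted context loss.
    _, y_c, _, y_t = forward(w, x_t, x_c)
    return (y_t - t_t) ** 2 + lam_c * (y_c - t_c) ** 2

def grads(w, x_t, x_c, t_t, t_c, lam_c):
    # Manual backpropagation of loss() with respect to every weight.
    h_c, y_c, h_t, y_t = forward(w, x_t, x_c)
    g_t = 2.0 * (y_t - t_t)           # dL/dy_t
    g_c = lam_c * 2.0 * (y_c - t_c)   # dL/dy_c, zero when the context loss is ignored
    return {
        "tgt_post": g_t * h_t,
        "tgt_pre": g_t * w["tgt_post"] * x_t,
        "ctx_post": g_c * h_c,        # reached only through the context loss
        # Reached through the hook (target loss) and, if lam_c > 0, the context loss.
        "ctx_pre": (g_t * w["tgt_post"] + g_c * w["ctx_post"]) * x_c,
    }

w = {"ctx_pre": 0.5, "ctx_post": 1.5, "tgt_pre": 0.8, "tgt_post": 1.2}
single = grads(w, 1.0, 2.0, 0.0, 1.0, lam_c=0.0)  # target loss only
multi = grads(w, 1.0, 2.0, 0.0, 1.0, lam_c=1.0)   # target + context loss
print(single["ctx_post"], multi["ctx_post"])      # prints 0.0 1.0
```

With the target loss alone, the post-hook context weight gets exactly zero gradient, while the pre-hook context weight still gets a nonzero one via the hook; adding the context loss changes the latter's value as well.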

Best wishes, Mart

mdeleeuw1 commented 9 months ago

Hi Mart,

Thank you for the clarification! Just to be sure: if I use only the output of the high-res branch for backpropagation, the weights of the other branches are updated only for the layers up to the point where the hook is applied? And to update the weights of those branches beyond the hooking point, I have to use the outputs of all branches for backpropagation? Interesting!

Best wishes,

Mike

martvanrijthoven commented 9 months ago

Yes, you are correct. Note, however, that the weights before the hooking point are affected by either a single loss or multiple losses, depending on whether you also backpropagate the context losses.

mdeleeuw1 commented 9 months ago

Hi Mart,

What exactly do you mean? The weights before the hooking point are affected regardless of whether I use a single loss or multiple losses, right?

Kind regards,

Mike

martvanrijthoven commented 9 months ago

Yes, indeed. Perhaps I added some unnecessary confusion; I just wanted to make clear that using a multi-loss affects the weights before the hook differently than using a single loss does.
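In symbols, with a total loss $\mathcal{L} = \lambda_t \mathcal{L}_t + \lambda_c \mathcal{L}_c$ (our notation: $\mathcal{L}_t$ and $\mathcal{L}_c$ are the target and context losses, $\lambda_t$ and $\lambda_c$ their weights) and $\theta_c^{\text{pre}}$ denoting context-branch weights before the hook point, the gradient splits into two contributions:

```latex
\frac{\partial \mathcal{L}}{\partial \theta_c^{\text{pre}}}
  = \lambda_t \, \frac{\partial \mathcal{L}_t}{\partial \theta_c^{\text{pre}}}
  + \lambda_c \, \frac{\partial \mathcal{L}_c}{\partial \theta_c^{\text{pre}}}
```

The first term flows back through the hook and is present in both settings, which is why these weights are always updated; the second term appears only when $\lambda_c > 0$, so the update these weights receive differs between single-loss and multi-loss training.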