Digital-Dermatology / t-loss

Official code for Robust T-Loss for Medical Image Segmentation (MICCAI 2023)
https://robust-tloss.github.io/
Apache License 2.0

Implementation questions + tf backend #1

Closed · andreped closed this issue 1 year ago

andreped commented 1 year ago

I wanted to reimplement the loss function with tensorflow (tf) backend, to enable people working with tf/keras to test this loss. I have made an initial attempt which is hosted at: https://github.com/andreped/t-loss-tf

  1. Loss terms: When reimplementing the loss, I noticed something strange. Of the six terms, only one actually seems to need to run in the forward method. Is that true? If so, the other five terms could instead be computed once in __init__(), which should make the overall forward() step faster.

  2. Multi-class support: It seems like the current implementation does not support a channel axis, which is often relevant when working with multiple classes (one-hot encoding). The input shapes to the loss are just B x H x W; hence, to support multiple classes one would need to encode the GT values as distinct integers, which rarely works well in practice. With one-hot encoding, one could compute the loss for each class separately and take the macro-average, which handles class imbalance better (see the sketch after this list).

  3. torch.nn.Parameter: I'm not that familiar with torch, but from here it seems like you initialize the nu parameter as a learnable parameter; is that correct? I was not sure, so I have kept it fixed in my implementation for now. If so, two of the terms (the 4th and 5th) can be moved to __init__().
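
As a rough illustration of point 2, here is a minimal sketch of the macro-averaging idea, assuming one-hot inputs of shape B x H x W x C and a hypothetical `single_class_loss` that accepts B x H x W tensors:

```python
import tensorflow as tf


def macro_average_loss(y_true, y_pred, single_class_loss):
    """Hypothetical wrapper: apply a B x H x W loss to each one-hot
    channel separately and macro-average over the class axis."""
    num_classes = y_pred.shape[-1]
    per_class = [
        single_class_loss(y_true[..., c], y_pred[..., c])
        for c in range(num_classes)
    ]
    # Macro-averaging weights every class equally, regardless of how
    # many pixels it occupies, which helps with class imbalance.
    return tf.add_n(per_class) / num_classes
```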

My initial tests with my implementation show some strange behaviour: feeding all-ones tensors as both the prediction and the ground truth yields an extremely high loss value. Have you observed the same?

lionettis commented 1 year ago

Dear andreped,

Thanks for your interest in the T-Loss. Here are my tentative answers:

  1. All terms that contain nu have to be computed during the forward pass, because this parameter is updated during training. The fourth term could be computed in the __init__ method, but we expect this to have minimal effect.
  2. This is correct; we still need to add multi-class support. There are a few nontrivial choices to be made, especially whether the tolerance should be adjusted per class or globally (there are arguments for both, so this should probably be left as a choice).
  3. nu is indeed a trainable parameter. If I am not mistaken, similar behaviour to torch.nn.Parameter can be achieved with tf.Variable. Note that the fifth term depends on nu, so it cannot be computed in the __init__.

Best, Simone
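
To make this concrete, below is a minimal, hypothetical sketch of the tf.Variable pattern from point 3, with the only nu-free term cached in __init__ as per point 1. The formula is a plain univariate Student-t negative log-likelihood with unit scale, used only to show which terms depend on nu; it is not the published T-Loss implementation:

```python
import math

import tensorflow as tf


class TLossSketch(tf.Module):
    """Structural sketch only: nu is kept trainable via tf.Variable,
    mirroring torch.nn.Parameter. The formula is a plain univariate
    Student-t negative log-likelihood with unit scale, standing in
    for the published T-Loss."""

    def __init__(self, initial_nu=1.0):
        super().__init__()
        # Trainable degrees-of-freedom parameter (the TF counterpart of
        # torch.nn.Parameter). Real implementations typically also keep
        # nu positive, e.g. via a softplus reparameterization.
        self.nu = tf.Variable(initial_nu, dtype=tf.float32, trainable=True)
        # The only nu-free term can be cached once here.
        self.log_pi_term = tf.constant(0.5 * math.log(math.pi))

    def __call__(self, y_true, y_pred):
        sq_err = tf.square(y_true - y_pred)
        # Every nu-dependent term must be recomputed on each forward
        # pass, because nu changes during training.
        nll = (
            -tf.math.lgamma((self.nu + 1.0) / 2.0)
            + tf.math.lgamma(self.nu / 2.0)
            + 0.5 * tf.math.log(self.nu)
            + self.log_pi_term
            + (self.nu + 1.0) / 2.0 * tf.math.log1p(sq_err / self.nu)
        )
        return tf.reduce_mean(nll)
```

In a custom training loop, the loss object's `trainable_variables` would have to be included in the gradient tape for nu to actually be updated; with model.fit this needs extra care, since variables owned by a plain loss object are not tracked by the model automatically.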

andreped commented 1 year ago

Thank you for the prompt reply, @lionettis!

I have addressed your comments and rewritten the loss in TensorFlow so that it works in a Keras workflow. I have also built a precompiled wheel and made it available in a release here.

Still, I notice that the loss values are very high at the start of training. Have you observed the same?

alvarogonjim commented 1 year ago

@andreped, indeed, we observed extremely high loss values in the early stages of training. This can be attributed to the use of exp operations. The crucial thing to check is that the loss starts to decrease after a certain number of iterations. Another notable aspect is that the elevated loss corresponds to the noisy images, whereas for normal/clean images the value remains small (even negative).
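
For intuition, exp grows fast enough that float32 overflows for fairly modest inputs, which is one way such terms can dominate the loss early in training:

```python
import numpy as np

# float32 tops out near 3.4e38, so exp overflows for quite modest
# arguments; such terms can easily dominate the loss early on.
print(np.exp(np.float32(88.0)))  # ~1.7e38, just below the float32 maximum
print(np.exp(np.float32(89.0)))  # inf: float32 overflow
```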

andreped commented 1 year ago

@alvarogonjim OK, thank you for the heads up! I was just worried I did something wrong.

But in that case, I believe this issue can be closed, as my initial concerns have been addressed.

It could be an idea to add a note in your README that a TensorFlow reimplementation of the loss function exists, with a hyperlink to the t-loss-tf repo. There are still plenty of stubborn developers using TF/Keras (myself included ;)) who could be interested in testing it in their own projects/studies.

alvarogonjim commented 1 year ago

@andreped Absolutely! I'll make sure to include it in the README. Thanks a lot for your contribution – it's greatly appreciated! :+1: