Sllambias / yucca

Apache License 2.0

Use generalized dice everywhere #192

Closed: asbjrnmunk closed this 2 months ago

asbjrnmunk commented 2 months ago

Closes #191

This PR changes the `YuccaLightningModule` so that it uses `torchmetrics.segmentation.GeneralizedDice` to calculate Dice metrics during training. This has the following implications:

  1. We no longer need a separate LightningModule for label regions, since the Dice computation was the only difference between the two.
  2. We can finally get Dice scores per label.
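A minimal NumPy sketch of the generalized Dice arithmetic, assuming the common "square" weighting w_c = 1/|G_c|². This is not the torchmetrics implementation; the function name `generalized_dice` and the way classes missing from the target are handled are assumptions for illustration. It shows why per-label scores come essentially for free: within a single class the weight cancels, leaving ordinary Dice.

```python
import numpy as np


def generalized_dice(pred_labels, target_labels, num_classes, eps=1e-8):
    """Generalized Dice with 'square' weights w_c = 1 / |G_c|^2 (sketch).

    pred_labels, target_labels: integer label maps of identical shape
    (e.g. the argmax of the network output and the ground truth).
    Returns (generalized_score, per_class_scores).
    """
    classes = np.arange(num_classes)[:, None]
    # One-hot encode the flattened maps: shape (C, N)
    pred = (pred_labels.reshape(1, -1) == classes).astype(float)
    target = (target_labels.reshape(1, -1) == classes).astype(float)

    intersection = (pred * target).sum(axis=1)           # |P_c ∩ G_c|
    cardinality = pred.sum(axis=1) + target.sum(axis=1)  # |P_c| + |G_c|

    # 'square' weighting down-weights large classes. Classes missing from the
    # target would get an infinite weight, so (by assumption) they receive the
    # largest finite weight instead.
    with np.errstate(divide="ignore"):
        weights = 1.0 / target.sum(axis=1) ** 2
    finite = np.isfinite(weights)
    weights[~finite] = weights[finite].max() if finite.any() else 1.0

    # Within one class the weight cancels, so per-class scores are plain Dice.
    # Classes absent from both prediction and target score 0 here.
    per_class = 2.0 * intersection / np.maximum(cardinality, eps)
    generalized = 2.0 * (weights * intersection).sum() / max((weights * cardinality).sum(), eps)
    return generalized, per_class


target = np.array([[0, 0, 1, 1],
                   [0, 0, 1, 1],
                   [2, 2, 2, 2],
                   [2, 2, 2, 2]])
pred = target.copy()
pred[0, 2] = 0  # one voxel of class 1 mislabelled as background

score, per_label = generalized_dice(pred, target, num_classes=3)
print(score)      # 0.9
print(per_label)  # approximately [8/9, 6/7, 1.0]
```

The small class (label 1) loses one of its four voxels and its Dice drops to 6/7, while the "square" weights keep the large class 2 from dominating the overall score.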

~It does, however, require us to add an argmax/sigmoid to our training and validation loops, which is not ideal. Alternatives would be to write two wrappers around the GeneralizedDice metric that implement this, or to make a LightningModule for each task type; in the interest of time, this has been deferred for now (I'll open an issue once this PR is merged).~
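The wrapper alternative mentioned above could look roughly like this. Everything here is hypothetical: `logits_to_onehot`, `per_class_dice`, `ActivatedMetric` and the 0.5 threshold are illustrative names and choices, and `per_class_dice` is a NumPy stand-in for the torchmetrics metric. Argmax is used when classes are mutually exclusive; a per-channel sigmoid threshold is used for label regions, which may overlap.

```python
import numpy as np


def logits_to_onehot(logits, multilabel, threshold=0.5):
    """Hypothetical helper: convert raw outputs of shape (C, *spatial)
    into a binary mask of the same shape."""
    if multilabel:
        # Label regions may overlap, so each channel is thresholded independently.
        probs = 1.0 / (1.0 + np.exp(-logits))
        return (probs > threshold).astype(np.int64)
    # Classes are mutually exclusive: keep one class per voxel, then one-hot encode.
    labels = logits.argmax(axis=0)
    classes = np.arange(logits.shape[0]).reshape(-1, *([1] * labels.ndim))
    return (labels[None] == classes).astype(np.int64)


def per_class_dice(pred, target, eps=1e-8):
    """Plain Dice per channel for binary masks of shape (C, *spatial)."""
    axes = tuple(range(1, pred.ndim))
    intersection = (pred * target).sum(axis=axes)
    cardinality = pred.sum(axis=axes) + target.sum(axis=axes)
    return 2.0 * intersection / np.maximum(cardinality, eps)


class ActivatedMetric:
    """Hypothetical wrapper: applies the task-specific activation to the
    logits so the training/validation loop does not have to."""

    def __init__(self, metric, multilabel):
        self.metric = metric
        self.multilabel = multilabel

    def __call__(self, logits, target):
        return self.metric(logits_to_onehot(logits, self.multilabel), target)


# Two wrappers, one per task type, sharing the same underlying metric.
multiclass_dice = ActivatedMetric(per_class_dice, multilabel=False)
label_region_dice = ActivatedMetric(per_class_dice, multilabel=True)
```

The design trade-off is that the loop only ever passes raw logits, at the cost of the task type having to be known when the metric is constructed.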

asbjrnmunk commented 2 months ago

@Sllambias ready for review! Please note the above comment 🥲

asbjrnmunk commented 2 months ago

@Sllambias now cleaned up.