computational-cell-analytics / micro-sam

Segment Anything for Microscopy
https://computational-cell-analytics.github.io/micro-sam/
MIT License

Add SemanticSamTrainer #637

Closed anwai98 closed 1 week ago

anwai98 commented 2 weeks ago

@constantinpape Here is the trainer for semantic segmentation using SAM. Let me know if this aligns with what we discussed.

anwai98 commented 2 weeks ago

Thanks @constantinpape. I tested this on a 2D dataset, and it looks like it's doing the job as expected. This is GTG from my side now (only a few minor discussions in the evaluation PR are still pending).

anwai98 commented 2 weeks ago

Hi @constantinpape,

Looks like the multi-class semantic segmentation works as expected now (at least from a first look at the TensorBoard logs). I am not a big fan of the workarounds I had to apply to make this work, but maybe we can find a better way to do this in a much more modular setup. Let me know if you spot something. We can discuss the details together tomorrow anyway.

ADDITION: I added support for an additional loss function (cross entropy) over the logits, computed between the low_res_masks returned by the model and a downscaled version of the ground truth. It looks like it's working as expected, and it converges a bit faster than using only the Dice loss over the masks.
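
For reference, here is a minimal sketch of the combined objective described above (Dice over the full-resolution masks plus cross entropy over the low-resolution logits). It is not the code from this PR: it uses NumPy to stay dependency-light, all function names are hypothetical, and it assumes nearest-neighbour downscaling of the labels and an unweighted sum of the two terms, neither of which is specified above.

```python
import numpy as np


def one_hot(labels, num_classes):
    """Convert an integer label map (H, W) to a one-hot array (C, H, W)."""
    return np.eye(num_classes, dtype=np.float64)[labels].transpose(2, 0, 1)


def softmax(logits):
    """Numerically stable softmax over the class axis (axis 0)."""
    shifted = logits - logits.max(axis=0, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=0, keepdims=True)


def downscale_labels(labels, factor):
    """Assumed nearest-neighbour downscaling by striding (keeps integer labels)."""
    return labels[::factor, ::factor]


def dice_loss(probs, target_one_hot, eps=1e-7):
    """1 minus the mean per-class Dice score; both inputs have shape (C, H, W)."""
    intersection = (probs * target_one_hot).sum(axis=(1, 2))
    denom = probs.sum(axis=(1, 2)) + target_one_hot.sum(axis=(1, 2))
    return 1.0 - ((2.0 * intersection + eps) / (denom + eps)).mean()


def cross_entropy_loss(logits, labels):
    """Mean pixel-wise cross entropy; logits (C, h, w), integer labels (h, w)."""
    shifted = logits - logits.max(axis=0, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=0, keepdims=True))
    picked = np.take_along_axis(log_probs, labels[None], axis=0)[0]
    return -picked.mean()


def combined_loss(mask_logits, low_res_logits, labels, factor):
    """Dice on full-resolution predictions plus CE on low-resolution logits.

    The unweighted sum is an assumption made for this sketch.
    """
    num_classes = mask_logits.shape[0]
    dice = dice_loss(softmax(mask_logits), one_hot(labels, num_classes))
    ce = cross_entropy_loss(low_res_logits, downscale_labels(labels, factor))
    return dice + ce
```

In a real training loop the same two terms would be computed on tensors so that gradients flow through both; the sketch only illustrates which prediction is compared against which version of the ground truth.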

anwai98 commented 1 week ago

@constantinpape,

I've removed the downscaling of masks. Should be GTG now. Thanks!

PS: I also tested it with a quick training run.