qubvel / segmentation_models.pytorch

Semantic segmentation models with 500+ pretrained convolutional and transformer-based backbones.
https://smp.readthedocs.io/
MIT License

Model can't predict more than 3 classes #791

Closed: ctyler9 closed this issue 12 months ago

ctyler9 commented 1 year ago

I have tried training both a Unet and an SMP model with the parameters below, but I can only get 3 classes to output.

encoder: 'resnet34'
encoder_weights: 'imagenet'
classes: [class_1, ..., class_n]
activation: 'sigmoid'
loss: JaccardLoss / DiceLoss
optimizer: Adam

import segmentation_models_pytorch as smp
model = smp.Unet(encoder_name=encoder, encoder_weights=encoder_weights, classes=len(classes), activation=activation)

The dataset output looks correct: when visualizing the preprocessed samples I can clearly see each class represented in the masks, and the ground truth, for example, shows all segmented classes. For some reason, however, after training the model only predicts 3 of them. The model output is (R, C, N), where N is the number of classes, so it is not outputting only 3 channels.

Any idea what could be wrong? I have a feeling the encoder/decoder is not being updated to match the number of classes.
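To check both the suspicion that the segmentation head is not sized to the class count and the symptom that only 3 classes are predicted, one option is to take the per-pixel argmax over the class channel, list which class indices actually win, and compare them with the ground-truth mask. The sketch below is a minimal pure-Python illustration on made-up scores; the helper names `predicted_classes` and `mask_classes` are hypothetical. With a real PyTorch output, which smp models return channel-first as (B, N, H, W), the equivalent is `logits.argmax(dim=1)` followed by `torch.unique`. Because sigmoid is strictly increasing, the argmax is the same whether or not the `'sigmoid'` activation has been applied.

```python
from collections import Counter


def predicted_classes(scores):
    """Count argmax class indices over all pixels.

    `scores` is a nested list laid out as [num_classes][height][width],
    i.e. one score map per class (one image's slice of a (B, N, H, W) output).
    """
    num_classes = len(scores)
    height, width = len(scores[0]), len(scores[0][0])
    counts = Counter()
    for r in range(height):
        for c in range(width):
            # Pick the channel with the highest score at this pixel.
            best = max(range(num_classes), key=lambda k: scores[k][r][c])
            counts[best] += 1
    return counts


def mask_classes(mask):
    """Count class indices present in a [height][width] label mask."""
    return Counter(label for row in mask for label in row)


# Toy example (made-up values): 4 output channels, but channel 3 never wins.
scores = [
    [[0.9, 0.1], [0.2, 0.1]],
    [[0.0, 0.8], [0.1, 0.1]],
    [[0.0, 0.0], [0.7, 0.6]],
    [[0.1, 0.3], [0.3, 0.5]],
]
mask = [[0, 1], [2, 3]]

print("output channels:", len(scores))                  # 4
print("predicted:", sorted(predicted_classes(scores)))  # [0, 1, 2]
print("ground truth:", sorted(mask_classes(mask)))      # [0, 1, 2, 3]
```

If the channel count equals `len(classes)` but some indices never appear in the prediction, the head is sized correctly and those classes are simply never winning, which points at training (for example the loss) rather than at the architecture.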

ljb-1 commented 1 year ago

Have you solved the problem? I have the same problem.

ctyler9 commented 1 year ago

I have not. @qubvel, is there a limit to the number of classes the models can predict?

ctyler9 commented 12 months ago

https://github.com/qubvel/segmentation_models.pytorch/pull/500

After digging around in the code a bit more and looking at past issues, I found this. As worked out in that thread, much of the trouble with multiclass prediction is that the custom-implemented Dice and Jaccard losses do not cope well with imbalanced classes. For the time being, try one of the other loss functions built into PyTorch; I may work on something that automatically adjusts the loss to class imbalance.
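One common way to account for class imbalance with a built-in PyTorch loss, which the thread itself does not prescribe, is to weight each class by its inverse pixel frequency. The sketch below is a minimal pure-Python illustration on made-up masks; the function name `inverse_frequency_weights` and the exact formula are illustrative choices, not part of smp. The resulting list could be converted to a tensor and passed as the `weight` argument of `torch.nn.CrossEntropyLoss`, which expects one weight per class.

```python
from collections import Counter


def inverse_frequency_weights(masks, num_classes):
    """Weight each class by total_pixels / (num_classes * class_pixels).

    Classes covering fewer pixels get larger weights; a perfectly balanced
    dataset gives every class a weight of 1.0. Classes absent from the masks
    get weight 0.0 to avoid a division by zero.
    """
    counts = Counter(label for mask in masks for row in mask for label in row)
    total = sum(counts.values())
    return [
        total / (num_classes * counts[k]) if counts[k] else 0.0
        for k in range(num_classes)
    ]


# Toy masks (made-up): class 0 dominates, class 2 is rare.
masks = [
    [[0, 0, 0, 0], [0, 0, 1, 1]],
    [[0, 0, 0, 1], [0, 0, 1, 2]],
]
weights = inverse_frequency_weights(masks, num_classes=3)
print([round(w, 3) for w in weights])  # prints [0.485, 1.333, 5.333]
```

If you switch to `torch.nn.CrossEntropyLoss`, build the model with `activation=None`: that loss applies log-softmax internally, so it expects raw logits of shape (B, N, H, W) and integer class-index targets of shape (B, H, W).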