ctyler9 closed this issue 12 months ago.
Have you solved the problem? I have the same problem.
I have not. @qubvel, is there a limit to the number of classes the models can predict?
https://github.com/qubvel/segmentation_models.pytorch/pull/500
After digging around in the code a bit more and looking at past issues, I found this. The problem with much of the multiclass prediction is that the custom-implemented Dice loss and Jaccard loss do not handle imbalanced classes well, as worked out in that post. In the meantime, try some of the other loss functions available in PyTorch. I may work on something that automatically adjusts the loss functions to class imbalance.
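One common way to make a PyTorch loss account for class imbalance is to weight each class by the inverse of its pixel frequency. The sketch below computes such weights in plain Python from integer label masks. Inverse frequency normalized to mean 1 is an assumed choice for illustration, not necessarily what the linked PR does, and `inverse_frequency_weights` is a hypothetical helper name.

```python
from collections import Counter

def inverse_frequency_weights(masks, num_classes):
    """Per-class weights proportional to 1 / pixel frequency, normalized so present classes average 1.

    Assumption: `masks` is an iterable of integer label masks, each a list of rows.
    Classes that never appear get weight 0.0 instead of causing a division by zero.
    """
    counts = Counter(label for mask in masks for row in mask for label in row)
    total = sum(counts.values())
    raw = [total / counts[c] if counts[c] else 0.0 for c in range(num_classes)]
    present = [w for w in raw if w > 0]
    if not present:
        return raw
    mean = sum(present) / len(present)
    return [w / mean for w in raw]

# Hypothetical masks: class 0 dominates, class 1 is rare, class 2 is in between.
masks = [
    [[0, 0, 0], [0, 0, 1]],
    [[0, 0, 0], [0, 2, 2]],
]
weights = inverse_frequency_weights(masks, num_classes=3)
print([round(w, 3) for w in weights])  # [0.207, 1.862, 0.931]
```

The resulting list can then be passed to PyTorch's weighted cross-entropy, e.g. `torch.nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float32))`, so that mistakes on rare classes cost more than mistakes on the dominant one.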
I have tried training both a Unet and an SMP model with the parameters below, but I can only get 3 classes to appear in the output.
encoder: 'resnet34'
encoder_weights: 'imagenet'
classes: [class_1, ... , class_n]
activation: 'sigmoid'
loss: JaccardLoss / DiceLoss
optimizer: Adam
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name=encoder,
    encoder_weights=encoder_weights,
    classes=len(classes),
    activation=activation,
)
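The config above pairs `activation: 'sigmoid'` with many classes. Sigmoid scores every class independently (multilabel), so after thresholding a pixel can receive no class at all. When each pixel belongs to exactly one class, softmax over the class dimension followed by argmax is the more usual decoding. The sketch below contrasts the two decodings for a single pixel in plain Python. The logits are made-up values, and whether this contributes to the 3-class behavior here is an assumption the thread does not confirm.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def softmax(scores):
    """Softmax over one pixel's per-class logits (shifted by the max for numerical stability)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def decode_multilabel(logits, threshold=0.5):
    """Sigmoid decoding: each class is an independent yes/no, so zero or several classes may fire."""
    return [c for c, s in enumerate(logits) if sigmoid(s) > threshold]

def decode_multiclass(logits):
    """Softmax + argmax decoding: exactly one class per pixel."""
    probs = softmax(logits)
    return max(range(len(probs)), key=probs.__getitem__)

# Hypothetical logits for one pixel over 5 classes; all negative, so every sigmoid is below 0.5.
pixel = [-2.0, -0.5, -1.0, -3.0, -0.2]
print(decode_multilabel(pixel))  # []
print(decode_multiclass(pixel))  # 4
```

In smp terms this would roughly correspond to `activation=None` (raw logits) combined with a multiclass loss such as `torch.nn.CrossEntropyLoss` or `smp.losses.DiceLoss(mode="multiclass")`, both of which apply the softmax internally. Check the documentation for the installed version before relying on the exact arguments.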
The data pipeline looks correct: in the preprocessing visualization I can clearly see each class represented in the masks, and the ground truth, for example, shows all of the segmented classes. During training, however, the model only ever predicts 3 of them. The model output has shape (R, C, N), where N is the number of classes, so it is not producing only 3 channels.
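One way to narrow this down is to count how many pixels the argmax of the output assigns to each class and compare that with the classes present in the ground truth. The sketch below does this in plain Python for the (R, C, N) layout described above. The toy scores and mask are hypothetical. With PyTorch tensors the equivalent count is roughly `torch.bincount(pred.argmax(dim=1).flatten(), minlength=N)`, assuming smp's usual (batch, N, H, W) output layout.

```python
from collections import Counter

def argmax(values):
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)

def predicted_class_counts(scores):
    """Number of pixels assigned to each class by argmax; `scores` is laid out as (R, C, N)."""
    return Counter(argmax(pixel) for row in scores for pixel in row)

def missing_classes(scores, mask):
    """Classes present in the integer ground-truth mask (R, C) that are never predicted."""
    predicted = set(predicted_class_counts(scores))
    present = {label for row in mask for label in row}
    return sorted(present - predicted)

# Hypothetical 2x2 image with 4 classes: class 3 is in the mask but never wins the argmax.
scores = [
    [[0.9, 0.1, 0.0, 0.0], [0.1, 0.8, 0.1, 0.0]],
    [[0.2, 0.1, 0.6, 0.1], [0.7, 0.1, 0.1, 0.1]],
]
mask = [[0, 1], [2, 3]]

print(predicted_class_counts(scores))  # Counter({0: 2, 1: 1, 2: 1})
print(missing_classes(scores, mask))   # [3]
```

If the missing classes turn out to be the rarest ones in the ground truth, that points toward the class-imbalance explanation rather than a wrong number of output channels.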
Any idea what could be wrong? I suspect the encoder/decoder classes are not being updated to match the number of classes.