Indhujamari opened this issue 3 years ago.
@Indhujamari just throwing out extra things for you to try, but have you tried:
@JordanMakesMaps Thank you for your inputs. Let me try it out.
I also have the same issue with multi-class segmentation: my dataset contains 8 classes (6 classes + 1 background + 1 ignore), which are highly imbalanced. How can I manually assign class weights to this imbalanced dataset? Please help me solve this issue. Thanks in advance!
To set the weights, use a function like this:
import numpy as np
import math

# labels_dict: {class_index: count_of_that_class}
# mu: smoothing parameter to tune
def create_class_weight(labels_dict, mu=0.15):
    total = np.sum(list(labels_dict.values()))
    class_weight = dict()
    for key, count in labels_dict.items():
        # rarer classes get a larger, log-scaled weight
        score = math.log(mu * total / float(count))
        # clamp so that no class gets a weight below 1.0
        class_weight[key] = score if score > 1.0 else 1.0
    return class_weight

# random labels_dict
labels_dict = {0: 2813, 1: 78, 2: 2814, 3: 78, 4: 7914, 5: 248, 6: 7914, 7: 248}
print(create_class_weight(labels_dict))
If you don't want the model to learn a class (for example, the ignore class), set its weight to 0. Use the class weights either in the loss function or in model.fit; I'm not sure what happens if you do both, though.
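Here is a minimal NumPy sketch of the zero-weight idea. It assumes the weights arrive as a {class_index: weight} dict like the one create_class_weight returns, and that masks are one-hot with the class on the last axis. The helper names weights_dict_to_array and weighted_categorical_crossentropy are hypothetical and not part of Keras or segmentation_models. The loss is a plain weighted cross-entropy, included only to show that pixels of a zero-weight class add nothing to the loss.

```python
import numpy as np


def weights_dict_to_array(class_weight, num_classes, ignore_classes=()):
    """Convert a {class_index: weight} dict into a dense per-class weight vector.

    Classes missing from the dict default to 1.0; classes listed in
    ignore_classes are set to 0.0 so they contribute nothing to the loss.
    """
    w = np.ones(num_classes, dtype=np.float64)
    for idx, value in class_weight.items():
        w[idx] = value
    for idx in ignore_classes:
        w[idx] = 0.0
    return w


def weighted_categorical_crossentropy(y_true, y_pred, w, eps=1e-7):
    """Per-pixel cross-entropy scaled by the true class's weight, averaged over all pixels.

    y_true: one-hot masks (..., C); y_pred: softmax probabilities (..., C).
    Pixels of a zero-weight class add 0 to the sum (they still count in the mean).
    """
    p = np.clip(y_pred, eps, 1.0)
    per_pixel = -np.sum(w * y_true * np.log(p), axis=-1)
    return float(np.mean(per_pixel))


# Usage: 8 classes as in the question, with class 7 as the "ignore" class.
# The weight values are illustrative only.
w = weights_dict_to_array({1: 3.0, 3: 3.0, 5: 2.5}, num_classes=8, ignore_classes=(7,))
y_true = np.eye(8)[np.array([[7, 7], [7, 7]])]  # a 2x2 mask made only of ignore pixels
y_pred = np.full((2, 2, 8), 1.0 / 8)            # uniform predictions
print(w)
print(weighted_categorical_crossentropy(y_true, y_pred, w))
```

Because the weight multiplies each pixel's loss term, a mask containing only ignore pixels yields zero loss regardless of what the model predicts.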
@JordanMakesMaps Thank you for your solution! I'll definitely try it.
I am currently working on multi-class segmentation of agricultural aerial images. I have 7 classes, which are highly imbalanced both in the number of images each class appears in and in the number of pixels each class covers.
I have tried the following steps to overcome this class imbalance:
1. I calculated the class weights following the procedure described here: https://github.com/qubvel/segmentation_models/issues/137#issuecomment-515774767
and obtained the class weights [1.52601775 27.67985103 126.29271733 402.83724388 72.80682642 86.53852883 6.85154221].
Since class weights are recommended to be in the range [0, 1], I rescaled the above weights to [0.0037 0.068 0.3141 1.002 0.1811 0.2152 0.017].
2. I tried a combined focal + Dice loss, both with and without class weights.
3. I tried both encoder_freeze=True and encoder_freeze=False.
4. I checked that my masks are in the range [0, 1].
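To make steps 1 and 2 concrete, here is a minimal NumPy sketch of a class-weighted Dice + focal loss. It is not the segmentation_models implementation. It assumes one-hot masks and softmax outputs of shape (N, H, W, C), and it assumes that dividing by the maximum weight is an acceptable way to bring the weights into [0, 1] (the ratios between classes stay the same). Combining the per-class Dice terms as a weighted mean is one design choice among several. The function names are hypothetical.

```python
import numpy as np

# Class weights reported above; dividing by the maximum is one (assumed) way
# to bring them into [0, 1] while keeping their ratios unchanged.
raw_weights = np.array([1.52601775, 27.67985103, 126.29271733, 402.83724388,
                        72.80682642, 86.53852883, 6.85154221])
class_weights = raw_weights / raw_weights.max()


def weighted_dice_loss(y_true, y_pred, w, eps=1e-7):
    """1 - soft Dice per class, combined as a weighted mean over classes.

    y_true: one-hot masks, shape (N, H, W, C); y_pred: softmax outputs, same shape.
    """
    axes = (0, 1, 2)  # sum over batch and spatial axes, keep the class axis
    intersection = np.sum(y_true * y_pred, axis=axes)
    denom = np.sum(y_true, axis=axes) + np.sum(y_pred, axis=axes)
    dice = (2.0 * intersection + eps) / (denom + eps)
    return float(np.sum(w * (1.0 - dice)) / np.sum(w))


def weighted_focal_loss(y_true, y_pred, w, gamma=2.0, eps=1e-7):
    """Categorical focal loss; each pixel is scaled by the weight of its true class."""
    p = np.clip(y_pred, eps, 1.0 - eps)
    # w broadcasts over the last (class) axis
    per_pixel = -np.sum(w * y_true * (1.0 - p) ** gamma * np.log(p), axis=-1)
    return float(np.mean(per_pixel))


def combined_loss(y_true, y_pred, w):
    return weighted_dice_loss(y_true, y_pred, w) + weighted_focal_loss(y_true, y_pred, w)


# Usage on random data: 2 images of 4x4 pixels, 7 classes.
rng = np.random.default_rng(0)
labels = rng.integers(0, 7, size=(2, 4, 4))
y_true = np.eye(7)[labels]  # one-hot, shape (2, 4, 4, 7)
logits = rng.normal(size=(2, 4, 4, 7))
y_pred = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax
print(combined_loss(y_true, y_pred, class_weights))
```

A perfect prediction drives both terms to (nearly) zero. With focal loss, confident correct pixels contribute very little, which is why it is often paired with Dice on imbalanced masks.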
I am using a U-Net with a ResNet-50 backbone.
However, val_loss does not decrease; it keeps fluctuating above 93%, and I cannot get a good IoU score.
Please give me some suggestions for overcoming this issue. I would also like to know how to print per-class metrics.
Thanks in advance
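One way to print per-class metrics is to compute them after training from the predictions, using a confusion matrix. The sketch below assumes integer label maps. For one-hot masks and softmax outputs, taking argmax over the class axis produces them. The names per_class_iou, images and masks are hypothetical, and model.predict is only shown in a comment.

```python
import numpy as np


def per_class_iou(y_true, y_pred, num_classes):
    """IoU for every class from integer label maps of identical shape.

    Returns NaN for a class that appears in neither y_true nor y_pred.
    """
    t = np.asarray(y_true, dtype=np.int64).ravel()
    p = np.asarray(y_pred, dtype=np.int64).ravel()
    # Confusion matrix: rows = true class, columns = predicted class.
    cm = np.bincount(num_classes * t + p, minlength=num_classes ** 2)
    cm = cm.reshape(num_classes, num_classes)
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp  # predicted as c, but actually another class
    fn = cm.sum(axis=1) - tp  # actually c, but predicted as another class
    denom = tp + fp + fn
    return np.where(denom > 0, tp / np.maximum(denom, 1), np.nan)


# Usage: with one-hot masks and softmax outputs of shape (N, H, W, C), e.g.
# y_prob = model.predict(images), take the argmax over the class axis first:
#   iou = per_class_iou(masks.argmax(axis=-1), y_prob.argmax(axis=-1), num_classes=7)
true_labels = np.array([[0, 0, 1], [1, 2, 2]])
pred_labels = np.array([[0, 1, 1], [1, 2, 0]])
iou = per_class_iou(true_labels, pred_labels, num_classes=3)
for c, score in enumerate(iou):
    print(f"class {c}: IoU = {score:.3f}")
print(f"mean IoU = {np.nanmean(iou):.3f}")
```

Looking at the IoU for each class, rather than only the mean, shows whether the rare classes are being learned at all or whether the score is carried by the large classes.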