yassouali / pytorch-segmentation

Semantic segmentation models, datasets and losses implemented in PyTorch.
MIT License

How to set the "weight" for CrossEntropyLoss2d in config file? #140

Closed. panovr closed this issue 2 years ago.

panovr commented 2 years ago

Hi, my dataset has a class imbalance problem, so I want to set the "weight" for the CrossEntropyLoss2d function. Suppose my dataset has 3 classes (1 background + 2 foregrounds) to be segmented. How can I set the "weight" for CrossEntropyLoss2d in the config file? Thanks!

yassouali commented 2 years ago

Hi @panovr

You can first go through your dataset and compute the weights manually from the frequency of each class (e.g. the number of pixels labeled with each class), or use a function such as scikit-learn's compute_class_weight. Then pass the resulting weights to your loss.
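A minimal sketch of the manual approach, assuming the labels are per-pixel integer masks with 255 as the ignore label. It counts pixels per class and applies the inverse-frequency formula n_pixels / (num_classes * count_c), which is the same formula scikit-learn's compute_class_weight uses with class_weight="balanced". The helper name compute_class_weights is hypothetical. The weights go into plain torch.nn.CrossEntropyLoss. This assumes that the repo's CrossEntropyLoss2d wraps it and accepts the same weight argument; check losses.py and train.py to see how the loss is actually built from the config.

```python
import numpy as np
import torch
import torch.nn as nn


def compute_class_weights(masks, num_classes, ignore_index=255):
    """Hypothetical helper: inverse-frequency class weights from label masks.

    weight_c = total_pixels / (num_classes * pixels_of_class_c),
    the same formula as scikit-learn's class_weight="balanced".
    """
    counts = np.zeros(num_classes, dtype=np.int64)
    for mask in masks:
        labels = np.asarray(mask).ravel()
        labels = labels[labels != ignore_index]  # skip unlabeled pixels
        counts += np.bincount(labels, minlength=num_classes)[:num_classes]
    if np.any(counts == 0):
        raise ValueError("a class never appears, so its weight would be infinite")
    return counts.sum() / (num_classes * counts)


# Toy 3-class data: 0 = background, 1 and 2 = foreground classes.
masks = [
    np.array([[0, 0, 0, 1], [0, 0, 2, 255]], dtype=np.uint8),
    np.array([[0, 0, 1, 1], [0, 0, 0, 2]], dtype=np.uint8),
]
weights = compute_class_weights(masks, num_classes=3)  # rare classes get larger weights

# The weight tensor must be float and have one entry per class (dim C of the logits).
criterion = nn.CrossEntropyLoss(
    weight=torch.tensor(weights, dtype=torch.float32), ignore_index=255
)
logits = torch.randn(2, 3, 2, 4)                    # (N, C, H, W) model output
target = torch.from_numpy(np.stack(masks)).long()  # (N, H, W) class indices
loss = criterion(logits, target)
```

For the toy masks above, the counts are 10, 3 and 2 pixels, so the weights come out as [0.5, 1.667, 2.5]. If you want the values in the config file, you could store them under a hypothetical key such as "class_weights" and convert that list to a float tensor when the loss is created. The training script would have to be edited to read that key, since this sketch does not assume it already exists.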