zihangJiang / TokenLabeling

Pytorch implementation of "All Tokens Matter: Token Labeling for Training Better Vision Transformers"
Apache License 2.0

Balanced class weights #11

Closed javierrodenas closed 3 years ago

javierrodenas commented 3 years ago

Hello,

I am trying to use "balanced" class weights. I see that there are two loss-weight arguments:

parser.add_argument('--dense-weight', type=float, default=0.5,
                    help='Token labeling loss multiplier (default: 0.5)')
parser.add_argument('--cls-weight', type=float, default=1.0,
                    help='Cls token prediction loss multiplier (default: 1.0)')

How can I pass a `weight` parameter into the loss?

What I did so far:

import numpy as np
import torch
import torch.nn as nn
from sklearn.utils import class_weight

# note: compute_class_weight takes keyword-only arguments in recent sklearn
targets = target_values.numpy()
class_weights = class_weight.compute_class_weight(
    class_weight='balanced', classes=np.unique(targets), y=targets)
class_weights = torch.tensor(class_weights, dtype=torch.float)

train_loss_fn = nn.CrossEntropyLoss(weight=class_weights).cuda()

Note that I am changing the loss function; before, I was using TokenLabelCrossEntropy:

train_loss_fn = TokenLabelCrossEntropy(dense_weight=args.dense_weight,\
      cls_weight = args.cls_weight, mixup_active = mixup_active, ground_truth=args.ground_truth).cuda()

Thank you in advance

zihangJiang commented 3 years ago

Hi @javierrodenas , we haven't run any experiments using class weights to balance the classes. I think you can modify the SoftTargetCrossEntropy loss function here https://github.com/zihangJiang/TokenLabeling/blob/09bb641b1e8f3e94fa1b6c7180addf4507458541/tlt/loss/cross_entropy.py#L6-L17 to add the balanced class weights.
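One way such a modification could look (a sketch only, not part of this repo; the `weight` argument and the `WeightedSoftTargetCrossEntropy` class name are my own, and the unweighted `forward` follows the timm-style soft-target cross entropy that the linked code is based on):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightedSoftTargetCrossEntropy(nn.Module):
    """Soft-target cross entropy with optional per-class weights.

    `weight` is a hypothetical (num_classes,) tensor, e.g. the balanced
    class weights computed above; passing None recovers the original
    unweighted soft-target loss.
    """

    def __init__(self, weight=None):
        super().__init__()
        self.weight = weight

    def forward(self, x, target):
        # per-sample loss: -sum_c w_c * target_c * log_softmax(x)_c
        log_probs = F.log_softmax(x, dim=-1)
        if self.weight is not None:
            target = target * self.weight.to(x.device)
        loss = torch.sum(-target * log_probs, dim=-1)
        return loss.mean()
```

With one-hot targets and `weight=None` this reduces to the standard cross entropy, so it should be a drop-in replacement wherever SoftTargetCrossEntropy is used inside TokenLabelCrossEntropy; whether weighting the dense (per-token) term as well as the cls term actually helps would need to be verified experimentally.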