vandit15 / Class-balanced-loss-pytorch

PyTorch implementation of the paper "Class-Balanced Loss Based on Effective Number of Samples"
MIT License
784 stars 120 forks

base information about cbloss #7

Adorablepet closed 4 years ago

Adorablepet commented 4 years ago

Thanks for sharing the code. I have modified it to compute the loss for multi-class classification, as shown below, but it always raises errors. Have you tried this?

pred = logits.log_softmax(dim=1)
cb_loss = F.cross_etropy(input=pred, target=labels, weights=weights)

Thanks.

LcenArthas commented 4 years ago

I have the same question. Have you solved it?

Adorablepet commented 4 years ago

@LcenArthas I have solved it.

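def forward(self, logits, labels):
    # Assumes the model and logits are already on the GPU.
    labels = labels.cuda()
    # Pass raw logits: F.cross_entropy applies log_softmax internally.
    # The keyword argument is `weight` (singular), not `weights`.
    loss = F.cross_entropy(input=logits, target=labels, weight=self.weight)
    return loss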
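
For anyone landing here with the same error, the fix above can be wrapped into a self-contained module. This is only a sketch, not the repository's code: the class name `WeightedCELoss` and the example weights are made up for illustration. It assumes `logits` has shape `(batch, num_classes)` and `labels` holds integer class indices. The weights are stored as a buffer and the labels follow the logits' device, so the same code runs on CPU or GPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightedCELoss(nn.Module):
    """Hypothetical wrapper: cross-entropy with fixed per-class weights."""

    def __init__(self, weight):
        super().__init__()
        # A buffer moves together with the module on .to(device) / .cuda().
        self.register_buffer(
            "weight", torch.as_tensor(weight, dtype=torch.float32)
        )

    def forward(self, logits, labels):
        # Follow the logits' device instead of hard-coding .cuda().
        labels = labels.to(logits.device)
        # Raw logits go in; F.cross_entropy applies log_softmax itself.
        # The keyword is `weight` (singular).
        return F.cross_entropy(input=logits, target=labels, weight=self.weight)


torch.manual_seed(0)
logits = torch.randn(4, 3)                    # batch of 4, 3 classes
labels = torch.tensor([0, 2, 1, 2])           # integer class indices
criterion = WeightedCELoss([1.0, 2.0, 0.5])   # made-up per-class weights
loss = criterion(logits, labels)              # scalar tensor
```

With the default `reduction="mean"`, the result is the weighted sum of per-sample losses divided by the sum of the weights of the target classes, so applying `log_softmax` beforehand is unnecessary.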