Closed: yan9qu closed this issue 2 years ago
Hi,
It is probably easiest to add a new class, but if you really don't want to, you could replace the forward function in BaseLoss instead. Be sure to change self.n and self.num_losses accordingly in the initializer.
If you do not mind adding a new class, you can go about it in the following way:
import torch
import torch.nn as nn

class NewBaseLoss(nn.modules.Module):

    def __init__(self, args):
        super(NewBaseLoss, self).__init__()
        self.device = args.device

        # Record the number of losses and their weights.
        self.num_losses = 3
        self.alphas = torch.zeros((self.num_losses,), requires_grad=False).type(torch.FloatTensor).to(self.device)

    def forward(self, pred, target):
        # Compute your three unweighted loss terms and return them as a list.
        lossA = ...
        lossB = ...
        lossC = ...
        return [lossA, lossB, lossC]
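Just to illustrate the return format, here is a hypothetical subclass where the three elided terms are filled in with standard regression criteria (MSE, L1, smooth L1). These are placeholder choices, not something from the repository; substitute your own terms.

import torch.nn.functional as F

class ExampleThreeTermLoss(NewBaseLoss):

    def forward(self, pred, target):
        # Placeholder criteria; replace with your own lossA, lossB, lossC.
        lossA = F.mse_loss(pred, target)
        lossB = F.l1_loss(pred, target)
        lossC = F.smooth_l1_loss(pred, target)
        return [lossA, lossB, lossC]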
Then adapt the CoV-Weighting class in losses/covweighting_loss.py like this:
class CoVWeightingLoss(NewBaseLoss): ...
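In other words, the only change in losses/covweighting_loss.py is the base class; the body of the class stays as it is in the repository. A rough, untested sketch (the import path for NewBaseLoss is hypothetical and depends on where you put the new class):

from losses.new_base_loss import NewBaseLoss  # hypothetical module path

class CoVWeightingLoss(NewBaseLoss):

    def __init__(self, args):
        super(CoVWeightingLoss, self).__init__(args)
        # ... keep the rest of the original initializer and forward unchanged ...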
Note that this should be possible for the static_loss and uncertainty_loss as well.
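If those are adapted the same way (class names assumed here to match the file names), the change is the same one-liner:

class StaticLoss(NewBaseLoss): ...
class UncertaintyLoss(NewBaseLoss): ...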
I haven't tested the code above, so let me know if I can help out.
Thank you! I'll try it and reply to you today. Have a nice day!
What nice work! These days I have been trying to use this method in my own algorithm. My loss is composed of three terms, as below:

Loss = lossA + lossB + lossC

How can I use your method directly, without adding another "loss" class? Thank you!