rickgroen / cov-weighting

Implementation for our WACV 2021 paper "Multi-Loss Weighting with Coefficient of Variations"
MIT License

Can I get help? #4

Closed: yan9qu closed this issue 2 years ago.

yan9qu commented 2 years ago

What nice work! I am currently trying to use this method in my own algorithm. My loss is composed of three terms: `Loss = lossA + lossB + lossC`. How can I use your method directly, without adding another loss class? Thank you!

rickgroen commented 2 years ago

Hi,

It is probably easiest to add a new class, but if you really don't want to, you could replace the `forward` function in `BaseLoss`. Be sure to change `self.n` and `self.num_losses` accordingly in the initializer.

If you do not mind adding a new class, you can go about it in the following way:

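```python
import torch
import torch.nn as nn


class NewBaseLoss(nn.modules.Module):

    def __init__(self, args):
        super(NewBaseLoss, self).__init__()
        self.device = args.device
        # Record the weights: one alpha per loss term (lossA, lossB, lossC).
        self.num_losses = 3
        self.alphas = torch.zeros((self.num_losses,), requires_grad=False).type(torch.FloatTensor).to(self.device)

    def forward(self, pred, target):
        # Compute your own loss terms here and return them unweighted;
        # the weighting subclass combines them.
        lossA = ...
        lossB = ...
        lossC = ...
        return [lossA, lossB, lossC]
```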
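For concreteness, here is one way the placeholders in `NewBaseLoss` might be filled in. The three terms (MSE, L1 and smooth L1) and the `SimpleNamespace` standing in for the repository's parsed `args` are hypothetical choices for illustration only; substitute your own lossA, lossB and lossC.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F


class NewBaseLoss(nn.Module):

    def __init__(self, args):
        super().__init__()
        self.device = args.device
        # One weight per loss term.
        self.num_losses = 3
        self.alphas = torch.zeros((self.num_losses,), requires_grad=False).to(self.device)

    def forward(self, pred, target):
        # Hypothetical example terms; replace with your own lossA, lossB, lossC.
        lossA = F.mse_loss(pred, target)
        lossB = F.l1_loss(pred, target)
        lossC = F.smooth_l1_loss(pred, target)
        # Return the unweighted terms as a list.
        return [lossA, lossB, lossC]


# Hypothetical stand-in for the parsed command-line arguments.
args = SimpleNamespace(device="cpu")
criterion = NewBaseLoss(args)
pred = torch.zeros(4)
target = torch.ones(4)
losses = criterion(pred, target)
print([round(l.item(), 4) for l in losses])  # [1.0, 1.0, 0.5]
```

The only contract that matters is that `forward` returns a list with `num_losses` entries, so that a weighting subclass can pair each entry with one weight.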

Then adapt the CoV-Weighting class in `losses/covweighting_loss.py` so that it inherits from the new base class: `class CoVWeightingLoss(NewBaseLoss): ...`. Note that this should be possible for the static_loss and uncertainty_loss as well.
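To illustrate what a weighting subclass then does with the returned list, here is a minimal, self-contained sketch of coefficient-of-variation weighting. Each loss is divided by its running mean to give a loss ratio; the running mean and standard deviation of those ratios are tracked with Welford's algorithm; and each weight is the ratio's coefficient of variation (std / mean), normalised so the weights sum to one. This is an independent sketch, not the repository's `CoVWeightingLoss`: the class names `ThreeTermLoss` and `CoVWeightedSum`, the plain cumulative means (no decay), the uniform weights on the first step, and the `eps` constant are all assumptions made here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ThreeTermLoss(nn.Module):
    """Returns three unweighted loss terms (hypothetical example terms)."""

    def forward(self, pred, target):
        return [F.mse_loss(pred, target),
                F.l1_loss(pred, target),
                F.smooth_l1_loss(pred, target)]


class CoVWeightedSum(ThreeTermLoss):
    """Sketch of coefficient-of-variation weighting on top of a base loss."""

    def __init__(self, num_losses=3, eps=1e-8):
        super().__init__()
        self.num_losses = num_losses
        self.eps = eps
        self.step = 0                               # training steps seen so far
        self.mean_L = torch.zeros(num_losses)       # running mean of raw losses
        self.mean_ratio = torch.zeros(num_losses)   # running mean of loss ratios
        self.m2_ratio = torch.zeros(num_losses)     # Welford sum of squared deviations
        self.alphas = torch.full((num_losses,), 1.0 / num_losses)

    def forward(self, pred, target):
        losses = super().forward(pred, target)
        if not self.training:
            # Unweighted sum, e.g. for validation.
            return sum(losses)
        # Detached values: the statistics must not receive gradients.
        L = torch.stack([l.detach() for l in losses])
        # Loss ratio w.r.t. the running mean of earlier losses (L itself at step 0).
        L0 = L if self.step == 0 else self.mean_L
        ratio = L / (L0 + self.eps)
        self.step += 1
        # Welford update of the running mean and variance of the ratios.
        delta = ratio - self.mean_ratio
        self.mean_ratio = self.mean_ratio + delta / self.step
        self.m2_ratio = self.m2_ratio + delta * (ratio - self.mean_ratio)
        # Cumulative mean of the raw losses, used as the next step's L0.
        self.mean_L = self.mean_L + (L - self.mean_L) / self.step
        if self.step > 1:
            var = torch.clamp(self.m2_ratio / (self.step - 1), min=0.0)
            cov = torch.sqrt(var) / (self.mean_ratio + self.eps)
            if cov.sum() > 0:
                self.alphas = cov / cov.sum()
        # Weighted sum; gradients flow through the original loss tensors.
        return sum(a * l for a, l in zip(self.alphas, losses))


# Usage: a few SGD steps on a toy regression problem.
torch.manual_seed(0)
model = nn.Linear(4, 1)
criterion = CoVWeightedSum()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x, y = torch.randn(32, 4), torch.randn(32, 1)
for _ in range(5):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
print(criterion.alphas)  # one non-negative weight per term, summing to one
```

Cumulative means are the simplest choice for this sketch; an exponentially decaying mean would make the weights react faster to recent changes in the losses.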

I haven't tested the code above, so let me know if I can help out.

yan9qu commented 2 years ago

Thank you! I'll try it and reply to you today. Have a nice day!