rickgroen / cov-weighting

Implementation for our WACV 2021 paper "Multi-Loss Weighting with Coefficient of Variations"
MIT License

Use COV-Weighting for custom Loss #2

Closed: YashRunwal closed this issue 2 years ago

YashRunwal commented 3 years ago

Hi,

I have an object detection model which has the following losses:

cls_loss_function = HeatmapLoss(reduction='mean')  # Custom Loss Function
txty_loss_function = nn.BCEWithLogitsLoss(reduction='none')
twth_loss_function = nn.SmoothL1Loss(reduction='none')
depth_loss_function = nn.BCEWithLogitsLoss(reduction='none')

Therefore, in my mytrain.py script I compute a total loss that is backpropagated:

total_loss = cls_loss + txty_loss + twth_loss + depth_loss

However, txty_loss dominates the other losses because its value is the highest. Hence, I would like to use COV-Weighting.

How should I use COV-Weighting in my case?

rickgroen commented 3 years ago

Hi,

I would probably make a new subclass of BaseLoss, found in losses/base_loss.py. Its forward function should return a list of torch tensors of size 1. So something like:

import torch
import torch.nn as nn
# plus an import for your custom HeatmapLoss

class NewBaseLoss(nn.modules.Module):

    def __init__(self, args):
        super(NewBaseLoss, self).__init__()
        self.device = args.device
        # Record the weights.
        self.num_losses = 4
        self.alphas = torch.zeros((self.num_losses,), requires_grad=False).type(torch.FloatTensor).to(self.device)
        # Instantiate the individual loss functions once.
        self.cls_loss_function = HeatmapLoss(reduction='mean')  # Custom loss function
        self.txty_loss_function = nn.BCEWithLogitsLoss(reduction='none')
        self.twth_loss_function = nn.SmoothL1Loss(reduction='none')
        self.depth_loss_function = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, pred, target):
        # Compute each loss and reduce it to a single scalar tensor.
        # Here pred and target are assumed to be (cls, txty, twth, depth) tuples;
        # adapt the indexing to however your model returns its outputs.
        cls_loss = self.cls_loss_function(pred[0], target[0])
        txty_loss = self.txty_loss_function(pred[1], target[1]).mean()
        twth_loss = self.twth_loss_function(pred[2], target[2]).mean()
        depth_loss = self.depth_loss_function(pred[3], target[3]).mean()
        return [cls_loss, txty_loss, twth_loss, depth_loss]

Then adapt the CoV-weighting class in losses/covweighting_loss.py so that it inherits from the new base class, i.e. class CoVWeightingLoss(NewBaseLoss): ... Note that the same change should work for the static_loss and uncertainty_loss as well. I haven't tested the code above, so let me know if I can help out.
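For concreteness, here is a minimal, untested sketch of that adaptation. It only shows the inheritance change and how the list of per-task losses gets combined into one scalar; the running statistics that update self.alphas every iteration (the means and standard deviations of the loss ratios) should be kept from the existing losses/covweighting_loss.py. The uniform starting weights below are just a placeholder.

class CoVWeightingLoss(NewBaseLoss):

    def __init__(self, args):
        super(CoVWeightingLoss, self).__init__(args)
        # Placeholder: start from uniform weights. The existing class updates
        # self.alphas every step from the coefficient of variation of each
        # loss ratio (see losses/covweighting_loss.py); keep that logic as-is.
        self.alphas = torch.ones((self.num_losses,), requires_grad=False).to(self.device) / self.num_losses

    def forward(self, pred, target):
        # Get the list of unweighted losses from the new base class.
        unweighted_losses = super(CoVWeightingLoss, self).forward(pred, target)
        # (The CoV statistics update for self.alphas would run here.)
        # Return the weighted sum as a single scalar for backpropagation.
        return sum(self.alphas[i] * unweighted_losses[i] for i in range(self.num_losses))

In your training script the criterion then replaces the hand-written sum: something like loss = criterion(pred, target) followed by loss.backward(), instead of total_loss = cls_loss + txty_loss + twth_loss + depth_loss. (Names like criterion here are placeholders, not from the repository.)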

YashRunwal commented 3 years ago

Hi,

Thanks for replying. I am a little busy right now but I will definitely test it and let you know. Will keep this Issue open till then.

Need some time :)

Have a nice day.

rickgroen commented 2 years ago

Closing this for now. Feel free to open another issue once you've tested the code ;)