YashRunwal closed this issue 2 years ago.
Hi,
I would probably make a new version of the BaseLoss class found in losses/base_loss.py. Its forward function returns a list of torch tensors of size 1, so something like:
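import torch
import torch.nn as nn

class NewBaseLoss(nn.Module):
    def __init__(self, args):
        super(NewBaseLoss, self).__init__()
        self.device = args.device
        # Record the weights: one alpha per loss term.
        self.num_losses = 4
        self.alphas = torch.zeros((self.num_losses,), requires_grad=False).type(torch.FloatTensor).to(self.device)
        # Build the loss functions once instead of on every forward call.
        self.cls_loss_function = HeatmapLoss(reduction="mean")  # Custom loss function, defined in your own code
        self.txty_loss_function = nn.BCEWithLogitsLoss(reduction='none')
        self.twth_loss_function = nn.SmoothL1Loss(reduction='none')
        self.depth_loss_function = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, pred, target):
        # Assumption: pred and target are ordered (cls, txty, twth, depth).
        # Each loss is reduced to a single-element tensor.
        cls_loss = self.cls_loss_function(pred[0], target[0])
        txty_loss = self.txty_loss_function(pred[1], target[1]).mean()
        twth_loss = self.twth_loss_function(pred[2], target[2]).mean()
        depth_loss = self.depth_loss_function(pred[3], target[3]).mean()
        return [cls_loss, txty_loss, twth_loss, depth_loss]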
Then adapt the COV-weighting class in losses/covweighting_loss.py like this:
class CoVWeightingLoss(NewBaseLoss): ...
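The body of CoVWeightingLoss is elided above, so here is a rough, untested illustration of what coefficient-of-variation weighting does with the list of losses. It is a framework-free sketch, not the repository's code: `CoVWeights` and `update` are hypothetical names, and the details (each loss divided by its own running mean to form a loss ratio, Welford statistics for that ratio, weights proportional to std/mean of the ratio, uniform weights until there is enough history) are assumptions about how the method could be implemented.

```python
import math
from typing import List, Sequence


class CoVWeights:
    """Minimal sketch of coefficient-of-variation (CoV) loss weighting.

    Hypothetical helper for illustration only; it is not the
    CoVWeightingLoss class from losses/covweighting_loss.py.
    Call update() once per training step with the detached loss values;
    it returns one weight per loss, and the weights sum to 1.
    """

    def __init__(self, num_losses: int, eps: float = 1e-8) -> None:
        self.num_losses = num_losses
        self.eps = eps
        self.step = 0
        self.mean_L = [0.0] * num_losses  # running mean of the raw loss values
        self.mean_l = [0.0] * num_losses  # running mean of the loss ratios (Welford)
        self.m2_l = [0.0] * num_losses    # sum of squared deviations of the ratios (Welford)

    def update(self, losses: Sequence[float]) -> List[float]:
        if len(losses) != self.num_losses:
            raise ValueError(f"expected {self.num_losses} loss values, got {len(losses)}")
        self.step += 1
        n = self.step
        for i, L in enumerate(losses):
            # Loss ratio: current loss relative to its mean over the previous steps.
            ratio = 1.0 if n == 1 else L / max(self.mean_L[i], self.eps)
            # Welford update of the running mean and variance of the ratio.
            delta = ratio - self.mean_l[i]
            self.mean_l[i] += delta / n
            self.m2_l[i] += delta * (ratio - self.mean_l[i])
            # Running mean of the raw loss, used as the next step's baseline.
            self.mean_L[i] += (L - self.mean_L[i]) / n
        uniform = [1.0 / self.num_losses] * self.num_losses
        if n < 2:
            return uniform  # not enough history for a standard deviation
        # Coefficient of variation of each loss ratio: std / mean.
        covs = [math.sqrt(m2 / (n - 1)) / max(m, self.eps)
                for m, m2 in zip(self.mean_l, self.m2_l)]
        total = sum(covs)
        if total <= 0.0:
            return uniform  # no loss ratio has varied yet
        return [c / total for c in covs]
```

In train.py this could replace the plain sum. Assuming a hypothetical `losses = criterion(pred, target)` that returns the four-element list from NewBaseLoss, one would compute `weights = cov.update([l.item() for l in losses])` and then `total_loss = sum(w * l for w, l in zip(weights, losses))`. Because the weights are plain floats, gradients flow only through the losses themselves, in the same spirit as the `requires_grad=False` alphas above.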
Note that the same change should work for static_loss and uncertainty_loss as well.
I haven't tested the code above, so let me know if I can help out.
Hi,
Thanks for replying. I am a little busy right now but I will definitely test it and let you know. Will keep this Issue open till then.
Need some time :)
Have a nice day.
Closing this for now. Feel free to open another issue once you have tested the code ;)
Hi,
I have an object detection model which has the following losses: cls_loss, txty_loss, twth_loss and depth_loss.
Therefore, in my train.py script, I backpropagate the total loss:
total_loss = cls_loss + txty_loss + twth_loss + depth_loss
However, txty_loss dominates the other losses because its value is the highest. Hence, I would like to use COV-Weighting. How should I use COV-Weighting in my case?
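Before changing the weighting, it can help to quantify how much each term contributes to the unweighted total. A minimal sketch, assuming the four values are logged as plain floats (for example via `.item()`); `loss_shares` is a hypothetical helper and the numbers in the usage line are made up for illustration:

```python
from typing import Dict, Mapping


def loss_shares(losses: Mapping[str, float]) -> Dict[str, float]:
    """Return each loss term's fraction of the unweighted total loss."""
    total = sum(losses.values())
    if total == 0:
        raise ValueError("total loss is zero; shares are undefined")
    return {name: value / total for name, value in losses.items()}


# Illustrative values only, not measurements.
shares = loss_shares({"cls_loss": 1.0, "txty_loss": 6.0, "twth_loss": 2.0, "depth_loss": 1.0})
print(max(shares, key=shares.get))  # txty_loss
```

Logging these shares over training shows whether one term stays dominant or only dominates early on.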