rickgroen / cov-weighting

Implementation for our WACV 2021 paper "Multi-Loss Weighting with Coefficient of Variations"
MIT License

How to use your method to my algorithm? #7

Closed leolin65 closed 1 year ago

leolin65 commented 1 year ago

Hi rick,

I want to use your method to optimize my multiple losses. The code below is part of my algorithm. Could you tell me how to adapt your code, in as much detail as possible? Thank you very much.

==========================================================

            l2_loss = loss_l2(gray_rec,gray_batch)
            ssim_loss = loss_ssim(gray_rec, gray_batch)
            segment_loss = loss_focal(out_mask_sm, anomaly_mask)

            target = make_one_hot(anomaly_mask, num_classes=2).cuda()
            criterion = DiceLoss(reduction='mean') #ignore_index=[2, 3],reduction='mean' 
            dice_loss = criterion(out_mask_sm, target)

            loss = l2_loss + ssim_loss + segment_loss + dice_loss  

            l2_losses.update(l2_loss.item(), gray_rec.size(0))
            ssim_losses.update(ssim_loss.item(), gray_rec.size(0))

            segment_losses.update(segment_loss.item(), out_mask_sm.size(0))                
            dice_losses.update(dice_loss.item(), out_mask_sm.size(0))    

            optimizer.zero_grad()

            loss.backward()
            optimizer.step()

        Train_total_losses = l2_losses.avg + ssim_losses.avg + segment_losses.avg +  dice_losses.avg
rickgroen commented 1 year ago

Hi Leolin,

I would probably make a new class of BaseLoss found in losses/base_loss.py. Its forward function returns a list of torch tensors of size 1. A basic example (you may need to adjust based on your specific code):

import torch
import torch.nn as nn

# DiceLoss, loss_ssim, loss_l2 and loss_focal come from your own code.
class YourNewBaseLoss(nn.modules.Module):

    def __init__(self, args):
        super(YourNewBaseLoss, self).__init__()
        self.device = args.device
        # Record the weights.
        self.num_losses = 4
        self.alphas = torch.zeros((self.num_losses,), requires_grad=False).type(torch.FloatTensor).to(self.device)
        # Your loss functions (if they need to be initialized).
        self.dice_loss = DiceLoss(reduction='mean')

    def forward(self, pred, target):
        """ Compute your losses and return them as a list.
        """
        dice_loss = self.dice_loss(pred, target)
        ssim = loss_ssim(pred, target)
        l2 = loss_l2(pred, target)
        segment = loss_focal(pred, target)
        return [dice_loss, ssim, l2, segment]

Then set it as the new super class for the CoV-weighting class in losses/covweighting_loss.py, like this: class CoVWeightingLoss(YourNewBaseLoss): ... You can do the same for any of the other loss classes, to compare against the other methods for which code is available.
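To give an intuition for what the weighting class does with the list returned by forward, here is a simplified sketch of weighting by the coefficient of variation (standard deviation divided by mean). It assumes each weight is proportional to that ratio and that the weights are normalized to sum to one. The function cov_weights and the loss histories are hypothetical; the class in the repository tracks running statistics during training and may differ in its details.

```python
from statistics import mean, pstdev


def cov_weights(loss_histories):
    """Return one weight per loss, proportional to std / mean of its history."""
    covs = [pstdev(h) / mean(h) for h in loss_histories]
    total = sum(covs)
    return [c / total for c in covs]


# Hypothetical loss values recorded over four training iterations.
l2_history = [1.0, 0.9, 0.8, 0.7]        # still changing noticeably
ssim_history = [0.50, 0.50, 0.49, 0.50]  # almost flat

alphas = cov_weights([l2_history, ssim_history])
print(alphas)
```

In this toy case the loss that is still varying relative to its mean receives the larger weight, and the nearly constant one is down-weighted.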

leolin65 commented 1 year ago

Hi Rick,

Thank you for your kind support.

I followed your suggestion and added a new class based on YourNewBaseLoss(nn.modules.Module), named MultiLoss. I would like to ask how to pass the two values (pred, target) to the forward function.

My existing four values (pred1/2, target1/2) are used inside a for loop (see below): for i_batch, sample_batched in enumerate(train_loader):

old

l2_loss = loss_l2(pred1, target1)
ssim_loss = loss_ssim(pred1, target1)
segment_loss = loss_focal(pred2, target2)
dice_loss = criterion(pred2, target2)

new

class MultiLoss_recconstruction(nn.modules.Module):

    def __init__(self, args):
        super(MultiLoss_recconstruction, self).__init__()
        self.device = args.device
        # Record the weights.
        self.num_losses = 2
        self.alphas = torch.zeros((self.num_losses,), requires_grad=False).type(torch.FloatTensor).to(self.device)
        # Your loss functions (if they need to be initialized).
        self.dice_loss = DiceLoss(reduction='mean')

    def forward(self, pred, target):
        """ Compute your losses and return them as a list.
        """
        l2 = loss_l2(pred, target)
        ssim = loss_ssim(pred, target)
        loss = l2 + ssim
        #return [ l2, ssim]
        return loss

class MultiLoss_discriminative(nn.modules.Module):

    def __init__(self, args):
        super(MultiLoss_discriminative, self).__init__()
        self.device = args.device
        # Record the weights.
        self.num_losses = 2
        self.alphas = torch.zeros((self.num_losses,), requires_grad=False).type(torch.FloatTensor).to(self.device)
        # Your loss functions (if they need to be initialized).
        self.dice_loss = DiceLoss(reduction='mean')

    def forward(self, pred, target):
        """ Compute your losses and return them as a list.
        """
        segment = loss_focal(pred, target)
        dice_loss = self.dice_loss(pred, target)
        loss = segment +  dice_loss
        #return [ segment, dice_loss ]
        return loss

I use these to replace the existing loss = l2_loss + ssim_loss + segment_loss + dice_loss and loss.backward():

loss_recconstruction = CoVWeightingLoss_rec(MultiLoss_recconstruction(gray_rec,gray_batch))
loss_discriminative = CoVWeightingLoss_dis(MultiLoss_discriminative(out_mask_sm, target))

This code produces the error below. How can I fix it?

loss_recconstruction = CoVWeightingLoss_rec(MultiLoss_recconstruction(gray_rec,gray_batch))
TypeError: __init__() takes 2 positional arguments but 3 were given

Also, how do I optimize two types of loss functions? Thank you very much.

rickgroen commented 1 year ago

Hi Leolin,

As I understand, you have two separate issues. First, you want to know how to pass multiple predictions and targets to the loss class. The second question concerns how to call the loss.

To answer the first question: The easiest solution would be to pass lists of predictions and targets to the loss. You will need to unpack those inside your MultiLoss forward function. So something along the lines of:

def forward(self, preds, targets):
    pred1, pred2 = preds
    target1, target2 = targets

And compute your individual loss functions using these variables.
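As a fuller illustration of the unpacking idea, here is a minimal runnable sketch. The PyTorch built-ins F.mse_loss, F.l1_loss and F.binary_cross_entropy, together with a simple soft Dice term, are stand-ins for your own loss_l2, loss_ssim, loss_focal and DiceLoss, which I do not have. The tensor shapes are also just assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiLoss(nn.Module):
    """Hypothetical base loss taking two (prediction, target) pairs."""

    def __init__(self):
        super().__init__()
        self.num_losses = 4

    def forward(self, preds, targets):
        # Unpack the reconstruction pair and the segmentation pair.
        pred1, pred2 = preds
        target1, target2 = targets
        # Stand-ins for loss_l2 and loss_ssim (reconstruction branch).
        l2 = F.mse_loss(pred1, target1)
        l1 = F.l1_loss(pred1, target1)
        # Stand-ins for loss_focal and DiceLoss (segmentation branch).
        bce = F.binary_cross_entropy(pred2, target2)
        overlap = (pred2 * target2).sum()
        soft_dice = 1 - 2 * overlap / (pred2.sum() + target2.sum() + 1e-6)
        return [l2, l1, bce, soft_dice]


torch.manual_seed(0)
gray_rec, gray_batch = torch.rand(2, 1, 8, 8), torch.rand(2, 1, 8, 8)
out_mask_sm = torch.rand(2, 1, 8, 8)
anomaly_mask = (torch.rand(2, 1, 8, 8) > 0.5).float()

# Both arguments are tuples, in matching order.
losses = MultiLoss()((gray_rec, out_mask_sm), (gray_batch, anomaly_mask))
print([round(l.item(), 4) for l in losses])
```

Each entry of the returned list is a scalar tensor, which is what the weighting class expects from its super class.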

To answer your second question: You do not need to supply an instance of the MultiLoss class to the CovWeighting class. The CovWeighting class has a super class (MultiLoss), which is constructed the moment you call its initializer. The only change you make to the CovWeighting class is: class CovWeighting(BaseLoss) becomes class CovWeighting(MultiLoss) in losses/covweighting_loss.py.

And then simply initialize the CovWeighting class using loss_reconstruction = CovWeighting(args)
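To make the inheritance point concrete, here is a self-contained sketch of the pattern. UniformWeighting is a hypothetical placeholder with fixed equal weights, not the CoV-weighting class from the repository. F.mse_loss and F.l1_loss stand in for loss_l2 and loss_ssim, and the small convolution stands in for your network.

```python
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiLoss(nn.Module):
    """Base class: returns the individual losses as a list."""

    def __init__(self, args):
        super().__init__()
        self.device = args.device
        self.num_losses = 2

    def forward(self, pred, target):
        # Stand-ins for loss_l2 and loss_ssim.
        return [F.mse_loss(pred, target), F.l1_loss(pred, target)]


class UniformWeighting(MultiLoss):
    """Placeholder for the weighting class; uses fixed equal weights."""

    def __init__(self, args):
        super().__init__(args)  # this constructs the MultiLoss part
        self.alphas = torch.full((self.num_losses,), 1.0 / self.num_losses)

    def forward(self, pred, target):
        # Ask the super class for the unweighted losses, then combine them.
        losses = super().forward(pred, target)
        return sum(a * l for a, l in zip(self.alphas, losses))


args = argparse.Namespace(device='cpu')
model = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # stand-in network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = UniformWeighting(args)  # only the derived class is created

gray_batch = torch.rand(2, 1, 8, 8)
gray_rec = model(gray_batch)
loss = criterion(gray_rec, gray_batch)  # called with tensors, not a loss object

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Only the derived class is ever instantiated, and it is called with the prediction and target tensors; the returned scalar replaces the manual sum before loss.backward().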

Good luck!

leolin65 commented 1 year ago

Hi Rick,

Thank you for your kind support.

I followed your suggestion to modify base_loss.py and got the error below.

====================================
covweighting_loss.py", line 18, in __init__
    self.mean_decay = True if args.mean_sort == 'decay' else False
AttributeError: 'Namespace' object has no attribute 'mean_sort'

=================================
Below is my MultiLoss class.

class MultiLoss_recconstruction(nn.modules.Module):

    def __init__(self, args):
        super(MultiLoss_recconstruction, self).__init__()
        self.device = args.gpu_id
        # Record the weights.
        self.num_losses = 2
        self.alphas = torch.zeros((self.num_losses,), requires_grad=False).type(torch.FloatTensor).to(self.device)
        # Your loss functions (if they need to be initialized).

    #def forward(self, preds, targets):
    def forward( self, preds, targets):
        """ Compute your losses and return them as a list.
        """
        #pred1, pred2 = preds
        #target1, target2 = targets
        l2 = loss_l2(preds, targets)
        ssim = loss_ssim(preds, targets)
        #loss = l2 + ssim
        return [ l2, ssim]
        #return loss

Could you help me check what I should do? Thank you very much.

rickgroen commented 1 year ago

From your code above, I would change two things:

First, the error that you are getting seems to indicate that you are not initializing the CovWeighting class correctly. Take for example the correct initialization of the loss class in methods/covweighting_method.py, line 14. There I supply args, which are the set of user inputs. You can see how to get these at train.py, line 90. This will sort your namespace error.

Second, in your current implementation of the Multiloss, are you passing a list of tensors as preds, targets? Or are you passing tensors as preds, targets? I do not know what your loss_ssim and loss_l2 functions accept, but I would expect them to need tensors? Did you want to first try the method using only l2 and ssim, since they require the same predictions and ground truths?
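On the first point, a minimal sketch of an args namespace that carries the attributes the loss classes read. The names device, mean_sort and mean_decay_param appear in this thread; the 'full' choice and the default values are assumptions on my part, so check train.py for the real definitions.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--device', default='cpu')
# Read in covweighting_loss.py; 'decay' selects the decaying mean.
# The 'full' option and both defaults below are assumptions.
parser.add_argument('--mean_sort', default='full', choices=['full', 'decay'])
parser.add_argument('--mean_decay_param', type=float, default=1.0)

# Parse an empty list here so the sketch runs without command-line input.
args = parser.parse_args([])
print(args.mean_sort, args.mean_decay_param)
```

Passing this args object to the loss class initializer avoids the AttributeError, because every attribute the initializer looks up now exists on the namespace.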

leolin65 commented 1 year ago

Hi Rick,

Thank you for your kind support.

First, I followed your suggestion and added the arguments mean_sort and mean_decay_param, which fixed the error. Second, you are right: both of them (preds, targets) are tensors. I now face a new problem, shown below.

line 171, in train_on_device
    lossCoVWeighting_rec = loss_recconstruction(MultiLoss_rec1)
File "C:\Users\leolin.conda\envs\env1\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'target'

My code is:

        MultiLoss_rec = MultiLoss_recconstruction(args)
        MultiLoss_rec1 = MultiLoss_rec(gray_rec, gray_batch)
        loss_recconstruction = CoVWeightingLoss_rec(args)

        lossCoVWeighting_rec = loss_recconstruction(MultiLoss_rec1)  # ==> the error occurs here

Sorry to bother you again.

rickgroen commented 1 year ago

Could you please read the documentation on python classes?

To repeat my statement from before: You do not use both MultiLoss and CovWeighting. CovWeighting is a derived class of MultiLoss. That means when you initialize CovWeighting, it automatically initializes all methods of its super class MultiLoss. For your code:

loss_covweighting_class = CoVWeightingLoss_rec(args)
loss_value = loss_covweighting_class(gray_rec, gray_batch)

That is all the code you need...

leolin65 commented 1 year ago

Hi Rick,

Thank you for your kind support. I can now run the code. Thank you again.

Shaw-Way commented 8 months ago

Hi Rick,

Thank you for your kind support. I can now run the code. Thank you again.

Hi, Leo. I tried this method in my project, but it didn't work. The loss didn't even converge. Has this method effectively improved the performance of your model?