Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

[how-to] Handle multiple losses and/or weighted losses? #2645

Closed vr140 closed 4 years ago

vr140 commented 4 years ago

What is your question?

I'm trying to replicate the model built in https://github.com/ekagra-ranjan/AE-CNN with PyTorch Lightning as a way to learn the framework.

They use an auto-encoder together with a CNN (e.g. Inception V3), which means there are multiple loss functions across the models, each with its own weight:

loss1 = MSE loss for auto-encoder
loss2 = 0.8 * BCE loss for one branch of inceptionv3 + 0.2 * BCE loss for another branch of inceptionv3
overall loss = 0.1 * loss1 + 0.9 * loss2

How would I ensure the losses are backpropagated correctly through:

a) the different models (autoencoder and inceptionv3)
b) the different branches of inceptionv3

Code

Here is my Lightning module code:

import torch
import pytorch_lightning as pl
# (import of ae_cnn_models elided for simplicity)


class LightningChestXrayCnnClassifierInceptionV3(pl.LightningModule):
  def __init__(self, classCount, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.module = ae_cnn_models.AECNN(classCount)
    self.classCount = classCount
    # ... more initialization elided for simplicity
    self.loss1 = torch.nn.MSELoss(size_average=True)  # reconstruction loss for the auto-encoder
    self.loss2 = torch.nn.BCELoss(size_average=True)  # classification loss for the Inception branches

  def configure_optimizers(self):
    # ... elided for simplicity

  def train_dataloader(self):
    # ... elided for simplicity

  def training_step(self, batch, batch_idx):
    x, y = batch
    input = self.trans_train(x)
    input = input.to(dtype=torch.float32, device=self.device)
    varInput = torch.autograd.Variable(input)
    varTarget1 = torch.autograd.Variable(input)
    varTarget2 = torch.autograd.Variable(y)
    varOutput1, varOutput2 = self.module.forward(varInput)
    classifierOut1, classifierOut2 = varOutput2

    # autoencoder model loss:
    lossvalue1 = self.loss1(varOutput1, varTarget1)
    # weighting between main and aux branch of inception model:
    lossvalue2 = 0.8 * self.loss2(classifierOut1, varTarget2) + 0.2 * self.loss2(classifierOut2, varTarget2)  
    # weighting between MSE and BCE respectively:
    loss = 0.1 * lossvalue1 + 0.9 * lossvalue2

    output = {
      'loss': loss,  # required
    }
    return output

  def trans_train(self, x):
    # ... some image transformations elided for simplicity

if __name__ == "__main__":
  classCount = 14
  model = LightningChestXrayCnnClassifierInceptionV3(classCount)
  trainer = pl.Trainer(max_epochs=15)
  trainer.fit(model)

When I look in the original source code, I see they simply do:

optimizer.zero_grad()
lossvalue.backward()
optimizer.step()

which seems identical to what PyTorch Lightning would be doing. But again, what's not clear is how this loss is backpropagated correctly through:

a) the different models (autoencoder and inceptionv3)
b) the different branches of inceptionv3

Should I configure multiple / weighted losses differently to ensure the correct losses are backpropagated to the respective models/branches of models? Or is setting up a single loss value in the training_step sufficient?

What's your environment?

github-actions[bot] commented 4 years ago

Hi! Thanks for your contribution, great first issue!

rohitgr7 commented 4 years ago

That's something the PyTorch autograd module handles itself. If a model, a branch of a model, or a layer of a model is involved in computing the final loss during the forward pass, and its parameters have requires_grad=True, those parameters will be updated during gradient descent. For a weighted loss, the weights simply scale the gradients computed during backpropagation w.r.t. the final loss. Setting up a single loss value in training_step is all you need.
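
To see the point concretely, here is a minimal, self-contained sketch (toy nn.Linear modules stand in for the auto-encoder and the CNN; none of this is the code from the issue). Both sub-models feed a single scalar loss, so one backward pass, which Lightning runs for you after training_step, fills the gradients of every parameter that took part in the forward pass:

import torch
import torch.nn as nn

# two independent sub-models, analogous to the auto-encoder and the CNN
autoencoder = nn.Linear(8, 8)
cnn = nn.Linear(8, 2)

x = torch.randn(4, 8)
y = torch.randn(4, 2)

recon = autoencoder(x)   # "auto-encoder" branch
logits = cnn(x)          # "classifier" branch

loss1 = nn.functional.mse_loss(recon, x)
loss2 = nn.functional.mse_loss(logits, y)
loss = 0.1 * loss1 + 0.9 * loss2  # single weighted scalar

loss.backward()  # autograd fills .grad for the parameters of BOTH sub-models
print(autoencoder.weight.grad is not None, cnn.weight.grad is not None)  # True True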

vr140 commented 4 years ago

Thanks! " For weighted loss, weighted gradients will be calculated in the first step of backward propagation w.r.t to the final loss."

Does this assume I use a WeightedLoss class? Instead of hand multiplying the weights myself ? ie What I'm wondering is how it knows to use the 0.2/0.8 weights for the branches of inception and then the 0.1/0.9 weights for Autoencoder vs Inception?

rohitgr7 commented 4 years ago

You don't have to use a WeightedLoss class; I don't think one exists. You can just multiply the loss by some real number yourself. To fully understand the second part you'd have to work through the math by hand, which is hard to explain in a comment 😅. Put simply: since you scaled the loss, that scaled loss is what's used to compute the gradients during backprop, so the gradients themselves are scaled by the same factor. And when you add two or more losses, each of them receives the same upstream gradient during backprop. For a better explanation I suggest: https://youtu.be/i94OvYb6noo
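
A tiny numeric sanity check of that scaling claim (purely illustrative, using a single scalar parameter rather than a real model):

import torch

p = torch.tensor(2.0, requires_grad=True)

loss1 = p ** 2   # d(loss1)/dp = 2p = 4 at p = 2
loss2 = 3 * p    # d(loss2)/dp = 3
total = 0.1 * loss1 + 0.9 * loss2
total.backward()

# expected gradient: 0.1 * 4 + 0.9 * 3 = 3.1
print(p.grad)  # tensor(3.1000)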

vr140 commented 4 years ago

Got it, thanks! What if, instead of setting my own weights for the losses, I wanted to learn them? What would I need to change?

s-rog commented 4 years ago

You can use a torch parameter for the weights (p and 1-p), but that would probably cause the network to lean towards one loss, which defeats the purpose of using multiple losses.

If you want the weights to change during training, you can use a scheduler to update the weight (increasing p with epoch/batch).
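
As a sketch of that scheduler idea (the class name and the _compute_losses helper are hypothetical placeholders, not code from this thread), one could ramp p linearly with the epoch inside training_step:

import pytorch_lightning as pl

class ScheduledLossModule(pl.LightningModule):
  def __init__(self, max_epochs=15):
    super().__init__()
    self.max_epochs = max_epochs

  def training_step(self, batch, batch_idx):
    loss1, loss2 = self._compute_losses(batch)  # hypothetical helper
    # ramp p from 0 toward 1 as training progresses
    p = min(1.0, self.current_epoch / self.max_epochs)
    return (1 - p) * loss1 + p * loss2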

rohitgr7 commented 4 years ago

Yeah, something like p = nn.Parameter(torch.tensor(0.5)) in your model's __init__.
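
A hedged sketch of that learnable-weight variant (again, the class name and _compute_losses helper are placeholders; note that nothing here constrains p to [0, 1], so clamp it or pass it through a sigmoid if you need a convex combination):

import torch
import torch.nn as nn
import pytorch_lightning as pl

class LearnedWeightModule(pl.LightningModule):
  def __init__(self):
    super().__init__()
    # learnable mixing weight, initialised at 0.5 as suggested above
    self.p = nn.Parameter(torch.tensor(0.5))

  def training_step(self, batch, batch_idx):
    loss1, loss2 = self._compute_losses(batch)  # hypothetical helper
    # p is optimized jointly with the rest of the model
    return self.p * loss1 + (1 - self.p) * loss2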

vr140 commented 4 years ago

Thank you!