yaringal / multi-task-learning-example

A multi-task learning example for the paper https://arxiv.org/abs/1705.07115
MIT License

Why return torch.mean(loss)? #16

Open edshkim98 opened 3 years ago

edshkim98 commented 3 years ago

@yaringal Hi, I have a question about your multi-task loss function. Below, you return the loss as torch.mean(loss), but if I understand this function correctly, loss is a single scalar tensor rather than a list, so torch.mean(loss) would be the same as loss. What was your motivation for using torch.mean(loss)? Thank you!

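import torch

def criterion(y_pred, y_true, log_vars):
    # y_pred, y_true: one tensor per task, each assumed (batch_size, output_dim)
    # log_vars: per-task log variances log(sigma_i^2), shape (num_tasks,)
    loss = 0
    for i in range(len(y_pred)):
        precision = torch.exp(-log_vars[i])  # 1 / sigma_i^2
        diff = (y_pred[i] - y_true[i]) ** 2.
        # Sum over the last (output) dim only: shape (batch_size,)
        loss += torch.sum(precision * diff + log_vars[i], -1)
    # Average the per-sample losses over the batch
    return torch.mean(loss)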
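
Written out as an equation, criterion computes the following. This is a sketch of what the code above does, not notation from the paper: s_k stands for log_vars[k], N is the batch size, D is the number of outputs per task, and the labels are introduced here only for illustration. The inner sum over d is torch.sum(..., -1) and the outer average over n is torch.mean. Note that s_k is added once per output dimension.

$$
\ell_n = \sum_{k} \sum_{d=1}^{D} \left( e^{-s_k} \left(\hat{y}_{k,n,d} - y_{k,n,d}\right)^2 + s_k \right),
\qquad
\texttt{criterion} = \frac{1}{N} \sum_{n=1}^{N} \ell_n
$$
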
everye commented 3 years ago

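loss is not a scalar. It holds batch_size values (20 here), one per sample, because torch.sum(..., -1) only sums over the last (output) dimension. torch.mean(loss) then averages those 20 per-sample values into a single scalar. @edshkim98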
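To check the shapes directly, here is a minimal sketch that inlines the criterion loop on random data. The toy setup is an assumption for illustration only: 2 tasks, a batch of 20, one output per task, and log_vars initialised to zero. It is not taken from the repo's notebook.

```python
import torch

# Assumed toy setup (not from the repo): 2 tasks, batch of 20, 1 output per task.
batch_size, num_tasks, output_dim = 20, 2, 1

torch.manual_seed(0)
y_pred = [torch.randn(batch_size, output_dim) for _ in range(num_tasks)]
y_true = [torch.randn(batch_size, output_dim) for _ in range(num_tasks)]
log_vars = torch.zeros(num_tasks, requires_grad=True)

loss = 0
for i in range(num_tasks):
    precision = torch.exp(-log_vars[i])
    diff = (y_pred[i] - y_true[i]) ** 2.
    # Summing over dim -1 removes only the output dim; the batch dim survives.
    loss = loss + torch.sum(precision * diff + log_vars[i], -1)

print(loss.shape)        # torch.Size([20]): one loss value per sample
batch_loss = torch.mean(loss)  # average over the 20 per-sample values
print(batch_loss.shape)  # torch.Size([]): a scalar, so .backward() works
batch_loss.backward()
```

Using the mean rather than the sum keeps the loss scale independent of the batch size. Summing the per-sample values would give the same gradient direction, scaled by batch_size.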