lorenmt / mtan

The implementation of "End-to-End Multi-Task Learning with Attention" [CVPR 2019].
https://shikun.io/projects/multi-task-attention-network
MIT License

Mean of the sum with only one element #23

Closed (joaopalotti closed 4 years ago)

joaopalotti commented 4 years ago

Dear Shikun Liu,

First of all, thanks for sharing your code.

I was checking it out and noticed this line computing the final loss used by both the dwa and equal weighting schemes:

```python
loss = torch.mean(sum(lambda_weight[i, index] * train_loss[i] for i in range(3)))
```

In this line, I believe the torch.mean is not doing anything and might be confusing: the weighted sum over the three task losses is already a single scalar, so taking its mean is a no-op.
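For example, a quick sanity check (with made-up scalar losses and unit weights, just to illustrate):

```python
import torch

# Hypothetical scalar per-task losses and weights, for illustration only
train_loss = [torch.tensor(0.5), torch.tensor(1.2), torch.tensor(0.3)]
lambda_weight = torch.ones(3, 1)
index = 0

total = sum(lambda_weight[i, index] * train_loss[i] for i in range(3))
print(total.shape)                  # torch.Size([]) -- a single scalar
print(torch.mean(total) == total)   # tensor(True)  -- the mean is a no-op
```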

Honestly, I am not sure whether dividing the total loss by 3 would make any practical difference, but just to clarify, was your intent to have something like

```python
torch.mean(torch.stack([lambda_weight[i, index] * train_loss[i] for i in range(3)]))
```

or

```python
torch.sum(torch.stack([lambda_weight[i, index] * train_loss[i] for i in range(3)]))
```

(I use torch.stack here rather than torch.Tensor, since constructing a new torch.Tensor from the losses would detach them from the autograd graph.)

Thanks,

Joao

https://github.com/lorenmt/mtan/blob/cb017f0ad19b9913cab0a128200d34483e816391/im2im_pred/model_segnet_mtan.py#L321

lorenmt commented 4 years ago

Hi Joao,

Thanks for this.

Yes, the torch.mean was supposed to average across the batch rather than across task losses, for the case where each train_loss[i] is a vector of per-image losses.
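For example, if each train_loss[i] were a per-image vector of shape [batch_size] (made-up values below, just to illustrate the shapes):

```python
import torch

batch_size = 4
# Hypothetical per-image losses, one vector per task
train_loss = [torch.rand(batch_size) for _ in range(3)]
lambda_weight = torch.ones(3, 1)
index = 0

# The weighted sum over tasks keeps the batch dimension ...
summed = sum(lambda_weight[i, index] * train_loss[i] for i in range(3))
print(summed.shape)  # torch.Size([4])

# ... and torch.mean then averages across the batch
loss = torch.mean(summed)
print(loss.shape)    # torch.Size([])
```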

Since we want each task-specific attention module to be updated with gradients of a similar scale, we directly sum across all task losses.

So, your second proposal is the one I meant to write.

I will update the code accordingly.
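i.e., dropping the torch.mean and summing the weighted task losses directly, something along the lines of:

```python
loss = sum(lambda_weight[i, index] * train_loss[i] for i in range(3))
```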

Thanks for pointing this out.

Sk.