jonbarron / robust_loss_pytorch

A pytorch port of google-research/google-research/robust_loss/
Apache License 2.0

The Adaptive.py version doesn't change the values of alpha and scale during training #7

Closed: ismarou closed this issue 5 years ago

ismarou commented 5 years ago

Wow! Mr. Barron, your work is amazing! One of the best papers I have read in the last few years! The possibility of learning the loss function during training makes me very excited :) However, when I downloaded your code and tried to replace an MSE loss in a regression CNN with your adaptive one, I noticed that the alpha and scale values don't change during the training procedure.

I'm posting my code below for you to have a clearer image:

from robust_loss_pytorch import adaptive

criterion = adaptive.AdaptiveLossFunction(6, torch.float32, device='cuda:0')

# output: Bx6, labels: Bx6
loss = criterion.lossfun(output - labels).mean()
print(criterion.alpha())
print(criterion.scale())

Is there something wrong with this PyTorch implementation, or am I making some deeper mistake that you can see? Thanks in advance!

jonbarron commented 5 years ago

Hi, thanks for the kind words.

It looks like you're just constructing and evaluating the loss, but not optimizing over the AdaptiveLossFunction's parameters. For the model to estimate the alpha and scale parameters, you need to do some optimization first, making sure to include the AdaptiveLossFunction's internal parameters as free variables in that optimization.
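
For concreteness, here's a minimal sketch of what that could look like; the model, data, and learning rate below are just placeholders standing in for your setup:

import torch
from robust_loss_pytorch import adaptive

# Placeholder model and data; substitute your CNN and your data loader.
net = torch.nn.Linear(16, 6).to('cuda:0')
inputs = torch.randn(32, 16, device='cuda:0')
labels = torch.randn(32, 6, device='cuda:0')

criterion = adaptive.AdaptiveLossFunction(6, torch.float32, device='cuda:0')

# Include the loss's latent alpha/scale parameters alongside the model's.
params = list(net.parameters()) + list(criterion.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

for step in range(100):
    optimizer.zero_grad()
    loss = criterion.lossfun(net(inputs) - labels).mean()  # Bx6 residuals
    loss.backward()
    optimizer.step()  # updates the model weights and alpha/scale together

print(criterion.alpha())  # should drift away from their initial values
print(criterion.scale())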

ismarou commented 5 years ago

Sorry, I didn't quite catch what you mean. After I calculate the loss I already do all the necessary backpropagation and GD-based training steps (optimizer.zero_grad(), loss.backward(), optimizer.step(), etc.), so that's not the problem. Do you mean that I should explicitly put the alpha and scale parameters into the optimizer, like below?

crit_params = torch.nn.ParameterList(criterion.parameters())
optimizer = optim.Adam(list(net.parameters()) + list(crit_params), lr=1e-03)

jonbarron commented 5 years ago

Yep! I think so, at least. If the optimizer doesn't know about the alpha and scale parameters, they won't be updated. It's hard to comment without seeing all of your code, though.

ismarou commented 5 years ago

I think that's the problem. I'll let you know when training finishes and then close the issue. Should all the parameters go into the same optimizer, or should the ones from the criterion go into a second optimizer with a lower learning rate for smoother convergence?

jonbarron commented 5 years ago

In all of my experiments, I just dumped everything into the same optimizer and it worked well enough.

khornlund commented 5 years ago

If you normally do something like this:

trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(trainable_params, ... )

Instead do something like this:

from itertools import chain

trainable_params = filter(
    lambda p: p.requires_grad, 
    chain(model.parameters(), criterion.parameters())
)
optimizer = torch.optim.Adam(trainable_params, ... )

PyTorch also supports per-parameter learning rates via optimizer parameter groups; see the torch.optim documentation on per-parameter options.
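
If you do want a lower learning rate for the criterion's alpha/scale parameters, here's a minimal sketch using parameter groups (the modules and the specific rates are just placeholders):

import torch
from robust_loss_pytorch import adaptive

model = torch.nn.Linear(16, 6)  # placeholder for your network
criterion = adaptive.AdaptiveLossFunction(6, torch.float32, device='cuda:0')

# One optimizer, two parameter groups: a smaller learning rate for the
# loss's alpha/scale parameters than for the model weights.
optimizer = torch.optim.Adam([
    {'params': model.parameters(), 'lr': 1e-3},
    {'params': criterion.parameters(), 'lr': 1e-4},
])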