f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/
MIT License

Support nn.GaussianNLLLoss #204

Open mseitzer opened 3 years ago

mseitzer commented 3 years ago

Hi,

I would like to apply the Cockpit library to my problem, which uses the Gaussian log-likelihood for training. If I only want to look at first-order information, this loss function should already work with BackPACK. However, I would be very interested in also seeing the second-order information, for which explicit support in BackPACK is needed.
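
For the first-order case, I imagine something along these lines would already work; a rough, untested sketch with a toy model, extending only the parameterized layer and leaving the loss as a plain PyTorch module:

import torch
from backpack import backpack, extend
from backpack.extensions import BatchGrad

# Toy setup: extend only the layer with parameters, keep the loss as-is.
model = extend(torch.nn.Linear(3, 2))
loss_fn = torch.nn.GaussianNLLLoss()

x, target = torch.randn(8, 3), torch.randn(8, 2)
var = torch.ones(8, 2)  # fixed variance, just for this sketch

loss = loss_fn(model(x), target, var)
with backpack(BatchGrad()):
    loss.backward()

for p in model.parameters():
    print(p.grad_batch.shape)  # one gradient per sample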

What would it take to integrate this loss? I might be able to contribute as well if it is not too complicated.

The documentation is here: https://pytorch.org/docs/stable/generated/torch.nn.GaussianNLLLoss.html

Thanks!

f-dangel commented 3 years ago

Hi,

I briefly glanced over the linked documentation. Two questions about your specific use case: (i) do you use GaussianNLLLoss with full=False (and do I understand the doc correctly that the log term will not be computed then?), and (ii) is var usually a trainable parameter?

The loss looks quite similar to MSELoss, and for it to be supported in BackPACK, a symmetric factorization of its Hessian w.r.t. the input is required (here is the code for MSELoss; the docstring is slightly misleading and has to be updated). We don't yet have a documented process that explains how to add new layers to second-order extensions, but if you think you could provide the linked functionality, I'd be happy to integrate and test it.
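
To illustrate what such a factorization looks like in the simplest setting (var treated as a constant rather than a function of the network), the Hessian of the loss w.r.t. the mean prediction is diagonal, so a symmetric square root can be written down directly. A small sanity check against autograd, using sum reduction to avoid scaling factors:

import torch

torch.manual_seed(0)
n = 4
target = torch.randn(n)
var = torch.rand(n) + 0.1  # fixed, positive variances
mean = torch.randn(n)

loss_fn = torch.nn.GaussianNLLLoss(reduction="sum")

def loss_wrt_mean(m):
    return loss_fn(m, target, var)

# Hessian of the loss w.r.t. the mean prediction is diag(1 / var), ...
H = torch.autograd.functional.hessian(loss_wrt_mean, mean)

# ... so S = diag(1 / sqrt(var)) satisfies S @ S.T == H.
S = torch.diag(1.0 / var.sqrt())
print(torch.allclose(S @ S.T, H, atol=1e-6))  # True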

For now, if you already want to start using cockpit, I suggest that you exclude the quantities that require second-order extensions, i.e. TIC and HessTrace (the HessMaxEV instrument purely relies on PyTorch's autodiff and should work).
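
Roughly, that could look like this (untested sketch; I'm assuming cockpit's configuration helper and the quantity classes from cockpit.quantities):

import torch
from backpack import extend
from cockpit import Cockpit, quantities
from cockpit.utils.configuration import configuration

model = extend(torch.nn.Linear(3, 2))  # toy model, extended for BackPACK

# Start from the full quantity set and drop the ones that need second-order
# loss support (TIC, HessTrace); HessMaxEV only uses autodiff and can stay.
quants = [
    q for q in configuration("full")
    if not isinstance(q, (quantities.TIC, quantities.HessTrace))
]
cockpit = Cockpit(model.parameters(), quantities=quants)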

mseitzer commented 3 years ago

Hi,

thanks for the response.

(i) do you use GaussianNLLLoss with full=False (and do I understand the doc correctly that the log term will not be computed then?)

The log term is always part of the loss. full=True only adds a constant to the loss, which does not matter for the gradients (I typically include it to get the correct value of the likelihood).
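
A quick autograd check of this, comparing the gradient w.r.t. the mean for full=True and full=False:

import torch

m = torch.randn(5, requires_grad=True)
var = torch.rand(5) + 0.1
target = torch.randn(5)

# full=True only adds the constant 0.5 * log(2 * pi) per element,
# so the gradients agree.
g_full, = torch.autograd.grad(torch.nn.GaussianNLLLoss(full=True)(m, target, var), m)
g_nofull, = torch.autograd.grad(torch.nn.GaussianNLLLoss(full=False)(m, target, var), m)
print(torch.allclose(g_full, g_nofull))  # True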

(ii) is var usually a trainable parameter?

Yes, when using this loss function, var is a trainable parameter, or, more commonly, a function of the input.

Typical example:

import torch
import torch.nn as nn
import torch.nn.functional as F

# toy batch: 8 samples with 3 input features, 2-dimensional targets
input = torch.randn(8, 3)
target = torch.randn(8, 2)

loss_fn = nn.GaussianNLLLoss()
features = nn.Linear(3, 32)
mean = nn.Linear(32, 2)
variance = nn.Linear(32, 2)

x = features(input)
m = mean(x)
var = F.softplus(variance(x))  # softplus keeps the predicted variance positive

loss = loss_fn(m, target, var)

From what I understand about BackPACK, the fork in the graph might create problems here, so it is not just as simple as the MSE loss.

For now, if you already want to start using cockpit, I suggest that you exclude the quantities that require second-order extensions, i.e. TIC and HessTrace (the HessMaxEV instrument purely relies on PyTorch's autodiff and should work).

Yes, I will do this. Thanks!

ponykid commented 10 months ago

Thank you for sharing. However, I think there is a mistake in the code: var = F.softplus(variance(x)) computes log(1 + exp(variance(x))), whereas the Gaussian NLL is 0.5 * log(var(x)^2) + MSE / var(x)^2 (note the var(x)^2), not softplus(variance(x))^2.
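
For reference, PyTorch's documented formula with full=False is 0.5 * (log(var) + (input - target)^2 / var), applied to whatever var tensor is passed in; a quick numerical check:

import torch

m = torch.randn(6)
target = torch.randn(6)
var = torch.rand(6) + 0.1  # well above eps, so no clamping

loss = torch.nn.GaussianNLLLoss(full=False, reduction="mean")(m, target, var)
manual = (0.5 * (var.log() + (m - target) ** 2 / var)).mean()
print(torch.allclose(loss, manual))  # True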