aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Questions about Subnetwork #86

Closed: FrederikWarburg closed this issue 2 years ago

FrederikWarburg commented 2 years ago

Hi!

Thanks for the cool new feature about the subnetwork. I have some comments and questions.

1) Bug in check for subnetwork

Running on GPU, the check fails: it only accepts torch.LongTensor, but on GPU the indices are of type torch.cuda.LongTensor, so the program errors out. A possible fix in utils/subnetwork line 94 is to change it to:

        elif not ((isinstance(subnetwork_indices, torch.LongTensor)
                   or isinstance(subnetwork_indices, torch.cuda.LongTensor))
                  and subnetwork_indices.numel() > 0
                  and len(subnetwork_indices.shape) == 1):
            raise ValueError('Subnetwork indices must be non-empty 1-dimensional torch.LongTensor.')
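
Alternatively (just a suggestion, not tested against the rest of the code base), checking the dtype instead of the device-specific tensor subclasses would cover both CPU and GPU tensors:

        elif not (isinstance(subnetwork_indices, torch.Tensor)
                  and subnetwork_indices.dtype == torch.long
                  and subnetwork_indices.numel() > 0
                  and subnetwork_indices.dim() == 1):
            raise ValueError('Subnetwork indices must be non-empty 1-dimensional torch.LongTensor.')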

2) Hyper optimizer

In your standard example, you first have trainer.fit() followed by hyperparameter optimisation like this:

    log_prior, log_sigma = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
    hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)
    for i in range(n_epochs):
        hyper_optimizer.zero_grad()
        neg_marglik = - la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
        neg_marglik.backward()
        hyper_optimizer.step()

    print(f"sigma={la.sigma_noise.item():.2f}")
    print(f"prior precision={la.prior_precision.item():.2f}")

Can you provide an example of how you would do something similar with the subnetwork? When I try naively, I get a dimension error:

    return (delta * self.prior_precision_diag) @ delta
    RuntimeError: The size of tensor a (2701) must match the size of tensor b (100) at non-singleton dimension 0
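
Roughly, what I'm trying is the following (a minimal sketch based on the README's subnetwork example; model, train_loader and n_epochs are defined elsewhere, and import paths are as in the docs):

    from laplace import Laplace
    from laplace.utils import ModuleNameSubnetMask

    subnetwork_mask = ModuleNameSubnetMask(model, module_names=['l1'])
    subnetwork_mask.select()
    subnetwork_indices = subnetwork_mask.indices

    # full-Hessian Laplace over the selected subnetwork
    la = Laplace(model, 'regression',
                 subset_of_weights='subnetwork',
                 hessian_structure='full',
                 subnetwork_indices=subnetwork_indices)
    la.fit(train_loader)

    # then the same hyperparameter loop as above, which raises the error
    log_prior, log_sigma = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
    hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)
    for i in range(n_epochs):
        hyper_optimizer.zero_grad()
        neg_marglik = - la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
        neg_marglik.backward()
        hyper_optimizer.step()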

Let me know if I should provide more details on network etc.

3) DiagLaplace for subnetwork

I would like to use the diagonal Hessian structure for the subnetwork. Could you give me some pointers on how to do this? If I understand correctly, I cannot just do:

    import torch
    subnetwork_indices = torch.tensor([diagonal elements])

as this will still account for the correlations between the selected elements. What would be the best way to combine the subnetwork feature with laplace.DiagLaplace?

4) Strange warning:

If I code my network like this:

    import torch
    import torch.nn as nn

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.l1 = torch.nn.Linear(1, 50)
            self.l2 = torch.nn.Linear(50, 100)
            self.l3 = torch.nn.Linear(100, 50)
            self.l4 = torch.nn.Linear(50, 1)
            self.tanh = torch.nn.Tanh()

        def forward(self, x):
            x = self.tanh(self.l1(x))
            x = self.tanh(self.l2(x))
            x = self.tanh(self.l3(x))
            x = self.tanh(self.l4(x))
            return x

    model = Model()
    subnetwork_mask = ModuleNameSubnetMask(model, module_names=['l1'])
    subnetwork_mask.select()
    subnetwork_indices = subnetwork_mask.indices

I get the following warning:

UserWarning: Extension saving to grad_batch does not have an extension for Module <class '__main__.get_model.<locals>.Model'> although the module has parameters

However, if I code the network like this:

    model = torch.nn.Sequential(
        torch.nn.Linear(1, 50),
        torch.nn.Tanh(),
        torch.nn.Linear(50, 100),
        torch.nn.Tanh(),
        torch.nn.Linear(100, 50),
        torch.nn.Tanh(),
        torch.nn.Linear(50, 1),
    )

    subnetwork_mask = ModuleNameSubnetMask(model, module_names=['0'])
    subnetwork_mask.select()
    subnetwork_indices = subnetwork_mask.indices

I do not get any warnings. Do you know what the warning means, and whether I should be careful with the first implementation?

Phoveran commented 2 years ago

I'm facing the same issues with 1. and 2. For 2., I'm using la.optimize_prior_precision(method='marglik') instead. For 3., the author mentioned that they don't think it makes sense (however, I think it could be useful...): https://github.com/AlexImmer/Laplace/pull/58#issue-1067104633
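
i.e., roughly (a sketch; la here is a fitted Laplace instance):

    # post-hoc tuning of the prior precision via marginal likelihood,
    # instead of the manual Adam loop over log_prior / log_sigma
    la.fit(train_loader)
    la.optimize_prior_precision(method='marglik')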

edaxberger commented 2 years ago

Thanks @FrederikWarburg for raising this issue and @Phoveran for commenting, and thanks to both of you for your interest in the subnetwork Laplace feature; this is very much appreciated (also, sorry for the late response due to the ICML deadline)!

I just opened a PR #87 that addresses your comments. Feel free to take a look at the PR and try out the corresponding branch; let me know if you encounter any further issues (either here or directly in the PR).

Detailed comments:

  1. Thanks for pointing out that bug, I fixed that!
  2. Good catch! I hadn't tested the subnetwork LA with the marginal likelihood optimization, so I didn't realise it throws an error. The reason I didn't test it is that I don't think it's a sensible thing to do: the marginal likelihood objective probably doesn't do what you want if you just compute it over a subnetwork (i.e. I think making the marginal likelihood work properly with subnetworks is an open research problem). I fixed the dimension error, so you should technically be able to use it, but be careful: you might not get sensible results. It is perhaps better to use cross-validation for hyperparameter tuning. But it might also work -- do let me know if you get good results with marginal likelihood optimization :)
  3. We didn't support this initially as we didn't think people would be interested in it. Could you elaborate on why you want to use a diagonal Hessian approximation over a subnetwork? Couldn't you just use a diagonal Hessian over the full model? Anyway, I've added this feature now; feel free to try it out and let me know what you think / if it works (see the sketch after this list)!
  4. This warning is raised by the BackPACK library, which we use as a backend for the Hessian computations. BackPACK only supports certain kinds of models and seems to work best with torch.nn.Sequential models; see this page of their documentation for details. I'm not sure what exactly the warning means, but I wouldn't expect it to be critical. Feel free to raise an issue on the BackPACK repo if you have problems with your models or questions about their model support. We also support the ASDL library as a second backend, which supports different models and might work in some cases where BackPACK doesn't (and vice versa). Feel free to play around with different backends if things don't work; you can do so by passing the backend argument to Laplace() (see the sketch after this list). Possible values include BackPackGGN, BackPackEF, AsdlGGN and AsdlEF (all in laplace.curvature).
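
For reference, a rough sketch of what trying both of these could look like (assuming the PR branch exposes the diagonal subnetwork variant through the usual Laplace() interface; model, train_loader and subnetwork_indices as in your example above):

    from laplace import Laplace
    from laplace.curvature import AsdlGGN

    # 3) diagonal Hessian over the selected subnetwork (the newly added feature)
    la_subnet_diag = Laplace(model, 'regression',
                             subset_of_weights='subnetwork',
                             hessian_structure='diag',
                             subnetwork_indices=subnetwork_indices)
    la_subnet_diag.fit(train_loader)

    # 4) swapping the curvature backend, e.g. ASDL instead of BackPACK
    la_asdl = Laplace(model, 'regression',
                      subset_of_weights='all',
                      hessian_structure='diag',
                      backend=AsdlGGN)
    la_asdl.fit(train_loader)
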
Phoveran commented 2 years ago

Thanks for your work! For 3., I wonder whether the lottery ticket hypothesis from pruning still applies in the Laplace setting. However, the subnetwork I found may be too big to use a full Hessian. I don't know if it makes sense, but I think it's worth a try.

edaxberger commented 2 years ago

Could you elaborate on what exactly you mean when referring to the LTH in this context? That the diagonal Hessian might perform as well as the full Hessian on certain subnetworks?

Phoveran commented 2 years ago

Yes. That's my guess.

edaxberger commented 2 years ago

I see. It's an interesting thought, but I think there is some empirical evidence that capturing correlations is generally favourable over a diagonal approximation. E.g. in our subnetwork inference paper, we showed that estimating a full Hessian over just a small subnetwork can outperform diagonal Laplace over the full model (see e.g. Fig. 4). But there might also be cases / subnetworks where a diagonal posterior is as good as a full posterior -- I'm not sure.

Phoveran commented 2 years ago

I see, thanks for your information!

edaxberger commented 2 years ago

I'll close this issue for now (as PR #87 should address the concerns raised once merged in) -- feel free to re-open (or open another issue) if anything else arises, and thanks again for your interest in our library @FrederikWarburg and @Phoveran!