aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Weird issue when Laplace is inside class/function #122

Closed (wiseodd closed 4 months ago)

wiseodd commented 1 year ago

[Tested in the master branch]

Say you have this class:

import torch
from laplace import Laplace


class Test:

    def __init__(self, train_loader, device='cpu'):
        super().__init__()

        # Small MLP for 1-D regression
        self.nn = torch.nn.Sequential(
            torch.nn.Linear(1, 50), torch.nn.Tanh(), torch.nn.Linear(50, 1)
        ).to(device)
        self.train_loader = train_loader
        self.nn.eval()

        # Full-network KFAC Laplace approximation, fitted at construction time
        self.bnn = Laplace(
            self.nn, 'regression',
            subset_of_weights='all', hessian_structure='kron',
        )
        self.bnn.fit(self.train_loader)
        self.bnn.optimize_prior_precision(n_steps=10)

Then, this works fine:

model1 = Test(train_loader)
_, _ = model1.bnn(X_test)  
model2 = Test(train_loader)
_, _ = model2.bnn(X_test) 

But this

model1 = Test(train_loader)
model2 = Test(train_loader)
_, _ = model1.bnn(X_test)  
_, _ = model2.bnn(X_test) 

throws an error:

[...]
File "/Users/agustinuskristiadi/Projects/Laplace/laplace/curvature/backpack.py", line 47, in jacobians
    to_cat.append(param.grad_batch.detach().reshape(x.shape[0], -1))
AttributeError: 'Parameter' object has no attribute 'grad_batch'

The same happens when Laplace is constructed inside a function. A minimal script reproducing both cases is here (put it inside the examples dir): https://gist.github.com/wiseodd/62fadb452f77e488acc3716ed3822ac7

When one removes these lines, things work: https://github.com/AlexImmer/Laplace/blob/ada1c6f9a4aa879939de52c8fa454ff27cbfe5a4/laplace/curvature/backpack.py#L57-L58

Felix currently has no idea why this happens from BackPACK's side.

Do we know from Laplace's side why this might happen?
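
For anyone debugging the AttributeError above, a small helper along these lines might help narrow down which parameters are missing `grad_batch` at the point of failure. This is only a sketch: `params_missing_attr` is a hypothetical name, not part of Laplace or BackPACK, and the stand-in objects below only mimic parameters so that the snippet runs without torch installed.

```python
from types import SimpleNamespace


def params_missing_attr(named_params, attr="grad_batch"):
    """Return the names of parameters that do not carry `attr`.

    `named_params` is any iterable of (name, parameter) pairs, e.g. what
    `torch.nn.Module.named_parameters()` yields.
    """
    return [name for name, param in named_params if not hasattr(param, attr)]


# Stand-in objects (an assumption for illustration) so the sketch runs
# without torch or BackPACK installed.
fake_params = [
    ("0.weight", SimpleNamespace(grad_batch=[0.1, 0.2])),  # attribute present
    ("0.bias", SimpleNamespace()),                          # attribute missing
]

missing = params_missing_attr(fake_params)
print(missing)  # ['0.bias']
```

With the class above, one would presumably pass `model1.nn.named_parameters()` and `model2.nn.named_parameters()` just before the failing line, to see whether all parameters or only some of them lack the attribute.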

wiseodd commented 4 months ago

This seems to be a bug in BackPACK. @f-dangel doesn't really know the cause, though. We should just document the caveats of each backend in the README.

This bug is less relevant anyway with the change to the curvlinops backend as the default.

wiseodd commented 4 months ago

Let's collapse this with #82.