aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Help with Kronecker-factored Laplace approximation for ResNet #210

Closed joemrt closed 2 weeks ago

joemrt commented 2 weeks ago

Hi,

I am trying to obtain a Kronecker-factored Laplace approximation for an entire ResNet (not only for the last layer), but I get an error that some layers are not supported (I suspect this is due to the BatchNorm layers?). For example, using the WideResNet from your package, the following

from laplace import Laplace
from examples.helper import wideresnet as wrn
import examples.helper.dataloaders as dl

train_loader = dl.CIFAR10(train=True, batch_size=100)
net = wrn.WideResNet(16, 4, num_classes=10).cuda().eval()

la = Laplace(net, likelihood='classification', subset_of_weights='all',
             hessian_structure='kron')
la.fit(train_loader)

throws the error `NotImplementedError: Found parameters in un-supported layers.` I tried selecting only convolutional and linear layers via your subnetwork option, but that only works with a full Hessian, which is infeasible here.

Is there a way around this or is it only possible to treat ResNets via their last layers with your package?

Thanks a lot for your help!

wiseodd commented 2 weeks ago

Try switching off the unsupported layers' grads, i.e. make sure only conv and linear layers have requires_grad = True.

Doc: https://github.com/aleximmer/Laplace?tab=readme-ov-file#laplace-on-llm

joemrt commented 2 weeks ago

Thank you @wiseodd for your reply, this did the trick! For anyone else running into the same problem, here is what I added before the call to Laplace above:

import torch.nn as nn  # needed for nn.BatchNorm2d

# Freeze the parameters of unsupported layer types so Laplace skips them
layers_to_ignore = [nn.BatchNorm2d]
for module in net.modules():
    if type(module) in layers_to_ignore:
        for par in module.parameters():
            par.requires_grad = False
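To see what this freezing loop actually does, here is a hypothetical sanity check on a toy model (not part of the original fix, and independent of the laplace package): after the loop, only the BatchNorm parameters should have requires_grad = False, while the conv and linear parameters stay trainable.

```python
import torch.nn as nn

# Toy stand-in for a ResNet block: conv + batch norm + linear head
net = nn.Sequential(
    nn.Conv2d(3, 8, 3),   # supported by the Kronecker factorization
    nn.BatchNorm2d(8),    # unsupported -> will be frozen
    nn.Flatten(),
    nn.Linear(8, 10),     # supported
)

# Same freezing loop as above
layers_to_ignore = [nn.BatchNorm2d]
for module in net.modules():
    if type(module) in layers_to_ignore:
        for par in module.parameters():
            par.requires_grad = False

# Count trainable vs. frozen parameters
n_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
n_frozen = sum(p.numel() for p in net.parameters() if not p.requires_grad)
# Frozen: BatchNorm2d(8) weight + bias = 16 parameters
# Trainable: Conv2d (8*3*3*3 + 8 = 224) + Linear (8*10 + 10 = 90) = 314
```

Laplace then only builds the Hessian factors over the parameters that still require gradients, which is why the NotImplementedError goes away.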