aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Help with Kronecker-factored Laplace approximation for ResNet #210

Closed joemrt closed 4 months ago

joemrt commented 4 months ago

Hi,

I am trying to obtain a Kronecker-factored Laplace approximation for an entire ResNet (not only for the last layer), but I get an error that some layers are not supported (I suspect this is due to the BatchNorm layers?). For example, using the WideResNet from your package, the following

from laplace import Laplace
from examples.helper import wideresnet as wrn
import examples.helper.dataloaders as dl

train_loader = dl.CIFAR10(train=True, batch_size=100)
net = wrn.WideResNet(16, 4, num_classes=10).cuda().eval()

la = Laplace(net, likelihood="classification", subset_of_weights="all",
             hessian_structure="kron")
la.fit(train_loader)

throws the error NotImplementedError: Found parameters in un-supported layers. I tried selecting only convolutional and linear layers via your subnetwork option, but that only works with a full Hessian, which is infeasible here.

Is there a way around this or is it only possible to treat ResNets via their last layers with your package?

Thanks a lot for your help!

wiseodd commented 4 months ago

Try switching off the unsupported layers' grads, i.e., make sure only conv and linear layers have requires_grad = True.

Doc: https://github.com/aleximmer/Laplace?tab=readme-ov-file#laplace-on-llm

joemrt commented 4 months ago

Thank you @wiseodd for your reply, this did the trick! For anyone else running into the same problem, here is what I added before the call to Laplace above:

import torch.nn as nn

# freeze all BatchNorm parameters so Laplace skips these layers
layers_to_ignore = [nn.BatchNorm2d]
for module in net.modules():
    if type(module) in layers_to_ignore:
        for par in module.parameters():
            par.requires_grad = False
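Before fitting, it may be worth sanity-checking that only the intended parameters were frozen. A minimal sketch with a toy model standing in for the WideResNet (the small `nn.Sequential` here is illustrative, not the actual network from the thread):

```python
import torch.nn as nn

# toy stand-in for a network containing a BatchNorm layer
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Flatten(), nn.LazyLinear(10))

# same freezing loop as above
layers_to_ignore = [nn.BatchNorm2d]
for module in net.modules():
    if type(module) in layers_to_ignore:
        for par in module.parameters():
            par.requires_grad = False

# list the frozen parameters; only the BatchNorm weight/bias should appear
frozen = [name for name, p in net.named_parameters() if not p.requires_grad]
print(frozen)  # → ['1.weight', '1.bias']
```

With this check passing, `Laplace(..., subset_of_weights='all', hessian_structure='kron')` should only see the conv and linear parameters when computing the Kronecker factors.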