aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Can't use subnetwork inference on custom module #127

Open ruili-pml opened 1 year ago

ruili-pml commented 1 year ago

Hi,

I tried to use subnetwork inference on my model and it didn't work. I traced the problem to the Hessian not being computed for my custom PyTorch module. How can I solve this? Thank you.

Kind regards, Rui

aleximmer commented 1 year ago

Hi, could you post code or describe the module? In general, there are two ways:

  1. implement the relevant extension for the module in backpack or asdl yourself
  2. implement your module in terms of standard modules (e.g. torch.nn.Linear) combined with non-parametrized transformations like reshape, view

It depends on your module what's possible and easier.

ruili-pml commented 1 year ago

Hi, thank you for your quick reply. The module I'm using is the FiLM layer (https://arxiv.org/abs/1709.07871), in case you happen to know it. Basically, it applies an affine transformation to each channel of the image. Otherwise, please see the attached code:

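import torch
import torch.nn as nn


class SimpleFiLM(nn.Module):

    def __init__(self, num_channels):
        super().__init__()

        self.num_channels = num_channels

        # One scale and one shift per channel, both initialised at zero.
        self.scale = nn.Parameter(torch.zeros(num_channels))
        self.shift = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        """
        x: [batch_size, num_channels, height, width]
        """
        # Offset by one so that the layer is the identity at initialisation.
        scale = self.scale + 1.

        # Reshape so both parameters broadcast over batch and spatial dims.
        cur_scale = scale.reshape(1, self.num_channels, 1, 1)
        cur_shift = self.shift.reshape(1, self.num_channels, 1, 1)

        x = cur_scale * x + cur_shift

        return x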

I suppose reimplementing it with standard modules might be easier in my case? So far I have only been digging into asdl. Do you happen to know which backend is better suited?
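
For reference, one way the "standard modules" route could look is a depthwise 1x1 convolution, which computes exactly one multiplicative weight and one bias per channel. The sketch below is only an illustration: `film_as_depthwise_conv` is a hypothetical helper, the layer is restated compactly so the snippet runs on its own, and whether a given Hessian backend handles grouped convolutions is an assumption that would need checking. Note also that the convolution stores scale + 1 directly in its weight, so the parameterization differs from the posted module.

```python
import torch
import torch.nn as nn


class SimpleFiLM(nn.Module):
    """Compact restatement of the posted layer: (scale + 1) * x + shift per channel."""

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels
        self.scale = nn.Parameter(torch.zeros(num_channels))
        self.shift = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        scale = (self.scale + 1.0).reshape(1, self.num_channels, 1, 1)
        shift = self.shift.reshape(1, self.num_channels, 1, 1)
        return scale * x + shift


def film_as_depthwise_conv(film):
    """Hypothetical helper: build a depthwise 1x1 Conv2d matching `film`."""
    c = film.num_channels
    # groups=c gives each channel its own 1x1 kernel, i.e. a single scalar.
    conv = nn.Conv2d(c, c, kernel_size=1, groups=c, bias=True)
    with torch.no_grad():
        # Conv weight has shape [c, 1, 1, 1]; the +1 offset is folded in here.
        conv.weight.copy_((film.scale + 1.0).reshape(c, 1, 1, 1))
        conv.bias.copy_(film.shift)
    return conv


torch.manual_seed(0)
film = SimpleFiLM(num_channels=3)
with torch.no_grad():
    # Move away from the identity so the comparison is not trivial.
    film.scale.normal_()
    film.shift.normal_()

conv = film_as_depthwise_conv(film)
x = torch.randn(2, 3, 4, 5)
max_abs_diff = (film(x) - conv(x)).abs().max().item()
```

If `max_abs_diff` is at floating-point noise level, the two layers compute the same function on that input.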

aleximmer commented 1 year ago

I would recommend extending asdl since it's faster and more flexible wrt architectures. For this layer, I think it's easier to implement an extension in asdl following the scale and bias modules, which are essentially scalar scale and shift parameter modules (see https://github.com/kazukiosawa/asdl/tree/master/asdl/operations). However, their extension is only correct if you remove the line scale = self.scale + 1. Is this offset necessary? I would argue you can instead just add the one at initialization. Let me know if anything is unclear or in case you run into problems with this.
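
To make the "add the one at initialization" suggestion concrete, here is a small sketch in plain PyTorch (no asdl involved; the tensor names and the toy loss are made up for illustration). It compares a zero-initialised scale with a +1 in the forward pass against a scale initialised at one with no offset.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 3, 4, 4)

# Parameterization A (as in the posted module): zero init, +1 in the forward pass.
scale_a = torch.zeros(3, requires_grad=True)
out_a = (scale_a + 1.0).reshape(1, 3, 1, 1) * x

# Parameterization B: no offset in the forward pass, parameter initialised at one.
scale_b = torch.ones(3, requires_grad=True)
out_b = scale_b.reshape(1, 3, 1, 1) * x

# Toy loss, only used to obtain gradients for the comparison.
out_a.pow(2).sum().backward()
out_b.pow(2).sum().backward()

same_output = torch.allclose(out_a, out_b)
same_grad = torch.allclose(scale_a.grad, scale_b.grad)
```

Outputs and parameter gradients agree at initialization. What differs is the stored parameter value (0 versus 1), so anything that acts on the raw value, such as weight decay or a zero-mean prior, treats the two parameterizations differently.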

ruili-pml commented 1 year ago

Thank you for your reply. Unfortunately, scale = self.scale + 1 is necessary: section 7.2 (Model Details) in the appendix points out that it has a large influence on performance. Does this mean I basically need to implement the whole module?

aleximmer commented 1 year ago

In this case, you need to implement the module extension yourself, but you can use the simple Bias and Scale operations as a template and only extend them to the multivariate setting. The +1 will be correctly incorporated into the partial gradient of the output, i.e., out_grads in asdl, so you don't need to handle it separately. I am happy to help if you run into problems; just let me know.
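
As a rough illustration of what the multivariate extension has to compute, the sketch below derives per-sample gradients for a per-channel scale and shift by hand and checks them against autograd. It is plain PyTorch and does not use asdl's extension API; the variable names and the toy loss are assumptions made for the example. The point it tries to show is that the parameter gradients depend only on the layer input and on the gradient with respect to the layer output, so the +1 needs no special handling there.

```python
import torch

torch.manual_seed(0)
n, c, h, w = 2, 3, 4, 5
x = torch.randn(n, c, h, w)
scale = torch.randn(c, requires_grad=True)
shift = torch.randn(c, requires_grad=True)

# Forward pass of the FiLM-style layer, including the +1 offset.
out = (scale + 1.0).reshape(1, c, 1, 1) * x + shift.reshape(1, c, 1, 1)
loss_per_sample = out.pow(2).sum(dim=(1, 2, 3))  # toy loss, one value per sample

# Gradient of the total loss w.r.t. the layer output (the role of out_grads).
(out_grads,) = torch.autograd.grad(loss_per_sample.sum(), out, retain_graph=True)

# Manual per-sample gradients: reduce over the spatial dimensions only.
grad_scale = (out_grads * x).sum(dim=(2, 3))  # shape [n, c]
grad_shift = out_grads.sum(dim=(2, 3))        # shape [n, c]

# Reference: autograd, one sample at a time.
ref_scale, ref_shift = [], []
for i in range(n):
    gs, gb = torch.autograd.grad(
        loss_per_sample[i], (scale, shift), retain_graph=True
    )
    ref_scale.append(gs)
    ref_shift.append(gb)
ref_scale = torch.stack(ref_scale)
ref_shift = torch.stack(ref_shift)

scale_matches = torch.allclose(grad_scale, ref_scale, rtol=1e-4, atol=1e-4)
shift_matches = torch.allclose(grad_shift, ref_shift, rtol=1e-4, atol=1e-4)
```

The offset only shows up in the gradient passed back to the layer input, which is (scale + 1) times the output gradient, and automatic differentiation already takes care of that.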

ruili-pml commented 1 year ago

I'll work on it when I'm back from the holiday, and there's a high chance I'll run into problems. Many thanks in advance :)