aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Using Laplace with purely convolutional networks for Image Regression Tasks #113

Closed icetube23 closed 1 year ago

icetube23 commented 1 year ago

Hello,

first of all, I want to express my appreciation for the work and effort that went into developing this helpful package!

I have a specific use-case for which I would love to use this package. I am interested in obtaining an uncertainty approximation for one of my models, similar to the sinusoidal toy example. However, this does not seem to be straightforward for my specific problem.

My problem consists of an image regression task (i.e., the inputs and outputs of my model are images) that is solved by a purely convolutional neural network. More precisely, I am given a gray-scale image (e.g., 64x64) and I want to predict a slightly revised image at twice the resolution (e.g., 128x128). Here is a minimal working example of a CNN that solves such a task:

import torch.nn as nn
import laplace
import torch

class MyConvModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, (3, 3), padding='same')
        self.convt = nn.ConvTranspose2d(4, 2, (2, 2), stride=2)
        self.conv2 = nn.Conv2d(2, 1, (3, 3), padding='same')

    def forward(self, x):
        x = self.conv1(x)
        x = self.convt(x)
        x = self.conv2(x)
        return x

model = MyConvModel()

# dummy batch consisting of 16 "images"
x = torch.randn(16, 1, 64, 64)
# produces 16 128x128 images, i.e., y.shape == (16, 1, 128, 128)
y = model(x)
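
As a sanity check on the shapes in the example above, the spatial size after each layer can be worked out by hand. The following is a minimal sketch in plain Python; the helper names `conv2d_same_out` and `conv_transpose2d_out` are made up for illustration, and the transposed-convolution formula is the one given in the PyTorch documentation for `ConvTranspose2d`.

```python
def conv2d_same_out(size: int) -> int:
    """Spatial size after a stride-1 Conv2d with padding='same' (unchanged)."""
    return size


def conv_transpose2d_out(size, kernel_size, stride=1, padding=0,
                         output_padding=0, dilation=1):
    """Spatial size after ConvTranspose2d (formula from the PyTorch docs)."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


h = 64
h = conv2d_same_out(h)                                # conv1 keeps 64
h = conv_transpose2d_out(h, kernel_size=2, stride=2)  # convt doubles to 128
h = conv2d_same_out(h)                                # conv2 keeps 128
print(h)  # 128
```

With a 2x2 kernel and stride 2, the transposed convolution exactly doubles the resolution, which is where the 128x128 output comes from.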

I tried applying Laplace to that model using different values for the subset_of_weights argument but ultimately failed. Continuing the MWE, I created some dummy data:

# create dummy data
dataset = torch.utils.data.TensorDataset(torch.randn(16, 1, 64, 64), torch.randn(16, 1, 128, 128))
dataloader = torch.utils.data.DataLoader(dataset)

and tried to fit a Laplace approximation to it. I tried:

la = laplace.Laplace(model,
                     likelihood='regression',
                     subset_of_weights='last_layer'
                     )
la.fit(dataloader)

This led to the following error: ValueError: Use model with a linear last layer.

Stacktrace (truncated)

```
ValueError Traceback (most recent call last)
LaplaceRedux.ipynb Cell 6 in ()
      1 la = laplace.Laplace(model,
      2                      likelihood='regression',
      3                      subset_of_weights='last_layer'
      4                      )
----> 5 la.fit(dataloader)

File .../lib/python3.8/site-packages/laplace/lllaplace.py:103, in LLLaplace.fit(self, train_loader, override)
    101 with torch.no_grad():
    102     try:
--> 103         self.model.find_last_layer(X[:1].to(self._device))
    104     except (TypeError, AttributeError):
    105         self.model.find_last_layer(X.to(self._device))

File .../lib/python3.8/site-packages/laplace/utils/feature_extractor.py:137, in FeatureExtractor.find_last_layer(self, x)
    135 layer = dict(self.model.named_modules())[key]
    136 if len(list(layer.children())) == 0:
--> 137     self.set_last_layer(key)
    139     # save features from first forward pass
    140     self._features[key] = act_out[key]

File .../lib/python3.8/site-packages/laplace/utils/feature_extractor.py:80, in FeatureExtractor.set_last_layer(self, last_layer_name)
     78 self.last_layer = dict(self.model.named_modules())[last_layer_name]
...
---> 80 raise ValueError('Use model with a linear last layer.')
     82 # set forward hook to extract features in future forward passes
     83 self.last_layer.register_forward_hook(self._get_hook(last_layer_name))

ValueError: Use model with a linear last layer.
```

Next, I tried approximating the whole network:

la = laplace.Laplace(model,
                     likelihood='regression',
                     subset_of_weights='all'
                     )
la.fit(dataloader)

Which again results in an error: ValueError: Only 2D inputs are currently supported for MSELoss.

Stacktrace (truncated)

```
ValueError Traceback (most recent call last)
LaplaceRedux.ipynb Cell 6 in ()
      1 la = laplace.Laplace(model,
      2                      likelihood='regression',
      3                      subset_of_weights='all'
      4                      )
----> 5 la.fit(dataloader)

File .../lib/python3.8/site-packages/laplace/baselaplace.py:797, in KronLaplace.fit(self, train_loader, override)
    794 # discount previous Kronecker factors to sum up properly together with new ones
    795 self.H_facs = self._rescale_factors(self.H_facs, n_data_old / (n_data_old + n_data_new))
--> 797 super().fit(train_loader, override=override)
    799 if self.H_facs is None:
    800     self.H_facs = self.H

File .../lib/python3.8/site-packages/laplace/baselaplace.py:377, in ParametricLaplace.fit(self, train_loader, override)
    375 self.model.zero_grad()
    376 X, y = X.to(self._device), y.to(self._device)
--> 377 loss_batch, H_batch = self._curv_closure(X, y, N)
    378 self.loss += loss_batch
    379 self.H += H_batch

File .../lib/python3.8/site-packages/laplace/baselaplace.py:777, in KronLaplace._curv_closure(self, X, y, N)
    776 def _curv_closure(self, X, y, N):
...
    105 """Raises an exception if the shapes of the input are not supported."""
    106 if not len(module.input0.shape) == 2:
--> 107     raise ValueError("Only 2D inputs are currently supported for MSELoss.")

ValueError: Only 2D inputs are currently supported for MSELoss.
```

Finally, I tried:

la = laplace.Laplace(model,
                     likelihood='regression',
                     subset_of_weights='subnetwork',
                     hessian_structure='full',
                     subnetwork_indices=torch.tensor([0, 2], dtype=torch.int64)
                     )
la.fit(dataloader)

Which resulted in: IndexError: index 1 is out of bounds for dimension 1 with size 1

Stacktrace (truncated)

```
IndexError Traceback (most recent call last)
LaplaceRedux.ipynb Cell 6 in ()
      1 la = laplace.Laplace(model,
      2                      likelihood='regression',
      3                      subset_of_weights='subnetwork',
      4                      hessian_structure='full',
      5                      subnetwork_indices=torch.tensor([0, 2], dtype=torch.int64)
      6                      )
----> 7 la.fit(dataloader)

File .../lib/python3.8/site-packages/laplace/baselaplace.py:691, in FullLaplace.fit(self, train_loader, override)
    689 def fit(self, train_loader, override=True):
    690     self._posterior_scale = None
--> 691     return super().fit(train_loader, override=override)

File .../lib/python3.8/site-packages/laplace/baselaplace.py:377, in ParametricLaplace.fit(self, train_loader, override)
    375 self.model.zero_grad()
    376 X, y = X.to(self._device), y.to(self._device)
--> 377 loss_batch, H_batch = self._curv_closure(X, y, N)
    378 self.loss += loss_batch
    379 self.H += H_batch

File .../lib/python3.8/site-packages/laplace/baselaplace.py:687, in FullLaplace._curv_closure(self, X, y, N)
    686 def _curv_closure(self, X, y, N):
...
---> 42     out[:, i].sum().backward()
     43 else:
     44     out.sum().backward()

IndexError: index 1 is out of bounds for dimension 1 with size 1
```

All examples were conducted using a virtual environment, Python 3.8.10, and laplace-torch 0.1a2 on Ubuntu 20.04.3 LTS.

I suspect that none of these are actual bugs but rather me misusing the package. However, this still leaves me with a few questions regarding my desired use-case:

Sorry for the wall of text and thanks in advance for your help!

aleximmer commented 1 year ago

Thanks for the detailed description of the issue. You did not misuse the package; this is a current limitation of it. Image outputs are quite tricky with Laplace, mostly because of the Hessian approximation it requires. However, there is a recent paper that tackles this problem in the context of VAEs and proposes an efficient Hessian/GGN approximation for this particular case: https://arxiv.org/pdf/2206.15078.pdf
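
To give a rough sense of scale for the toy model in the question, here is a back-of-the-envelope sketch in plain Python. It assumes the GGN is assembled from per-input Jacobians of the network outputs with respect to the parameters (the standard construction), so each input contributes a Jacobian with one row per output pixel. The helper `conv_params` is a hypothetical name used only for this illustration.

```python
def conv_params(in_ch, out_ch, kernel_size):
    """Weights plus biases of a Conv2d/ConvTranspose2d layer with a square kernel."""
    return in_ch * out_ch * kernel_size * kernel_size + out_ch


n_params = (
    conv_params(1, 4, 3)    # conv1: 36 weights + 4 biases = 40
    + conv_params(4, 2, 2)  # convt: 32 weights + 2 biases = 34
    + conv_params(2, 1, 3)  # conv2: 18 weights + 1 bias   = 19
)
n_outputs = 1 * 128 * 128   # one value per output pixel

# One Jacobian per input image has shape (n_outputs, n_params).
jacobian_entries = n_outputs * n_params
print(n_params, n_outputs, jacobian_entries)
```

Even for this tiny 93-parameter network, the per-image Jacobian has about 1.5 million entries, compared with 93 for a scalar regression output. The cost grows with the number of output pixels, which is why dedicated approximations are of interest for image outputs.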

Regarding your questions:

I hope this helps. Otherwise, please let me know.


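import torch
import torch.nn as nn
import laplace
from laplace.curvature.backpack import BackPackEF

# same layers as in the question, with the image output flattened to a 2D (batch, pixels) tensor
model = nn.Sequential(
    nn.Conv2d(1, 4, (3, 3), padding='same'),
    nn.ConvTranspose2d(4, 2, (2, 2), stride=2),
    nn.Conv2d(2, 1, (3, 3), padding='same'),
    nn.Flatten(start_dim=1)
)

# smaller dummy data: 16x16 inputs, targets flattened to 32 * 32 = 1024 values per image
dataset = torch.utils.data.TensorDataset(torch.randn(8, 1, 16, 16), torch.randn(8, 32 * 32))
dataloader = torch.utils.data.DataLoader(dataset)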

la = laplace.Laplace(
    model,
    likelihood='regression',
    subset_of_weights='all',  # or subnetwork
    hessian_structure='full',
    backend=BackPackEF
    #  subnetwork_indices=torch.tensor([], dtype=torch.int64)
)
la.fit(dataloader)
print(la.H.shape)
f_mu, f_var = la(torch.randn(8, 1, 16, 16), pred_type='nn', link_approx='mc', n_samples=100)
print(f_mu.shape, f_var.shape)
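
The key change in the snippet above is the final `nn.Flatten(start_dim=1)`, which turns the image output into a 2D tensor so that it passes the `len(module.input0.shape) == 2` check visible in the second stack trace. A minimal sketch of that shape bookkeeping in plain Python follows; `flatten_shape` is a hypothetical helper that only mimics the shape effect of `nn.Flatten(start_dim=1)`, it does not touch any tensors.

```python
import math


def flatten_shape(shape, start_dim=1):
    """Shape after merging all dimensions from start_dim onward into one."""
    return tuple(shape[:start_dim]) + (math.prod(shape[start_dim:]),)


image_out = (8, 1, 32, 32)           # conv output for a batch of 8 16x16 inputs
flat_out = flatten_shape(image_out)  # (8, 1024)
target = (8, 32 * 32)                # shape of the dummy targets

print(flat_out, len(flat_out) == 2, flat_out == target)
```

The flattened output is 2D and matches the target shape, so the loss sees one 1024-dimensional regression target per image instead of a 4D image tensor.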

icetube23 commented 1 year ago

Thanks a lot for your input! I will take a look at the paper you provided and play around with the hacky solution to see if I can fit it to my needs. If I come across further questions, I'll get back to you and reopen the issue.

I'll also look forward to any future extensions to this package. :)