Closed — icetube23 closed this issue 1 year ago
Thanks for the detailed description of the issue. You did not actually misuse the package; rather, this is a current limitation of it. Image outputs are quite tricky with Laplace, mostly due to the Hessian approximation it requires. However, a recent paper tackles this problem in the context of VAEs and proposes an efficient Hessian/GGN approximation for this particular case: https://arxiv.org/pdf/2206.15078.pdf
Regarding your questions: you can append `nn.Flatten(start_dim=1)` to your model so that its outputs are 2D, and flatten the observations correspondingly. Then you can use the full-network or subnetwork options, but a good approximation will require 128 × 128 backward passes (one per output dimension), so I don't think this is a good solution in this case. The last-layer solution we will implement will be practical and not require this. In case you are still interested in the hacky version, you can find it below with some further tricks: 1) use `backend=BackPackEF`, which does not require computing the Jacobians for the GGN (so 1 instead of 128 × 128 backward passes), 2) use a sampling-based predictive of the neural network instead of the improved linearized predictive, and 3) reduced dimensionality, just so I could test it faster. Most likely this will not give very good performance, but it could be fine with some tuning of the prior precision parameter and choosing the `subnetwork_indices` to be just the last layer. To tune the predictive performance of this approach, you would need to pass `prior_precision=delta`, where `delta` is a very large value, for example 1000. This is to concentrate the posterior due to the rough approximations necessary for this approach. I hope this helps. Otherwise, please let me know.
```python
import torch
import torch.nn as nn
import laplace
from laplace.curvature.backpack import BackPackEF

model = nn.Sequential(
    nn.Conv2d(1, 4, (3, 3), padding='same'),
    nn.ConvTranspose2d(4, 2, (2, 2), stride=2),
    nn.Conv2d(2, 1, (3, 3), padding='same'),
    nn.Flatten(start_dim=1),
)
dataset = torch.utils.data.TensorDataset(
    torch.randn(8, 1, 16, 16), torch.randn(8, 32 * 32)
)
dataloader = torch.utils.data.DataLoader(dataset)

la = laplace.Laplace(
    model,
    likelihood='regression',
    subset_of_weights='all',  # or 'subnetwork'
    hessian_structure='full',
    backend=BackPackEF,
    # subnetwork_indices=torch.tensor([], dtype=torch.int64)
)
la.fit(dataloader)
print(la.H.shape)

f_mu, f_var = la(
    torch.randn(8, 1, 16, 16), pred_type='nn', link_approx='mc', n_samples=100
)
print(f_mu.shape, f_var.shape)
```
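Since the trick above flattens the network's outputs, the predictive mean comes back as a batch of 32 × 32 = 1024-dimensional vectors rather than images. To inspect the results as images again, you can reshape them; below is a minimal sketch with dummy tensors standing in for the outputs above, assuming per-pixel variances (the exact shape of `f_var` depends on the predictive used):

```python
import torch

# Dummy stand-ins for the flattened predictive outputs of the snippet above.
f_mu = torch.randn(8, 32 * 32)   # predictive mean, one flattened image per row
f_var = torch.rand(8, 32 * 32)   # assumed per-pixel predictive variances

# Reshape back to image layout for visualization.
mean_images = f_mu.reshape(-1, 1, 32, 32)
std_images = f_var.sqrt().reshape(-1, 1, 32, 32)  # per-pixel standard deviation

print(mean_images.shape, std_images.shape)
```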
Thanks a lot for your input! I will take a look at the paper you provided and play around with the hacky solution to see if I can fit it to my needs. If I come across further questions, I'll get back to you and reopen the issue.
I also look forward to any future extensions to this package. :)
Hello,
first of all, I want to express my appreciation for the work and effort that went into developing this helpful package!
I have a specific use-case for which I would love to use this package. I am interested in obtaining an uncertainty approximation for one of my models, similar to the sinusoidal toy example. However, it seems that this is not so straightforward for my specific problem.
My problem is an image regression task (i.e., both the inputs and outputs of my model are images) solved by a purely convolutional neural network. More precisely, I am given a gray-scale image (e.g., 64x64) and I want to predict a slightly revised image at twice the resolution (e.g., 128x128). Here is a minimal working example of a CNN that solves such a task:
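A network of this shape can be sketched as follows. This is a hedged illustration rather than the exact original model: it mirrors the reduced-size model in the reply above, scaled to the stated 64×64 → 128×128 resolution, and the layer widths are assumptions.

```python
import torch
import torch.nn as nn

# Sketch of a purely convolutional image-regression network:
# takes a 1x64x64 gray-scale image and predicts a 1x128x128 image.
# Channel counts are illustrative, not the original model's.
model = nn.Sequential(
    nn.Conv2d(1, 4, (3, 3), padding='same'),     # 1x64x64  -> 4x64x64
    nn.ConvTranspose2d(4, 2, (2, 2), stride=2),  # 4x64x64  -> 2x128x128
    nn.Conv2d(2, 1, (3, 3), padding='same'),     # 2x128x128 -> 1x128x128
)

out = model(torch.randn(1, 1, 64, 64))
print(out.shape)  # torch.Size([1, 1, 128, 128])
```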
I tried applying Laplace to that model using different values for the `subset_of_weights` argument but ultimately failed. Continuing the MWE, I created some dummy data and tried to fit a Laplace approximation to it. I first tried:
This led to the following error:
ValueError: Use model with a linear last layer.
Stacktrace (truncated)
```
ValueError  Traceback (most recent call last)
LaplaceRedux.ipynb Cell 6 in
```

Next, I tried approximating the whole network:
Which again results in an error:
ValueError: Only 2D inputs are currently supported for MSELoss.
Stacktrace (truncated)
```
ValueError  Traceback (most recent call last)
LaplaceRedux.ipynb Cell 6 in
```

Finally, I tried:
Which resulted in:
IndexError: index 1 is out of bounds for dimension 1 with size 1
Stacktrace (truncated)
```
IndexError  Traceback (most recent call last)
LaplaceRedux.ipynb Cell 6 in
```

All examples were run in a virtual environment with Python 3.8.10 and laplace-torch 0.1a2 on Ubuntu 20.04.3 LTS.
I suspect that none of these are actual bugs but rather that I am misusing the package. However, this still leaves me with a few questions regarding my desired use-case:
Sorry for the wall of text and thanks in advance for your help!