wiseodd opened this issue 2 years ago
Beyond CIFAR images, though, the GGN computation will also be an issue. E.g., for ImageNet the output dim is `224*224*3 = 150528`, much larger than CIFAR's `32*32*3 = 3072`. I talked to Felix about this, and one solution is to exploit the per-pixel nature of the loss and compute the minibatch GGN in chunks over the output dimension: since the sum-reduced loss decomposes over output pixels, the GGN of the full loss is the sum of the per-chunk GGNs. See the MNIST example below.
Thoughts?
```python
import torch
from torch import nn

from backpack import backpack, extend
from backpack.custom_module.slicing import Slicing
from backpack.extensions import DiagGGNExact

# `model`, `trainloader`, and `DEVICE` are defined in the full script (gist below)
lastlayer = extend(model.last_layer)
lossfunc = extend(nn.MSELoss(reduction="sum"))

chunked_ggn = torch.zeros_like(model.last_layer[0].weight)

for x, _ in trainloader:
    x = x.to(DEVICE)

    # Reconstruction of the flattened MNIST image, shape [N, 784]
    reconstruction = lastlayer(model.feature_extractor(x))

    # Accumulate the exact diagonal GGN over 28-pixel chunks of the output
    for i in range(28):
        slicing = (slice(None), slice(i * 28, (i + 1) * 28))
        slicing_module = extend(Slicing(slicing))
        sliced_reconstruction = slicing_module(reconstruction)
        sliced_loss = lossfunc(sliced_reconstruction, x.flatten(1)[slicing])

        with backpack(DiagGGNExact(), retain_graph=True):
            sliced_loss.backward(retain_graph=True)

        chunked_ggn += model.last_layer[0].weight.diag_ggn_exact
```
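As a quick consistency check, here is a self-contained sketch (with a tiny toy layer, not part of the prototype): since the sum-reduced MSE decomposes over output coordinates, the chunked diagonal GGNs should sum to the full one.

```python
import torch
from torch import nn

from backpack import backpack, extend
from backpack.custom_module.slicing import Slicing
from backpack.extensions import DiagGGNExact

torch.manual_seed(0)
layer = extend(nn.Sequential(nn.Linear(5, 8)))
lossfunc = extend(nn.MSELoss(reduction="sum"))
x, y = torch.randn(4, 5), torch.randn(4, 8)

# Full diagonal GGN in one backward pass
out = layer(x)
with backpack(DiagGGNExact()):
    lossfunc(out, y).backward()
full = layer[0].weight.diag_ggn_exact.clone()

# Chunked diagonal GGN over two halves of the output
chunked = torch.zeros_like(full)
out = layer(x)
for i in range(2):
    sl = (slice(None), slice(i * 4, (i + 1) * 4))
    sliced_out = extend(Slicing(sl))(out)
    with backpack(DiagGGNExact(), retain_graph=True):
        lossfunc(sliced_out, y[sl]).backward(retain_graph=True)
    chunked += layer[0].weight.diag_ggn_exact

print(torch.allclose(full, chunked, atol=1e-5))  # should be True
```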
For predictions/reconstructions, my proposal is to use https://github.com/f-dangel/unfoldNd. With it, `conv_transpose2d` becomes just a matrix multiplication between the original weights/filters and the unfolded input, implying that we can easily obtain $p(f(x))$.
```python
import torch
import unfoldNd

# `diag_GGN` (shaped like the last layer's weight), `model`, and `testloader`
# are defined in the full script (gist below)
prec0 = 1

# Diagonal Laplace posterior covariance
diag_Sigma = 1 / (diag_GGN + prec0)
diag_Sigma = diag_Sigma.transpose(0, 1).flatten(1)

# diag_Sigma.shape should be (c_out, c_in*k*k)
assert len(diag_Sigma.shape) == 2 and diag_Sigma.shape == (1, 100 * 3 * 3)

# Following the last layer of the model
unfold_transpose = unfoldNd.UnfoldTransposeNd(
    kernel_size=3, dilation=1, padding=1, stride=3
)


@torch.no_grad()
def reconstruct(x):
    phi = model.feature_extractor(x)

    # MAP prediction
    mean_pred = model.last_layer(phi).reshape(x.shape)

    # Variance: diag(J Sigma J^T) per output pixel
    J_pred = unfold_transpose(phi)
    var_pred = torch.einsum(
        "bij,ki,bij->bkj", J_pred, diag_Sigma, J_pred
    ).reshape(mean_pred.shape)

    return mean_pred.cpu().numpy(), var_pred.cpu().numpy()


x_recons = []
for x, _ in testloader:
    x = x.cuda()
    x_recons.append(reconstruct(x))
```
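To see why the variance einsum above is the right contraction, here is a sanity-check sketch of the "transposed convolution = matmul over the unfolded input" claim, with hypothetical shapes mirroring the snippet above (assumes a bias-free layer):

```python
import torch
from torch import nn
import unfoldNd

# Hypothetical layer matching the UnfoldTransposeNd arguments above
tconv = nn.ConvTranspose2d(100, 1, kernel_size=3, stride=3, padding=1, bias=False)
phi = torch.randn(2, 100, 10, 10)

unfold_transpose = unfoldNd.UnfoldTransposeNd(
    kernel_size=3, dilation=1, padding=1, stride=3
)
J = unfold_transpose(phi)  # [N, c_in*k*k, L]

# Weight as a (c_out, c_in*k*k) matrix, same layout as diag_Sigma above
W = tconv.weight.transpose(0, 1).flatten(1)

out_matmul = torch.einsum("ki,bij->bkj", W, J)
out_ref = tconv(phi).flatten(2)  # [N, c_out, L]
print(torch.allclose(out_matmul, out_ref, atol=1e-5))  # should be True
```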
Full, self-contained prototype here: https://gist.github.com/wiseodd/b8d57fa029f876e00b336b7b3b5052bd
Hello @wiseodd, have there been any updates on this topic over the last year? I am currently working on Laplace approximations for segmentation tasks and would be very interested. Thank you!
Unfortunately, there's no update on this. That's partly because the loss function usually used in image problems (BCELoss) is not supported by the Hessian backends, and partly because my research agenda is far from computer vision/graphics.
In any case, I hope the references above and the snippets in the previous posts point you in a good direction.
This issue should be easier to solve once #145 is merged. Will work on this after the release of milestone 0.2.
@AlexImmer, @runame, @edaxberger: As you know, I'm currently working on last-layer Laplace for img2img tasks, e.g. autoencoders and image segmentation. We can't use the current implementation in this library, mainly because we hard-code the last-layer Jacobian to be the fully-connected Jacobian (see #111 for an example). Note: GGN computation using BackPACK & ASDL doesn't seem to pose any problem (#111 for ASDL, below for BackPACK).
So, my current thinking is to simply generalize `last_layer_jacobians` in `laplace/curvature/curvature.py` using `functorch`; see the `predict` function below. I also propose to only support diagonal LLLA, since it's too costly otherwise. Let me know your thoughts and whether I missed anything. Feel free to try out the self-contained script below.
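For concreteness, a minimal sketch of what the `functorch`-based generalization could look like, assuming a `model` with `feature_extractor`/`last_layer` attributes and an input batch `x` (placeholder names, not the final API):

```python
import torch
from functorch import make_functional, jacrev, vmap

# Make the last layer functional: fmodel(params, inputs) == model.last_layer(inputs)
fmodel, params = make_functional(model.last_layer)


def output_fn(params, phi_single):
    # Run a single sample through the last layer and flatten the
    # img2img output (e.g. [C, H, W]) into a vector
    return fmodel(params, phi_single.unsqueeze(0)).flatten()


phi = model.feature_extractor(x)  # [N, ...] last-layer inputs

# Per-sample last-layer Jacobians: a tuple with one tensor of shape
# [N, n_outputs, *p.shape] for each last-layer parameter p
jacs = vmap(jacrev(output_fn), in_dims=(None, 0))(params, phi)
```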