aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Handle feature reduction properly in LLLA #169

Closed: wiseodd closed this issue 3 weeks ago

wiseodd commented 2 months ago

Suppose you have an LLM with a regression head on top. Then the following code

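# self.model wraps the LLM and exposes its last layer
f, phi = self.model.forward_with_features(x)
print(f.shape, phi.shape)  # model outputs, last-layer input features
for p in self.model.last_layer.parameters():
    print(p.shape)  # parameters of the last layer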

outputs

torch.Size([4, 2]) torch.Size([4, 9, 768])
torch.Size([2, 768])

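There is a mismatch between the shape of the last layer's inputs and the last layer itself. The features phi still carry a sequence dimension ([batch, seq_len, hidden] = [4, 9, 768]), whereas the last layer's weight ([2, 768]) and the outputs f ([4, 2]) are per sequence. The reduction over tokens happens inside the model's forward pass, so the LLLA, which only sees the last layer's inputs, does not know about it.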
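A minimal toy reproduction may make the mismatch concrete. It assumes one way the situation can arise: the head is applied to every token and only afterwards is a single token (here the last one) kept. The model, its embedding "body", and the hook-based capture are all hypothetical stand-ins, not the LLM or the Laplace library's feature extractor.

```python
import torch
from torch import nn


class ToyRegressor(nn.Module):
    """Hypothetical stand-in for an LLM with a 2-output regression head."""

    def __init__(self, vocab_size: int = 100, hidden: int = 768, n_outputs: int = 2):
        super().__init__()
        self.body = nn.Embedding(vocab_size, hidden)  # stands in for the transformer
        self.last_layer = nn.Linear(hidden, n_outputs)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        h = self.body(input_ids)        # per-token features: [batch, seq_len, hidden]
        out = self.last_layer(h)        # head applied to every token: [batch, seq_len, n_outputs]
        return out[:, -1, :]            # reduction happens AFTER the head: keep last token


model = ToyRegressor()
captured = {}
# Record the tensor the last layer actually receives (its first positional input).
model.last_layer.register_forward_hook(
    lambda module, inputs, output: captured.update(phi=inputs[0])
)

x = torch.randint(0, 100, (4, 9))  # batch of 4 sequences, 9 tokens each
f = model(x)
phi = captured["phi"]
print(f.shape, phi.shape)
print(model.last_layer.weight.shape)
```

With these sizes the shapes match the issue's output: f is [4, 2], phi is [4, 9, 768], and the weight is [2, 768]. Nothing at the last layer's input reveals that only the final token's features end up in f.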

The best solution seems to be letting the user specify which reduction their model uses. Common choices: first (the <CLS> token in BERT), last (in causal LMs), and average (https://arxiv.org/abs/2402.05015). We can use an enum for this.
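A rough sketch of what the enum-based option could look like, reducing the per-token features to one vector per sequence before they reach the LLLA. The names FeatureReduction and reduce_features are hypothetical illustrations, not the library's actual API.

```python
from enum import Enum
from typing import Optional, Union

import torch


class FeatureReduction(str, Enum):
    """Hypothetical enum of token-reduction strategies."""

    FIRST = "first"      # e.g. the <CLS> token in BERT-style encoders
    LAST = "last"        # e.g. the final token in causal LMs
    AVERAGE = "average"  # mean over all tokens


def reduce_features(
    phi: torch.Tensor, reduction: Optional[Union[FeatureReduction, str]]
) -> torch.Tensor:
    """Map features [batch, seq_len, hidden] to [batch, hidden]."""
    if reduction is None or phi.dim() == 2:
        return phi  # already one feature vector per example
    reduction = FeatureReduction(reduction)  # accepts plain strings; raises ValueError otherwise
    if reduction is FeatureReduction.FIRST:
        return phi[:, 0, :]
    if reduction is FeatureReduction.LAST:
        return phi[:, -1, :]  # assumes no right padding
    return phi.mean(dim=1)    # AVERAGE; ignores any padding mask


phi = torch.randn(4, 9, 768)
reduced = reduce_features(phi, FeatureReduction.LAST)
print(reduced.shape)  # torch.Size([4, 768]), compatible with a [2, 768] last layer
```

Subclassing str lets callers pass either the enum member or its string value. With padded batches, "last" and "average" would need the attention mask (last non-padding token, masked mean); the sketch leaves that out.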