Closed: charlio23 closed this issue 1 week ago.
Hi Carles, those are indeed good suggestions! I wrote down the formula for KronLLLaplace in App. B.1 of my paper; not sure why I haven't implemented it here yet!
Would you like to submit a pull request for this? No problem if you can't---I can also do this rather quickly.
Thank you very much for your fast reply.
I am not sure my pull request would meet the coding standards of the repo, so if you already have experience, I would appreciate it if you could do it instead. One thing to note: I noticed a slight difference in the resulting covariance matrix when using the new formula, and I am not sure whether the discrepancy comes from numerical precision.
PS: In fact the formula I used is from your paper 😄.
@aleximmer what's the best way to go about this? For KronLLLaplace it seems like this should be implemented in KronDecomposed in matrix.py.
This change would be very useful in relation to #144, e.g. for LLMs, where the number of classes is on the order of $10^4$.
When experimenting with last-layer Laplace approximation on a classification task with many labels (1K+, e.g. ImageNet), memory usage rapidly jumps to 10-20 GB and inference becomes slower.
For context, with a last layer of 128 -> 3100 (where 3100 is the number of labels), each inference step takes ~10 seconds and 15 GB with a single sample (batch_size=1) using diagonal Laplace. Kronecker Laplace does not fit in GPU memory with this setup.
When inspecting the code, I noticed that last-layer Laplace computes the Jacobian $\phi(x)^\top \otimes I$ explicitly, which is very costly with >1K labels, as the resulting matrix has shape (num_labels, num_features * num_labels).
Note: $\phi(x)$ denotes the features of the penultimate (L-1) layer given input $x$, and $I$ is the identity matrix of size (num_labels, num_labels).
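To make the cost concrete, here is a small numpy sketch (not the package's code; the sizes are shrunk for illustration) of materialising the explicit Jacobian $\phi(x)^\top \otimes I$:

```python
import numpy as np

# Toy sizes; in the setting above F = 128 and L = 3100, so the explicit
# Jacobian alone is 3100 x 396800 floats (~4.9 GB in float32) per input.
F, L = 8, 5  # num_features, num_labels
phi = np.random.default_rng(0).standard_normal(F)  # phi(x), last-layer features

# Explicit Jacobian of the last layer w.r.t. its flattened weights:
# phi(x)^T kron I, shape (num_labels, num_features * num_labels).
J = np.kron(phi[None, :], np.eye(L))
assert J.shape == (L, F * L)
```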
For last-layer Laplace, I believe this could be implemented more efficiently by exploiting the factorisation of the Jacobian. For example, the posterior predictive variance for the Kronecker Laplace could be implemented as follows:
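The original snippet is not reproduced here, but the identity such a factorised computation relies on can be stated as follows. Writing the Kronecker-factored posterior covariance as $\Sigma = S_A \otimes S_B$, with $S_A$ of size (num_features, num_features) and $S_B$ of size (num_labels, num_labels) (the factor names here are mine, not the package's):

$$ J \Sigma J^\top = (\phi^\top \otimes I)\,(S_A \otimes S_B)\,(\phi \otimes I) = (\phi^\top S_A \phi)\, S_B, $$

i.e. a scalar times a (num_labels, num_labels) matrix, computed without ever materialising $J$.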
Below is some code I used to speed up last-layer Laplace by modifying the code from the repository. It would be nice if you could consider making this modification, as I am sure other users of your package would benefit from it.
Diagonal Laplace
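(The original snippet is not shown above; the following is a minimal numpy sketch of how the diagonal case can avoid the explicit Jacobian. The variable names and the `sigma2` layout are my assumptions, not the repository's API.)

```python
import numpy as np

rng = np.random.default_rng(1)
F, L = 8, 5  # num_features, num_labels (tiny for illustration)
phi = rng.standard_normal(F)

# Hypothetical diagonal posterior variance over the flattened last-layer
# weights, arranged to match the phi^T kron I Jacobian: sigma2[f, c].
sigma2 = rng.random((F, L))

# Naive: J diag(sigma2) J^T with the explicit Jacobian.
J = np.kron(phi[None, :], np.eye(L))
f_var_naive = J @ np.diag(sigma2.ravel()) @ J.T

# Factorised: the result is diagonal, with entries sum_f phi_f^2 * sigma2[f, c],
# so no (L, F*L) matrix is ever built.
f_var_fast = np.diag(phi**2 @ sigma2)

assert np.allclose(f_var_naive, f_var_fast)
```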
Kronecker Laplace
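(Again, the original snippet is not shown; here is a hedged numpy sketch of the Kronecker case, assuming a posterior covariance factored as $S_A \otimes S_B$. In the package these factors would come from the eigendecomposed KFAC factors in KronDecomposed; here they are just random SPD matrices.)

```python
import numpy as np

rng = np.random.default_rng(0)
F, L = 8, 5  # num_features, num_labels (tiny for illustration)
phi = rng.standard_normal(F)

def rand_spd(n):
    """Random symmetric positive-definite matrix, standing in for a
    posterior covariance Kronecker factor."""
    M = rng.standard_normal((n, n))
    return M @ M.T + n * np.eye(n)

S_A, S_B = rand_spd(F), rand_spd(L)

# Naive: build J = phi^T kron I and the full (F*L, F*L) covariance.
J = np.kron(phi[None, :], np.eye(L))
f_var_naive = J @ np.kron(S_A, S_B) @ J.T

# Factorised: (phi^T kron I)(S_A kron S_B)(phi kron I)
#           = (phi^T S_A phi) * S_B -- a scalar times an (L, L) matrix.
f_var_fast = (phi @ S_A @ phi) * S_B

assert np.allclose(f_var_naive, f_var_fast)
```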
Note 1: The speedup is considerable; in my case, evaluation on the test set went from 3 hours to 5 seconds using the diagonal Laplace. Note 2: The Kronecker Laplace gives a slightly different result from the original implementation, but the math is correctly implemented as far as I can tell, and empirically the results are reasonable.
Best,
Carles