aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Last Layer Laplace predictions could be computed much faster; they become problematic for classification with many labels #138

Closed: charlio23 closed this issue 1 week ago

charlio23 commented 8 months ago

When experimenting with the Last Layer Laplace approximation on a classification task with a large number of labels (1K or more, e.g. ImageNet), memory usage quickly rises to 10 to 20 GB and inference slows down.

For context, with a last layer mapping 128 -> 3100 (where 3100 is the number of labels), each inference step takes ~10 seconds and 15 GB for a single sample (batch_size=1) using Diagonal Laplace. Kronecker Laplace does not fit on the GPU in this setup.

When inspecting the code, I noticed that Last Layer Laplace explicitly computes the Jacobian as $\phi(x)^T \otimes I$, which is very costly with more than 1K labels because the resulting matrix has shape (num_labels, num_features*num_labels).

Note: $\phi(x)$ are the features output by layer L-1 (the penultimate layer) for input x, and $I$ is the identity matrix of size (num_labels, num_labels).
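As a rough back-of-the-envelope check, the sketch below counts only the entries of the explicit Jacobian itself (float32, no intermediates from $J \Sigma J^T$). The helper name `explicit_jacobian_bytes` is illustrative and not part of the library.

```python
def explicit_jacobian_bytes(num_features: int, num_labels: int,
                            batch_size: int = 1, bytes_per_entry: int = 4) -> int:
    """Bytes needed to materialise J = phi(x)^T (x) I for a whole batch.

    Each per-sample Jacobian has shape (num_labels, num_features * num_labels),
    so the cost grows quadratically in num_labels.
    """
    entries = batch_size * num_labels * (num_features * num_labels)
    return entries * bytes_per_entry


# The 128 -> 3100 last layer from this issue, float32, batch_size=1
n_bytes = explicit_jacobian_bytes(128, 3100)
print(f"{n_bytes / 1e9:.2f} GB")  # prints "4.92 GB"
```

So the Jacobian alone is roughly 4.9 GB per sample here, before any of the matrix products that use it, which is consistent with the memory figures reported above.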

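For Last Layer Laplace, I believe this could be implemented far more efficiently by exploiting the factorisation of the Jacobian instead of forming it explicitly. This works for both the diagonal and the Kronecker Laplace; for example, the predictive variance for the Kronecker Laplace only needs the two Kronecker factors and the features.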
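Spelled out as equations, under the assumption of a bias-free last layer whose weight is flattened row-major (with column-major vectorisation the Jacobian is written $\phi(x)^T \otimes I$ as above, and the result is unchanged). Here $C$ is num_labels, $D$ is num_features, $S$ is the diagonal posterior variance reshaped to $C \times D$, and $U$, $V$ are the label-side and feature-side Kronecker factors of the posterior precision, matching `vec_U`/`vec_V` in the Kronecker snippet. The prior term is left out for clarity.

```latex
% Last layer: f(x) = W \phi(x), W \in \mathbb{R}^{C \times D}, \theta = \mathrm{vec}(W) (row-major)
J(x) = I_C \otimes \phi(x)^\top \in \mathbb{R}^{C \times CD}

% Diagonal posterior: \Sigma = \mathrm{diag}(\mathrm{vec}(S)), S \in \mathbb{R}^{C \times D}
\big[J \Sigma J^\top\big]_{cc'} = \delta_{cc'} \sum_{i=1}^{D} S_{ci}\, \phi_i(x)^2
\quad\Rightarrow\quad
J \Sigma J^\top = \mathrm{diag}\big(S\,(\phi(x) \odot \phi(x))\big)

% Kronecker posterior: \Sigma = (U \otimes V)^{-1} = U^{-1} \otimes V^{-1},
% U \in \mathbb{R}^{C \times C}, V \in \mathbb{R}^{D \times D},
% using the mixed-product rule (A \otimes B)(E \otimes F) = AE \otimes BF
J \Sigma J^\top
= (I_C \otimes \phi^\top)(U^{-1} \otimes V^{-1})(I_C \otimes \phi)
= U^{-1} \otimes \big(\phi^\top V^{-1} \phi\big)
= \big(\phi(x)^\top V^{-1} \phi(x)\big)\, U^{-1}
```

Neither expression ever forms the $C \times CD$ Jacobian: the diagonal case is a $(C \times D)$ by $D$ product per sample, and the Kronecker case is one $D \times D$ quadratic form per sample times a $C \times C$ matrix shared across the batch.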

Below is some code I used to speed up Last Layer Laplace by modifying the code from the repository. It would be great if you could consider making this change, as I am sure other users of the package would benefit from it.

Diagonal Laplace

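    import torch

    def _glm_predictive_distribution(self, X):
        # Assumes a bias-free last layer: posterior_variance reshapes to (num_labels, emb_size)
        f_mu, phi = self.model.forward_with_features(X)  # phi: (batch, emb_size)
        emb_size = phi.shape[-1]
        # var[b, c] = sum_i sigma^2[c, i] * phi[b, i]^2; diag_embed -> (batch, num_labels, num_labels)
        f_var = torch.diag_embed(torch.matmul(self.posterior_variance.reshape(-1, emb_size), (phi * phi).transpose(0, 1)).transpose(0, 1))
        return f_mu.detach(), f_var.detach()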
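To sanity-check the diagonal formula without the library or a trained model, the sketch below uses NumPy in place of torch, with small random shapes (all illustrative assumptions), and compares it against the explicit $J \Sigma J^T$ with $J = I_C \otimes \phi^T$ (row-major flattening, as implied by `reshape(-1, emb_size)`).

```python
import numpy as np

rng = np.random.default_rng(0)
B, D, C = 4, 5, 3  # batch size, feature dim (emb_size), num_labels; small for checking
phi = rng.normal(size=(B, D))  # last-layer features phi(x)
post_var = rng.uniform(0.1, 1.0, size=C * D)  # diagonal posterior variance of W (C, D), row-major

# Factored formula, mirroring the snippet: var[b, c] = sum_i S[c, i] * phi[b, i]^2
f_var_fast = (phi ** 2) @ post_var.reshape(C, D).T  # (B, C)
f_cov_fast = f_var_fast[:, :, None] * np.eye(C)  # diag_embed equivalent, (B, C, C)

# Explicit reference: build J = I_C (x) phi^T per sample and compute J diag(s) J^T
f_cov_ref = np.stack([
    np.kron(np.eye(C), p[None, :]) @ np.diag(post_var) @ np.kron(np.eye(C), p[None, :]).T
    for p in phi
])

print(np.allclose(f_cov_fast, f_cov_ref))  # prints True
```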

Kronecker Laplace

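    import torch

    def _glm_predictive_distribution(self, X):
        f_mu, phi = self.model.forward_with_features(X)  # phi: (batch, emb_size)

        # Eigendecomposition of the Kronecker factors of the last-layer precision:
        # U is (num_labels x num_labels), V is (emb_size x emb_size)
        eig_U, eig_V = self.posterior_precision.eigenvalues[0]
        vec_U, vec_V = self.posterior_precision.eigenvectors[0]
        # Prior precision split as sqrt(delta) across both factors
        delta = self.posterior_precision.deltas.sqrt()
        inv_U_eig, inv_V_eig = torch.pow(eig_U + delta, -1), torch.pow(eig_V + delta, -1)

        # phi^T V^{-1} phi for each sample: (batch, 1, 1)
        phiT_Q = torch.matmul(phi[:, None, :], vec_V[None, :, :])
        phiTVphi = torch.matmul(torch.matmul(phiT_Q, torch.diag(inv_V_eig)[None, :, :]), phiT_Q.transpose(1, 2))

        # f_var = (phi^T V^{-1} phi) * U^{-1}: (batch, num_labels, num_labels)
        f_var = phiTVphi * ((vec_U @ torch.diag(inv_U_eig) @ vec_U.T)[None, :, :])

        return f_mu.detach(), f_var.detach()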
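The Kronecker formula can be checked the same way. This NumPy sketch assumes random symmetric positive-definite factors `U` (labels) and `V` (features) in place of the fitted ones and leaves out the prior term, then compares the eigendecomposition-based result against the explicit $J (U \otimes V)^{-1} J^T$.

```python
import numpy as np

rng = np.random.default_rng(1)
B, D, C = 4, 5, 3  # batch size, feature dim, num_labels; small for checking
phi = rng.normal(size=(B, D))


def random_spd(n: int) -> np.ndarray:
    """Random well-conditioned symmetric positive-definite matrix (illustrative)."""
    A = rng.normal(size=(n, n))
    return A @ A.T + n * np.eye(n)


U = random_spd(C)  # label-side Kronecker factor of the precision (C x C)
V = random_spd(D)  # feature-side Kronecker factor of the precision (D x D)

# Factored formula via eigendecompositions, as in the snippet (prior omitted)
eig_U, vec_U = np.linalg.eigh(U)
eig_V, vec_V = np.linalg.eigh(V)
U_inv = vec_U @ np.diag(1.0 / eig_U) @ vec_U.T
phiT_Q = phi @ vec_V  # (B, D): projections of phi onto V's eigenvectors
phiTVphi = np.sum(phiT_Q ** 2 / eig_V, axis=1)  # phi^T V^{-1} phi, (B,)
f_cov_fast = phiTVphi[:, None, None] * U_inv[None]  # (B, C, C)

# Explicit reference with J = I_C (x) phi^T and Sigma = (U (x) V)^{-1}
Sigma = np.linalg.inv(np.kron(U, V))
f_cov_ref = np.stack([
    np.kron(np.eye(C), p[None, :]) @ Sigma @ np.kron(np.eye(C), p[None, :]).T
    for p in phi
])

print(np.allclose(f_cov_fast, f_cov_ref))  # prints True
```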

Note 1: The speedup is considerable: in my case, evaluation on the test set went from 3 hours to 5 seconds with the Diagonal Laplace.

Note 2: The Kronecker Laplace gives slightly different results from the original implementation, but as far as I can tell the math is correct, and empirically the results are reasonable.
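One possible explanation for the small Kronecker discrepancy, assuming the original code path adds the prior once to each product of eigenvalues ($\lambda_U \lambda_V + \delta$), is that the snippet above instead adds $\sqrt{\delta}$ to each factor. The two differ by $\sqrt{\delta}(\lambda_U + \lambda_V)$ even in exact arithmetic, so the gap would not be a precision issue. The eigenvalues and prior below are illustrative values only.

```python
import numpy as np

eig_U = np.array([0.5, 2.0, 4.0])  # illustrative eigenvalues of the label-side factor
eig_V = np.array([1.0, 3.0])  # illustrative eigenvalues of the feature-side factor
delta = 0.1  # prior precision

# Prior added once per eigenvalue product: lambda_U * lambda_V + delta
exact = np.outer(eig_U, eig_V) + delta
# Prior split across factors: (lambda_U + sqrt(delta)) * (lambda_V + sqrt(delta))
factored = np.outer(eig_U + np.sqrt(delta), eig_V + np.sqrt(delta))

# The split version carries an extra sqrt(delta) * (lambda_U + lambda_V) term
extra = factored - exact
print(np.allclose(extra, np.sqrt(delta) * np.add.outer(eig_U, eig_V)))  # prints True
```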

Best,

Carles

wiseodd commented 8 months ago

Hi Carles, they are indeed good suggestions! I wrote the formula for KronLLLaplace in App. B.1 of my paper; not sure why I haven't managed to implement that yet here!

Would you like to submit a pull request for this? No problem if you can't; I can also do this rather quickly.

charlio23 commented 8 months ago

Thank you very much for your fast reply.

I am not sure my pull request would meet the coding standards of the repo, so since you already have experience with the codebase, I would appreciate it if you could do it instead. The only thing to note is that I noticed a slight difference in the resulting covariance matrix when using the new formula, and I am not sure whether it comes from numerical precision.

PS: In fact the formula I used is from your paper 😄.

wiseodd commented 4 months ago

@aleximmer what's the best way to go about this? For KronLLLaplace, it seems this should be implemented in KronDecomposed in matrix.py.

This change would be very useful in relation to #144, e.g. for LLMs where the number of classes is on the order of $10^4$.