normal-computing / posteriors

Uncertainty quantification with PyTorch
https://normal-computing.github.io/posteriors/
Apache License 2.0

Add dense VI #59

Open SamDuffield opened 2 months ago

SamDuffield commented 2 months ago

Add dense covariance matrix variational inference (which might be useful in e.g. last layer approaches).

Gradients should be taken directly on the triangular Cholesky factor to ensure the covariance is always symmetric positive definite.

It is probably best/easiest to store said Cholesky factor as a Tensor rather than a TensorTree, as is done in laplace.dense_fisher.
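A minimal sketch of the Cholesky parameterization idea (variable names here are illustrative assumptions, not the posteriors API): keep an unconstrained square Tensor, take its lower triangle, and form the covariance from the triangular factor, so it stays symmetric positive semi-definite under plain gradient updates.

```python
import torch

# Hypothetical sketch: store an unconstrained square Tensor and take
# its lower triangle, so the implied covariance L @ L.T is symmetric
# positive semi-definite for any value the optimizer reaches.
d = 3
chol_params = torch.randn(d, d, requires_grad=True)

L = torch.tril(chol_params)  # lower-triangular Cholesky factor
cov = L @ L.T                # symmetric PSD by construction

# Gradients flow only through the lower triangle, so gradient steps
# on chol_params never leave the valid covariance parameterization.
cov.trace().backward()
```

Since torch.tril zeroes the upper triangle, the gradient with respect to the strictly upper-triangular entries of chol_params is identically zero.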

gil2rok commented 2 months ago

Why is it preferable to represent the Cholesky factor as a Tensor and not a TensorTree?

SamDuffield commented 2 months ago

In general, in posteriors we choose to represent dense matrices relating to parameters as Tensors, since it is unclear (at least to me) how a dense matrix would be represented as a TensorTree.

Suppose your parameters are

params = {'a' : torch.tensor([1.0]), 'b':  torch.tensor([2.0, 3.0])}

How would you then store a covariance matrix representing params such that you can access the covariance between a and b?

We decided it was much simpler and clearer here to rely on optree.integration.torch.tree_ravel, which also makes many of the matrix operations (inverse, Cholesky) much easier.
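To illustrate, here is a minimal stand-in for optree.integration.torch.tree_ravel, assuming it returns a (flat_vector, unravel_fn) pair in the style of JAX's ravel_pytree, and restricted to a flat dict TensorTree for brevity:

```python
import torch

# Simplified stand-in for optree.integration.torch.tree_ravel
# (assumption: the real helper also returns (flat_vector, unravel_fn)).
def tree_ravel(tree):
    sizes = [t.numel() for t in tree.values()]
    flat = torch.cat([t.reshape(-1) for t in tree.values()])

    def unravel(vec):
        chunks = torch.split(vec, sizes)
        return {k: c.reshape(t.shape)
                for (k, t), c in zip(tree.items(), chunks)}

    return flat, unravel

params = {'a': torch.tensor([1.0]), 'b': torch.tensor([2.0, 3.0])}
flat, unravel = tree_ravel(params)

# flat is a length-3 vector, so a dense covariance over params is one
# 3x3 Tensor; e.g. entry [0, 1] is the covariance between a and b[0].
cov = torch.eye(flat.numel())
```

Once parameters live in a single flat vector, the covariance (or its Cholesky factor) is just an ordinary square Tensor, and cross-covariances between different leaves of the tree correspond to off-diagonal blocks.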

Hope that makes sense!

gil2rok commented 1 month ago

@SamDuffield Would be happy to attempt a PR for this.

Could you go into more specific detail about what code needs to be changed? I looked into the diag VI code and struggled to find what would be different in a dense VI implementation.

Seems like posteriors.utils.diag_normal_sample would need to be different? And potentially some other math derivations would change?
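For intuition, a hedged sketch of the difference (function names are assumptions, not the posteriors API): diagonal VI scales noise elementwise by the standard deviations, while dense VI multiplies by the lower-triangular Cholesky factor instead.

```python
import torch

# Diagonal case (roughly what a diagonal Gaussian sampler does):
# scale standard normal noise elementwise by the standard deviations.
def diag_normal_sample_sketch(mean, sd_diag, n_samples):
    eps = torch.randn(n_samples, mean.numel())
    return mean + sd_diag * eps

# Dense case: multiply the noise by the lower-triangular Cholesky
# factor, giving draws from N(mean, L @ L.T).
def dense_normal_sample_sketch(mean, chol, n_samples):
    eps = torch.randn(n_samples, mean.numel())
    return mean + eps @ torch.tril(chol).T
```

When chol is diagonal with the standard deviations on its diagonal, the two samplers coincide, which is a handy sanity check for a dense implementation.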

Have a couple hours to kill now so if you get back soon, would be happy to start today.

SamDuffield commented 1 month ago

That would be wonderful! (Sorry for not replying yesterday, was travelling this weekend 🥾 )

gil2rok commented 1 month ago

Will happily attempt this PR. Though will need to start in a week or two because of some prior commitments.