SamDuffield opened 2 months ago
Why is it preferable to represent the Cholesky factor as a `Tensor` and not a `TensorTree`?
In general, for posteriors we choose to represent dense matrices relating to parameters as `Tensor`s, as it is unclear (at least to me) how to represent a dense matrix as a `TensorTree`. Suppose your parameters are `params = {'a': torch.tensor([1.0]), 'b': torch.tensor([2.0, 3.0])}`. How would you then store a covariance matrix representing `params` such that you can access the covariance between `a` and `b`?
We decided it was much simpler and clearer here to rely on `optree.integration.torch.tree_ravel`, which also makes many of the matrix operations (inverse, Cholesky) much easier. Hope that makes sense!
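To make the flattening concrete, here is a minimal sketch. It uses a plain `torch.cat` in place of `tree_ravel` (which additionally returns an unravel function) to show why a dense covariance over the raveled parameters is naturally a single `Tensor`:

```python
import torch

params = {'a': torch.tensor([1.0]), 'b': torch.tensor([2.0, 3.0])}

# tree_ravel flattens the TensorTree into one 1D Tensor (and returns an
# inverse function); for this flat dict that is equivalent to:
flat_params = torch.cat([v.ravel() for v in params.values()])  # shape (3,)

# A dense covariance over all parameters is then simply a (3, 3) Tensor.
# E.g. cov[0, 1] is the covariance between 'a' and the first element of 'b'.
cov = torch.eye(flat_params.numel())
```

Cross-covariances between leaves (here `a` and `b`) are just off-diagonal blocks of `cov`, which has no obvious `TensorTree` analogue.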
@SamDuffield Would be happy to attempt a PR for this. Could you go into more specific detail about what code needs to be changed? I looked into the diag VI code and struggled to find what would be different in a dense VI implementation. It seems like `posteriors.utils.diag_normal_sample` would need to be different, and potentially some other math derivations would change? I have a couple of hours to kill now, so if you get back soon I'd be happy to start today.
That would be wonderful! (Sorry for not replying yesterday, was travelling this weekend 🥾)

The main change to `vi.diag`, as you noted above, is that the covariance should be stored as a `Tensor` rather than a `TensorTree`. This will likely make judicious use of `optree.integration.torch.tree_ravel` and follow the API of our `laplace.dense_fisher` implementation. With the covariance as a `Tensor`, we might not need a bespoke sampling function over using `torch.distributions.MultivariateNormal`. Indeed, the `vi.dense.sample` function can be exactly the same as `laplace.dense_fisher.sample`.
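A hypothetical sketch of what such a dense sample function could look like (the name `dense_normal_sample` and its signature are illustrative assumptions, not the library's actual API): with the mean raveled to a flat vector and the Cholesky factor held as a `Tensor`, `MultivariateNormal` handles sampling directly.

```python
import torch
from torch.distributions import MultivariateNormal

def dense_normal_sample(mean_flat: torch.Tensor,
                        chol: torch.Tensor,
                        sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
    # scale_tril lets MultivariateNormal consume the lower-triangular
    # Cholesky factor directly, so no bespoke sampling code is needed.
    return MultivariateNormal(mean_flat, scale_tril=chol).sample(sample_shape)

mean = torch.zeros(3)
chol = torch.eye(3)
samples = dense_normal_sample(mean, chol, torch.Size([5]))  # shape (5, 3)
```

Samples come back flat; the unravel function returned by `tree_ravel` would map each one back to the parameter `TensorTree`.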
Will happily attempt this PR, though I will need to start in a week or two because of some prior commitments.
Add dense covariance matrix variational inference (which might be useful in e.g. last layer approaches).
Gradients should be taken directly on the triangular Cholesky factor to ensure the covariance is always symmetric positive definite.
Probably best/easiest to store said Cholesky factor as a `Tensor` rather than a `TensorTree`, as done in `laplace.dense_fisher`.
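A minimal sketch of the Cholesky parameterization described above (illustrative only, not the library's implementation): optimizing the triangular factor itself, rather than the covariance, keeps the covariance symmetric positive semi-definite by construction.

```python
import torch

d = 3
# Unconstrained parameters for the Cholesky factor, stored as a Tensor.
L_raw = torch.randn(d, d, requires_grad=True)

# Keep only the lower triangle. (A softplus or exp on the diagonal would
# additionally keep it strictly positive, guaranteeing positive
# definiteness; omitted here for brevity.)
L = torch.tril(L_raw)
cov = L @ L.T  # symmetric PSD by construction, for any L

# Any scalar loss of cov backpropagates to the triangular factor directly.
cov.trace().backward()
```

Since `cov = L @ L.T` holds for every gradient step on `L_raw`, no projection back onto the SPD cone is ever needed.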