stan-dev / math

The Stan Math Library is a C++ template library for automatic differentiation of any order using forward, reverse, and mixed modes. It includes a range of built-in functions for probabilistic modeling, linear algebra, and equation solving.
https://mc-stan.org
BSD 3-Clause "New" or "Revised" License

[WIP] Sparse cholesky #3069

Open dpsimpson opened 1 month ago

dpsimpson commented 1 month ago

Summary

Decided to dive in on the sparse cholesky. This is a WIP.

The interesting wrinkle here is that

Eigen::SimplicialLLT<SpMat> llt(A);

does not (without a lot of coercing) perform a Cholesky decomposition of A. It instead performs a Cholesky of A.twistedBy(perm), where perm is stored in the Eigen::SimplicialLLT<T> class. This means that it is not enough for cholesky_decompose to return

llt.matrixL()

as this will not be the matrix we are looking for. This is the problem with @SteveBronder's old branch https://github.com/stan-dev/math/blob/a97bfa2f9a418bba192666133a28930651066fe7/stan/math/prim/mat/fun/cholesky_decompose.hpp#L114
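To make the issue concrete, here is a minimal sketch using only the standard library and small dense matrices in place of Eigen's sparse types. The helpers `cholesky_lower`, `permute_symmetric`, `llt_product`, and `max_abs_diff` are hypothetical names for this illustration, and the index convention B(i, j) = A(p[i], p[j]) is chosen for the sketch and may differ from Eigen's. It shows that the factor of the permuted matrix reproduces the permuted matrix, not A.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<double>>;
using Perm = std::vector<std::size_t>;

// Dense lower Cholesky factor: returns L with A = L * L^T (A must be SPD).
Mat cholesky_lower(const Mat& A) {
  const std::size_t n = A.size();
  Mat L(n, std::vector<double>(n, 0.0));
  for (std::size_t j = 0; j < n; ++j) {
    double d = A[j][j];
    for (std::size_t k = 0; k < j; ++k) d -= L[j][k] * L[j][k];
    assert(d > 0.0);  // fails if A is not positive definite
    L[j][j] = std::sqrt(d);
    for (std::size_t i = j + 1; i < n; ++i) {
      double s = A[i][j];
      for (std::size_t k = 0; k < j; ++k) s -= L[i][k] * L[j][k];
      L[i][j] = s / L[j][j];
    }
  }
  return L;
}

// Symmetric permutation B = P A P^T, with the convention B(i, j) = A(p[i], p[j]).
Mat permute_symmetric(const Mat& A, const Perm& p) {
  const std::size_t n = A.size();
  Mat B(n, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j) B[i][j] = A[p[i]][p[j]];
  return B;
}

// Reconstructs the product L * L^T.
Mat llt_product(const Mat& L) {
  const std::size_t n = L.size();
  Mat M(n, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j)
      for (std::size_t k = 0; k < n; ++k) M[i][j] += L[i][k] * L[j][k];
  return M;
}

// Largest entrywise absolute difference between two equally sized matrices.
double max_abs_diff(const Mat& X, const Mat& Y) {
  double m = 0.0;
  for (std::size_t i = 0; i < X.size(); ++i)
    for (std::size_t j = 0; j < X[i].size(); ++j)
      m = std::max(m, std::fabs(X[i][j] - Y[i][j]));
  return m;
}
```

For a 3x3 SPD matrix A and the ordering p = {2, 0, 1}, `llt_product(cholesky_lower(permute_symmetric(A, p)))` matches the permuted matrix, while its entries differ from those of A and the factor itself differs from `cholesky_lower(A)`.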

There are two options. The first is to compute perm explicitly somewhere else and carry it around (I do not like this option). The other is to treat Eigen::SimplicialLLT<T> as if it were the lower triangular matrix when performing triangular solves, multiplication, and anything else we might need (those operations require knowledge of perm).
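The second option might look roughly like the following sketch: an object that owns both the ordering and the factor, and exposes the operations that need them. `PermutedCholesky` is a hypothetical dense stand-in written with the standard library only; it is not Eigen's or Stan Math's API, and the convention B(i, j) = A(perm[i], perm[j]) is an assumption of the sketch.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;
using Perm = std::vector<std::size_t>;

// Hypothetical stand-in for a factorization object: it stores the ordering
// `perm` and the lower factor L of B, where B(i, j) = A(perm[i], perm[j]).
struct PermutedCholesky {
  Perm perm;
  Mat L;

  PermutedCholesky(const Mat& A, Perm p) : perm(std::move(p)) {
    const std::size_t n = A.size();
    L.assign(n, Vec(n, 0.0));
    for (std::size_t j = 0; j < n; ++j) {
      double d = A[perm[j]][perm[j]];
      for (std::size_t k = 0; k < j; ++k) d -= L[j][k] * L[j][k];
      assert(d > 0.0);  // fails if A is not positive definite
      L[j][j] = std::sqrt(d);
      for (std::size_t i = j + 1; i < n; ++i) {
        double s = A[perm[i]][perm[j]];
        for (std::size_t k = 0; k < j; ++k) s -= L[i][k] * L[j][k];
        L[i][j] = s / L[j][j];
      }
    }
  }

  // Solves A x = b: permute b, do both triangular solves, undo the permutation.
  Vec solve(const Vec& b) const {
    const std::size_t n = L.size();
    Vec y(n);
    for (std::size_t i = 0; i < n; ++i) {  // forward solve L z = P b
      double s = b[perm[i]];
      for (std::size_t k = 0; k < i; ++k) s -= L[i][k] * y[k];
      y[i] = s / L[i][i];
    }
    for (std::size_t i = n; i-- > 0;) {  // backward solve L^T y = z
      double s = y[i];
      for (std::size_t k = i + 1; k < n; ++k) s -= L[k][i] * y[k];
      y[i] = s / L[i][i];
    }
    Vec x(n);
    for (std::size_t i = 0; i < n; ++i) x[perm[i]] = y[i];  // x = P^T y
    return x;
  }

  // log det(A) = 2 * sum(log(L_ii)); the permutation does not change it.
  double log_determinant() const {
    double s = 0.0;
    for (std::size_t i = 0; i < L.size(); ++i) s += std::log(L[i][i]);
    return 2.0 * s;
  }
};
```

Callers never see `perm` or `L` directly; they ask the object for a solve or a log determinant, which is the sense in which the factorization object plays the role of the lower triangular matrix.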

I'm not quite sure how my preferred option will work with autodiff - we might need a light specialization.

Tests

Coming

Side Effects

By adding a new, non-matrix Eigen type, we need to be very careful that none of the other template patterns match it. It should be fine, but care is needed.

Release notes

Replace this text with a short note on what will change if this pull request is merged in which case this will be included in the release notes.


dpsimpson commented 1 month ago

argh sorry. didn't expect the CI to flag with a draft WIP PR.

spinkney commented 1 month ago

If you figure out the permutation thing I think that would also apply to the LDL decomposition because it does pivoting. I've been wanting to get that into stan-lang but I don't know how to differentiate and handle that issue.

dpsimpson commented 1 month ago

Do you think it needs to be exposed at the language level? To me, things like pivoting are just implementation details that should be abstracted away. If this thing works with sparse matrices, it should be the same there too.

spinkney commented 1 month ago

I mean the derivatives are what's hard. At the language level no one needs to know the pivoting. I have an implementation of a Cholesky factor of correlation matrices, but it constructs the LD factor instead and removes the need for the square root. So users could have an ldl_factor_correlation that is a tuple of a lower triangular matrix and a vector that holds the diagonal of D. It also makes sense to be able to construct an LDL tuple from a PD matrix, which is why I want the LDL factorization.

dpsimpson commented 1 month ago

Permutation matrices are linear and have determinant ±1, so I don't think they affect the derivatives (it is 8am and I've not had coffee, so I might be wrong)
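A short derivation of that claim, assuming the permutation matrix P is fixed (that is, it does not depend on the numerical values in A) and writing B for the permuted matrix and a bar for a reverse-mode adjoint; these symbols are introduced here for illustration only:

```latex
\begin{align*}
B &= P A P^{\top} = L L^{\top}
  &&\Longrightarrow\quad A = P^{\top} L L^{\top} P, \\
\mathrm{d}B &= P \,\mathrm{d}A\, P^{\top}
  &&\Longrightarrow\quad \bar{A} = P^{\top} \bar{B}\, P, \\
\log\det B &= \log\det A + 2 \log\lvert \det P \rvert = \log\det A
  &&\Longrightarrow\quad \log\det A = 2 \sum_{i} \log L_{ii}.
\end{align*}
```

Under that assumption the permutation only reorders entries of the adjoint, so quantities that depend on A through solves or the log determinant have the same derivatives as in the unpermuted case. The factor L itself is still a different matrix from the Cholesky factor of A.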

dpsimpson commented 1 month ago

In my mind, this would not be terribly difficult to implement, but it would add the wrinkle that the vari type would now have a solver type for val_ and a matrix type for adj_. I'm not deep enough into the autodiff system to know if that would be a disaster (maybe for higher order? afaik the only place val and adj meet is in the chain methods, which are hand-written).

SteveBronder commented 1 month ago

> In my mind, this would not be terribly difficult to implement, but it would add the wrinkle that the vari type would now have a solver type for val_ and a matrix type for adj_. I'm not deep enough into the autodiff system to know if that would be a disaster (maybe for higher order? afaik the only place val and adj meet is in the chain methods, which are hand-written).

The return type of the function has to be a matrix or a tuple of matrices for the Stan language, but you can store the actual solver for use in the reverse pass, like we do here for mdivide_left_ldlt:

https://github.com/stan-dev/math/blob/develop/stan/math/rev/fun/mdivide_left_ldlt.hpp#L42
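The idea of returning a plain value while keeping the solver alive for the reverse pass can be sketched with a toy tape. This is not Stan Math's autodiff machinery: `Tape`, `VarVec`, `DenseLLT`, and `solve_spd` are hypothetical names, the matrix A is treated as data rather than as a differentiable input, and only the standard library is used.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;

// Hypothetical solver object standing in for a stored factorization.
struct DenseLLT {
  Mat L;
  explicit DenseLLT(const Mat& A) : L(A.size(), Vec(A.size(), 0.0)) {
    const std::size_t n = A.size();
    for (std::size_t j = 0; j < n; ++j) {
      double d = A[j][j];
      for (std::size_t k = 0; k < j; ++k) d -= L[j][k] * L[j][k];
      assert(d > 0.0);  // fails if A is not positive definite
      L[j][j] = std::sqrt(d);
      for (std::size_t i = j + 1; i < n; ++i) {
        double s = A[i][j];
        for (std::size_t k = 0; k < j; ++k) s -= L[i][k] * L[j][k];
        L[i][j] = s / L[j][j];
      }
    }
  }
  Vec solve(const Vec& b) const {
    const std::size_t n = L.size();
    Vec y(n);
    for (std::size_t i = 0; i < n; ++i) {  // forward solve with L
      double s = b[i];
      for (std::size_t k = 0; k < i; ++k) s -= L[i][k] * y[k];
      y[i] = s / L[i][i];
    }
    for (std::size_t i = n; i-- > 0;) {  // backward solve with L^T
      double s = y[i];
      for (std::size_t k = i + 1; k < n; ++k) s -= L[k][i] * y[k];
      y[i] = s / L[i][i];
    }
    return y;
  }
};

// Toy tape: callbacks run in reverse order of registration.
struct Tape {
  std::vector<std::function<void()>> callbacks;
  void reverse_pass() {
    for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) (*it)();
  }
};

// Toy differentiable vector: values and adjoints in shared storage.
struct VarVec {
  std::shared_ptr<Vec> val;
  std::shared_ptr<Vec> adj;
  explicit VarVec(Vec v)
      : val(std::make_shared<Vec>(std::move(v))),
        adj(std::make_shared<Vec>(val->size(), 0.0)) {}
};

// x = A^{-1} b for an SPD data matrix A. The caller gets an ordinary VarVec;
// the factorization survives only inside the reverse-pass closure.
VarVec solve_spd(Tape& tape, const Mat& A, const VarVec& b) {
  auto solver = std::make_shared<DenseLLT>(A);
  VarVec x(solver->solve(*b.val));
  tape.callbacks.push_back([solver, b, x]() {
    // A is symmetric, so the adjoint update is b_adj += A^{-1} x_adj.
    const Vec g = solver->solve(*x.adj);
    for (std::size_t i = 0; i < g.size(); ++i) (*b.adj)[i] += g[i];
  });
  return x;
}
```

The return type stays a plain vector-like value, and the factorization is reused in the reverse pass without being exposed to the caller.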

SteveBronder commented 1 month ago

Also, is the permutation matrix fixed for each iteration of the sampler? If so, would we want a cholesky signature where the user passes in the permutation matrix?