Open fritzo opened 3 years ago
@eb8680 it looks like `AutoGaussian(pyrocov_model)` runs out of GPU memory when constructing low-rank precision matrices `precision = sqrt @ sqrt.T`. One possible solution is to use a `sqrt(precision)` representation in funsor's `Gaussian`. I guess the crux is whether we can implement a cheap Gaussian tensordot without materializing intermediate low-rank precision matrices. @fehiepsi already worked out most of the sqrt representation in Pyro PR #2019, where `ops.add` becomes mere concatenation.
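To make the "add becomes concatenation" point concrete, here is a minimal NumPy sketch. It is not the code from Pyro PR #2019 or funsor; the names `sqrt1`, `sqrt2` and the shapes are made up for illustration, and it only assumes the factorization `precision = sqrt @ sqrt.T` mentioned above.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, rank1, rank2 = 5, 2, 3  # hypothetical sizes, chosen only for illustration

# Two low-rank square-root factors, each of shape (dim, rank).
sqrt1 = rng.standard_normal((dim, rank1))
sqrt2 = rng.standard_normal((dim, rank2))

# Dense route: materialize both (dim, dim) precision matrices and add them.
dense_sum = sqrt1 @ sqrt1.T + sqrt2 @ sqrt2.T

# Square-root route: adding precisions is concatenation along the rank axis.
sqrt_sum = np.concatenate([sqrt1, sqrt2], axis=-1)  # shape (dim, rank1 + rank2)

# The quadratic form x^T P x can be evaluated without ever forming P.
x = rng.standard_normal(dim)
quad_dense = x @ dense_sum @ x
quad_sqrt = np.sum((sqrt_sum.T @ x) ** 2)

shapes_ok = sqrt_sum.shape == (dim, rank1 + rank2)
sums_match = np.allclose(sqrt_sum @ sqrt_sum.T, dense_sum)
quads_match = np.isclose(quad_dense, quad_sqrt)
```

The memory saving comes from storing `dim * rank` numbers per factor instead of `dim * dim`, which only helps while the accumulated rank stays well below `dim`.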
@fehiepsi how much effort do you think it would take for us to port your Pyro PR #2019 to funsor (where it would also be available in NumPyro)?
Here is the optimized GFVE schedule for my pyro-cov model. It fits in main memory but runs out of GPU memory.
The crux is this pair of Gaussian contractions with over 1e9 elements
I believe we can work around this using a combination of @fehiepsi's `prec_sqrt` representation https://github.com/pyro-ppl/pyro/pull/2019 and a `ConditionalGaussian` that generalizes `AffineNormal`. Happy to discuss.
My impression is that most of the details can be preserved (e.g. block vector, block matrix, align gaussian). Back then, one issue was that batched QR was very slow on GPU, but torch.linalg seems to have improved a lot since then.
@fehiepsi do you recall whether Cholesky was sufficient instead of QR? IIRC there was a PyTorch discussion about cheaply testing for positive definiteness or condition number using torch.linalg.cholesky_ex().
Looking at the code, I guess we need to triangularize a non-positive-definite precision matrix (e.g. a zeros matrix), but I can't recall when we need such triangularization. :( It is probably unnecessary. (Anyway, we can switch to QR if we face the positive definiteness issue.)
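For reference, a small NumPy sketch of the QR route to a triangular square root, including the all-zeros case raised above. The helper name `triangular_sqrt_qr` is hypothetical and this is not taken from the PR; it only relies on the identity that if `prec_sqrt.T = Q @ R` then `prec_sqrt @ prec_sqrt.T = R.T @ R`.

```python
import numpy as np

def triangular_sqrt_qr(prec_sqrt):
    """Return L with L @ L.T == prec_sqrt @ prec_sqrt.T, using QR only.

    prec_sqrt has shape (dim, rank). When rank >= dim, L is a square
    lower-triangular matrix of shape (dim, dim).
    """
    # mode="r" returns only the upper-triangular factor R of prec_sqrt.T.
    r = np.linalg.qr(prec_sqrt.T, mode="r")
    return r.T

rng = np.random.default_rng(0)
prec_sqrt = rng.standard_normal((4, 7))
tri = triangular_sqrt_qr(prec_sqrt)
qr_matches = np.allclose(tri @ tri.T, prec_sqrt @ prec_sqrt.T)
is_lower = np.allclose(tri, np.tril(tri))

# Degenerate case: an all-zeros factor, i.e. a zero precision matrix.
zero_sqrt = np.zeros((4, 7))
zero_tri = triangular_sqrt_qr(zero_sqrt)
qr_handles_zero = np.allclose(zero_tri, 0.0)

# Cholesky needs strict positive definiteness, so it is expected to
# reject the same zero precision matrix; record what actually happens.
try:
    np.linalg.cholesky(zero_sqrt @ zero_sqrt.T)
    cholesky_handles_zero = True
except np.linalg.LinAlgError:
    cholesky_handles_zero = False
```

Note that QR never forms the `(dim, dim)` product at all, whereas the Cholesky route has to materialize it first; in PyTorch, `torch.linalg.cholesky_ex` reports failure through its `info` output instead of raising.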
@eb8680 want to pair code next week on the high-level algorithm for variable elimination, continuing our work from https://github.com/pyro-ppl/funsor/compare/tractable-for-gaussians ?
Sure!
Addresses https://github.com/pyro-ppl/pyro/pull/2929. See design doc.
This issue tracks changes needed to efficiently perform variable elimination in Gaussian graphical models with plates. While `funsor.sum_product.sum_product()` is a partial solution, we'd like to generalize to a complete solution.

Tasks
[x] Introduce a new Funsor `ConditionalGaussian(info_vec, precision, conditional, inputs)` representing the batched conditional distribution of the rightmost real input variable, conditioned on other real input variables. This could be (i) a new Funsor in addition to `Gaussian`, (ii) a replacement or generalization of `Gaussian`, or (iii) a special case of `Gaussian` where the input `info_vec` and `precision` are structured (requires #556). This may allow cheaper linear algebra. Alternatively #567. Temporary workaround: naively scatter the three parameters `(info_vec, precision, conditional)` into a dense `Gaussian`. This can be much more computationally expensive.
… `.to_event()`.
… (`x_i --> y_ij <-- z_j`). Currently `sum_product()` and the TVE algorithm give up in this case with "intractable!". Temporary workaround: no known workaround.
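To illustrate why the "naively scatter into a dense Gaussian" workaround costs more, here is a rough NumPy sketch. The semantics of `conditional` are an assumption on my part, not something this issue specifies: I take it to be a matrix `C` such that the log-density of `y` given `x` is `r @ info_vec - 0.5 * r @ precision @ r` with `r = y - C @ x`. The helper `scatter_to_dense` is hypothetical and is not funsor API.

```python
import numpy as np

def scatter_to_dense(info_vec, precision, conditional):
    """Scatter (info_vec, precision, conditional) into a joint over z = [x, y].

    ASSUMPTION: `conditional` is a matrix C of shape (n_y, n_x) and the
    structured log-density is r @ info_vec - 0.5 * r @ precision @ r,
    where r = y - C @ x and `precision` is symmetric.
    """
    pc = precision @ conditional  # shape (n_y, n_x)
    dense_precision = np.block([
        [conditional.T @ pc, -pc.T],
        [-pc, precision],
    ])
    dense_info_vec = np.concatenate([-conditional.T @ info_vec, info_vec])
    return dense_info_vec, dense_precision

rng = np.random.default_rng(0)
n_x, n_y = 3, 2
a = rng.standard_normal((n_y, n_y))
precision = a @ a.T + np.eye(n_y)  # symmetric positive definite
info_vec = rng.standard_normal(n_y)
conditional = rng.standard_normal((n_y, n_x))

dense_info_vec, dense_precision = scatter_to_dense(info_vec, precision, conditional)

# Check both parameterizations give the same log-density at a random point.
x, y = rng.standard_normal(n_x), rng.standard_normal(n_y)
r = y - conditional @ x
z = np.concatenate([x, y])
structured = r @ info_vec - 0.5 * r @ precision @ r
dense = z @ dense_info_vec - 0.5 * z @ dense_precision @ z
densities_match = np.isclose(structured, dense)

# Parameter counts: structured vs dense.
structured_size = info_vec.size + precision.size + conditional.size
dense_size = dense_info_vec.size + dense_precision.size
```

Under this assumption the dense form stores `(n_x + n_y) ** 2` precision entries instead of `n_y ** 2 + n_y * n_x`, and the gap widens as the number of conditioning variables grows.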