pyro-ppl / funsor

Functional tensors for probabilistic programming
https://funsor.pyro.ai
Apache License 2.0

Gaussian funsor variable elimination #559

Open fritzo opened 3 years ago

fritzo commented 3 years ago

Addresses https://github.com/pyro-ppl/pyro/pull/2929. See design doc.

This issue tracks changes needed to efficiently perform variable elimination in Gaussian graphical models with plates. While funsor.sum_product.sum_product() is a partial solution, we'd like to generalize to a complete solution.

Tasks

fritzo commented 3 years ago

@eb8680 it looks like AutoGaussian(pyrocov_model) runs out of GPU memory while constructing low-rank precision matrices precision = sqrt @ sqrt.T. One possible solution is to use a sqrt(precision) representation in funsor's Gaussian. I guess the crux is whether we can implement a cheap Gaussian tensordot without materializing intermediate low-rank precision matrices. @fehiepsi already worked out most of the sqrt representation in Pyro PR #2019, where ops.add becomes mere concatenation.
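The concatenation claim is easy to check numerically. The sketch below uses NumPy rather than torch or funsor, and the names `sqrt_a` and `sqrt_b` are made up for illustration. It assumes both factors are over the same, already aligned variables, and it only covers the precision part of a Gaussian (not the info vector or the normalizer).

```python
import numpy as np

rng = np.random.default_rng(0)
dim, rank_a, rank_b = 6, 2, 3

# Hypothetical low-rank square roots: precision_x = sqrt_x @ sqrt_x.T
sqrt_a = rng.standard_normal((dim, rank_a))
sqrt_b = rng.standard_normal((dim, rank_b))

# Dense route: materialize both dim x dim precision matrices, then add them.
dense_sum = sqrt_a @ sqrt_a.T + sqrt_b @ sqrt_b.T

# Square-root route: adding the two quadratic forms is just concatenating
# the square roots along the rank axis.
sqrt_sum = np.concatenate([sqrt_a, sqrt_b], axis=-1)

# The dense product is formed here only to check the identity.
assert sqrt_sum.shape == (dim, rank_a + rank_b)
assert np.allclose(sqrt_sum @ sqrt_sum.T, dense_sum)
```

The square-root form stores dim x (rank_a + rank_b) numbers instead of dim x dim, which is where the memory saving would come from as long as the accumulated rank stays well below dim.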

@fehiepsi how much effort do you think it would take for us to port your Pyro PR #2019 to funsor (where it would also be available in NumPyro 😉)?

fritzo commented 3 years ago

Here is the optimized GFVE schedule for my pyro-cov model. It fits in main memory but runs out of GPU memory.

```
Contraction(ops.null, ops.add,
 frozenset(),
 (Contraction(ops.logaddexp, ops.add,
   frozenset({Variable('rate_loc_scale__BOUND_13', Real)}),
   (Gaussian(
     torch.tensor(...1..., dtype=torch.float32),
     torch.tensor(...1 x 1..., dtype=torch.float32),
     (('rate_loc_scale__BOUND_13', Real),)),
    Contraction(ops.logaddexp, ops.add,
     frozenset({Variable('rate_scale__BOUND_14', Real)}),
     (Gaussian(
       torch.tensor(...1..., dtype=torch.float32),
       torch.tensor(...1 x 1..., dtype=torch.float32),
       (('rate_scale__BOUND_14', Real),)),
      Contraction(ops.logaddexp, ops.add,
       frozenset({Variable('coef__BOUND_12', Reals[2367])}),
       (Gaussian(
         torch.tensor(...2367..., dtype=torch.float32),
         torch.tensor(...2367 x 2367..., dtype=torch.float32),
         (('coef__BOUND_12', Reals[2367]),)),
        Contraction(ops.add, ops.null,
         frozenset({Variable('strain__BOUND_11', Bint[1343])}),
         (Contraction(ops.logaddexp, ops.add,
           frozenset({Variable('rate_loc__BOUND_10', Real)}),
           (Gaussian(
             torch.tensor(...1343 x 2369..., dtype=torch.float32),
             torch.tensor(...1343 x 2369 x 2369..., dtype=torch.float32),
             (('strain__BOUND_11', Bint[1343]),
              ('rate_loc__BOUND_10', Real),
              ('rate_loc_scale__BOUND_13', Real),
              ('coef__BOUND_12', Reals[2367]),)),
            Contraction(ops.add, ops.null,
             frozenset({Variable('place__BOUND_4', Bint[1372])}),
             (Contraction(ops.logaddexp, ops.null,
               frozenset({Variable('rate__BOUND_3', Real)}),
               (Gaussian(
                 torch.tensor(...1372 x 1343 x 3..., dtype=torch.float32),
                 torch.tensor(...1372 x 1343 x 3 x 3..., dtype=torch.float32),
                 (('place__BOUND_4', Bint[1372]),
                  ('strain__BOUND_11', Bint[1343]),
                  ('rate__BOUND_3', Real),
                  ('rate_scale__BOUND_14', Real),
                  ('rate_loc__BOUND_10', Real),)),)),)),)),)),)),)),)),
  Contraction(ops.null, ops.add,
   frozenset(),
   (Contraction(ops.logaddexp, ops.add,
     frozenset({Variable('pois_loc__BOUND_16', Real)}),
     (Gaussian(
       torch.tensor(...1..., dtype=torch.float32),
       torch.tensor(...1 x 1..., dtype=torch.float32),
       (('pois_loc__BOUND_16', Real),)),
      Contraction(ops.logaddexp, ops.add,
       frozenset({Variable('pois_scale__BOUND_15', Real)}),
       (Gaussian(
         torch.tensor(...1..., dtype=torch.float32),
         torch.tensor(...1 x 1..., dtype=torch.float32),
         (('pois_scale__BOUND_15', Real),)),
        Contraction(ops.add, ops.null,
         frozenset({Variable('place__BOUND_6', Bint[1372]), Variable('time__BOUND_7', Bint[49])}),
         (Contraction(ops.logaddexp, ops.null,
           frozenset({Variable('pois__BOUND_5', Real)}),
           (Gaussian(
             torch.tensor(...49 x 1372 x 3..., dtype=torch.float32),
             torch.tensor(...49 x 1372 x 3 x 3..., dtype=torch.float32),
             (('time__BOUND_7', Bint[49]),
              ('place__BOUND_6', Bint[1372]),
              ('pois__BOUND_5', Real),
              ('pois_loc__BOUND_16', Real),
              ('pois_scale__BOUND_15', Real),)),)),)),)),)),
    Contraction(ops.logaddexp, ops.add,
     frozenset({Variable('init_loc_scale__BOUND_17', Real)}),
     (Gaussian(
       torch.tensor(...1..., dtype=torch.float32),
       torch.tensor(...1 x 1..., dtype=torch.float32),
       (('init_loc_scale__BOUND_17', Real),)),
      Contraction(ops.logaddexp, ops.add,
       frozenset({Variable('init_scale__BOUND_18', Real)}),
       (Gaussian(
         torch.tensor(...1..., dtype=torch.float32),
         torch.tensor(...1 x 1..., dtype=torch.float32),
         (('init_scale__BOUND_18', Real),)),
        Contraction(ops.add, ops.null,
         frozenset({Variable('strain__BOUND_9', Bint[1343])}),
         (Contraction(ops.logaddexp, ops.add,
           frozenset({Variable('init_loc__BOUND_8', Real)}),
           (Gaussian(
             torch.tensor(...1343 x 2..., dtype=torch.float32),
             torch.tensor(...1343 x 2 x 2..., dtype=torch.float32),
             (('strain__BOUND_9', Bint[1343]),
              ('init_loc__BOUND_8', Real),
              ('init_loc_scale__BOUND_17', Real),)),
            Contraction(ops.add, ops.null,
             frozenset({Variable('place__BOUND_2', Bint[1372])}),
             (Contraction(ops.logaddexp, ops.null,
               frozenset({Variable('init__BOUND_1', Real)}),
               (Gaussian(
                 torch.tensor(...1372 x 1343 x 3..., dtype=torch.float32),
                 torch.tensor(...1372 x 1343 x 3 x 3..., dtype=torch.float32),
                 (('place__BOUND_2', Bint[1372]),
                  ('strain__BOUND_9', Bint[1343]),
                  ('init__BOUND_1', Real),
                  ('init_scale__BOUND_18', Real),
                  ('init_loc__BOUND_8', Real),)),)),)),)),)),)),)),)),)),))
```

The crux is this pair of Gaussian contractions with over 1e9 elements:

```
Contraction(ops.logaddexp, ops.add,
 frozenset({Variable('coef__BOUND_12', Reals[2367])}),
 (Gaussian(
   torch.tensor(...2367..., dtype=torch.float32),
   torch.tensor(...2367 x 2367..., dtype=torch.float32),
   (('coef__BOUND_12', Reals[2367]),)),
  Contraction(ops.add, ops.null,
   frozenset({Variable('strain__BOUND_11', Bint[1343])}),
   (Contraction(ops.logaddexp, ops.add,
     frozenset({Variable('rate_loc__BOUND_10', Real)}),
     (Gaussian(
       torch.tensor(...1343 x 2369..., dtype=torch.float32),
       torch.tensor(...1343 x 2369 x 2369..., dtype=torch.float32),  # <-------- OOM here
       (('strain__BOUND_11', Bint[1343]),
        ('rate_loc__BOUND_10', Real),
        ('rate_loc_scale__BOUND_13', Real),
        ('coef__BOUND_12', Reals[2367]),)),
      Contraction(ops.add, ops.null,
       frozenset({Variable('place__BOUND_4', Bint[1372])}),
       (Contraction(ops.logaddexp, ops.null,
         frozenset({Variable('rate__BOUND_3', Real)}),
         (Gaussian(
           torch.tensor(...1372 x 1343 x 3..., dtype=torch.float32),
           torch.tensor(...1372 x 1343 x 3 x 3..., dtype=torch.float32),
           (('place__BOUND_4', Bint[1372]),
            ('strain__BOUND_11', Bint[1343]),
            ('rate__BOUND_3', Real),
            ('rate_scale__BOUND_14', Real),
            ('rate_loc__BOUND_10', Real),)),)),)),)),)),)),)),)),
```
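For scale, here is a back-of-the-envelope calculation of the dense sizes involved. It is only a sketch: the shapes are copied from the schedule above, float32 is taken from the printed dtype, and `tensor_gib` is a made-up helper. It counts a single copy of each tensor and ignores any temporaries the contraction itself allocates.

```python
import math

BYTES_PER_FLOAT32 = 4


def tensor_gib(shape):
    """Return (element count, size in GiB) of one dense float32 tensor."""
    numel = math.prod(shape)
    return numel, numel * BYTES_PER_FLOAT32 / 2**30


# Shapes copied from the printed schedule.
shapes = {
    "strain-plated precision (flagged OOM)": (1343, 2369, 2369),
    "place x strain leaf precision": (1372, 1343, 3, 3),
    "coef prior precision": (2367, 2367),
}

for name, shape in shapes.items():
    numel, gib = tensor_gib(shape)
    print(f"{name}: {numel:.3e} elements, {gib:.3f} GiB")
```

The flagged precision tensor alone comes to about 7.5e9 elements, roughly 28 GiB in float32, while the 1372 x 1343 x 3 x 3 leaf precision stays under 0.1 GiB.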

I believe we can work around this using a combination of @fehiepsi's prec_sqrt representation https://github.com/pyro-ppl/pyro/pull/2019 and a ConditionalGaussian that generalizes AffineNormal. Happy to discuss.

fehiepsi commented 3 years ago

My impression is that most of the details can be preserved (e.g. block vector, block matrix, align gaussian). Back then, one issue was that batched QR was very slow on GPU, but torch.linalg seems to have improved a lot since then.

fritzo commented 3 years ago

@fehiepsi do you recall whether Cholesky was sufficient instead of QR? IIRC there was a PyTorch discussion about cheaply testing for positive definiteness or condition number using torch.linalg.cholesky_ex().

fehiepsi commented 3 years ago

Looking at the code, I guess we need to triangularize a non-positive-definite precision matrix (e.g. a zero matrix), but I can't recall when we need such a triangularization. :( It is probably unnecessary. (Anyway, we can switch to QR if we run into the positive-definiteness issue.)
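To make the Cholesky-versus-QR trade-off concrete, here is a small NumPy sketch (not funsor or Pyro code; `prec_sqrt` and the shapes are hypothetical). It assumes the goal is a triangular factor `r` with `r.T @ r == precision`, and shows that QR of the transposed square root provides one even when the precision is all zeros, where Cholesky fails.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, rank = 4, 7

# Hypothetical wide square root, e.g. after several concatenations:
# precision = prec_sqrt @ prec_sqrt.T
prec_sqrt = rng.standard_normal((dim, rank))
precision = prec_sqrt @ prec_sqrt.T

# QR route: prec_sqrt.T = q @ r with r upper triangular (dim x dim), hence
# precision = r.T @ q.T @ q @ r = r.T @ r. No dense precision is needed.
r = np.linalg.qr(prec_sqrt.T, mode="r")
assert np.allclose(r.T @ r, precision)

# Degenerate case: an all-zero precision is not positive definite,
# so Cholesky of the dense matrix raises.
zero_sqrt = np.zeros((dim, rank))
try:
    np.linalg.cholesky(zero_sqrt @ zero_sqrt.T)
    cholesky_ok = True
except np.linalg.LinAlgError:
    cholesky_ok = False

# QR still returns a (zero) triangular factor for the same input.
r_zero = np.linalg.qr(zero_sqrt.T, mode="r")
assert not cholesky_ok
assert np.allclose(r_zero, 0.0)
```

Whether the extra robustness is worth QR's cost on GPU is exactly the open question in this thread; the sketch only illustrates the failure mode.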

fritzo commented 2 years ago

@eb8680 want to pair code next week on the high-level algorithm for variable elimination, continuing our work from https://github.com/pyro-ppl/funsor/compare/tractable-for-gaussians ?

eb8680 commented 2 years ago

Sure!