pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

Implement specialized MvNormal density based on precision matrix #7345

Open ricardoV94 opened 4 weeks ago

ricardoV94 commented 4 weeks ago

Description

This PR explores a specialized logp for an MvNormal (and possibly MvStudentT) parametrized directly in terms of tau. A common model implementation looks like:

import pymc as pm
import numpy as np

A = np.array([
    [0, 1, 1],
    [1, 0, 1], 
    [1, 1, 0]
])
# Diagonal degree matrix: number of neighbours of each node
D = np.diag(A.sum(axis=-1))
np.testing.assert_allclose(A, A.T, err_msg="should be symmetric")

with pm.Model() as m:
    tau = pm.InverseGamma("tau", 1, 1)
    alpha = pm.Beta("alpha", 10, 10)
    Q = tau * (D - alpha * A)
    y = pm.MvNormal("y", mu=np.zeros(3), tau=Q)
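
For reference, the density being specialized can be written with the precision matrix alone: logp(x) = 0.5 * (log|Q| - (x - mu)^T Q (x - mu) - k * log(2 * pi)). The following is a minimal NumPy sketch of that formula, not the implementation in this PR. The function name `mvnormal_logp_from_precision`, the fixed values of `tau` and `alpha`, and the test point `x` are illustrative assumptions, and `D` is taken to be the diagonal degree matrix.

```python
import numpy as np


def mvnormal_logp_from_precision(x, mu, Q):
    """Log-density of N(mu, Q^{-1}) at x, using only the precision matrix Q.

    Hypothetical helper for illustration; not part of PyMC.
    """
    k = x.shape[-1]
    delta = x - mu
    # Cholesky factor of the precision, Q = L @ L.T
    # (raises LinAlgError if Q is not positive definite)
    L = np.linalg.cholesky(Q)
    # log|Q| = 2 * sum(log(diag(L)))
    logdet_Q = 2.0 * np.sum(np.log(np.diag(L)))
    # Quadratic form needs only products with Q, no inverse or solve
    quad = delta @ Q @ delta
    return 0.5 * (logdet_Q - quad - k * np.log(2.0 * np.pi))


# Same adjacency matrix as in the model above, with fixed illustrative parameters
A = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float)
D = np.diag(A.sum(axis=-1))  # assumed: diagonal degree matrix
tau, alpha = 2.0, 0.5  # illustrative values, not draws from the priors
Q = tau * (D - alpha * A)
mu = np.zeros(3)
x = np.array([0.3, -0.2, 0.1])

logp_precision = mvnormal_logp_from_precision(x, mu, Q)

# Reference computation through the covariance parametrization
Sigma = np.linalg.inv(Q)
_, logdet_Sigma = np.linalg.slogdet(Sigma)
delta = x - mu
logp_cov = -0.5 * (
    logdet_Sigma + delta @ np.linalg.solve(Sigma, delta) + 3 * np.log(2.0 * np.pi)
)

np.testing.assert_allclose(logp_precision, logp_cov)
```

The precision-based version avoids inverting Q, which is the step a covariance-based logp would otherwise need when the model is defined through tau.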

TODO (some are optional for this PR)

Related Issue

Checklist

Type of change

CC @theorashid @elizavetasemenova


📚 Documentation preview 📚: https://pymc--7345.org.readthedocs.build/en/7345/

ricardoV94 commented 4 weeks ago

Implementation checks may fail until https://github.com/pymc-devs/pytensor/issues/799 is fixed

ricardoV94 commented 1 week ago

Benchmark code

import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pt

rng = np.random.default_rng(123)

n = 100
L = rng.uniform(low=0.1, high=1.0, size=(n,n))
Sigma = L @ L.T
Q_test = np.linalg.inv(Sigma)
x_test = rng.normal(size=n)

with pm.Model(check_bounds=False) as m:
    Q = pm.Data("Q", Q_test)
    x = pm.MvNormal("x", mu=pt.zeros(n), tau=Q)

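# .f is the underlying compiled PyTensor function;
# trust_input=True skips input validation to reduce call overhead
logp_fn = m.compile_logp().f
logp_fn.trust_input = True
print("logp")
pytensor.dprint(logp_fn)

# Same for the gradient of the logp with respect to x
dlogp_fn = m.compile_dlogp().f
dlogp_fn.trust_input = True
print("dlogp")
pytensor.dprint(dlogp_fn)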

np.testing.assert_allclose(logp_fn(x_test), np.array(-1789.93662205))

np.testing.assert_allclose(np.sum(dlogp_fn(x_test) ** 2), np.array(18445204.8755109), rtol=1e-6)

# Before: 2.66 ms
# After: 1.31 ms
%timeit -n 1000 logp_fn(x_test)

# Before: 2.45 ms
# After: 72 µs
%timeit -n 1000 dlogp_fn(x_test)
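
One possible reading of the large dlogp speedup: when the density is expressed through the precision matrix, the gradient with respect to x is -Q (x - mu), a single matrix-vector product. The sketch below checks that identity against central finite differences in plain NumPy. It is an illustration under that assumption, not the graph PyTensor actually compiles; the helper names `logp` and `dlogp_dx` and the small random test problem are hypothetical.

```python
import numpy as np


def logp(x, mu, Q):
    """MvNormal log-density parametrized by the precision matrix Q."""
    k = x.shape[-1]
    delta = x - mu
    _, logdet_Q = np.linalg.slogdet(Q)
    return 0.5 * (logdet_Q - delta @ Q @ delta - k * np.log(2.0 * np.pi))


def dlogp_dx(x, mu, Q):
    """Gradient of logp with respect to x for symmetric Q: -Q (x - mu)."""
    return -Q @ (x - mu)


rng = np.random.default_rng(0)
n = 5
B = rng.normal(size=(n, n))
Q = B @ B.T + n * np.eye(n)  # symmetric positive definite by construction
mu = np.zeros(n)
x = rng.normal(size=n)

# Central finite differences along each coordinate direction
eps = 1e-6
fd = np.array(
    [(logp(x + eps * e, mu, Q) - logp(x - eps * e, mu, Q)) / (2 * eps) for e in np.eye(n)]
)

grad = dlogp_dx(x, mu, Q)
np.testing.assert_allclose(grad, fd, rtol=1e-5, atol=1e-6)
```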

codecov[bot] commented 6 days ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 92.20%. Comparing base (b496127) to head (2b9886e). Report is 1 commit behind head on main.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #7345      +/-   ##
==========================================
+ Coverage   92.19%   92.20%   +0.01%
==========================================
  Files         103      103
  Lines       17214    17247      +33
==========================================
+ Hits        15870    15903      +33
  Misses       1344     1344
```

| [Files](https://app.codecov.io/gh/pymc-devs/pymc/pull/7345?dropdown=coverage&src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=pymc-devs) | Coverage Δ | |
|---|---|---|
| [pymc/distributions/multivariate.py](https://app.codecov.io/gh/pymc-devs/pymc/pull/7345?src=pr&el=tree&filepath=pymc%2Fdistributions%2Fmultivariate.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=pymc-devs#diff-cHltYy9kaXN0cmlidXRpb25zL211bHRpdmFyaWF0ZS5weQ==) | `93.07% <100.00%> (+0.22%)` | :arrow_up: |
| [pymc/logprob/rewriting.py](https://app.codecov.io/gh/pymc-devs/pymc/pull/7345?src=pr&el=tree&filepath=pymc%2Flogprob%2Frewriting.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=pymc-devs#diff-cHltYy9sb2dwcm9iL3Jld3JpdGluZy5weQ==) | `89.30% <100.00%> (+0.17%)` | :arrow_up: |

ricardoV94 commented 6 days ago

The final question is whether we want to, or can, do a similar thing for MvStudentT. Otherwise it's ready to merge on my end.

CC @elizavetasemenova