pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

Convert random variables to value variables so `pm.sample(var_names)` works correctly #7284

Closed tomicapretto closed 5 months ago

tomicapretto commented 5 months ago

Description

This PR converts the random variables to value variables when `var_names` is not None in `pm.sample()`. Before this PR, passing `var_names` resulted in sampling from the prior. The linked issue explains the problem in more detail.

Related Issue

Closes #7258

Type of change


📚 Documentation preview 📚: https://pymc--7284.org.readthedocs.build/en/7284/

tomicapretto commented 5 months ago

@ricardoV94, to match what `model.unobserved_value_vars` does, I tried the following:

    # Get value variables for the trace
    if var_names is not None:
        value_vars = []
        transformed_rvs = []
        for rv in model.unobserved_RVs:
            if rv.name in var_names:
                value_var = model.rvs_to_values[rv]
                transform = model.rvs_to_transforms[rv]
                if transform is not None:
                    transformed_rvs.append(rv)
                value_vars.append(value_var)

        transformed_value_vars = model.replace_rvs_by_values(transformed_rvs)
        trace_vars = value_vars + transformed_value_vars
        assert len(trace_vars) == len(var_names), "Not all var_names were found in the model"

However, an assertion error is raised when there are transformed variables because they add two elements to the trace_vars list (i.e. kappa_log__ and kappa in the example shown in the issue).
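The count mismatch can be reproduced without PyMC. The sketch below uses a plain dictionary as a hypothetical stand-in for the model's bookkeeping (the `transforms` mapping and the `trace_var_names` helper are assumptions made for illustration, not PyMC API), and mimics how the snippet above builds `trace_vars`:

```python
# Hypothetical stand-in: requested RV name -> transform name (None if untransformed)
transforms = {"b_batch": None, "b_temp": None, "kappa": "log"}
var_names = ["b_batch", "b_temp", "kappa"]


def trace_var_names(var_names, transforms):
    """Mimic the snippet: one value variable per RV, plus one extra
    entry for every RV that has a transform."""
    value_names = []
    transformed = []
    for name in var_names:
        transform = transforms[name]
        if transform is None:
            value_names.append(name)
        else:
            # Transformed value variables are named like "kappa_log__"
            value_names.append(f"{name}_{transform}__")
            transformed.append(name)
    return value_names + transformed


names = trace_var_names(var_names, transforms)
n_transformed = sum(t is not None for t in transforms.values())

print(names)
print(len(names), len(var_names), len(var_names) + n_transformed)
```

Under these assumptions the list has one more entry than `var_names` for each transformed variable, so a check of the form `len(trace_vars) == len(var_names)` can only pass when no requested variable is transformed.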

Do you think the modification already pushed in the PR is the correct one, or do we need to somehow explicitly account for the transformations?

tomicapretto commented 5 months ago

It seems the current state is doing the right thing. See the following example.

    import arviz as az
    import numpy as np
    import pymc as pm

    batch = np.array(
        [
            1,  1,  1,  1,  2,  2,  2,  3,  3,  3,  4,  4,  4,  4,  5,  5,  5,
            6,  6,  6,  7,  7,  7,  7,  8,  8,  8,  9,  9, 10, 10, 10
        ]
    )
    temp = np.array(
        [
            205, 275, 345, 407, 218, 273, 347, 212, 272, 340, 235, 300, 365,
            410, 307, 367, 395, 267, 360, 402, 235, 275, 358, 416, 285, 365,
            444, 351, 424, 365, 379, 428
        ]
    )

    y = np.array(
        [
            0.122, 0.223, 0.347, 0.457, 0.08 , 0.131, 0.266, 0.074, 0.182,
            0.304, 0.069, 0.152, 0.26 , 0.336, 0.144, 0.268, 0.349, 0.1  ,
            0.248, 0.317, 0.028, 0.064, 0.161, 0.278, 0.05 , 0.176, 0.321,
            0.14 , 0.232, 0.085, 0.147, 0.18
        ]
    )

    batch_values, batch_idx = np.unique(batch, return_inverse=True)

    coords = {
        "batch": batch_values
    }

    with pm.Model(coords=coords) as model:
        b_batch = pm.Normal("b_batch", dims="batch")
        b_temp = pm.Normal("b_temp")
        mu = pm.Deterministic("mu", pm.math.invlogit(b_batch[batch_idx] + b_temp * temp))
        kappa = pm.Gamma("kappa", alpha=2, beta=2)

        alpha = mu * kappa
        beta = (1 - mu) * kappa

        pm.Beta("y", alpha=alpha, beta=beta, observed=y)

    with model:
        idata_1 = pm.sample(random_seed=1234)
        idata_2 = pm.sample(var_names=["b_batch", "b_temp", "kappa"], random_seed=1234)

    az.plot_forest([idata_1, idata_2], var_names=["b_batch"])
    az.plot_forest([idata_1, idata_2], var_names=["b_temp"])
    az.plot_forest([idata_1, idata_2], var_names=["kappa"])

[Images: three forest plots comparing idata_1 and idata_2 for b_batch, b_temp, and kappa]

codecov-commenter commented 5 months ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 92.33%. Comparing base (60a6314) to head (281fb5d).

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #7284      +/-   ##
==========================================
+ Coverage   91.67%   92.33%   +0.65%
==========================================
  Files         102      102
  Lines       17017    17018       +1
==========================================
+ Hits        15600    15713     +113
+ Misses       1417     1305     -112
```

| [Files](https://app.codecov.io/gh/pymc-devs/pymc/pull/7284?dropdown=coverage&src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=pymc-devs) | Coverage Δ | |
|---|---|---|
| [pymc/sampling/mcmc.py](https://app.codecov.io/gh/pymc-devs/pymc/pull/7284?src=pr&el=tree&filepath=pymc%2Fsampling%2Fmcmc.py&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=pymc-devs#diff-cHltYy9zYW1wbGluZy9tY21jLnB5) | `87.74% <100.00%> (+0.46%)` | :arrow_up: |

... and [3 files with indirect coverage changes](https://app.codecov.io/gh/pymc-devs/pymc/pull/7284/indirect-changes?src=pr&el=tree-more&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=pymc-devs)

ricardoV94 commented 5 months ago

This one warranted a regression test