pymc-devs / pymc-experimental

https://pymc-experimental.readthedocs.io

Copy model-related shared variables in `model_fgraph` #218

Closed ricardoV94 closed 12 months ago

ricardoV94 commented 12 months ago

In earlier iterations of the do-blogpost it became clear that it's more intuitive for model-related shared variables to be copied rather than shared across different models. Some of these variables are created by PyMC itself (e.g. `dim_lengths`), and users would have to know the internals not to be surprised by the behavior.

User-defined MutableData is also copied because, unlike the underlying SharedVariables, nothing about its name or documentation suggests these variables are supposed to be "shared" across models (just that they can be mutated).

I have only allowed user-defined shared variables that are not model variables to remain shared when a model is cloned.

This now works as (imo) expected:

import pymc as pm
from pymc_experimental.model_transform.conditioning import do

with pm.Model(coords_mutable={"i": [0]}) as m:
    x = pm.Normal("x")
    y = pm.Normal("y", x, dims=("i",))

new_m = do(m, {x: 0})
new_m.set_dim("i", new_length=5, coord_values=list(range(5)))

assert pm.draw(m["y"]).shape == (1,)  # Before this PR, this would also be (5,)
assert pm.draw(new_m["y"]).shape == (5,)
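
The copying rule described above (model-related shared variables are copied, other user-defined shared variables stay shared) can be pictured with a toy sketch in plain Python. This is only an illustration of the intended semantics, not the `model_fgraph` implementation: `ToyShared`, `ToyModel`, and `toy_clone` are hypothetical names invented for this example and do not exist in PyMC or pymc-experimental.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class ToyShared:
    """Hypothetical stand-in for a shared variable: a named, mutable value."""
    name: str
    value: list


@dataclass
class ToyModel:
    """Hypothetical stand-in for a model that tracks its own shared variables."""
    # Variables the model owns (e.g. dim lengths, MutableData).
    model_shared: dict = field(default_factory=dict)
    # User-created shared variables that are not model variables.
    external_shared: dict = field(default_factory=dict)


def toy_clone(model: ToyModel) -> ToyModel:
    # Model-related shared variables are deep-copied, so each clone owns its state.
    copied = {k: copy.deepcopy(v) for k, v in model.model_shared.items()}
    # External shared variables are passed by reference, so they stay shared.
    return ToyModel(model_shared=copied, external_shared=dict(model.external_shared))


m = ToyModel(
    model_shared={"i_length": ToyShared("i_length", [1])},
    external_shared={"scale": ToyShared("scale", [1.0])},
)
new_m = toy_clone(m)

new_m.model_shared["i_length"].value[0] = 5    # only affects the clone
new_m.external_shared["scale"].value[0] = 2.0  # visible from both models
```

In this sketch, resizing the dimension on `new_m` leaves `m` untouched, mirroring the `set_dim` example above, while the externally created `scale` variable is still the same object in both models.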

This PR also fixes a bug: with the old clone approach, RNG shared variables weren't actually being copied and made independent. The relevant test was missing an assert :( (see the first commit of this PR)
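
As a rough analogy for the RNG fix, the difference between an aliased and a copied generator can be shown with Python's standard-library `random.Random` standing in for an RNG shared variable. This is an assumption-laden sketch of what "independent" means here (advancing one generator must not advance the other); it is not the test from this PR and does not use PyTensor's RNG types.

```python
import copy
import random

original_rng = random.Random(123)

aliased_rng = original_rng                # not copied: both names refer to one generator
copied_rng = copy.deepcopy(original_rng)  # copied: a separate generator with its own state

state_before = copied_rng.getstate()
original_rng.random()  # advance the original generator by one draw

# The alias moved together with the original; the copy did not.
aliased_moved = aliased_rng.getstate() == original_rng.getstate()
copy_untouched = copied_rng.getstate() == state_before
```

A check like `copy_untouched` is the kind of assertion that catches the bug: without it, a test can run to completion while the generators are still silently tied together.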