
Relationship between valued RVs lost during logp inference #6917

Closed ricardoV94 closed 1 month ago

ricardoV94 commented 1 year ago

Description

Discovered by @tvwenger in https://github.com/pymc-devs/pymc/discussions/6905#discussioncomment-7027204

import pymc as pm
from pymc.logprob import conditional_logp

a_base = pm.Normal.dist()
a = a_base * 5  # a is a deterministic (invertible) transform of a_base
b = pm.Normal.dist(a * 8)  # b's mean depends on a

# Create value variables matching the types of a and b
a_value = a.type()
b_value = b.type()
conditional_logp({a: a_value, b: b_value})  # UserWarning: Random variables detected in the logp graph

In this case the value variable of a is never substituted into the graph of b, because b is rewritten as pm.Normal.dist(a_base * 40), and a_base * 40 is not associated with any value variable!

The immediate issue is that we apply canonicalization rewrites that do not take into account whether a variable is valued.
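
The collapse can be reproduced outside the logprob machinery. Here is a minimal sketch (assuming PyTensor's standard "canonicalize" rewrite set) showing how the nested multiplications are merged, which is what removes the intermediate node corresponding to a:

import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

a_base = pt.scalar("a_base")
mean = (a_base * 5) * 8  # same structure as the mean of b above

rewritten = rewrite_graph(mean, include=("canonicalize",))
pytensor.dprint(rewritten)
# The rewritten graph should compute a_base * 40 directly; the intermediate
# (a_base * 5) node, which corresponded to the valued variable a, is gone.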

The bigger issue is that we don't have a means of allowing symbolic value transforms in our logprob inference. The correct thing would be to allow variables to depend on invertible transformations of valued variables: in this case, a_base * 40 = (a_value / 5) * 40.
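
Until that exists, the substitution can be done by hand. A hedged workaround sketch (the names here are mine, not an official API): write b's parameters directly in terms of a's value variable via the inverse transform a_base = a_value / 5, and let pm.logp handle the change of variables for a itself:

import pymc as pm
import pytensor

a_base = pm.Normal.dist()
a = a_base * 5
a_value = a.type()

# b's mean expressed in terms of the value of a, via a_base = a_value / 5
b = pm.Normal.dist((a_value / 5) * 40)
b_value = b.type()

logp_a = pm.logp(a, a_value)  # includes the Jacobian term for the scaling by 5
logp_b = pm.logp(b, b_value)

logp_fn = pytensor.function([a_value, b_value], logp_a + logp_b)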

ricardoV94 commented 1 year ago

This might be a good time to introduce ValuedVariables in the IR graph. That would block such rewrites from happening!
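
For illustration only, a rough sketch of the idea (the Op name and details are hypothetical, not PyMC's actual implementation): an identity-like Op that ties an RV to its value variable in the IR, so generic rewrites cannot reach through the pairing:

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op

class ValuedRV(Op):
    """Identity-like marker pairing an RV with its value variable (hypothetical sketch)."""

    view_map = {0: [0]}

    def make_node(self, rv, value):
        return Apply(self, [rv, value], [rv.type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = inputs[0]

# Downstream graphs would reference the wrapped output, e.g.
#   a_valued = ValuedRV()(a, a_value)
#   b = pm.Normal.dist(a_valued * 8)
# Since canonicalization has no rule for ValuedRV, it cannot merge the
# multiplications across the valued boundary.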