pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

`compile_forward_sampling_function` does not take into consideration shape of variables in trace #6243

Open ricardoV94 opened 2 years ago

ricardoV94 commented 2 years ago

`compile_forward_sampling_function` tries to reuse any variables that are in the trace as inputs. It decides which ones to reuse with logic based on variable names, shared inputs, volatile inputs, and so on:

https://github.com/pymc-devs/pymc/blob/534b89bb73c11eb32f65b71344199b45717c8914/pymc/sampling.py#L1706-L1723

It currently does not check variable shapes (except in some cases involving dims or shared variables), which can lead to incorrect samples or even invalid graphs:

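import pymc as pm

# Fit a model in which x has 3 elements
with pm.Model() as m1:
    x = pm.Normal("x", shape=(3,))
    y = pm.Normal("y", x[2])
    idata = pm.sample(chains=2, draws=100)

# Reuse the trace in a model in which x has 4 elements
with pm.Model() as m2:
    x = pm.Normal("x", shape=(4,))
    y = pm.Normal("y", x[3])
    # The 3-element draws of x from idata are reused, so x[3] is out of bounds
    pp = pm.sample_posterior_predictive(idata, var_names=["y"])  # IndexError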

CC @lucianopaz

lucianopaz commented 2 years ago

I think that we need to ask ourselves whether we want this to throw an error, or to add it as a criterion for deciding whether a variable is volatile.

ricardoV94 commented 2 years ago

This might be a self-correcting issue, because the next release of Aesara will start outputting static shape types for RVs: https://github.com/aesara-devs/aesara/pull/1253/commits/b11a9a0a8c553491b57bf8b35dffdb35a1b792f1

Aesara main now does the following:

import aesara
import aesara.tensor as at

x = at.random.normal(size=(3,))
print(x.type)  # TensorType(float64, (3,))
y = x + 1

f = aesara.function([x], y)
f([0, 0, 0, 0])  # TypeError: The type's shape ((3,)) is not compatible with the data's ((4,))

ricardoV94 commented 2 years ago

> I think that we need to ask ourselves whether we want this to throw an error, or to add it as a criterion for deciding whether a variable is volatile.

Yeah, the error part will be done automatically in the next version (see my previous comment).

It will require more work, but I think it makes sense to flag a variable as volatile if its shape is known to be different from that in the trace (or if we cannot be sure it is the same).
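A rough sketch of what such a check could look like, in plain NumPy. This is not PyMC's implementation: the helper name `shape_may_differ` is hypothetical, and it assumes the variable's static shape is available as a tuple with `None` for unknown dimensions, and that the trace stores draws with leading (chain, draw) dimensions.

```python
from typing import Optional, Sequence

import numpy as np


def shape_may_differ(
    static_shape: Sequence[Optional[int]],
    trace_values: np.ndarray,
    n_sample_dims: int = 2,
) -> bool:
    """Hypothetical helper: return True if the model variable's shape is
    known to differ from the trace, or cannot be shown to match it.

    static_shape: per-dimension sizes from the model graph (None = unknown).
    trace_values: stored draws with leading (chain, draw) dimensions.
    """
    # Drop the sampling dimensions to get the per-draw shape in the trace
    trace_shape = trace_values.shape[n_sample_dims:]
    if len(static_shape) != len(trace_shape):
        return True
    for model_dim, trace_dim in zip(static_shape, trace_shape):
        # An unknown dimension counts as "cannot be sure it is the same"
        if model_dim is None or model_dim != trace_dim:
            return True
    return False


# Trace from m1 above: 2 chains, 100 draws, x has shape (3,)
trace_x = np.zeros((2, 100, 3))
print(shape_may_differ((3,), trace_x))     # False: shapes match, safe to reuse
print(shape_may_differ((4,), trace_x))     # True: known mismatch (the m2 case)
print(shape_may_differ((None,), trace_x))  # True: shape unknown, be conservative
```

Under this rule, a variable for which the helper returns True would be treated as volatile and resampled instead of being read from the trace.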