ricardoV94 opened this issue 2 years ago
I think we need to ask ourselves whether we want this to throw an error, or to add it as a criterion for telling whether a variable is volatile.
This might be a self-correcting issue, because the next release of Aesara will start outputting static-shape types for RVs: https://github.com/aesara-devs/aesara/pull/1253/commits/b11a9a0a8c553491b57bf8b35dffdb35a1b792f1
Aesara's main branch now does the following:
```python
import aesara
import aesara.tensor as at
x = at.random.normal(size=(3,))
print(x.type) # TensorType(float64, (3,))
y = x + 1
f = aesara.function([x], y)
f([0, 0, 0, 0]) # TypeError: The type's shape ((3,)) is not compatible with the data's ((4,))
```
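The rule behind that TypeError can be sketched in plain Python. This is a simplified model of the check, not Aesara's implementation; `shape_is_compatible` is a hypothetical helper, and it assumes a static shape uses `None` for a dimension whose length is unknown.

```python
from typing import Optional, Sequence, Tuple

# Assumed convention: a static shape such as (3,) or (None, 2), where None
# marks a dimension whose length is unknown when the graph is built.
StaticShape = Tuple[Optional[int], ...]


def shape_is_compatible(static_shape: StaticShape, data_shape: Sequence[int]) -> bool:
    """Hypothetical helper: True if a value of `data_shape` fits `static_shape`.

    The number of dimensions must match, and every statically known
    dimension must equal the corresponding data dimension.
    """
    if len(static_shape) != len(data_shape):
        return False
    return all(s is None or s == d for s, d in zip(static_shape, data_shape))


print(shape_is_compatible((3,), (3,)))     # True
print(shape_is_compatible((3,), (4,)))     # False, like the f([0, 0, 0, 0]) call
print(shape_is_compatible((None,), (4,)))  # True: an unknown length accepts any value
```

Once RVs carry static shapes like `(3,)`, a trace value of a different length is rejected at call time instead of flowing silently into the graph.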
> I think we need to ask ourselves whether we want this to throw an error, or to add it as a criterion for telling whether a variable is volatile.
Yeah, the error part will be done automatically in the next version (see my previous comment).
It will require more work, but I think it makes sense to flag a variable as volatile if its shape is known to differ from the shape in the trace (or if we cannot be sure it is the same).
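A minimal sketch of that proposed criterion, assuming the variable's static shape uses `None` for unknown dimensions and that `trace_shape` is the shape of a single draw (chain and draw dimensions removed). `shape_makes_volatile` is hypothetical and not part of PyMC.

```python
from typing import Optional, Sequence, Tuple

# Assumed convention: None marks a dimension with unknown length.
StaticShape = Tuple[Optional[int], ...]


def shape_makes_volatile(static_shape: StaticShape, trace_shape: Sequence[int]) -> bool:
    """Hypothetical criterion: True if the trace value must not be reused.

    The variable is volatile when its shape is known to differ from the
    trace value's shape, or when some dimension is unknown, so we cannot
    be sure the two shapes agree.
    """
    if len(static_shape) != len(trace_shape):
        return True  # different number of dimensions: known to differ
    for s, t in zip(static_shape, trace_shape):
        if s is None:
            return True  # unknown length: cannot be sure it matches the trace
        if s != t:
            return True  # known length that disagrees with the trace
    return False


print(shape_makes_volatile((3,), (3,)))     # False: safe to reuse the trace value
print(shape_makes_volatile((4,), (3,)))     # True: known to differ
print(shape_makes_volatile((None,), (3,)))  # True: cannot be sure
```

Treating unknown dimensions as volatile is the conservative choice: it may resample variables that would have been fine to reuse, but it never feeds a wrong-shaped trace value into the graph.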
`compile_forward_sampling_function` tries to reuse any variables that are in the trace as inputs. It uses logic based on variable names, shared inputs, volatile inputs, and so on: https://github.com/pymc-devs/pymc/blob/534b89bb73c11eb32f65b71344199b45717c8914/pymc/sampling.py#L1706-L1723
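The kind of reuse logic described here can be modeled as volatility propagating through the model graph. This is a heavily simplified, hypothetical sketch and not PyMC's code (the real rules are in the linked lines); it assumes the graph is given as a dict mapping each variable name to its parents, and it only models changed shared inputs.

```python
from graphlib import TopologicalSorter  # Python 3.9+
from typing import Dict, Iterable, Set, Tuple


def plan_forward_sampling(
    parents: Dict[str, Iterable[str]],
    in_trace: Set[str],
    changed_shared: Set[str],
) -> Tuple[Set[str], Set[str]]:
    """Hypothetical planner: split variables into reused inputs and volatile ones.

    Simplified rules:
    - a shared variable whose value changed since sampling is volatile;
    - any variable with a volatile parent is volatile;
    - a non-volatile variable found in the trace is reused as an input.
    """
    volatile: Set[str] = set()
    # static_order() yields every node after all of its parents.
    for node in TopologicalSorter(parents).static_order():
        if node in changed_shared or any(p in volatile for p in parents.get(node, ())):
            volatile.add(node)
    reused = {n for n in in_trace if n not in volatile}
    return reused, volatile


# mu_data is shared data that was updated after sampling.
graph = {"x": ["mu_data"], "y": ["x", "sigma"], "sigma": []}
reused, volatile = plan_forward_sampling(graph, {"x", "sigma", "y"}, {"mu_data"})
print(sorted(reused))    # ['sigma']
print(sorted(volatile))  # ['mu_data', 'x', 'y']
```

Nothing in these rules looks at shapes, so a variable whose shape changed but whose parents did not would still be reused from the trace.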
`compile_forward_sampling_function` currently does not check variable shapes (except in some cases involving dims or shared variables), which can lead to incorrect samples or even invalid graphs.
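Both failure modes can be illustrated with a plain NumPy analogy. This is not PyMC code: `forward_sample_y` is a hypothetical stand-in for a forward-sampling step that reuses `x` from the trace while drawing new noise with the model's current length.

```python
import numpy as np

rng = np.random.default_rng(0)

# Value of `x` stored in the trace when the model had 3 entries.
x_from_trace = np.zeros(3)


def forward_sample_y(x_value: np.ndarray, n: int) -> np.ndarray:
    """Hypothetical step y = x + noise, where the noise uses the model's
    current length n rather than the length of the reused x_value."""
    noise = rng.normal(size=n)
    return x_value + noise


# Model resized to 1 entry: (3,) + (1,) broadcasts silently, so y comes out
# with the wrong shape and no error is raised (incorrect samples).
y = forward_sample_y(x_from_trace, n=1)
print(y.shape)  # (3,) although the model now expects (1,)

# Model resized to 4 entries: (3,) + (4,) cannot broadcast (invalid computation).
try:
    forward_sample_y(x_from_trace, n=4)
except ValueError as err:
    print("ValueError:", err)
```

The silent broadcasting case is the more dangerous of the two, since it produces samples without any error.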
CC @lucianopaz