pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

BUG: Unable to apply `Interval` forward transformation #7193

Closed: tvwenger closed this issue 4 months ago

tvwenger commented 4 months ago

Describe the issue:

Description

I wish to evaluate the log probability of a model at a specific point. Since some RVs may be transformed, I must apply the transformation to the input point and pass the transformed point to `model.logp`. If one of the RV transforms is `Interval`, the call to `forward()` fails with an obscure `IndexError`.

Expected Behavior

Applying `forward` to an RV transformed via `Interval` should not raise an exception.

Actual Behavior

Applying `forward` to an RV transformed via `Interval` raises an exception.

Minimum working example

The following example demonstrates what I want to do. I have a model where `x` is not transformed, `y` has a `LogTransform`, and `z` has an `Interval` transform. I want to evaluate `logp` at a given `x`, `y`, `z`. Since `logp` expects the point on the transformed scale, I must apply the transformations and pass the transformed point to `logp().eval()`.

The output of this MWE is below:

```shell
5.10.4
name: x param: x transform: None
value: 1.0 transformed_value: 1.0
name: y param: y_log__ transform: LogTransform
value: 0.5 transformed_value: -0.6931471824645996
name: z param: z_interval__ transform: Interval
```

Reproducible code example:

```python
import pymc as pm

print(pm.__version__)

with pm.Model() as model:
    x = pm.Normal("x", mu=0.0, sigma=1.0)
    y = pm.LogNormal("y", mu=1.0, sigma=1.0)
    z = pm.TruncatedNormal("z", mu=0.0, sigma=1.0, lower=-1.0, upper=1.0)

point = {"x": 1.0, "y": 0.5, "z": 0.0}
point_transformed = {}
for rv in model.free_RVs:
    name = rv.name
    param = model.rvs_to_values[rv]
    transform = model.rvs_to_transforms[rv]
    print(f"name: {name} param: {param} transform: {transform}")
    if transform is None:
        point_transformed[param] = point[name]
    else:
        point_transformed[param] = transform.forward(point[name]).eval()
    print(f"value: {point[name]} transformed_value: {point_transformed[param]}")

print(f"Log prob: {model.logp().eval(point_transformed)}")
```

Error message:

```shell
IndexError                                Traceback (most recent call last)
Cell In[25], line 20
     18     point_transformed[param] = point[name]
     19 else:
---> 20     point_transformed[param] = transform.forward(point[name]).eval()
     21 print(f"value: {point[name]} transformed_value: {point_transformed[param]}")
     23 print(f"Log prob: {model.logp().eval(point_transformed)}")

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/logprob/transforms.py:853, in IntervalTransform.forward(self, value, *inputs)
    852 def forward(self, value, *inputs):
--> 853     a, b, lower_bounded, upper_bounded = self.get_a_and_b(inputs)
    855     log_lower_distance = pt.log(value - a)
    856     log_upper_distance = pt.log(b - value)

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/logprob/transforms.py:842, in IntervalTransform.get_a_and_b(self, inputs)
    836 def get_a_and_b(self, inputs):
    837     """Return interval bound values.
    838 
    839     Also returns two boolean variables indicating whether the transform is known to be statically bounded.
    840     This is used to generate smaller graphs in the transform methods.
    841     """
--> 842     a, b = self.args_fn(*inputs)
    843     lower_bounded, upper_bounded = True, True
    844     if a is None:

File ~/miniconda3/envs/pymc/lib/python3.11/site-packages/pymc/distributions/continuous.py:180, in bounded_cont_transform.<locals>.transform_params(*args)
    178 lower, upper = None, None
    179 if bound_args_indices[0] is not None:
--> 180     lower = args[bound_args_indices[0]]
    181 if bound_args_indices[1] is not None:
    182     upper = args[bound_args_indices[1]]

IndexError: tuple index out of range
```
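For reference, the `IntervalTransform.forward` frame above needs the bounds `a` and `b` before it can build `log(value - a)` and `log(b - value)`. A plain-Python sketch of the doubly bounded case, assuming finite bounds and the log-odds form `log(x - a) - log(b - x)` (our assumption of what PyMC's graph is mathematically equivalent to; `interval_forward` and `interval_backward` are hypothetical names):

```python
import math


def interval_forward(x, a, b):
    """Map x in the open interval (a, b) to the real line via log-odds."""
    return math.log(x - a) - math.log(b - x)


def interval_backward(u, a, b):
    """Inverse map: a + (b - a) * sigmoid(u)."""
    return a + (b - a) / (1.0 + math.exp(-u))


# Bounds of z in the example: a = -1, b = 1
u = interval_forward(0.0, -1.0, 1.0)
print(u)
print(interval_backward(u, -1.0, 1.0))
```

Neither direction can be evaluated without `a` and `b`, which is why the transform has to receive them from somewhere.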

PyMC version information:

pymc: 5.10.4

Context for the issue:

I would like to be able to retrieve the transform of any RV and apply its `forward` method.

ricardoV94 commented 4 months ago

The signature of `forward` expects the value followed by all the inputs of the RV (`forward(value, *inputs)`).
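To make this concrete, here is a plain-Python mock of the bound lookup at the bottom of the traceback (`lower = args[bound_args_indices[0]]`). The name `make_args_fn`, the index pair `(5, 6)` and the input layout are illustrative assumptions, not PyMC internals; the point is only that the bounds are picked out of the RV inputs by position, so calling `forward` with the value alone leaves an empty tuple to index into.

```python
def make_args_fn(bound_args_indices):
    """Hypothetical, simplified stand-in for the `transform_params` closure:
    it picks the lower/upper bounds out of the RV inputs by position."""
    lower_idx, upper_idx = bound_args_indices

    def transform_params(*args):
        lower = args[lower_idx] if lower_idx is not None else None
        upper = args[upper_idx] if upper_idx is not None else None
        return lower, upper

    return transform_params


# Hypothetical input layout of a bounded RV node: (rng, size, dtype, mu, sigma, lower, upper)
rv_inputs = ("rng", "size", "dtype", 0.0, 1.0, -1.0, 1.0)
args_fn = make_args_fn((5, 6))

print(args_fn(*rv_inputs))  # bounds found by position: (-1.0, 1.0)

try:
    args_fn()  # what forward(value) alone amounts to: no inputs at all
except IndexError as err:
    print("IndexError:", err)  # tuple index out of range
```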

For your goal, if you don't care about the Jacobian term, you can remove the value transforms with `remove_value_transforms`: https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html

Otherwise, you need to define a function that pushes all the values forward. That's not straightforward, and it's something we want to offer users: https://github.com/pymc-devs/pymc/issues/6721

Help there would be very welcome.

Unless I missed something, I would close this issue as a duplicate. The code doesn't show a bug, but an incorrect use of the transform objects.

tvwenger commented 4 months ago

@ricardoV94 Thanks for the insight. Following your response, I was able to get this to work by also passing `rv.owner.inputs` to `forward`:

```python
import pymc as pm

print(pm.__version__)

with pm.Model() as model:
    x = pm.Normal("x", mu=0.0, sigma=1.0)
    y = pm.LogNormal("y", mu=1.0, sigma=1.0)
    z = pm.TruncatedNormal("z", mu=0.0, sigma=1.0, lower=-1.0, upper=1.0)

point = {"x": 1.0, "y": 0.5, "z": 0.0}
point_transformed = {}
for rv in model.free_RVs:
    name = rv.name
    param = model.rvs_to_values[rv]
    transform = model.rvs_to_transforms[rv]
    print(f"name: {name} param: {param} transform: {transform}")
    if transform is None:
        point_transformed[param] = point[name]
    else:
        point_transformed[param] = transform.forward(
            point[name], *rv.owner.inputs
        ).eval()
    print(f"value: {point[name]} transformed_value: {point_transformed[param]}")

print(f"Log prob: {model.logp().eval(point_transformed)}")
```

ricardoV94 commented 4 months ago

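Note that this only works because `lower`/`upper` are constants. If they depended on other parameters, you would get wrong results: `rv.owner.inputs` would then contain those random variables themselves, and `.eval()` would draw random values for them instead of using the values in `point`.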