pyro-ppl / funsor

Functional tensors for probabilistic programming
https://funsor.pyro.ai
Apache License 2.0
236 stars 20 forks source link

`funsor.joint.eager_reduce_exp` behaves differently with `memoize` #561

Open ordabayevy opened 3 years ago

ordabayevy commented 3 years ago

The `if` statement in `eager_reduce_exp` evaluates to `False` under `memoize`, so the function returns `None`. Without `memoize` it returns `log_result.exp()` as expected.

https://github.com/pyro-ppl/funsor/blob/ca1557b3786235e7803e45a4d7025c75b13c7d5f/funsor/joint.py#L157-L165

Example code:

from funsor.cnf import Contraction
from funsor.tensor import Tensor
import torch
import funsor.ops as ops
from funsor import Bint, Real
from funsor.terms import Unary, Binary, Variable, Number, eager, lazy, to_data, Reduce
from funsor.constant import Constant
from funsor.delta import Delta
from funsor.integrate import Integrate
import funsor

# Select the torch backend; the Tensor data below is built with torch.tensor.
funsor.set_backend("torch")

# Term type and constructor arguments handed to eager.interpret below.
# The term is Reduce(ops.add, exp(<Contraction>), {x__BOUND_16}).
# NOTE(review): the trailing frozenset is presumably Reduce's set of reduced
# variables -- confirm against the Reduce signature.
cls = Reduce
args = (ops.add,
        Unary(ops.exp,
         Contraction(ops.null, ops.add,
          frozenset(),
          (Delta(
            (('x__BOUND_16',
              (Tensor(
                torch.tensor([1, 0, 1, 0, 0, 0, 1, 0, 1, 1], dtype=torch.int64),
                (('plate__BOUND_17',
                  Bint[10],),),
                3),
               Number(0.0),),),)),
           Tensor(
            torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=torch.float64),  # noqa
            (('plate__BOUND_17',
              Bint[10],),),
            'real'),))),
        frozenset({Variable('x__BOUND_16', Bint[3])})
    )

# Without memoize: evaluates to a Tensor (the expected result).
result = eager.interpret(cls, *args)

# Same call with the same arguments, now under the memoize interpretation.
with funsor.interpretations.memoize():
    # evaluates to a lazy Contraction term instead of a Tensor -- this is the
    # discrepancy reported in this issue.
    result2 = eager.interpret(cls, *args)