pyro-ppl / funsor

Functional tensors for probabilistic programming
https://funsor.pyro.ai
Apache License 2.0

Avoid trivial subs #545

Open ordabayevy opened 3 years ago

ordabayevy commented 3 years ago

This proposes to avoid any trivial subs (do not call interpret if all subs are trivial).

Trivial subs can arise, for example, in eager MarkovProduct with step_names = {"prev": "prev", "curr": "curr"}:

def eager_markov_product(sum_op, prod_op, trans, time, step, step_names):
    ...
    return Subs(result, step_names)

which then pollutes AdjointTape.tape under adjoint interpretation (#493 #544).
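To make "trivial" concrete, here is a minimal sketch of the check being proposed. It is not funsor code: it assumes a substitution can be modelled as a plain mapping from names to names, and the helpers `is_trivial_subs` and `drop_trivial_subs` are hypothetical names used only for illustration.

```python
from typing import Dict, Mapping


def is_trivial_subs(subs: Mapping[str, object]) -> bool:
    """Return True if every entry renames a variable to itself.

    Hypothetical helper: values are assumed to be plain strings here,
    which is a simplification of what a real substitution may hold.
    """
    return all(isinstance(new, str) and new == old for old, new in subs.items())


def drop_trivial_subs(subs: Mapping[str, object]) -> Dict[str, object]:
    """Keep only the entries that actually change something."""
    return {
        old: new
        for old, new in subs.items()
        if not (isinstance(new, str) and new == old)
    }


step_names = {"prev": "prev", "curr": "curr"}
print(is_trivial_subs(step_names))
print(drop_trivial_subs({"prev": "prev", "curr": "next"}))
```

Under this model an empty mapping also counts as trivial, since there is nothing to substitute.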

eb8680 commented 3 years ago

I don't think *Meta.__call__ methods are the right place for this kind of simplification logic (these methods should do nothing except fill default values for optional arguments and possibly call to_funsor where appropriate), and it really shouldn't be necessary regardless - it sounds like the issue, as is so often unfortunately the case, is with the logic in funsor.adjoint, or perhaps with alpha-renaming. Can you distill the underlying issue into a failing test comparing a correct expression and a very simple incorrect one generated by adjoint?

Pragmatically, however, I'm OK with merging this if you can do that and explain how this fix would unblock what you're actually working on.

ordabayevy commented 3 years ago

Can you distill the underlying issue into a failing test comparing a correct expression and a very simple incorrect one generated by adjoint?

I discovered this issue by examining test_adjoint.py::test_sequential_sum_product_adjoint: xfail_param(MarkovProduct, reason="mysteriously doubles adjoint values?").

In that test, under the adjoint interpretation, eager_markov_product calls sequential_sum_product (no issues there) and then calls Subs(result, step_names), where step_names = {"prev": "prev", "curr": "curr"}. Since the names in step_names are identical, Subs(result, step_names) returns the same result. However, it also gets appended to AdjointTape.tape, and I think that doubles the adjoint values.

A simple solution would be to check step_names and call Subs(result, step_names) only if the names are not identical. I moved that logic to SubsMeta.__call__, thinking that it might help avoid similar issues in the future.

def eager_markov_product(sum_op, prod_op, trans, time, step, step_names):
    if step:
        result = sequential_sum_product(sum_op, prod_op, trans, time, dict(step))
    ...
    return Subs(result, step_names)
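The effect described above can be illustrated with a toy model. This is only a sketch under stated assumptions: `ToyTape`, `rename` and `rename_guarded` are hypothetical stand-ins and do not reproduce funsor's AdjointTape or Subs. It shows only that an identity renaming still leaves an entry on a recording tape unless it is skipped; whether that extra entry is what doubles the adjoint values remains the hypothesis stated in this comment.

```python
class ToyTape:
    """Toy stand-in for a tape that records every operation it sees."""

    def __init__(self):
        self.tape = []

    def record(self, op_name, args):
        self.tape.append((op_name, args))


def rename(value, names, tape):
    """Rename the keys of `value` according to `names`, always recording the op."""
    tape.record("subs", dict(names))
    return {names.get(key, key): val for key, val in value.items()}


def rename_guarded(value, names, tape):
    """Like `rename`, but skip the op (and the tape entry) for identity renamings."""
    if all(old == new for old, new in names.items()):
        return value
    return rename(value, names, tape)


value = {"prev": 1.0, "curr": 2.0}
identity = {"prev": "prev", "curr": "curr"}

unguarded_tape = ToyTape()
rename(value, identity, unguarded_tape)

guarded_tape = ToyTape()
rename_guarded(value, identity, guarded_tape)

# The unguarded call records one no-op entry; the guarded call records none.
print(len(unguarded_tape.tape), len(guarded_tape.tape))
```

The guard sits at the call site in this sketch; placing the same check inside the constructor instead is the design question raised in the review above.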

ordabayevy commented 3 years ago

Can you distill the underlying issue into a failing test comparing a correct expression and a very simple incorrect one generated by adjoint?

Yes, I believe this can be boiled down to a much simpler failing test, with just a couple of lines of code, than test_adjoint.py::test_sequential_sum_product_adjoint: xfail_param(MarkovProduct, reason="mysteriously doubles adjoint values?"). Should I add such a test?

eb8680 commented 3 years ago

Should I add such a test?

Yes, that would be very helpful!