Open ordabayevy opened 3 years ago
I don't think `*Meta.__call__` methods are the right place for this kind of simplification logic (these methods should do nothing except fill default values for optional arguments and possibly call `to_funsor` where appropriate), and it really shouldn't be necessary regardless. It sounds like the issue, as is so often unfortunately the case, is with the logic in `funsor.adjoint`, or perhaps with alpha-renaming. Can you distill the underlying issue into a failing test comparing a correct expression and a very simple incorrect one generated by `adjoint`?
Pragmatically, however, I'm OK with merging this if you can do that and explain how this fix would unblock what you're actually working on.
> Can you distill the underlying issue into a failing test comparing a correct expression and a very simple incorrect one generated by `adjoint`?
I discovered this issue by examining `test_adjoint.py::test_sequential_sum_product_adjoint`: `xfail_param(MarkovProduct, reason="mysteriously doubles adjoint values?")`.
In that test, under the `adjoint` interpretation, `eager_markov_product` calls `sequential_sum_product` (no issues there) and then calls `Subs(result, step_names)` where `step_names = {"prev": "prev", "curr": "curr"}`. Since the names in `step_names` are identical, `Subs(result, step_names)` returns the same `result`. However, the call also gets appended to `AdjointTape.tape`, and I think that doubles the adjoint values.
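One way such doubling could arise can be shown with a toy tape. This is only a sketch of the suspected mechanism, not funsor's implementation: `ToyTape`, `record`, and `backward` are hypothetical names, nodes are plain strings, and the assumption is that a no-op call is recorded with an output that is the same node as its input.

```python
from collections import defaultdict


class ToyTape:
    """Hypothetical stand-in for an adjoint tape (not funsor's AdjointTape)."""

    def __init__(self):
        # Each entry is (output_node, input_node, local_derivative).
        self.entries = []

    def record(self, output, inp, local_grad):
        self.entries.append((output, inp, local_grad))

    def backward(self, root):
        # Walk the tape in reverse, accumulating adjoints per node.
        adjoints = defaultdict(float)
        adjoints[root] = 1.0
        for output, inp, local_grad in reversed(self.entries):
            adjoints[inp] += adjoints[output] * local_grad
        return dict(adjoints)


# Clean tape: y = 3 * x.
clean = ToyTape()
clean.record("y", "x", 3.0)
clean_adjoints = clean.backward("y")

# Polluted tape: the same op, followed by an identity op whose output
# is the very same node as its input (a trivial substitution).
polluted = ToyTape()
polluted.record("y", "x", 3.0)
polluted.record("y", "y", 1.0)
polluted_adjoints = polluted.backward("y")

print(clean_adjoints["x"], polluted_adjoints["x"])
```

In this toy model the identity entry adds the adjoint of `y` to itself before the real op is processed, so the adjoint that reaches `x` is twice the clean value.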
A simple solution would be to check `step_names` and call `Subs(result, step_names)` only if the names are not identical. I moved that logic to `SubsMeta.__call__`, thinking that it might help avoid similar issues in the future.

    def eager_markov_product(sum_op, prod_op, trans, time, step, step_names):
        if step:
            result = sequential_sum_product(sum_op, prod_op, trans, time, dict(step))
        ...
        return Subs(result, step_names)
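The check described above, skipping the substitution when every name maps to itself, can be sketched without funsor. The helpers `drop_trivial_subs` and `maybe_subs` and the `apply_subs` callback are hypothetical names for illustration, and the sketch assumes substitutions are given as a plain name-to-name dict; the real `Subs` accepts richer values.

```python
def drop_trivial_subs(subs):
    """Keep only the pairs where the name actually changes."""
    return {name: value for name, value in subs.items() if name != value}


def maybe_subs(result, subs, apply_subs):
    """Call apply_subs only when at least one substitution is non-trivial."""
    nontrivial = drop_trivial_subs(subs)
    if not nontrivial:
        # Nothing to rename: return the input untouched, so nothing is recorded.
        return result
    return apply_subs(result, nontrivial)


# Record every call that would have reached the substitution machinery.
calls = []


def recording_subs(result, subs):
    calls.append(subs)
    return (result, tuple(sorted(subs.items())))


same = maybe_subs("result", {"prev": "prev", "curr": "curr"}, recording_subs)
renamed = maybe_subs("result", {"prev": "prev", "curr": "next"}, recording_subs)
print(same, renamed, calls)
```

With identical names the callback is never invoked, so a tape that records calls would not see the no-op. With a mixed dict only the non-trivial pair is passed through.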
> Can you distill the underlying issue into a failing test comparing a correct expression and a very simple incorrect one generated by `adjoint`?
Yes, I believe this can be boiled down to a much simpler failing test, with just a couple of lines of code, than `test_adjoint.py::test_sequential_sum_product_adjoint`: `xfail_param(MarkovProduct, reason="mysteriously doubles adjoint values?")`. Should I add such a test?
> Should I add such a test?
Yes, that would be very helpful!
This proposes to avoid any trivial `subs` (do not call `interpret` if all `subs` are trivial). Trivial subs can arise, for example, in eager `MarkovProduct` with `step_names = {"prev": "prev", "curr": "curr"}`, which then pollutes `AdjointTape.tape` under `adjoint` interpretation (#493 #544).