Closed ordabayevy closed 3 years ago
Why is this blocked by #545?
`MarkovProduct` test (along with `_old_interpretation` fix).
Why do we need to unblock the `MarkovProduct` test to merge this?
Sorry for making it confusing, I wasn't sure how to organize these PRs. I will make a separate PR for the `MarkovProduct` test; it is related but a separate issue.
> I wasn't sure how to organize these PRs
No worries, I just don't want your substantive work to be held up with yet more adjoint/alpha-conversion insanity.
Do the Pyro `contrib.funsor` tests pass under these changes?
Yes, just tested it.
$ pytest tests/contrib/funsor/
============================================================================ test session starts =============================================================================
platform linux -- Python 3.8.10, pytest-4.3.1, py-1.10.0, pluggy-0.13.1
rootdir: /home/ordabayev/repos/pyro, inifile: setup.cfg
plugins: xdist-1.27.0, nbval-0.9.6, forked-1.3.0
collected 331 items
tests/contrib/funsor/test_enum_funsor.py .........................x...................x.x..... [ 16%]
tests/contrib/funsor/test_infer_discrete.py .......................... [ 23%]
tests/contrib/funsor/test_named_handlers.py ...... [ 25%]
tests/contrib/funsor/test_pyroapi_funsor.py ................x [ 30%]
tests/contrib/funsor/test_tmc.py ........................................................................................ [ 57%]
tests/contrib/funsor/test_valid_models_enum.py ................xxxx.................................... [ 74%]
tests/contrib/funsor/test_valid_models_plate.py ........X.X.X.X................ [ 83%]
tests/contrib/funsor/test_valid_models_sequential_plate.py ......... [ 86%]
tests/contrib/funsor/test_vectorized_markov.py ........................XXXXXXXXXXXX.....x... [100%]
============================================================= 306 passed, 9 xfailed, 16 xpassed in 89.28 seconds =============================================================
Fixes NaN gradients in `TraceEnum_ELBO` when evaluated eagerly (mentioned in https://github.com/pyro-ppl/funsor/issues/493).
The problem seems to be that `reflect.interpret` in line 62 leads to substitutions that have `adjoint` as a `base_interpretation`, which in turn leads to some extra expressions being added to the `tape`. Here I fix it by changing the interpretation to `_old_interpretation`.
This also solves `test_adjoint.py::test_sequential_sum_product_adjoint`: `xfail_param(MarkovProduct, reason="mysteriously doubles adjoint values?")` (blocked by #545).