pyro-ppl / funsor

Functional tensors for probabilistic programming
https://funsor.pyro.ai
Apache License 2.0

Fix nan gradients in TraceEnum_ELBO when evaluated eagerly #544

Closed: ordabayevy closed this 3 years ago

ordabayevy commented 3 years ago

Fixes NaN gradients in TraceEnum_ELBO when it is evaluated eagerly (mentioned in https://github.com/pyro-ppl/funsor/issues/493).

The problem seems to be that the reflect.interpret call on line 62 triggers substitutions whose base_interpretation is adjoint, which in turn adds extra expressions to the tape. This PR fixes it by changing interpretation to _old_interpretation in that call.

This also fixes the xfail_param(MarkovProduct, reason="mysteriously doubles adjoint values?") case in test_adjoint.py::test_sequential_sum_product_adjoint (blocked by #545).
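As a rough illustration of how extra tape entries could show up as "doubled" adjoint values, here is a minimal reverse-mode sketch. It assumes a simple tape of `(output, [(input, local_gradient), ...])` records; the `backward` helper and the tape layout are hypothetical and are not funsor's adjoint implementation.

```python
# Toy reverse-mode accumulation over an explicit tape (hypothetical layout,
# not funsor's adjoint machinery).
def backward(tape, output, seed=1.0):
    """tape: list of (out_name, [(input_name, local_grad), ...]) records."""
    adjoint = {output: seed}
    for out, parents in reversed(tape):
        g = adjoint.get(out, 0.0)
        for name, local in parents:
            # Each record adds its contribution, so a duplicated record
            # contributes twice.
            adjoint[name] = adjoint.get(name, 0.0) + g * local
    return adjoint


# z = x * y at x = 2, y = 3: dz/dx = y = 3, dz/dy = x = 2.
record = ("z", [("x", 3.0), ("y", 2.0)])

clean = backward([record], "z")
duplicated = backward([record, record], "z")  # the same entry taped twice

print(clean["x"], clean["y"])            # 3.0 2.0
print(duplicated["x"], duplicated["y"])  # 6.0 4.0
```

This only shows that a duplicated record doubles the accumulated adjoint in a naive tape; it does not by itself explain the NaN gradients.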

eb8680 commented 3 years ago

Why is this blocked by #545?

ordabayevy commented 3 years ago

#545 is needed to unblock the MarkovProduct test (along with the _old_interpretation fix).

eb8680 commented 3 years ago

Why do we need to unblock the MarkovProduct test to merge this?

ordabayevy commented 3 years ago

Sorry for the confusion; I wasn't sure how to organize these PRs. I will open a separate PR for the MarkovProduct test, since it is a related but separate issue.

eb8680 commented 3 years ago

> I wasn't sure how to organize these PRs

No worries, I just don't want your substantive work to be held up with yet more adjoint/alpha-conversion insanity.

eb8680 commented 3 years ago

Do the Pyro contrib.funsor tests pass under these changes?

ordabayevy commented 3 years ago

Yes, just tested it.

$ pytest tests/contrib/funsor/
============================================================================ test session starts =============================================================================
platform linux -- Python 3.8.10, pytest-4.3.1, py-1.10.0, pluggy-0.13.1
rootdir: /home/ordabayev/repos/pyro, inifile: setup.cfg
plugins: xdist-1.27.0, nbval-0.9.6, forked-1.3.0
collected 331 items

tests/contrib/funsor/test_enum_funsor.py .........................x...................x.x.....                                                                         [ 16%]
tests/contrib/funsor/test_infer_discrete.py ..........................                                                                                                 [ 23%]
tests/contrib/funsor/test_named_handlers.py ......                                                                                                                     [ 25%]
tests/contrib/funsor/test_pyroapi_funsor.py ................x                                                                                                          [ 30%]
tests/contrib/funsor/test_tmc.py ........................................................................................                                              [ 57%]
tests/contrib/funsor/test_valid_models_enum.py ................xxxx....................................                                                                [ 74%]
tests/contrib/funsor/test_valid_models_plate.py ........X.X.X.X................                                                                                        [ 83%]
tests/contrib/funsor/test_valid_models_sequential_plate.py .........                                                                                                   [ 86%]
tests/contrib/funsor/test_vectorized_markov.py ........................XXXXXXXXXXXX.....x...                                                                           [100%]

============================================================= 306 passed, 9 xfailed, 16 xpassed in 89.28 seconds =============================================================