automated_dr_learner.ipynb produces wrong results #571

Open dimkab opened 1 month ago

dimkab commented 1 month ago

Running the automated_dr_learner.ipynb notebook from the main branch produces the following figures, which differ from the ones saved in the notebook itself. [two figures attached]

eb8680 commented 1 month ago

I just ran the version on master unmodified on my machine and I can't reproduce this - I get correct figures that look pretty much the same as the ones in the notebook now. Did you change the notebook in any way before you ran it? It looks to me like the data underlying the "DR-Monte Carlo" line could have been accidentally overwritten somehow with more draws from the "Plug-in" distribution.

Also, what versions of Pyro and PyTorch are you using?
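One way to answer the version question is to print what the notebook kernel actually sees. A minimal sketch using only the standard library; the helper name `installed_versions` is hypothetical, and the distribution names are assumed to be the PyPI names used in this thread:

```python
from importlib import metadata


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Distribution names as published on PyPI (assumed): torch, pyro-ppl, chirho.
    for name, version in installed_versions(["torch", "pyro-ppl", "chirho"]).items():
        print(f"{name}: {version or 'not installed'}")
```

Running this in the same kernel as the notebook avoids reporting versions from a different environment than the one that produced the figures.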

dimkab commented 1 month ago

I ran it unmodified on my MacBook, and also on a clean checkout on a remote Linux machine, installed with pip install -e ".[test]". Same result as before.

For the linux machine:

For the mac machine:

eb8680 commented 1 month ago

I'd be surprised if this was the issue, but can you try installing the latest version of Pyro (pyro-ppl==1.9.1) and running it again from scratch? Pyro is pinned to pyro-ppl<1.9 in the ChiRho dependencies but that's mainly for type-checking reasons - I believe it should work fine with pyro-ppl>=1.9.1 at runtime.

eb8680 commented 1 month ago

And can you also try with PyTorch pinned to torch==2.4.1? chirho.robust makes use of PyTorch features that (shouldn't in principle, but do in reality) change and often break in undocumented and unpredictable ways from release to release.

dimkab commented 1 month ago

Setting both pyro-ppl==1.9.1 and torch==2.4.1 did it!
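To guard against silently running with other versions, a notebook could check the pins before doing any work. A rough sketch, assuming that comparing only the leading numeric release segment is good enough (so a build suffix such as `.post2` or `+cu121` is ignored); `release_tuple` and `matches_pin` are hypothetical helpers, not part of ChiRho:

```python
import re


def release_tuple(version):
    """Return the leading numeric release segment, e.g. '2.4.1.post2' -> (2, 4, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def matches_pin(installed, pinned):
    """True if the installed release segment equals the pinned one."""
    return release_tuple(installed) == release_tuple(pinned)


# Pins taken from this thread; the "installed" strings are examples only.
print(matches_pin("2.4.1.post2", "2.4.1"))  # True
print(matches_pin("1.8.6", "1.9.1"))        # False
```

In practice the installed strings would come from `importlib.metadata.version`, and a mismatch could raise an error at the top of the notebook instead of printing.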

qinqian commented 1 month ago

I could test this as well.

It seems that with pyro-ppl==1.9.1 and python==3.9 the library cannot run the notebook:

TypeError                                 Traceback (most recent call last)
Cell In[1], line 14
     11 import pyro
     12 import pyro.distributions as dist
---> 14 from chirho.counterfactual.handlers import MultiWorldCounterfactual
     15 from chirho.indexed.ops import IndexSet, gather
     16 from chirho.interventional.handlers import do

File /opt/homebrew/Caskroom/miniconda/base/envs/chirho/lib/python3.9/site-packages/chirho/counterfactual/handlers/
----> 1 from .counterfactual import (  # noqa: F401
      2     MultiWorldCounterfactual,
      3     SingleWorldCounterfactual,
      4     SingleWorldFactual,
      5     TwinWorldCounterfactual,
      6 )

File /opt/homebrew/Caskroom/miniconda/base/envs/chirho/lib/python3.9/site-packages/chirho/counterfactual/handlers/
      6 import torch
      8 from chirho.counterfactual.handlers.ambiguity import FactualConditioningMessenger
----> 9 from chirho.counterfactual.ops import preempt, split
     10 from chirho.indexed.handlers import IndexPlatesMessenger
     11 from chirho.indexed.ops import get_index_plates

File /opt/homebrew/Caskroom/miniconda/base/envs/chirho/lib/python3.9/site-packages/chirho/counterfactual/
      6 import pyro
      8 from chirho.indexed.ops import IndexSet, cond_n, scatter_n
----> 9 from chirho.interventional.ops import Intervention, intervene
     11 S = TypeVar("S")
     12 T = TypeVar("T")

File /opt/homebrew/Caskroom/miniconda/base/envs/chirho/lib/python3.9/site-packages/chirho/interventional/
----> 1 from . import handlers  # noqa: F401

File /opt/homebrew/Caskroom/miniconda/base/envs/chirho/lib/python3.9/site-packages/chirho/interventional/
    111             return
    113         msg["value"] = intervene(
    114             msg["value"],
    115             self.actions[msg["name"]],
    116             event_dim=len(msg["fn"].event_shape),
    117             name=msg["name"],
    118         )
--> 121 do = pyro.poutine.handlers._make_handler(Interventions)[1]

TypeError: 'function' object is not subscriptable

eb8680 commented 1 month ago

@qinqian that error is in the latest release of ChiRho (chirho==0.2.0), which is significantly behind the master branch and is not compatible with the notebook in question.
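The `TypeError` itself is the generic symptom of indexing the result of a helper that returns a function rather than a sequence. A toy illustration of that failure mode follows; `make_handler_pair` and `make_handler_decorator` are hypothetical stand-ins, and this is not a description of Pyro's actual internals:

```python
def make_handler_pair(cls):
    """Hypothetical helper that returns a (name, handler) tuple."""
    def handler(*args, **kwargs):
        return cls(*args, **kwargs)
    return cls.__name__.lower(), handler


def make_handler_decorator(cls):
    """Hypothetical helper that returns a single function instead of a tuple."""
    def decorator(fn):
        return fn
    return decorator


class Interventions:
    """Placeholder class standing in for a messenger."""


# Indexing works when the helper returns a tuple.
do = make_handler_pair(Interventions)[1]

# Indexing the second helper's result fails, because a function is returned.
try:
    make_handler_decorator(Interventions)[1]
except TypeError as err:
    message = str(err)
    print(message)
```

Code written against the tuple-returning shape breaks at import time if the helper's return type changes, which is why the traceback ends at a module-level assignment.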

qinqian commented 1 month ago

Thanks @eb8680! I reinstalled with pip install . from the master branch, and now the notebook runs.

eb8680 commented 1 month ago

> Setting both pyro-ppl==1.9.1 and torch==2.4.1 did it!

@dimkab I see, good to know. We should probably cut a new ChiRho release at some point that drops backward compatibility with pyro-ppl<=1.9.0 and torch<=2.4.0.

I would guess master still works with torch==2.5.0, which I believe was released quite recently, but maybe we should add CI build stages for each version of PyTorch we support.
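A sketch of what such a build matrix could enumerate, assuming the candidate versions are the ones named in this thread (whether `torch==2.5.0` belongs in the supported set is untested here); `install_commands` is a hypothetical helper and the resulting commands are illustrative, not an existing CI configuration:

```python
from itertools import product

# Versions mentioned in this thread; the supported set is an assumption.
TORCH_VERSIONS = ["2.4.1", "2.5.0"]
PYRO_VERSIONS = ["1.9.1"]


def install_commands(torch_versions, pyro_versions):
    """Yield one pip command per (torch, pyro-ppl) combination."""
    for torch_v, pyro_v in product(torch_versions, pyro_versions):
        yield f"pip install torch=={torch_v} pyro-ppl=={pyro_v}"


for command in install_commands(TORCH_VERSIONS, PYRO_VERSIONS):
    print(command)
```

Each generated command would correspond to one CI stage that installs the pinned pair and then runs the notebook tests.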

qinqian commented 1 month ago

I could reproduce the figures from @dimkab with pyro-ppl==1.9.1 and torch==2.4.1

eb8680 commented 1 month ago

@qinqian can you please provide more detail? Do you mean the correct figures in the notebook now on master, or the incorrect figures in the first post of this issue? Are you sure you're running the unmodified notebook in a fresh kernel/environment with the correct versions of ChiRho (the current master branch, unmodified) and its dependencies (particularly torch==2.4.1 and pyro-ppl==1.9.1) installed? What OS are you using, and what packages are installed in your Conda/pip environment?

qinqian commented 1 month ago

@eb8680 I mean the incorrect figures in the first post of this issue. I am running the unmodified notebook in a fresh environment (the current master branch).

I am running on macOS Sonoma with PyTorch 2.4.1.post2 and pyro-ppl 1.8.6; pip install . actually downgraded my pyro-ppl from 1.9.1 to 1.8.6.

Here are the packages I installed in a Conda environment:

eb8680 commented 1 month ago

@qinqian can you try installing pyro-ppl==1.9.1 after you install ChiRho but before you run the notebook?
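The downgrade is consistent with the pyro-ppl<1.9 pin mentioned earlier in this thread: pip resolves ChiRho's declared requirements at install time, so a newer Pyro has to be installed afterwards. A small sketch for inspecting what an installed distribution declares; `declared_requirements` is a hypothetical helper, and the name matching is deliberately simple:

```python
from importlib import metadata


def declared_requirements(requirements, dependency):
    """Filter requirement strings down to those that name the given dependency."""
    prefix = dependency.lower()
    matches = []
    for requirement in requirements or []:
        text = requirement.strip().lower()
        if not text.startswith(prefix):
            continue
        # Accept only an exact name match: the next character, if any,
        # must not continue the distribution name.
        rest = text[len(prefix):]
        if not rest or not (rest[0].isalnum() or rest[0] in "-_."):
            matches.append(requirement)
    return matches


if __name__ == "__main__":
    try:
        # metadata.requires returns the declared requirement strings, or None.
        reqs = metadata.requires("chirho")
    except metadata.PackageNotFoundError:
        reqs = None
    print(declared_requirements(reqs, "pyro-ppl"))
```

If the printed specifier excludes the version you want, installing that version after ChiRho overrides the resolver's choice, at the cost of a pip warning about the conflicting requirement.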

qinqian commented 1 month ago

Yes, @eb8680, I tried this version and it fixed the issue above. Below are the figures I reproduced, which match the original notebook:

[two figures attached]

What's the key difference between the two versions?