biomedia-mira / causal-gen

(ICML 2023) High Fidelity Image Counterfactuals with Probabilistic Causal Models
https://arxiv.org/abs/2306.15764
MIT License

Request for Python version #10

Open changeworld-star opened 4 months ago

changeworld-star commented 4 months ago

Thank you for sharing this project. Could you please specify which version of Python you used for this code? This will help me ensure compatibility and avoid any issues while running it.

changeworld-star commented 4 months ago

I noticed that the requirements file contains two versions of pandas: pandas==1.5.3 and pandas==2.0.1. This causes an error due to conflicting dependencies. Could you please clarify which version of pandas should be used for this project?

fabio-deep commented 4 months ago

Hi! I believe the minimum Python version that should work is 3.8. As for pandas, I think the duplicate entry came from auto-generating the requirements file; I'd try version 2.0.1 and see.
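
If it helps, here is a quick sanity check you could run after installing; this is just an illustrative snippet, not part of the repo:

```python
import sys
import pandas as pd

# Expect Python >= 3.8 and a single pandas pin resolved to 2.0.1.
assert sys.version_info >= (3, 8), "Python 3.8+ is expected for this codebase"
print(pd.__version__)  # should print 2.0.1 if the duplicate requirement was resolved in its favour
```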

Best, Fabio

changeworld-star commented 4 months ago

Dear Fabio, thank you for your previous assistance with the Python version for configuring the environment. With your guidance, I successfully set up the environment using the recommended version.

I am currently working with your codebase and have encountered a couple of questions regarding its implementation details. I would appreciate it if you could provide some clarification on the following:

In the train_pgm.py script, on line 132, there is the statement bs = batch["x"].shape[0]. Could you please clarify what "x" represents in this context? My batch dictionary does not have a key named "x", and I am unsure of its intended meaning.

Similarly, in the flow_pgm.py script, the svi_model method contains the line with pyro.plate("observations", obs["x"].shape[0]):. Here, could you explain what "x" signifies within obs["x"].shape[0]? I assume it relates to the input data or features, but I would like to confirm its purpose.

Your support has been invaluable, and I truly appreciate the time and effort you have dedicated to helping me. Thank you for your assistance. I look forward to hearing from you soon.

Best regards, Ling Lin

fabio-deep commented 4 months ago

No worries, happy to help!

To answer your question, the "x" key in the batch dictionary refers to an input image variable - it is returned by the respective dataloader.
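
To make that concrete, here is a minimal sketch of the kind of batch the dataloaders are expected to yield; the image shape and the extra attribute key below are made up for illustration and will differ per dataset:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Illustrative only: each item is a dict with an "x" image tensor plus parent attributes."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {
            "x": torch.randn(1, 64, 64),  # input image: channels x height x width (shape is arbitrary here)
            "age": torch.randn(1),        # a hypothetical parent attribute of x
        }

loader = DataLoader(ToyDataset(), batch_size=4)
batch = next(iter(loader))
bs = batch["x"].shape[0]  # the batch size, as in train_pgm.py
print(bs)  # 4
```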

If you only ever plan to use probabilistic graphical models (PGMs) that do not involve an "x" variable at all, then you can edit the bits of code that depend on it.

The following, lifted from the Pyro documentation, might help with any lingering doubts:

Pyro models can use the context manager pyro.plate to declare that certain batch dimensions are independent. Inference algorithms can then take advantage of this independence to e.g. construct lower variance gradient estimators or to enumerate in linear space rather than exponential space. An example of an independent dimension is the index over data in a minibatch: each datum should be independent of all others.
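
And a minimal, self-contained illustration of that pattern; the toy Gaussian model below is made up for clarity and is not the PGM used in this repo:

```python
import torch
import pyro
import pyro.distributions as dist

def toy_model(obs):
    # Declare the minibatch dimension as conditionally independent,
    # mirroring the svi_model pattern in flow_pgm.py.
    with pyro.plate("observations", obs["x"].shape[0]):
        z = pyro.sample("z", dist.Normal(0.0, 1.0))          # one latent per datum
        pyro.sample("x", dist.Normal(z, 1.0), obs=obs["x"])  # condition on the observed data

toy_model({"x": torch.randn(16)})  # 16 independent scalar observations
```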

Best, Fabio

changeworld-star commented 4 months ago

Dear Fabio, I hope this message finds you well. Thank you for your previous response. Your explanation helped me understand the role of the "x" key in the batch dictionary. I realized that my issue was due to a problem with how I was loading the dataset. After making the necessary adjustments, I successfully resolved the issue.

I now have another question regarding the checkpoint files needed to run train_cf.py. Here is my current workflow:

I ran main.py to obtain a checkpoint.pt file, and then ran train_pgm.py to get another checkpoint.pt file. However, it seems that train_cf.py requires three checkpoint files. I understand that:

- --pgm_path should correspond to the checkpoint file generated by train_pgm.py.
- --vae_path should correspond to the checkpoint file generated by main.py.

My question is about the --predictor_path argument: how do I obtain the pre-trained checkpoint file for predictor_path? I would greatly appreciate any guidance you can provide on this matter. I hope my questions do not cause any inconvenience.

Thank you for your time and assistance.

Best regards, Ling Lin

fabio-deep commented 4 months ago

My apologies for the delay.

Note that train_cf.py is for the optional counterfactual training/fine-tuning step described in the paper, which you may not need depending on how well your base model performs without it.

If you do in fact want to use counterfactual training, then you need to first train some classifiers/regressors for each parent of x which will give you the predictor checkpoint you're missing. This can be done using train_pgm.py under the sup_aux setup for example, although you may need some additional changes to train_pgm.py and flow_pgm.py to accommodate new datasets/scenarios.
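
Roughly, the full sequence looks like the outline below. I've elided all other arguments, and the exact flags depend on your dataset/config, so treat this as a sketch rather than literal commands:

```
python main.py ...        # 1) train the HVAE                           -> checkpoint for --vae_path
python train_pgm.py ...   # 2) train the causal PGM                     -> checkpoint for --pgm_path
python train_pgm.py ...   # 3) same script, run under the sup_aux setup -> checkpoint for --predictor_path
python train_cf.py --vae_path <ckpt 1> --pgm_path <ckpt 2> --predictor_path <ckpt 3>
```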

Best, Fabio

changeworld-star commented 3 months ago

Hi Fabio, thanks for your previous response; your assistance has been invaluable. Could you please share the plotting code for the example result from the UKBB dataset mentioned on the GitHub page? Your help would be greatly appreciated!

changeworld-star commented 3 months ago

Hi Fabio,

I hope this message finds you well. I am writing to follow up on my previous request regarding the plotting code for the example result from the UKBB dataset mentioned on the GitHub page.

Your work has been incredibly insightful and has greatly benefited my project. I deeply appreciate the effort and expertise you have put into it. If you could share the plotting code, it would be immensely helpful and I would be very grateful.

Thank you once again for your time and for your outstanding contributions.