BiomedSciAI / causallib

A Python package for modular causal inference analysis and model evaluations
Apache License 2.0
728 stars, 97 forks

IPW computation #66

winston-zillow closed this issue 10 months ago

winston-zillow commented 10 months ago

I briefly read the code around IPW, in particular these lines:

```python
# weight_matrix = 1.0 / probabilities
weight_matrix = probabilities.rdiv(1.0)
```

My understanding of IPW is that outcomes are weighted by $w_i = \frac{a_i}{e_i} + \frac{1 - a_i}{1 - e_i}$. Am I missing something, or is this a bug?

ehudkr commented 10 months ago

Hi Winston, not a bug 🙃 🐞

The equation you presented is limited to binary treatment scenarios, while causallib's implementation scales to an arbitrary number of treatment values. Reformulating your equation in the binary case, treated observations ($A_i=1$) are weighted by $\frac{1}{\Pr[A=1|X=x_i]}$ (that is, $\frac{1}{e_i}$ in your notation) and control observations ($A_i=0$) are weighted by $\frac{1}{1-\Pr[A=1|X=x_i]} = \frac{1}{\Pr[A=0|X=x_i]}$. The generalization is therefore straightforward once we observe that each observation $i$ is weighted by the inverse probability of being assigned to its observed treatment group: $w_i=\frac{1}{\Pr[A=a_i|X=x_i]}$. Causallib's implementation does the most general calculation: it first considers, for each $i$, the probabilities of every possible treatment (hence the probabilities matrix, which is similar to the output of a `predict_proba()`), and only slices out the required values according to the treatment assignment at the very end.
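
A minimal pandas sketch of the "compute the full matrix, slice at the end" idea, checked against the binary formula from the question. It assumes `probabilities` is a DataFrame with one row per sample and one column per treatment value; `ipw_weights` is a hypothetical helper written for illustration, not causallib's API, and the propensity numbers are made up.

```python
import numpy as np
import pandas as pd


def ipw_weights(probabilities: pd.DataFrame, a: pd.Series) -> pd.Series:
    """Hypothetical helper (not causallib's API): w_i = 1 / Pr[A = a_i | X = x_i].

    probabilities: one row per sample, one column per treatment value,
                   like the output of a predict_proba().
    a:             observed treatment value of each sample.
    """
    # Inverse of every entry: the weight each sample would get under every treatment.
    weight_matrix = probabilities.rdiv(1.0)  # same as 1.0 / probabilities
    # Only now slice out, per row, the column of the treatment actually received.
    cols = weight_matrix.columns.get_indexer(a)
    if (cols < 0).any():
        raise ValueError("treatment value missing from probability columns")
    values = weight_matrix.to_numpy()[np.arange(len(a)), cols]
    return pd.Series(values, index=a.index)


# Binary case: compare with w_i = a_i / e_i + (1 - a_i) / (1 - e_i).
e = pd.Series([0.2, 0.5, 0.9, 0.7])  # made-up Pr[A=1 | X=x_i]
a = pd.Series([1, 0, 1, 0])
w_general = ipw_weights(pd.DataFrame({0: 1 - e, 1: e}), a)
w_binary = a / e + (1 - a) / (1 - e)
print(np.allclose(w_general, w_binary))  # True

# The same helper handles three treatment arms without any change.
p3 = pd.DataFrame({"x": [0.5, 0.2], "y": [0.3, 0.3], "z": [0.2, 0.5]})
print(ipw_weights(p3, pd.Series(["z", "y"])).round(3).tolist())  # [5.0, 3.333]
```

In the binary case the two formulas coincide, while the matrix version needs no special casing when a third treatment value appears.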

Thanks for bringing this up, and sorry for causing any confusion. I hope I was able to clear some of it up. Please feel free to reach out if there's anything else.

ehudkr commented 10 months ago

To justify this further: there are other small benefits to this implementation that I have yet to take advantage of. For example, generalizing the average treatment effect on the treated/controls (ATT/ATC) to any arbitrary treatment level $v$ can be done simply by multiplying each row of the inverse probability matrix (`weight_matrix`) by that row's entry in the probability column for $v$, e.g. `weight_matrix = weight_matrix.mul(probabilities[v], axis="index")`. Then the column corresponding to $v$ is all ones ($p_{i,v}$ times its inverse), and every other column $u$ holds the odds $\frac{p_{i,v}}{p_{i,u}}$ with respect to $v$.
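
A minimal pandas sketch of that row-wise multiplication, under the same assumptions as before (`probabilities` is a samples-by-treatment-values DataFrame; `target_level_weights` is a hypothetical helper, not an existing causallib function; numbers are made up). It checks the result against the usual binary ATT/ATC odds weights.

```python
import numpy as np
import pandas as pd


def target_level_weights(probabilities: pd.DataFrame, a: pd.Series, v) -> pd.Series:
    """Hypothetical helper (not causallib's API): weights whose target population
    is the group that received treatment v (ATT when v is "treated", ATC when v
    is "control")."""
    weight_matrix = probabilities.rdiv(1.0)
    # Scale row i by p_{i,v}. axis="index" aligns probabilities[v] with the rows;
    # a bare `weight_matrix * probabilities[v]` would align it with the columns.
    weight_matrix = weight_matrix.mul(probabilities[v], axis="index")
    # Column v is now all ones; any other column u holds the odds p_{i,v} / p_{i,u}.
    cols = weight_matrix.columns.get_indexer(a)
    values = weight_matrix.to_numpy()[np.arange(len(a)), cols]
    return pd.Series(values, index=a.index)


e = pd.Series([0.2, 0.5, 0.9, 0.7])  # made-up Pr[A=1 | X=x_i]
a = pd.Series([1, 0, 1, 0])
probs = pd.DataFrame({0: 1 - e, 1: e})

# ATT (v=1): 1 for treated, e_i / (1 - e_i) for controls.
w_att = target_level_weights(probs, a, v=1)
print(np.allclose(w_att, a + (1 - a) * e / (1 - e)))  # True

# ATC (v=0): (1 - e_i) / e_i for treated, 1 for controls.
w_atc = target_level_weights(probs, a, v=0)
print(np.allclose(w_atc, (1 - a) + a * (1 - e) / e))  # True
```

The `axis="index"` argument is the important design detail here: pandas aligns a Series with a DataFrame's columns by default, so the multiplication has to be told explicitly to broadcast along the rows.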

I hope this gives better context to the design decisions.

winston-zillow commented 10 months ago

I see. You turn the binary one-class label into a multi-class label first. That works. Thanks.