py-why / dowhy

DoWhy is a Python library for causal inference that supports explicit modeling and testing of causal assumptions. DoWhy is based on a unified language for causal inference, combining causal graphical models and potential outcomes frameworks.
https://www.pywhy.org/dowhy
MIT License

Vectorize operations for propensity score matching #1179

Closed: rahulbshrestha closed this 4 weeks ago

rahulbshrestha commented 1 month ago

This PR addresses this issue by replacing the existing for-loops in propensity score matching with vectorized operations, which should speed up estimation on large datasets.
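To illustrate the kind of change described here, below is a minimal, self-contained sketch (not DoWhy's code) of one-to-one nearest-neighbour matching on propensity scores, computing the ATT once with a Python loop and once with array operations. The function names `att_loop` and `att_vectorized` and the brute-force NumPy neighbour search are assumptions chosen for illustration, and the propensity scores are taken as given rather than estimated.

```python
import numpy as np


def att_loop(treated_ps, treated_y, control_ps, control_y):
    """Hypothetical helper: ATT with an explicit loop, one treated unit at a time."""
    total = 0.0
    for i in range(len(treated_ps)):
        # Position of the control unit whose propensity score is closest
        j = int(np.argmin(np.abs(control_ps - treated_ps[i])))
        total += treated_y[i] - control_y[j]
    return total / len(treated_ps)


def att_vectorized(treated_ps, treated_y, control_ps, control_y):
    """Hypothetical helper: same ATT, with matching and averaging as array operations."""
    # Broadcasting builds an (n_treated, n_control) matrix of score distances
    distances = np.abs(treated_ps[:, None] - control_ps[None, :])
    indices = distances.argmin(axis=1)  # nearest control position per treated unit
    return float(np.mean(treated_y - control_y[indices]))


# Tiny hand-checkable case: 0.1 matches 0.0 and 0.9 matches 1.0
t_ps, t_y = np.array([0.1, 0.9]), np.array([5.0, 7.0])
c_ps, c_y = np.array([0.0, 1.0]), np.array([1.0, 2.0])
print(att_loop(t_ps, t_y, c_ps, c_y))        # 4.5
print(att_vectorized(t_ps, t_y, c_ps, c_y))  # 4.5

# Larger random case: both versions should agree up to floating-point rounding
rng = np.random.default_rng(0)
t_ps, c_ps = rng.uniform(size=200), rng.uniform(size=300)
t_y, c_y = rng.normal(10.0, 1.0, size=200), rng.normal(0.0, 1.0, size=300)
print(abs(att_loop(t_ps, t_y, c_ps, c_y) - att_vectorized(t_ps, t_y, c_ps, c_y)))
```

The distance matrix costs O(n_treated × n_control) memory, so for large datasets a tree-based neighbour index (for example scikit-learn's `NearestNeighbors`) is the more usual way to obtain the match indices; only the final averaging step then needs vectorizing. An ATC sketch would swap the roles of the two groups, matching each control unit to its nearest treated unit.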

This PR is a work in progress, and the remaining tasks include:

amit-sharma commented 1 month ago

Thanks for starting this, @rahulbshrestha. Let us know once the PR is ready for review.

rahulbshrestha commented 1 month ago

I ran some tests to check whether the ATT and ATC values are the same before and after the changes in this PR:


```python
### PREVIOUS IMPLEMENTATION (loop over treated units)
att = 0
numtreatedunits = treated.shape[0]
treated_outcomes_old = []
control_outcomes_old = []

for i in range(numtreatedunits):
    treated_outcome = treated.iloc[i][self._target_estimand.outcome_variable[0]].item()
    control_outcome = control.iloc[indices[i]][self._target_estimand.outcome_variable[0]].item()
    treated_outcomes_old.append(treated_outcome)
    control_outcomes_old.append(control_outcome)
    att += treated_outcome - control_outcome

att /= numtreatedunits

print('Checking values of ATT: ')
print('ATT (before): ', att)

### NEW IMPLEMENTATION (vectorized)
outcome_variable = self._target_estimand.outcome_variable[0]
treated_outcomes = treated[outcome_variable]
# list() drops the control index labels, so the subtraction below pairs values by position
control_outcomes = list(control.iloc[indices.flatten()][outcome_variable])

att = (treated_outcomes - control_outcomes).mean()

print('ATT (after): ', att)
# Compare plain lists; comparing a list to a Series directly returns an elementwise Series
print('Treated outcomes ', treated_outcomes_old == list(treated_outcomes))
print('Control outcomes', control_outcomes_old == control_outcomes)
```
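The vectorized version wraps the matched control outcomes in `list(...)` before subtracting. One plausible reason, which is a reading of the snippet rather than something stated in the PR, is pandas index alignment: `control.iloc[...]` keeps the control rows' index labels, and subtracting two Series aligns on labels, not positions. The toy frames below are hypothetical; they show the label-aligned result filling with `NaN`, while a positional subtraction pairs each treated unit with its matched control.

```python
import numpy as np
import pandas as pd

# Hypothetical data: each frame keeps its own row labels, as after splitting on treatment
treated = pd.DataFrame({"y": [5.0, 7.0]}, index=[0, 1])
control = pd.DataFrame({"y": [1.0, 2.0, 3.0]}, index=[2, 3, 4])
indices = np.array([[2], [0]])  # nearest-control positions, one column per match

treated_outcomes = treated["y"]
matched = control.iloc[indices.flatten()]["y"]  # values [3.0, 1.0], labels [4, 2]

# Label alignment: the two Series share no labels, so every entry is NaN
aligned = treated_outcomes - matched
print(aligned.isna().all())  # True

# Positional subtraction pairs treated row i with its matched control
positional = treated_outcomes - list(matched)
print(positional.tolist())  # [2.0, 6.0]
print(positional.mean())    # 4.0

# Compare outcome lists as lists; list == Series would return an elementwise Series
print(treated_outcomes.tolist() == [5.0, 7.0])  # True
```

Using `.to_numpy()` on both sides gives the same positional result and avoids building an intermediate Python list.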

Here are the results on some test data:

```
Checking values of ATT:
ATT (before):  10.923190922091228
ATT (after):  10.923190922091242
Treated outcomes  True
Control outcomes True
Checking values of ATC:
ATC (before):  10.506587873468016
ATC (after):  10.506587873468012
Treated outcomes  True
Control outcomes True
```

Both lists (treated outcomes and control outcomes) are identical before and after my changes. The ATT and ATC differ only in the last one or two digits (see the values above), which is probably floating-point rounding error: the loop accumulates the differences one at a time, while `.mean()` sums them in a different order. Is this a problem, @amit-sharma?
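A tolerance-based comparison is the usual way to check results like these, since floating-point addition is not associative and two summation orders can disagree in the last digits. The sketch below checks the values reported in this thread with `math.isclose`, then averages some synthetic per-unit differences (not the PR's test data) with a sequential loop, with `np.mean`, and with `math.fsum` as a correctly rounded reference. The tolerances are assumptions chosen for illustration.

```python
import math

import numpy as np

# Values reported in this thread
att_before, att_after = 10.923190922091228, 10.923190922091242
atc_before, atc_after = 10.506587873468016, 10.506587873468012
print(math.isclose(att_before, att_after, rel_tol=1e-12))  # True
print(math.isclose(atc_before, atc_after, rel_tol=1e-12))  # True

# Synthetic per-unit differences, averaged several ways
rng = np.random.default_rng(42)
diffs = rng.normal(loc=10.9, scale=3.0, size=10_000)

loop_mean = 0.0
for d in diffs:  # sequential left-to-right accumulation, like the old loop
    loop_mean += float(d)
loop_mean /= len(diffs)

vec_mean = float(np.mean(diffs))             # NumPy's own summation order
exact_mean = math.fsum(diffs) / len(diffs)   # correctly rounded sum as a reference

# The results may differ in the last few digits, so compare with a tolerance
print(math.isclose(loop_mean, vec_mean, rel_tol=1e-9))  # True
```

In a unit test, `numpy.testing.assert_allclose` or `pytest.approx` express the same idea.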

rahulbshrestha commented 1 month ago

Hey @amit-sharma! I think this PR is ready to be merged :)

amit-sharma commented 4 weeks ago

@all-contributors please add @rahulbshrestha for code.

allcontributors[bot] commented 4 weeks ago

@amit-sharma

I've put up a pull request to add @rahulbshrestha! :tada: