matheusfacure / python-causality-handbook

Causal Inference for the Brave and True. A light-hearted yet rigorous approach to learning about impact estimation and causality.
https://matheusfacure.github.io/python-causality-handbook/landing-page.html
MIT License
2.61k stars, 456 forks

Ch 25: `join_weights()` Function Correction #317

Closed: trevorvogel closed this issue 1 year ago

trevorvogel commented 1 year ago

I'm proposing a change to `join_weights()` in the synthetic DiD section. As it stands, the function is meant to apply a uniform time weight to post-treatment periods and a uniform unit weight to treated units, and it fills those weights with the means of `post_col` and `treat_col`. Depending on the ratio of post- to pre-treatment periods, and of treated to control units, those means may not equal the intended weights of 1/(number of post-treatment periods) and 1/(number of treated units), respectively. The means happen to be correct for some panels, but the approach does not generalize to all panel data.
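To make the condition explicit, assume a balanced panel with $N$ units observed over $T$ periods, of which $N_{tr}$ units are treated and $T_{post}$ periods are post-treatment, with `treat_col` constant within a unit and `post_col` constant within a period. The column means over all $N T$ rows are then:

$$
\overline{\text{treat}} = \frac{N_{tr}\,T}{N\,T} = \frac{N_{tr}}{N},
\qquad
\overline{\text{post}} = \frac{N\,T_{post}}{N\,T} = \frac{T_{post}}{T}
$$

and they match the intended uniform weights only when

$$
\frac{N_{tr}}{N} = \frac{1}{N_{tr}} \iff N_{tr}^2 = N,
\qquad
\frac{T_{post}}{T} = \frac{1}{T_{post}} \iff T_{post}^2 = T.
$$

So the column means coincide with the intended weights only when these equalities happen to hold.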

Take, for example, a simple panel with 5 units and 4 time periods, the last two of which are post-treatment. Suppose that 2 of the units are treated. In this panel, the treated-unit weights will be 0.4 (the mean of `treat_col`, i.e. 2/5) when they should be 1/2. The mean of `post_col` is 1/2, which happens to equal 1 divided by the number of post-treatment periods, so the post-period time weights come out correct. With one additional pre-treatment period, however, the mean of `post_col` would drop to 2/5 = 0.4 and the time weights would be wrong as well.
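The arithmetic can be checked on a toy panel. This is a minimal sketch: the column names `state`, `year`, `treated` and `after_treatment` and the helpers `toy_panel` and `fill_values` are hypothetical, chosen only for illustration. It compares the fill values implied by the column means with the intended uniform weights, counting distinct treated units and post-treatment periods the same way as the proposed fix.

```python
import itertools

import pandas as pd


def toy_panel(n_units, n_periods, n_treated, n_post):
    # Hypothetical balanced panel: the last `n_treated` units are treated and
    # the last `n_post` periods are post-treatment.
    rows = [
        {
            "state": s,
            "year": t,
            "treated": int(s >= n_units - n_treated),
            "after_treatment": int(t >= n_periods - n_post),
        }
        for s, t in itertools.product(range(n_units), range(n_periods))
    ]
    return pd.DataFrame(rows)


def fill_values(df):
    # Fill values implied by the column means (current behaviour).
    old = (float(df["treated"].mean()), float(df["after_treatment"].mean()))
    # Intended uniform weights: 1 / (distinct treated units), 1 / (distinct post periods).
    n_treat = df[["state", "treated"]].drop_duplicates()["treated"].sum()
    t_post = df[["year", "after_treatment"]].drop_duplicates()["after_treatment"].sum()
    new = (float(1 / n_treat), float(1 / t_post))
    return old, new


# 5 units (2 treated), 4 periods (2 post): only the unit weight is off.
print(fill_values(toy_panel(5, 4, n_treated=2, n_post=2)))  # ((0.4, 0.5), (0.5, 0.5))
# One extra pre-treatment period: the time weight is now off as well.
print(fill_values(toy_panel(5, 5, n_treated=2, n_post=2)))  # ((0.4, 0.4), (0.5, 0.5))
```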

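The following should generalize `join_weights()` so that it applies the correct uniform weights in all cases:

```python
def join_weights(data, unit_w, time_w, year_col, state_col, treat_col, post_col):
    # Count distinct treated units and distinct post-treatment periods.
    # Deduplicating first means each unit/period is counted once,
    # regardless of how many rows it has in the panel.
    n_treat = data[[state_col, treat_col]].drop_duplicates(subset=[state_col, treat_col])[treat_col].sum()
    t_post = data[[year_col, post_col]].drop_duplicates(subset=[year_col, post_col])[post_col].sum()

    return (
        data
        .set_index([year_col, state_col])
        .join(time_w)   # time weights are expected to cover pre-treatment periods only
        .join(unit_w)   # unit weights are expected to cover control units only
        .reset_index()
        # Post-treatment periods and treated units are left as NaN by the joins;
        # give them uniform weights 1/t_post and 1/n_treat.
        .fillna({time_w.name: 1/t_post,
                 unit_w.name: 1/n_treat})
        .assign(**{"weights": lambda d: (d[time_w.name]*d[unit_w.name]).round(10)})
        .astype({treat_col: int, post_col: int}))
```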
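A usage sketch of the proposed function on a toy panel. The function body is copied from the proposal so the snippet runs on its own (`drop_duplicates()` without `subset` is equivalent here, since the subset is every selected column). The panel, the column names, and the weight Series `unit_weight` and `time_weight` are hypothetical placeholders, not weights actually estimated by synthetic DiD. As the fill logic implies, it assumes `unit_w` is indexed by the state column and covers only control units, and `time_w` is indexed by the year column and covers only pre-treatment periods.

```python
import itertools

import pandas as pd


def join_weights(data, unit_w, time_w, year_col, state_col, treat_col, post_col):
    # Same logic as the proposed fix: count distinct treated units / post periods.
    n_treat = data[[state_col, treat_col]].drop_duplicates()[treat_col].sum()
    t_post = data[[year_col, post_col]].drop_duplicates()[post_col].sum()
    return (
        data
        .set_index([year_col, state_col])
        .join(time_w)
        .join(unit_w)
        .reset_index()
        .fillna({time_w.name: 1 / t_post, unit_w.name: 1 / n_treat})
        .assign(**{"weights": lambda d: (d[time_w.name] * d[unit_w.name]).round(10)})
        .astype({treat_col: int, post_col: int})
    )


# Hypothetical toy panel: 5 states x 4 years; states 3-4 treated, years 2-3 post.
data = pd.DataFrame(
    [
        {"year": t, "state": s, "treated": s >= 3, "after_treatment": t >= 2}
        for t, s in itertools.product(range(4), range(5))
    ]
)

# Hypothetical weights: unit weights for control states (indexed by state),
# time weights for pre-treatment years (indexed by year).
unit_w = pd.Series([0.2, 0.3, 0.5], index=pd.Index([0, 1, 2], name="state"), name="unit_weight")
time_w = pd.Series([0.6, 0.4], index=pd.Index([0, 1], name="year"), name="time_weight")

out = join_weights(data, unit_w, time_w, "year", "state", "treated", "after_treatment")
print(out.query("treated == 1")["unit_weight"].unique())          # uniform 1/2 for both treated states
print(out.query("after_treatment == 1")["time_weight"].unique())  # uniform 1/2 for both post years
```

Because each Series' `name` becomes a column after the `join`, those names must not clash with existing columns of `data`; pandas raises a `ValueError` on overlapping columns when no suffix is given.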

matheusfacure commented 1 year ago

Fixed with https://github.com/matheusfacure/python-causality-handbook/issues/333