bernardodionisi / differences

difference-in-differences in Python
https://bernardodionisi.github.io/differences/latest/
GNU General Public License v3.0

Treated Cohort Error #12

Closed: achinmay17 closed this issue 10 months ago

achinmay17 commented 10 months ago

Hi, I am trying to run Doubly Robust S-DID with an unbalanced panel and a varying base period. The control group is 'not_yet_treated'. My code is as follows:

    att_gt = ATTgt(data=diddata, cohort_name="course_month_end_date", base_period='varying', freq='M')
    att_gt.fit(formula=formula, est_method='dr', control_group=control_group, progress_bar=True)

However, I am getting the following error, which I am not able to understand:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/anaconda3/lib/python3.11/site-packages/pandas/core/indexes/base.py:3790, in Index.get_loc(self, key)
   3789 try:
-> 3790     return self._engine.get_loc(casted_key)
   3791 except KeyError as err:

File index.pyx:152, in pandas._libs.index.IndexEngine.get_loc()

File index.pyx:181, in pandas._libs.index.IndexEngine.get_loc()

File pandas/_libs/hashtable_class_helper.pxi:7080, in pandas._libs.hashtable.PyObjectHashTable.get_item()

File pandas/_libs/hashtable_class_helper.pxi:7088, in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: '_course_month_end_date'

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
Cell In[176], line 15
     13 config = 'config' + str(i)
     14 att_gt = ATTgt(data=diddata, cohort_name="course_month_end_date", base_period='varying', freq='M') 
---> 15 att_gt.fit(formula = formula, est_method='dr',control_group=control_group, progress_bar = True)

File ~/anaconda3/lib/python3.11/site-packages/differences/attgt/attgt.py:597, in ATTgt.fit(self, formula, weights_name, control_group, base_delta, est_method, as_repeated_cross_section, boot_iterations, random_state, alpha, cluster_var, split_sample_by, n_jobs, backend, progress_bar)
    593     is_panel, is_balanced_panel = False, False
    595 # --------------------- filter cohort times ---------------------
--> 597 group_time = self.group_time(feasible=False)
    599 # todo: if balance panel if filter? should already be done if balance 2*2
    600 filter_gt = None

File ~/anaconda3/lib/python3.11/site-packages/differences/attgt/attgt.py:213, in ATTgt.group_time(self, feasible)
    204 """
    205 Returns
    206 -------
    207 a list of dictionaries where each dictionary keys are:
    208 ``cohort``, ``base_period``, ``time``, (``stratum``)
    209 """
    211 if self.base_period_type == "varying":
    212     cbt = varying_base_period(
--> 213         cohort_ar=self._cohorts,
    214         time_ar=self._times,
    215         anticipation=self.anticipation,
    216     )
    218 if self.base_period_type == "universal":
    219     cbt = universal_base_period(
    220         cohort_ar=self._cohorts,
    221         time_ar=self._times,
    222         anticipation=self.anticipation,
    223     )

File ~/anaconda3/lib/python3.11/site-packages/differences/attgt/attgt.py:162, in ATTgt._cohorts(self)
    160 @property
    161 def _cohorts(self):
--> 162     cohorts = np.array(sorted(self.data[self.cohort_name].dropna().unique()))
    163     return cohorts[cohorts > self._times[0] + self.anticipation]

File ~/anaconda3/lib/python3.11/site-packages/pandas/core/frame.py:3896, in DataFrame.__getitem__(self, key)
   3894 if self.columns.nlevels > 1:
   3895     return self._getitem_multilevel(key)
-> 3896 indexer = self.columns.get_loc(key)
   3897 if is_integer(indexer):
   3898     indexer = [indexer]

File ~/anaconda3/lib/python3.11/site-packages/pandas/core/indexes/base.py:3797, in Index.get_loc(self, key)
   3792     if isinstance(casted_key, slice) or (
   3793         isinstance(casted_key, abc.Iterable)
   3794         and any(isinstance(x, slice) for x in casted_key)
   3795     ):
   3796         raise InvalidIndexError(key)
-> 3797     raise KeyError(key) from err
   3798 except TypeError:
   3799     # If we have a listlike key, _check_indexing_error will raise
   3800     #  InvalidIndexError. Otherwise we fall through and re-raise
   3801     #  the TypeError.
   3802     self._check_indexing_error(key)

KeyError: '_course_month_end_date'

The 'course_month_end_date' column does exist in the DataFrame. I would really appreciate your help in debugging this.

bernardodionisi commented 10 months ago

Hi,

I can't reproduce the error (see below for working examples).

Is there any way you could share a bit more information about the time and cohort columns? After initializing ATTgt, you can check the processed data via att_gt.data (in your case) and make sure the columns are correctly specified.
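A minimal sketch of the kind of summary that would be useful to share here, assuming (as in the examples below) a panel indexed by (entity, time) with the cohort stored as a regular column. `describe_cohort_time` is a hypothetical helper, not part of differences; run it on the DataFrame you pass to ATTgt.

```python
import pandas as pd


def describe_cohort_time(df: pd.DataFrame, cohort_name: str, time_level: str = "time") -> dict:
    """Summarize the cohort column and the time index level of a panel.

    Hypothetical helper (not part of the differences package); assumes the
    panel is indexed by (entity, time) with a time level named `time_level`.
    """
    time_values = df.index.get_level_values(time_level)
    cohort = df[cohort_name]
    return {
        "time_dtype": str(time_values.dtype),
        "cohort_dtype": str(cohort.dtype),
        "n_times": int(time_values.nunique()),
        # Sorted distinct cohort values, ignoring missing entries.
        "cohorts": sorted(cohort.dropna().unique().tolist()),
        "n_missing_cohort": int(cohort.isna().sum()),
    }


# Toy panel (hypothetical data): cohort is missing (NaN) for entity 2.
toy = pd.DataFrame(
    {
        "entity": [1, 1, 2, 2],
        "time": [1, 2, 1, 2],
        "cohort": [2, 2, float("nan"), float("nan")],
        "y": [0.1, 0.2, 0.3, 0.4],
    }
).set_index(["entity", "time"])
print(describe_cohort_time(toy, "cohort"))
```

Posting this dictionary (dtypes, distinct cohorts, number of periods) usually makes a mismatch between the declared and actual columns easy to spot.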

It seems you are using monthly data and relying on the freq argument, which casts the datetimes to integers; the time and cohort columns in your instance's data attribute should therefore have integer data types.

If the error arises when the freq argument casts them to integers, you could instead cast the time and cohort columns to integers yourself.
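A minimal sketch of doing that cast yourself for monthly datetimes: encode each date as year * 12 + (month - 1), so consecutive months differ by exactly 1. `to_month_index` is a hypothetical helper, not part of differences. Missing dates are kept as `<NA>` via pandas' nullable "Int64" dtype; how the package expects never-treated cohorts to be coded is an assumption you should check against its documentation. If time is an index level, `reset_index()` first, as in the monthly example below.

```python
import pandas as pd


def to_month_index(dates: pd.Series) -> pd.Series:
    """Encode datetimes as integers where consecutive months differ by 1.

    Hypothetical helper (not part of the differences package). NaT values
    become <NA> thanks to the nullable "Int64" dtype.
    """
    dates = pd.to_datetime(dates)
    return (dates.dt.year * 12 + dates.dt.month - 1).astype("Int64")


# Toy data (hypothetical): the day of the month is ignored by the encoding.
df = pd.DataFrame(
    {
        "time": pd.to_datetime(["1900-01-01", "1900-02-01", "1901-01-15"]),
        "cohort": pd.to_datetime(["1900-02-01", None, "1900-02-01"]),
    }
)
df = df.assign(
    time=lambda x: to_month_index(x["time"]),
    cohort=lambda x: to_month_index(x["cohort"]),
)
print(df)
```

With integer time and cohort columns in place, ATTgt would presumably be initialized without freq, since no datetime casting is needed.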


One way to see how the functionality works with datetime columns is to try the following examples, which run fine for me:

    from differences import simulate_data, ATTgt

    panel_data = (
        simulate_data(datetime=True)  # yearly data where the date and cohort are datetimes
        .sample(frac=0.9)  # making it unbalanced
    )

    att_gt = ATTgt(data=panel_data, cohort_name='cohort', base_period="varying", freq="YS")

    print(att_gt.is_balanced_panel)

    att_gt.fit("y ~ x0", est_method='dr', control_group="not_yet_treated", progress_bar=True)

    import pandas as pd
    from differences import simulate_data, ATTgt

    to_month = {
        1900: "1900-01-01",
        1901: "1900-02-01",
        1902: "1900-03-01",
        1903: "1900-04-01",
        1904: "1900-05-01",
        1905: "1900-06-01",
        1906: "1900-07-01",
        1907: "1900-08-01",
        1908: "1900-09-01",
        1909: "1900-10-01",
        1910: "1900-11-01",
    }

    panel_data = (
        simulate_data()
        .sample(frac=0.9)
        .reset_index()
        .assign(
            time=lambda x: pd.to_datetime(x["time"].map(to_month)),
            cohort=lambda x: pd.to_datetime(x["cohort"].map(to_month)),
        )
        .set_index(["entity", "time"])
    )

    att_gt = ATTgt(data=panel_data, cohort_name='cohort', base_period="varying", freq="MS")

    print(att_gt.is_balanced_panel)

    att_gt.fit("y ~ x0", est_method='dr', control_group="not_yet_treated", progress_bar=True)

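The hand-written to_month dictionary in the monthly example maps the simulated periods 1900 to 1910 onto the first day of January through November 1900. A dictionary comprehension builds the same mapping; this sketch assumes, like the example, 11 consecutive integer periods starting at 1900, and it only works for up to 12 periods because every month is formatted within the single year 1900.

```python
# Map the simulated integer periods 1900..1910 to the first day of
# consecutive months in 1900 (same mapping as the hand-written to_month dict).
first_period = 1900
n_periods = 11  # 1900, 1901, ..., 1910; must not exceed 12

to_month = {
    first_period + k: f"1900-{k + 1:02d}-01"
    for k in range(n_periods)
}
print(to_month[1900], to_month[1910])  # 1900-01-01 1900-11-01
```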
achinmay17 commented 10 months ago

I was able to find the cause of the error using the debugging method: there was some duplication at the id level. Thanks for the prompt response!
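A minimal sketch of how id-level duplicates like these can be found and resolved with pandas before passing the data to ATTgt. The toy panel and column names are hypothetical, and keeping the first occurrence is only one option; whether to drop, aggregate, or correct duplicate rows depends on why they exist in the data.

```python
import pandas as pd

# Toy panel (hypothetical data): entity 1 appears twice in period 2,
# i.e. the kind of id-level duplication described above.
panel = pd.DataFrame(
    {
        "entity": [1, 1, 1, 2, 2],
        "time": [1, 2, 2, 1, 2],
        "y": [0.5, 1.0, 1.1, 0.3, 0.4],
    }
)

# Flag every row whose (entity, time) pair occurs more than once.
dup_mask = panel.duplicated(subset=["entity", "time"], keep=False)
print(panel[dup_mask])

# One possible fix: keep the first occurrence of each (entity, time) pair.
deduped = panel.drop_duplicates(subset=["entity", "time"], keep="first")

# After deduplication each (entity, time) pair identifies exactly one row.
panel_indexed = deduped.set_index(["entity", "time"])
print(panel_indexed.index.is_unique)  # True
```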