pymc-labs / CausalPy

A Python package for causal inference in quasi-experimental settings
https://causalpy.readthedocs.io
Apache License 2.0

Synthetic control: `IndexError` in `.plot()` method when using a data frame with `datetime` index #148

Closed: AlxndrMlk closed this issue 1 year ago

AlxndrMlk commented 1 year ago

Hi,

First, thanks for your great work on CausalPy!

I experienced the following issue:

When the data frame has a datetime-typed index, the .plot() method of the cp.pymc_experiments.SyntheticControl class raises an IndexError, which matplotlib then re-raises as a ConversionError.

Code:

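import causalpy as cp
import matplotlib.pyplot as plt

data = ... # A data frame with datetime-type index

# Treatment time passed as a plain string (this is what triggers the error below)
treatment_index = '2022-10-28'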

# Build the model
model = cp.pymc_models.WeightedSumFitter(sample_kwargs={'target_accept': 0.95})

# Run the experiment
results = cp.pymc_experiments.SyntheticControl(
    data,
    treatment_index,
    formula="y ~ 0 + x1 + x2 + x3",
    model=model,
)

# Plot results
results.plot(plot_predictors=True)
plt.show()

This results in the following error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
File ~\anaconda3\envs\causal-pymc\lib\site-packages\matplotlib\axis.py:1665, in Axis.convert_units(self, x)
   1664 try:
-> 1665     ret = self.converter.convert(x, self.units, self)
   1666 except Exception as e:

File ~\anaconda3\envs\causal-pymc\lib\site-packages\matplotlib\dates.py:1980, in _SwitchableDateConverter.convert(self, *args, **kwargs)
   1979 def convert(self, *args, **kwargs):
-> 1980     return self._get_converter().convert(*args, **kwargs)

File ~\anaconda3\envs\causal-pymc\lib\site-packages\matplotlib\dates.py:1909, in DateConverter.convert(value, unit, axis)
   1903 """
   1904 If *value* is not already a number or sequence of numbers, convert it
   1905 with `date2num`.
   1906 
   1907 The *unit* and *axis* arguments are not used.
   1908 """
-> 1909 return date2num(value)

File ~\anaconda3\envs\causal-pymc\lib\site-packages\matplotlib\dates.py:449, in date2num(d)
    448     return d
--> 449 tzi = getattr(d[0], 'tzinfo', None)
    450 if tzi is not None:
    451     # make datetime naive:

IndexError: too many indices for array: array is 0-dimensional, but 1 were indexed

The above exception was the direct cause of the following exception:

ConversionError                           Traceback (most recent call last)
Cell In[49], line 2
      1 # Plot results
----> 2 results.plot(plot_predictors=True)
      3 plt.show()

File ~\anaconda3\envs\causal-pymc\lib\site-packages\causalpy\pymc_experiments.py:231, in SyntheticControl.plot(self, plot_predictors)
    229 def plot(self, plot_predictors=False):
    230     """Plot the results"""
--> 231     fig, ax = super().plot()
    232     if plot_predictors:
    233         # plot control units as well
    234         ax[0].plot(self.datapre.index, self.pre_X, "-", c=[0.8, 0.8, 0.8], zorder=1)

File ~\anaconda3\envs\causal-pymc\lib\site-packages\causalpy\pymc_experiments.py:200, in TimeSeriesExperiment.plot(self)
    198 # Intervention line
    199 for i in [0, 1, 2]:
--> 200     ax[i].axvline(
    201         x=self.treatment_time,
    202         ls="-",
    203         lw=3,
    204         color="r",
    205     )
    207 ax[0].legend(
    208     handles=(h_tuple for h_tuple in handles),
    209     labels=labels,
    210     fontsize=LEGEND_FONT_SIZE,
    211 )
    213 return (fig, ax)

File ~\anaconda3\envs\causal-pymc\lib\site-packages\matplotlib\axes\_axes.py:825, in Axes.axvline(self, x, ymin, ymax, **kwargs)
    822 xmin, xmax = self.get_xbound()
    824 # Strip away the units for comparison with non-unitized bounds.
--> 825 xx, = self._process_unit_info([("x", x)], kwargs)
    826 scalex = (xx < xmin) or (xx > xmax)
    828 trans = self.get_xaxis_transform(which='grid')

File ~\anaconda3\envs\causal-pymc\lib\site-packages\matplotlib\axes\_base.py:2603, in _AxesBase._process_unit_info(self, datasets, kwargs, convert)
   2601             if dataset_axis_name == axis_name and data is not None:
   2602                 axis.update_units(data)
-> 2603 return [axis_map[axis_name].convert_units(data)
   2604         if convert and data is not None else data
   2605         for axis_name, data in datasets]

File ~\anaconda3\envs\causal-pymc\lib\site-packages\matplotlib\axes\_base.py:2603, in <listcomp>(.0)
   2601             if dataset_axis_name == axis_name and data is not None:
   2602                 axis.update_units(data)
-> 2603 return [axis_map[axis_name].convert_units(data)
   2604         if convert and data is not None else data
   2605         for axis_name, data in datasets]

File ~\anaconda3\envs\causal-pymc\lib\site-packages\matplotlib\axis.py:1667, in Axis.convert_units(self, x)
   1665     ret = self.converter.convert(x, self.units, self)
   1666 except Exception as e:
-> 1667     raise munits.ConversionError('Failed to convert value(s) to axis '
   1668                                  f'units: {x!r}') from e
   1669 return ret

ConversionError: Failed to convert value(s) to axis units: '2022-10-28'

drbenvincent commented 1 year ago

Thanks for the bug report. My hunch is that treatment_index needs to be a datetime, not a string. Can you check whether it works if you replace it with pd.to_datetime('2022-10-28')?

Do let me know if that fixes it and I'll add in automatic conversion.
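
For anyone landing here with the same error, the suggested workaround and one possible shape for the automatic conversion could be sketched as follows. This is only an illustration: the toy data frame is made up, and coerce_treatment_time is a hypothetical helper name, not an existing CausalPy function.

```python
import pandas as pd

# Toy data frame with a datetime index (made-up values, for illustration only)
index = pd.date_range("2022-10-01", periods=60, freq="D")
data = pd.DataFrame({"y": range(60)}, index=index)


def coerce_treatment_time(treatment_time, index):
    """Hypothetical helper (not part of CausalPy): turn a string into a
    pandas Timestamp when the data frame is indexed by datetimes."""
    if isinstance(index, pd.DatetimeIndex) and isinstance(treatment_time, str):
        return pd.to_datetime(treatment_time)
    # Leave everything else (integers, existing Timestamps) untouched
    return treatment_time


# Equivalent to writing pd.to_datetime('2022-10-28') by hand
treatment_time = coerce_treatment_time("2022-10-28", data.index)

# The Timestamp splits the datetime index into pre and post periods
datapre = data[data.index < treatment_time]
datapost = data[data.index >= treatment_time]
```

The resulting Timestamp would then be passed to SyntheticControl in place of the string.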

AlxndrMlk commented 1 year ago

@drbenvincent thank you, converting to a datetime worked!

I realized that it's actually how you'd do it in pandas as well.