def ggdid_attge(did_object, ylim=None, xlab=None, ylab=None, title="Group", xgap=1, ncol=1, legend=True, group=None, ref_line=0, theming=True, grtitle="Group", **kwargs):
    """Plot group-time effect estimates with one panel per treatment group.

    Parameters
    ----------
    did_object : mapping
        Must be indexable by ``'group'``, ``'t'``, ``'att'``, ``'se'`` and
        ``'c'``. ``'att'``, ``'se'`` and ``'c'`` are stored as columns of a
        frame with one row per (group, period) pair, ordered group-major /
        period-minor. NOTE(review): this assumes the producer emits them in
        that same order -- confirm against the estimator output.
    ylim, xlab, ylab, xgap, ref_line, theming
        Forwarded unchanged to ``gplot`` for every panel.
    title : str
        Currently unused: each panel is titled ``"<grtitle> <group>"``.
        Kept for signature compatibility.
    ncol : int
        Number of panel columns. The default of 1 stacks panels vertically.
    legend : bool
        When truthy, draw one figure-level legend at the bottom. The legend is
        built from the handles of the last drawn panel.
    group : scalar, sequence or None
        Treatment group(s) to plot. ``None`` (the default) plots every group.
        If any requested group is absent from the data, a ``UserWarning`` is
        issued and all groups are plotted instead.
    grtitle : str
        Prefix used for each panel title.
    **kwargs
        Accepted for compatibility; ignored.

    Returns
    -------
    matplotlib.figure.Figure
        The figure. ``plt.show()`` has already been called on it.
    """
    import warnings

    # np.unique already returns sorted values, so no extra argsort is needed.
    groups_all = np.unique(did_object['group']).astype(int)
    periods = np.unique(did_object['t'])
    n_groups, n_periods = len(groups_all), len(periods)

    # One row per (group, period), group-major. This ordering must line up
    # with did_object['att'] / ['se'] / ['c'], which are assigned positionally.
    results = pd.DataFrame({'year': np.tile(periods, n_groups)})
    results['group'] = np.repeat(groups_all, n_periods)
    results['grtitle'] = grtitle + ' ' + results['group'].astype(str)
    results['att'] = did_object['att']
    results['att_se'] = did_object['se']
    results['post'] = np.where(results['year'] >= results['group'], 1, 0)
    results['c'] = did_object['c']

    # Resolve which groups to draw. Unknown groups fall back to "all groups",
    # as the message says. The previous code raised here, which left the
    # fallback assignment unreachable.
    if group is None:
        selected = groups_all
    else:
        selected = np.atleast_1d(group)  # also accept a single scalar group
        if any(gr not in groups_all for gr in selected):
            warnings.warn(
                "Some of the specified groups do not exist in the data. Reporting all available groups.",
                stacklevel=2,
            )
            selected = groups_all

    n_panels = len(selected)
    # Honour `ncol` (previously ignored), without allocating more columns
    # than there are panels.
    ncols = max(1, min(int(ncol), n_panels))
    nrows = -(-n_panels // ncols)  # ceiling division
    # squeeze=False keeps `axes` 2-D even for a single panel. Without it, a
    # one-group plot received a bare Axes and indexing it raised TypeError.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    flat_axes = axes.ravel()

    # Per-panel legends are disabled; a single shared legend is added below.
    panel_legend = False
    last_ax = None
    for panel_ax, group_cat in zip(flat_axes, selected):
        group_data = results.loc[results['group'] == group_cat]
        panel_title = group_data['grtitle'].iloc[0]
        last_ax = gplot(group_data, panel_ax, ylim, xlab, ylab, panel_title, xgap, panel_legend, ref_line, theming)

    # Hide grid cells left empty when n_panels is not a multiple of ncols.
    for spare_ax in flat_axes[n_panels:]:
        spare_ax.set_visible(False)

    fig.tight_layout()  # once, after every panel has been drawn
    if legend and last_ax is not None:
        # Every panel is drawn by the same gplot call, so the last panel's
        # handles are used for the shared legend.
        handles, labels = last_ax.get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower center', fontsize='small')
        fig.subplots_adjust(bottom=0.15)  # leave room for the bottom legend
    plt.show()
    return fig
def gplot(ssresults, ax, ylim=None, xlab=None, ylab=None, title="Group", xgap=1, legend=True, ref_line=0, theming=True):