d2cml-ai / csdid

CSDID
https://d2cml-ai.github.io/csdid/index.html
MIT License
20 stars 5 forks source link

Make the plots with matplotlib #6

Closed carlosguevara1 closed 1 year ago

carlosguevara1 commented 1 year ago

def ggdid_attge(did_object, ylim=None, xlab=None, ylab=None, title="Group", xgap=1, ncol=1, legend=True, group=None, ref_line=0, theming=True, grtitle="Group", **kwargs):
    """Plot group-time estimates, one stacked subplot per group.

    Parameters
    ----------
    did_object : mapping-like
        Indexed with the keys ``'group'``, ``'t'``, ``'att'``, ``'se'`` and
        ``'c'``. NOTE(review): ``'att'`` and ``'se'`` are assumed to be
        ordered group-major (every period of the first group, then the
        second, ...) over the full group x period grid, because that is how
        the ``year``/``group`` columns are laid out below -- confirm against
        the producer of ``did_object``.
    ylim, xlab, ylab, xgap, ref_line, theming
        Forwarded unchanged to ``gplot`` for every subplot.
    title : str
        Accepted for interface compatibility; each subplot is titled
        ``"<grtitle> <group>"`` instead.
    ncol : int
        Accepted for interface compatibility; the layout is always a single
        column.
    legend : bool
        When truthy, one figure-level legend is drawn at the bottom centre.
        Per-axes legends are always suppressed.
    group : iterable, optional
        Groups to plot. ``None`` plots every group found in ``did_object``.
    grtitle : str
        Prefix of each subplot title.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    The matplotlib figure holding all subplots.

    Warns
    -----
    UserWarning
        If ``group`` contains a value that is not present in the data; all
        available groups are plotted in that case.
    """
    import warnings

    # np.unique already returns sorted values, so no extra argsort is needed.
    g = np.unique(did_object['group']).astype(int)
    y = np.unique(did_object['t'])
    G = len(g)
    Y = len(y)

    results = pd.DataFrame({'year': np.tile(y, G)})
    results['group'] = np.repeat(g, Y)
    results['grtitle'] = grtitle + ' ' + results['group'].astype(str)
    results['att'] = did_object['att']
    results['att_se'] = did_object['se']
    results['post'] = np.where(results['year'] >= results['group'], 1, 0)
    results['c'] = did_object['c']

    if group is None:
        group = g
    elif any(requested not in g for requested in group):
        # The message promises a fallback, so warn and continue rather than
        # raising (the original `raise` made the fallback unreachable).
        warnings.warn("Some of the specified groups do not exist in the data. Reporting all available groups.")
        group = g

    # squeeze=False keeps `axes` two-dimensional even for a single group;
    # otherwise plt.subplots returns a bare Axes and indexing it fails.
    fig, axes = plt.subplots(nrows=len(group), ncols=1, squeeze=False)

    for i, group_cat in enumerate(group):
        group_data = results.loc[results['group'] == group_cat]
        panel_title = group_data['grtitle'].unique()[0]
        # Per-axes legends are disabled; a shared legend is added below.
        ax = gplot(group_data, axes[i, 0], ylim, xlab, ylab, panel_title, xgap, False, ref_line, theming)

    fig.tight_layout()
    if legend:
        # Every subplot carries the same two labelled series, so the handles
        # of the last one are enough for the shared legend.
        handles, labels = ax.get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower center', fontsize='small')
        # Leave room under the last subplot for the legend.
        fig.subplots_adjust(bottom=0.15)
    plt.show()
    return fig

def gplot(ssresults, ax, ylim=None, xlab=None, ylab=None, title="Group", xgap=1, legend=True, ref_line=0, theming=True):
    """Draw pre- and post-period estimates with error bars on ``ax``.

    Parameters
    ----------
    ssresults : pandas.DataFrame
        Rows to plot, with columns ``'year'``, ``'att'``, ``'att_se'``,
        ``'c'`` and ``'post'``. Rows with ``post == 0`` form the "Pre"
        series and rows with ``post == 1`` the "Post" series. Error bars
        extend ``c * att_se`` on either side of ``att``.
    ax : matplotlib Axes
        Axes to draw on; it is modified in place.
    ylim
        Passed straight to ``ax.set_ylim``.
    xlab, ylab
        Axis labels.
    title : str
        Axes title.
    xgap : int
        Step between consecutive x ticks; values below 1 are treated as 1.
    legend : bool
        When false, the axes legend is hidden.
    ref_line : float, optional
        Height of a dashed horizontal reference line; ``None`` draws none.
    theming : bool
        When true, applies the white/dark-grey styling and places a legend
        listing only the series that have data.

    Returns
    -------
    The same ``ax`` that was passed in.

    Notes
    -----
    This function does not call ``plt.show()``; displaying the figure is
    left to the caller, which may still be adding other subplots.
    """
    pre_points = ssresults.loc[ssresults['post'] == 0]
    post_points = ssresults.loc[ssresults['post'] == 1]

    # Both series are always drawn, even when empty, so the legend handles
    # keep a fixed order: index 0 is "Pre", index 1 is "Post".
    for points, colour, label in (
        (pre_points, '#e87d72', 'Pre'),
        (post_points, '#56bcc2', 'Post'),
    ):
        ax.errorbar(points['year'], points['att'], yerr=points['c'] * points['att_se'],
                    fmt='o', markersize=5, color=colour, ecolor=colour, capsize=5, label=label)

    # min()/max() are undefined for an empty frame, so leave the ticks alone.
    if not ssresults.empty:
        first_year = int(ssresults['year'].min())
        last_year = int(ssresults['year'].max())
        tick_step = max(int(xgap), 1)
        ax.set_xticks(list(range(first_year, last_year + 1, tick_step)))
    ax.set_ylim(ylim)
    ax.set_title(title)
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)

    handles, labels = ax.get_legend_handles_labels()

    if ref_line is not None:
        ax.axhline(ref_line, linestyle='dashed', color='#1F1F1F')
    if theming:
        ax.set_facecolor('white')
        ax.set_title(title, color="#1F1F1F", fontweight="bold", fontsize=10)
        ax.spines['bottom'].set_color('#1F1F1F')
        ax.spines['left'].set_color('#1F1F1F')
        ax.tick_params(axis='x', colors='#1F1F1F')
        ax.tick_params(axis='y', colors='#1F1F1F')
        # List only the series that actually contain points.
        if not pre_points.empty and not post_points.empty:
            ax.legend(handles[0:2], labels[0:2], loc='lower center', fontsize='small')
        elif not pre_points.empty:
            ax.legend(handles[:1], labels[:1], loc='lower center', fontsize='small')
        elif not post_points.empty:
            ax.legend(handles[1:2], labels[1:2], loc='lower center', fontsize='small')
    if not legend:
        ax.legend().set_visible(False)

    return ax
TJhon commented 1 year ago

Fixed in PR #7.