Open: fatalfeel opened this issue 1 month ago.
def plot_probabilities(self, dataframe, name, exponential):
    """
    Sets chart parameters, generates the probability chart, and saves it.

    dataframe: dataframe, the dataframe to be plotted. It needs a 'Dates'
        column of 'YYYY-MM-DD' strings, a 'Recession' column (1 = recession,
        0 = no recession), and, after renaming with self.prediction_names,
        every column in self.prediction_names.values().
    name: string, chart title
    exponential: boolean, whether to plot exponentially weighted output data or not

    Side effects: the caller's dataframe still has its columns renamed in
    place (unchanged behaviour), but its 'Recession' values are no longer
    overwritten. The chart is saved via self.pdf_object.savefig() and the
    figure is then closed.
    """
    # Kept in place to preserve the original side effect on the caller's frame.
    dataframe.rename(columns=self.prediction_names, inplace=True)
    if exponential:
        dataframe = self.exponential_conversion(dataframe=dataframe)
    # Work on a copy: the remapping below would otherwise rewrite the caller's
    # 'Recession' column, so a later call on the same frame would find no 1s.
    dataframe = dataframe.copy()
    is_recession = dataframe['Recession'] == 1
    is_not_recession = dataframe['Recession'] == 0
    dataframe.loc[is_recession, 'Recession'] = 100
    dataframe.loc[is_not_recession, 'Recession'] = -1
    dataframe = dataframe[['Dates'] + list(self.prediction_names.values())]

    # Positional views. The loops below index by row position, which must not
    # depend on the frame's index labels. Plain `series[i]` is label-based.
    dates = dataframe['Dates'].tolist()
    recession_flags = is_recession.to_numpy()

    figure = plt.figure(figsize=(12, 6))
    subplot = seaborn.lineplot(data=pd.melt(dataframe, ['Dates']), x='Dates', y='value', hue='variable')
    subplot.set_title(name, fontsize=20)
    subplot.set_ylabel('Probability')
    subplot.set_ylim((0, 1))

    # Tick i is paired with date i. NOTE(review): this assumes one categorical
    # tick per (unique) date, in row order. Only January 1st of years divisible
    # by 10 keep a visible tick and year label; all other ticks are hidden.
    mticks = plt.gca().xaxis.get_major_ticks()
    xlabels = subplot.get_xticklabels()
    last_jan_idx = None
    last_jan_year = None
    for i, (label, tick, text) in enumerate(zip(dates, mticks, xlabels)):
        parts = label.split('-')
        year, month, day = int(parts[0]), int(parts[1]), int(parts[2])
        is_jan_first = month == 1 and day == 1
        if is_jan_first:
            last_jan_idx, last_jan_year = i, parts[0]
        if is_jan_first and year % 10 == 0:
            tick.tick1line.set_markersize(5)
            text.set_text(parts[0])
            text.set_color('black')
        else:
            tick.tick1line.set_markersize(0)
            text.set_text('')
            text.set_color('None')
    # Label the last January 1st in red. If no January 1st exists, skip this
    # instead of mislabelling the first date.
    if last_jan_idx is not None:
        mticks[last_jan_idx].tick1line.set_markersize(5)
        xlabels[last_jan_idx].set_text(last_jan_year)
        xlabels[last_jan_idx].set_color('red')
    # use set_xticks avoid UserWarning: FixedFormatter should only be used together with FixedLocator
    plt.gca().set_xticks(subplot.get_xticks())
    plt.gca().set_xticklabels(xlabels)

    # Shade every contiguous run of recession rows, from its first to its last date.
    start = None
    for i, in_recession in enumerate(recession_flags):
        if in_recession and start is None:  # Start of a recession period
            start = i
        elif not in_recession and start is not None:  # End of a recession period
            plt.gca().axvspan(dates[start], dates[i - 1], color='grey', alpha=0.2)
            start = None
    if start is not None:  # Recession still running at the last row
        plt.gca().axvspan(dates[start], dates[-1], color='grey', alpha=0.2)

    self.pdf_object.savefig()
    # Release the figure; otherwise every call leaves one open figure behind.
    plt.close(figure)
Fixed as follows: