tzhangwps / Recession-Predictor

Project description: https://medium.com/p/recession-prediction-using-machine-learning-de6eee16ca94?source=email-2adc3d3cd2ed--writer.postDistributed&sk=2f1dab9738769f9658634e61576a08bd
MIT License
30 stars 28 forks source link

PDF plot: x-axis labels render as a solid black block (#22)

Open fatalfeel opened 1 month ago

fatalfeel commented 1 month ago

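Fixed as follows: every x-axis label except January 1 of each decade year is blanked, so the labels no longer overlap into a black block.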

def plot_probabilities(self, dataframe, name, exponential):
    """
    Sets chart parameters, generates the chart, and saves it to the PDF.

    Only x-axis labels that fall on January 1 of a decade year (e.g.
    1970-01-01) are kept, shown as just the year; all other labels are
    blanked so the dense date labels no longer overlap into a black block.

    dataframe: dataframe, the dataframe to be plotted. Its columns are renamed
        in place, and when `exponential` is False its 'Recession' column is
        also overwritten in place (1 -> 100, 0 -> -1).

    name: string, chart title

    exponential: boolean, whether to plot exponentially weighted output data or not
    """
    dataframe.rename(columns=self.prediction_names, inplace=True)
    if exponential:
        dataframe = self.exponential_conversion(dataframe=dataframe)
    # Map recession flags to values outside the (0, 1) y-limits set below;
    # presumably so the 'Recession' series shows as vertical bars rather
    # than a line in the probability range — confirm intent.
    is_recession = dataframe['Recession'] == 1
    is_not_recession = dataframe['Recession'] == 0
    dataframe.loc[is_recession, 'Recession'] = 100
    dataframe.loc[is_not_recession, 'Recession'] = -1
    dataframe = dataframe[['Dates'] + list(self.prediction_names.values())]

    fig = plt.figure(figsize=(15, 5))
    ax = seaborn.lineplot(data=pd.melt(dataframe, ['Dates']),
                          x='Dates', y='value', hue='variable')
    ax.set_title(name, fontsize=20)
    ax.set_ylabel('Probability')
    ax.set_ylim((0, 1))
    # Draw once so the tick label Text objects carry their date strings.
    fig.canvas.draw()

    xticks = ax.get_xticks()
    xlabels = ax.get_xticklabels()
    for label in xlabels:
        try:
            year, month, day = (int(part) for part in label.get_text().split('-')[:3])
        except ValueError:
            # Not a 'YYYY-MM-DD' label (e.g. empty text): hide it instead of crashing.
            label.set_text('')
            continue
        is_decade_start = year % 10 == 0 and month == 1 and day == 1
        label.set_text(str(year) if is_decade_start else '')

    # Re-pin the tick positions before assigning labels; this avoids
    # "UserWarning: FixedFormatter should only be used together with FixedLocator".
    ax.set_xticks(xticks)
    ax.set_xticklabels(xlabels)
    ax.tick_params(axis='x', which='major', length=0, width=0)  # hide tick marks

    # NOTE(review): the figure is left open after saving; consider plt.close(fig)
    # if many charts are generated in one run.
    self.pdf_object.savefig()
fatalfeel commented 1 month ago

V2 fix (additionally labels the most recent January 1 in red and shades recession periods):

def plot_probabilities(self, dataframe, name, exponential):
    """
    Sets chart parameters, generates the chart, and saves it to the PDF.

    X-axis labels (and tick marks) are shown only for January 1 of decade
    years. The most recent January 1 in the data is also labelled, in red.
    Each contiguous run of recession rows is shaded with a grey span.

    dataframe: dataframe, the dataframe to be plotted. Its columns are renamed
        in place, and when `exponential` is False its 'Recession' column is
        also overwritten in place (1 -> 100, 0 -> -1).

    name: string, chart title

    exponential: boolean, whether to plot exponentially weighted output data or not
    """
    dataframe.rename(columns=self.prediction_names, inplace=True)
    if exponential:
        dataframe = self.exponential_conversion(dataframe=dataframe)
    is_recession = (dataframe['Recession'] == 1)
    is_not_recession = (dataframe['Recession'] == 0)
    dataframe.loc[is_recession, 'Recession'] = 100
    dataframe.loc[is_not_recession, 'Recession'] = -1
    dataframe = dataframe[['Dates'] + list(self.prediction_names.values())]

    # Positional copies of the dates and recession flags. The dataframe index
    # need not be 0..n-1 (e.g. a sliced date range), so never index the
    # Series with loop positions by label.
    dates = list(dataframe['Dates'])
    recession_flags = list(is_recession)

    plt.figure(figsize=(12, 6))
    ax = seaborn.lineplot(data=pd.melt(dataframe, ['Dates']),
                          x='Dates', y='value', hue='variable')
    ax.set_title(name, fontsize=20)
    ax.set_ylabel('Probability')
    ax.set_ylim((0, 1))

    # Assumes one major tick/label per date (categorical x-axis); zip stops
    # at the shorter sequence instead of raising IndexError on a mismatch.
    mticks = ax.xaxis.get_major_ticks()
    xlabels = ax.get_xticklabels()
    last_jan = None  # (position, year string) of the latest January 1 seen
    for pos, (date_str, tick, label) in enumerate(zip(dates, mticks, xlabels)):
        year_str, month_str, day_str = date_str.split('-')[:3]
        is_jan_first = int(month_str) == 1 and int(day_str) == 1
        if is_jan_first:
            last_jan = (pos, year_str)
        if is_jan_first and int(year_str) % 10 == 0:
            tick.tick1line.set_markersize(5)
            label.set_text(year_str)
            label.set_color('black')
        else:
            tick.tick1line.set_markersize(0)
            label.set_text('')
            label.set_color('None')

    # Highlight the most recent January 1 only if the data contains one.
    if last_jan is not None:
        jan_pos, jan_year = last_jan
        mticks[jan_pos].tick1line.set_markersize(5)
        xlabels[jan_pos].set_text(jan_year)
        xlabels[jan_pos].set_color('red')

    # Re-pin the tick positions before assigning labels; this avoids
    # "UserWarning: FixedFormatter should only be used together with FixedLocator".
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(xlabels)

    # Shade every contiguous block of recession rows.
    start = None
    for pos, in_recession in enumerate(recession_flags):
        if in_recession and start is None:  # start of a recession period
            start = pos
        elif not in_recession and start is not None:  # end of a recession period
            ax.axvspan(dates[start], dates[pos - 1], color='grey', alpha=0.2)
            start = None
    if start is not None:  # recession still ongoing at the last row
        ax.axvspan(dates[start], dates[-1], color='grey', alpha=0.2)

    # NOTE(review): the figure is left open after saving; consider plt.close()
    # if many charts are generated in one run.
    self.pdf_object.savefig()