CityofToronto / bdit_plotting_gallery

A gallery of static Python plots that the Big Data Innovation Team has produced
GNU General Public License v3.0
1 stars 0 forks source link

Add vertical bar chart option and example #5

Closed radumas closed 4 months ago

radumas commented 11 months ago
def bar_chart(data_in, xlab, ylab, horizontal=False, **kwargs):
    """Creates a horizontal or vertical bar chart

    Parameters
    -----------
    data_in : dataframe
        Data for the bar chart. The dataframe must have exactly 2 columns,
        the first holding the bar labels (category ticks) and the second
        holding the bar values. The input is copied and never modified.
    xlab : str
        Label for the x axis. In vertical mode it may be None to omit it.
    ylab : str
        Label for the y axis. In horizontal mode it may be None to omit it.
    horizontal : bool, optional, default is False
        True draws horizontal bars, anything falsy draws vertical bars.
    xymax : number, optional, default is the max data value
        Upper limit of the value axis. When omitted the upper limit is
        ``4 * xyinc + xymin``.
    xymin : number, optional, default is 0
        First tick of the value axis; also used to derive the default
        tick increment.
    precision : int, optional, default is 0
        Decimal places in the bar annotations. Values below 1 cast the
        data to int before annotating; -1 disables annotations entirely.
    xyinc : number, optional
        Tick increment on the value axis (x axis when horizontal, y axis
        otherwise). By default it is derived from ``(xymax - xymin) / 4``.

    Returns
    --------
    fig
        Matplotlib fig object
    ax
        Matplotlib ax object

    Raises
    ------
    ValueError
        If the value range is not finite (e.g. empty or all-NaN data and
        no ``xymax`` supplied), since no tick increment can be derived.

    """
    import math

    # Helper defined outside this block.
    # NOTE(review): presumably applies the shared plot styling - confirm.
    func()

    data = data_in.copy(deep=True)
    data.columns = ['name', 'values1']

    # Normalise once so the drawing and annotation steps always agree on
    # the orientation, even for truthy/falsy non-bool arguments.
    horizontal = bool(horizontal)

    xymin = kwargs.get('xymin', 0)
    xymax = kwargs.get('xymax', None)
    precision = kwargs.get('precision', 0)

    xymax_given = xymax is not None
    if not xymax_given:
        xymax = data['values1'].max()

    delta = (xymax - xymin) / 4
    if not math.isfinite(delta):
        # NaN/inf never drops below 10, so the scaling loop below would
        # otherwise spin forever.
        raise ValueError(
            'Cannot derive axis ticks: the value range is not finite '
            '(is the data empty or all NaN?)')

    # Reduce delta to its leading digit(s) and remember the power of ten,
    # so the default increment is a round number.
    exponent = 0
    while delta >= 10:
        delta /= 10
        exponent += 1
    xyinc = kwargs.get('xyinc', int(round(delta + 1) * pow(10, exponent)))

    upper = xymax if xymax_given else int(4 * xyinc + xymin)

    # np.arange yields the same ticks as range() for integers but also
    # accepts float limits/increments; the small overshoot keeps the last
    # tick when it lands exactly on the upper limit.
    ticks = np.arange(xymin, upper + int(0.1 * xyinc), xyinc)

    ind = np.arange(len(data))
    label_font = dict(fontname=font.normal, fontsize=10, fontweight='bold')
    tick_font = dict(fontname=font.normal, fontsize=10)

    # Fixed figure size (an earlier variant scaled the height with the
    # number of bars).
    fig, ax = plt.subplots(dpi=450.0)
    fig.set_size_inches(6.1, 4.2)
    ax.grid(color='k', linestyle='-', linewidth=0.25)

    # NOTE(review): the value axis always starts at 0 even when xymin is
    # non-zero; only the ticks honour xymin. Kept as-is to preserve output.
    if horizontal:
        ax.barh(ind, data['values1'], 0.75, align='center',
                color=colour.purple)
        ax.xaxis.set_major_formatter(
            mpl.ticker.StrMethodFormatter('{x:,.0f}'))
        ax.xaxis.grid(True)
        ax.yaxis.grid(False)
        ax.set_yticks(ind)
        ax.set_xlim(0, upper)
        ax.set_yticklabels(data['name'])
        ax.set_xlabel(xlab, horizontalalignment='left', x=0, labelpad=10,
                      **label_font)
        if ylab is not None:
            ax.set_ylabel(ylab, labelpad=10, **label_font)
        plt.xticks(ticks, **tick_font)
        plt.yticks(**tick_font)
    else:
        ax.bar(ind, data['values1'], 1.0, align='center',
               color=colour.purple)
        ax.yaxis.set_major_formatter(
            mpl.ticker.StrMethodFormatter('{x:,.0f}'))
        ax.yaxis.grid(True)
        ax.xaxis.grid(False)
        ax.set_xticks(ind)
        ax.set_ylim(0, upper)
        ax.set_xticklabels(data['name'], rotation=0.0)
        ax.set_ylabel(ylab, labelpad=10, **label_font)
        if xlab is not None:
            ax.set_xlabel(xlab, labelpad=10, **label_font)
        plt.yticks(ticks, **tick_font)
        plt.xticks(**tick_font)

    ax.set_facecolor('xkcd:white')

    if precision < 1:
        data['values1'] = data['values1'].astype(int)

    if precision != -1:
        for pos, value in enumerate(data['values1']):
            text = format(round(value, precision), ',')
            # Bars shorter than 10% of the axis cannot hold the label, so
            # it goes just past the bar end in black; otherwise it sits
            # inside the bar in white.
            is_short = value < 0.1 * upper
            if horizontal:
                if is_short:
                    ax.annotate(text, xy=(value + 0.015 * upper, pos - 0.05),
                                ha='left', color='k', **tick_font)
                else:
                    ax.annotate(text, xy=(value - 0.015 * upper, pos - 0.05),
                                ha='right', color='w', **tick_font)
            else:
                if is_short:
                    ax.annotate(text, xy=(pos - 0.15, value + 0.015 * upper),
                                ha='left', color='k', rotation=90.,
                                **tick_font)
                else:
                    ax.annotate(text, xy=(pos + 0.15, value - 0.06 * upper),
                                ha='right', color='w', rotation=90.,
                                **tick_font)

    return fig, ax