schymans / ESSM_plotting

Python modules to simplify plotting of equations generated using ESSM
GNU General Public License v3.0
0 stars 0 forks source link

allow for automatic unit conversions #1

Open schymans opened 3 months ago

schymans commented 3 months ago

I just dug up some old code that allowed for automatic unit conversions; it would be good to implement something similar:

def plot_expr2(xvar_min_max, yldata, yllabel=None, yrdata=None,
               yrlabel='', clf=True, npoints=100, ylmin=None, ylmax=None,
               yrmin=None, yrmax=None, xlabel=None, xunit=None,
               ylunit=None, yrunit=None, colors=None,
               loc_legend_left='best', loc_legend_right='right',
               linestylesl=('-', '--', '-.', ':'),
               linestylesr=('-', '--', '-.', ':'),
               fontsize=None, fontsize_ticks=None, fontsize_labels=None,
               fontsize_legend=None,
               fig1=None, **args):
    '''
    Plot expressions as function of xvar from xmin to xmax.

    Only the left y-axis is drawn by this version: ``yrdata`` is not
    plotted, it only switches the colour scheme and forces a legend.
    ``yrlabel``, ``yrmin``, ``yrmax``, ``yrunit``, ``loc_legend_right``,
    ``linestylesr`` and ``clf`` are accepted but not used.

    **Arguments:**

    - xvar_min_max: ``(xvar, xmin, xmax)``. ``xmin``/``xmax`` are given in
      ``xunit`` if that is set, otherwise in the unit of ``xvar``.
    - yldata: an equation (its rhs is plotted, its lhs is the label), a
      single expression, a single ``(expr, label)`` tuple, or a list of
      ``(expr, label)`` tuples.
    - yllabel: left-axis label. A string is used verbatim; anything else
      (e.g. a Variable) is rendered as ``$latex$ (unit)``. Defaults to the
      label of the first ``yldata`` entry.
    - npoints: number of x samples; ``xmax`` itself is not sampled.
    - ylmin, ylmax: optional left-axis limits (a limit of 0 is honoured).
    - xlabel: x-axis label; defaults to ``$latex(xvar)$ (unit)``.
    - xunit: unit of ``xmin``/``xmax`` and of the plotted x-axis. The
      samples are converted to ``xvar.definition.unit`` before evaluating.
    - ylunit: if given, every curve is converted from the unit of its
      label's ``definition.unit`` to ``ylunit``; the labels in ``yldata``
      must then be Variables.
    - colors, linestylesl: colours and line styles, cycled when there are
      more curves than entries.
    - fontsize: if given, overrides ``fontsize_ticks``,
      ``fontsize_labels`` and ``fontsize_legend``.
    - fig1: existing figure to draw into; otherwise a new figure is
      created with ``plt.figure(**args)``.

    **Returns:** the matplotlib figure, or None if ``ylunit`` is given and
    a label in ``yldata`` has no ``definition``.

    **Examples:**

    from essm.variables import Variable
    from essm.variables.physics.thermodynamics import T_a
    from essm.equations.physics.thermodynamics import eq_nua, eq_ka
    vdict = Variable.__defaults__.copy()
    expr = eq_nua.subs(vdict)
    exprr = eq_ka.subs(vdict)
    xvar = T_a
    yldata = [(expr.rhs, 'full'), (expr.rhs/2, 'half')]
    yrdata = exprr
    plot_expr2((T_a, 273, 373), yldata, yllabel = (nu_a), yrdata=yrdata)
    plot_expr2((T_a, 273, 373), yldata, yllabel = (nu_a),
               yrdata=[(1/exprr.rhs, 1/exprr.lhs)],
               loc_legend_right='lower right')
    plot_expr2((T_a, 273, 373), expr)
    plot_expr2((T_a, 273, 373), yldata, yllabel = (nu_a))
    '''
    from numpy import linspace
    from sympy.physics.units import convert_to

    (xvar, xmin, xmax) = xvar_min_max
    if not colors:
        if yrdata is not None:
            colors = ['black', 'blue', 'red', 'green']
        else:
            colors = ['blue', 'black', 'red', 'green']
    if fontsize:
        fontsize_labels = fontsize
        fontsize_legend = fontsize
        fontsize_ticks = fontsize

    def unit_suffix(unit):
        # ' (unit)' for an axis label; nothing for a missing or unit-less value.
        if unit is None or unit == 1:
            return ''
        return ' (' + markdown(unit) + ')'

    # Allows appending to an existing fig:
    fig = fig1 if fig1 else plt.figure(**args)

    # Label for x-axis: requested display unit, else the unit of xvar.
    xlabel_unit = xunit
    if xlabel_unit is None and hasattr(xvar, 'definition'):
        xlabel_unit = derive_unit(xvar)
    if not xlabel:
        xlabel = '$' + latex(xvar) + '$' + unit_suffix(xlabel_unit)

    # xvals are plotted (in xunit, if given); xvalsSU are the same points in
    # the unit of xvar and are used to evaluate the expressions.
    # linspace(..., endpoint=False) yields the grid of
    # arange(xmin, xmax, (xmax - xmin)/npoints) but always exactly npoints
    # values, so xvals and xvalsSU can never differ in length.
    xvals = linspace(xmin, xmax, int(npoints), endpoint=False)
    if xunit is not None:
        unit_su = xvar.definition.unit
        xmin_su = float(convert_to(xmin * xunit, unit_su) / unit_su)
        xmax_su = float(convert_to(xmax * xunit, unit_su) / unit_su)
        xvalsSU = linspace(xmin_su, xmax_su, int(npoints), endpoint=False)
    else:
        xvalsSU = xvals

    def evaluate(expr, var, unit):
        # Evaluate expr on xvalsSU. If a target unit is given, the raw values
        # are assumed to be in var.definition.unit and converted to `unit`.
        if unit is None:
            return [float(expr.subs(xvar, x).n()) for x in xvalsSU]
        var_unit = var.definition.unit
        return [float(convert_to(expr.subs(xvar, x).n() * var_unit, unit)
                      / unit)
                for x in xvalsSU]

    # Plotting data on left axis
    ax1 = fig.add_subplot(1, 1, 1)
    if hasattr(yldata, 'rhs'):
        # yldata is an equation: plot its rhs, label it with its lhs
        yldata = (yldata.rhs, yldata.lhs)
    if not isinstance(yldata, (list, tuple)):
        # If only an expression given
        yldata = [(yldata, '')]
    if yldata and not isinstance(yldata[0], tuple):
        # A single (expr, label) pair
        yldata = [yldata]

    color = colors[0] if yrdata is not None else 'black'
    ax1.set_xlabel(xlabel)

    # Label for left y-axis:
    if not yllabel:
        try:
            yllabel = yldata[0][1]
        except (IndexError, TypeError) as e1:
            print(e1)
            print('yldata must be equation or list of (expr, name) tuples')
    if yllabel is None:
        yllabel = ''
    elif not isinstance(yllabel, str):
        # Show the unit the data is actually plotted in.
        label_unit = ylunit if ylunit is not None else derive_unit(yllabel)
        yllabel = '$' + latex(yllabel) + '$' + unit_suffix(label_unit)
    ax1.set_ylabel(yllabel, color=color)
    ax1.tick_params(axis='y', labelcolor=color)

    for i, (expr1, y1var) in enumerate(yldata):
        linestyle = linestylesl[i % len(linestylesl)]
        if yrdata is None:
            color = colors[i % len(colors)]
        if ylunit is not None and not hasattr(y1var, 'definition'):
            # Unit conversion needs y1var.definition.unit.
            print('{0} is not a Variable'.format(y1var))
            return
        try:
            y1vals = evaluate(expr1, y1var, ylunit)
            ax1.plot(xvals, y1vals, color=color, linestyle=linestyle,
                     label=y1var)
        except Exception as e1:
            # Best effort: report the failing curve and plot the others.
            print(expr1)
            print(e1)

    # Limits are applied after plotting so that a single given bound does
    # not freeze the other one at its pre-plot default.
    if ylmin is not None:
        ax1.set_ylim(bottom=float(ylmin))
    if ylmax is not None:
        ax1.set_ylim(top=float(ylmax))

    if len(yldata) > 1 or yrdata is not None:
        ax1.legend(loc=loc_legend_left, fontsize=fontsize_legend)

    for item in (ax1.xaxis.label, ax1.yaxis.label):
        item.set_fontsize(fontsize_labels)
    ax1.tick_params(axis='both', which='major', labelsize=fontsize_ticks)
    fig.tight_layout()  # otherwise the right y-label is slightly clipped
    return fig
schymans commented 3 months ago

Here is another version:

import matplotlib.pyplot as plt
from sympy import latex
from sympy import N
from numpy import arange
from essm.variables.units import derive_unit, SI, Quantity
from essm.variables.utils import markdown

def plot_expr2(xvar_min_max, yldata, yllabel=None, yrdata=None,
               yrlabel='', clf=True, npoints=100, ylmin=None, ylmax=None,
               yrmin=None, yrmax=None, xlabel=None, xunit=None,
               ylunit=None, yrunit=None, colors=None,
               loc_legend_left='best', loc_legend_right='right',
               linestylesl=('-', '--', '-.', ':'),
               linestylesr=('-', '--', '-.', ':'),
               fontsize=None, fontsize_ticks=None, fontsize_labels=None,
               fontsize_legend=None,
               fig1=None, **args):
    '''
    Plot expressions as function of xvar from xmin to xmax.

    Curves in ``yldata`` are drawn against the left y-axis; curves in
    ``yrdata`` (optional) against a second y-axis on the right that shares
    the x-axis.

    **Arguments:**

    - xvar_min_max: ``(xvar, xmin, xmax)``. ``xmin``/``xmax`` are given in
      ``xunit`` if that is set, otherwise in the unit of ``xvar``.
    - yldata: an equation (its rhs is plotted, its lhs is the label), a
      single expression, a single ``(expr, label)`` tuple, or a list of
      ``(expr, label)`` tuples.
    - yllabel: left-axis label. A string is used verbatim; anything else
      (e.g. a Variable) is rendered as ``$latex$ (unit)``. Defaults to the
      label of the first ``yldata`` entry.
    - yrdata: an equation, or a list of equations and/or ``(expr, var)``
      tuples, plotted on the right axis.
    - yrlabel: right-axis label; defaults to ``yrdata.lhs`` when ``yrdata``
      is an equation.
    - clf: accepted but currently unused.
    - npoints: number of x samples; ``xmax`` itself is not sampled.
    - ylmin, ylmax, yrmin, yrmax: optional axis limits (0 is honoured).
    - xlabel: x-axis label; defaults to ``$latex(xvar)$ (unit)`` when
      ``xvar`` has a ``definition`` or ``xunit`` is given, else ``xvar``.
    - xunit: unit of ``xmin``/``xmax`` and of the plotted x-axis. The
      samples are converted to ``xvar.definition.unit`` before evaluating.
    - ylunit, yrunit: if given, curves on that axis are converted from
      their variable's ``definition.unit`` to this unit; the variables must
      then be Variables.
    - colors, linestylesl, linestylesr: cycled when there are more curves
      than entries.
    - fontsize: if given, overrides ``fontsize_ticks``,
      ``fontsize_labels`` and ``fontsize_legend``.
    - fig1: existing figure to draw into; otherwise a new figure is
      created with ``plt.figure(**args)``.

    **Returns:** the matplotlib figure, or None if an entry of ``yrdata``
    is neither a tuple nor an equation.

    **Examples:**

    from essm.variables import Variable
    from essm.variables.physics.thermodynamics import T_a
    from essm.equations.physics.thermodynamics import eq_nua, eq_ka
    vdict = Variable.__defaults__.copy()
    expr = eq_nua.subs(vdict)
    exprr = eq_ka.subs(vdict)
    xvar = T_a
    yldata = [(expr.rhs, 'full'), (expr.rhs/2, 'half')]
    yrdata = exprr
    plot_expr2((T_a, 273, 373), yldata, yllabel = (nu_a), yrdata=yrdata)
    plot_expr2((T_a, 273, 373), yldata, yllabel = (nu_a),
               yrdata=[(1/exprr.rhs, 1/exprr.lhs)],
               loc_legend_right='lower right')
    plot_expr2((T_a, 273, 373), expr)
    plot_expr2((T_a, 273, 373), yldata, yllabel = (nu_a))
    '''
    from numpy import linspace
    from sympy.physics.units import convert_to

    (xvar, xmin, xmax) = xvar_min_max
    if not colors:
        if yrdata is not None:
            colors = ['black', 'blue', 'red', 'green']
        else:
            colors = ['blue', 'black', 'red', 'green']
    if fontsize:
        fontsize_labels = fontsize
        fontsize_legend = fontsize
        fontsize_ticks = fontsize

    def unit_suffix(unit):
        # ' (unit)' for an axis label; nothing for a missing or unit-less value.
        if unit is None or unit == 1:
            return ''
        return ' (' + markdown(unit) + ')'

    # Allows appending to an existing fig:
    fig = fig1 if fig1 else plt.figure(**args)

    # Label for x-axis: show the unit the x values are plotted in.
    if not xlabel:
        if xunit is not None:
            xlabel = '$' + latex(xvar) + '$' + unit_suffix(xunit)
        elif hasattr(xvar, 'definition'):
            xlabel = '$' + latex(xvar) + '$' + unit_suffix(derive_unit(xvar))
        else:
            xlabel = xvar

    # xvals are plotted (in xunit, if given); xvalsSU are the same points in
    # the unit of xvar and are used to evaluate the expressions.
    # linspace(..., endpoint=False) yields the grid of
    # arange(xmin, xmax, (xmax - xmin)/npoints) but always exactly npoints
    # values, so xvals and xvalsSU can never differ in length.
    xvals = linspace(xmin, xmax, int(npoints), endpoint=False)
    if xunit is not None:
        unit_su = xvar.definition.unit
        xmin_su = float(convert_to(xmin * xunit, unit_su) / unit_su)
        xmax_su = float(convert_to(xmax * xunit, unit_su) / unit_su)
        xvalsSU = linspace(xmin_su, xmax_su, int(npoints), endpoint=False)
    else:
        xvalsSU = xvals

    def evaluate(expr, var, unit):
        # Evaluate expr on xvalsSU. If a target unit is given, the raw values
        # are assumed to be in var.definition.unit and converted to `unit`.
        if unit is None:
            return [float(expr.subs(xvar, x).n()) for x in xvalsSU]
        var_unit = var.definition.unit
        return [float(convert_to(expr.subs(xvar, x).n() * var_unit, unit)
                      / unit)
                for x in xvalsSU]

    # Plotting data on left axis
    ax1 = fig.add_subplot(1, 1, 1)
    # Normalise yldata to a list of (expr, label) tuples. An equation must
    # be unpacked before the generic "single expression" wrapping below.
    if hasattr(yldata, 'lhs'):
        yldata = (yldata.rhs, yldata.lhs)
    if not isinstance(yldata, (list, tuple)):
        # If only an expression given
        yldata = [(yldata, '')]
    if yldata and not isinstance(yldata[0], tuple):
        # A single (expr, label) pair
        yldata = [yldata]

    color = colors[0] if yrdata is not None else 'black'
    ax1.set_xlabel(xlabel)

    # Label for left y-axis:
    if not yllabel:
        try:
            yllabel = yldata[0][1]
        except (IndexError, TypeError) as e1:
            print(e1)
            print('yldata must be equation or list of (expr, name) tuples')
    if yllabel is None:
        yllabel = ''
    elif not isinstance(yllabel, str):
        # Show the unit the data is actually plotted in.
        label_unit = ylunit if ylunit is not None else derive_unit(yllabel)
        yllabel = '$' + latex(yllabel) + '$' + unit_suffix(label_unit)
    ax1.set_ylabel(yllabel, color=color)
    ax1.tick_params(axis='y', labelcolor=color)

    for i, (expr1, y1var) in enumerate(yldata):
        linestyle = linestylesl[i % len(linestylesl)]
        if yrdata is None:
            color = colors[i % len(colors)]
        try:
            y1vals = evaluate(expr1, y1var, ylunit)
            ax1.plot(xvals, y1vals, color=color, linestyle=linestyle,
                     label=y1var)
        except Exception as e1:
            # Best effort: report the failing curve and plot the others.
            print(expr1)
            print(e1)

    # Limits are applied after plotting so that a single given bound does
    # not freeze the other one at its pre-plot default.
    if ylmin is not None:
        ax1.set_ylim(bottom=float(ylmin))
    if ylmax is not None:
        ax1.set_ylim(top=float(ylmax))

    if len(yldata) > 1 or yrdata is not None:
        ax1.legend(loc=loc_legend_left, fontsize=fontsize_legend)

    # Plotting data on right axis
    if yrdata is not None:
        ax2 = ax1.twinx()  # second axes sharing the same x-axis
        if yrlabel == '' and hasattr(yrdata, 'lhs'):
            yrlabel = yrdata.lhs
        if not isinstance(yrdata, (list, tuple)):
            # If only an expression given
            yrdata = [yrdata]
        color = colors[1]
        y2var = None
        for i, item in enumerate(yrdata):
            if isinstance(item, tuple):
                (expr2, y2var) = item
            else:
                try:
                    (y2var, expr2) = (item.lhs, item.rhs)
                except AttributeError as e1:
                    print(e1)
                    print('yrdata must be a list of equations or tuples (expr, var)')
                    return
            linestyle = linestylesr[i % len(linestylesr)]
            try:
                y2vals = evaluate(expr2, y2var, yrunit)
                ax2.plot(xvals, y2vals, color=color, linestyle=linestyle,
                         label=y2var)
            except Exception as e1:
                # Best effort: report the failing curve and plot the others.
                print(expr2)
                print(e1)

        # Unit shown on the right axis: the requested one, else that of the
        # last right-axis variable.
        if yrunit is not None:
            yrunitstr = markdown(yrunit)
        elif y2var is not None:
            yrunitstr = markdown(derive_unit(y2var))
        else:
            yrunitstr = ''
        yrlabel1 = '${0}$ ({1})'.format(latex(yrlabel), yrunitstr)
        ax2.set_ylabel(yrlabel1, color=color)
        ax2.tick_params(axis='y', labelcolor=color)
        if yrmin is not None:
            ax2.set_ylim(bottom=float(yrmin))
        if yrmax is not None:
            ax2.set_ylim(top=float(yrmax))
        ax2.legend(loc=loc_legend_right, fontsize=fontsize_legend)
        for item in (ax2.xaxis.label, ax2.yaxis.label):
            item.set_fontsize(fontsize_labels)
        ax2.tick_params(axis='both', which='major', labelsize=fontsize_ticks)

    for item in (ax1.xaxis.label, ax1.yaxis.label):
        item.set_fontsize(fontsize_labels)
    ax1.tick_params(axis='both', which='major', labelsize=fontsize_ticks)
    fig.tight_layout()  # otherwise the right y-label is slightly clipped
    return fig