arviz-devs / arviz

Exploratory analysis of Bayesian models with Python
https://python.arviz.org
Apache License 2.0

Labeling forest plots with alternate names (and colors) for variables #1467

Open elnjensen opened 3 years ago

elnjensen commented 3 years ago

Short Description

I'm using arviz to make plots from a trace generated with pymc3, and I'd like more descriptive labels on the plots than the variable names themselves.

Code Example or link

For example, using plot_forest, I created the ridgeplot shown below:

import arviz as az
import matplotlib.pyplot as plt

# trace: the pymc3 trace returned by pm.sample()
axes = az.plot_forest(trace, var_names=['r3', 'r0', 'r1', 'r2'], combined=True,
                      kind='ridgeplot', figsize=(4, 4), ridgeplot_quantiles=[.16, .5, .84],
                      ridgeplot_truncate=False, hdi_prob=0.95,
                      ridgeplot_overlap=4, ridgeplot_alpha=0.5)
plt.xlabel('$R_p/R_*$')
plt.ylabel('Band')

(Figure: example ridgeplot)

I'd like to label the y axis with something other than 'r0', 'r1', etc. If there's a general solution, I'd like to apply it to other plots as well, such as those made with plot_trace.

Ideally, I'd also like each variable to be plotted in a different color.

Thanks in advance for any help!

Python version: 3.7.3
IPython version: 7.19.0
matplotlib: 3.3.2
arviz: 0.10.0
pymc3: 3.9.3

canyon289 commented 3 years ago

Hey @elnjensen, here's a workaround: for the names, I modify them in the source object (in this case, trace), and they are then renamed in the plot. Not the most elegant, but it works.
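A minimal sketch of this rename-the-source workaround, assuming an ArviZ version that provides InferenceData.rename (it returns a renamed copy by default). The synthetic draws are a stand-in for the pymc3 trace, and the band names (band_J, band_H, band_K, band_L) are hypothetical labels; with a real trace you would first convert it with az.from_pymc3(trace).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import arviz as az
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for az.from_pymc3(trace): 4 chains x 500 draws for each radius ratio
idata = az.from_dict(
    posterior={name: rng.normal(0.1, 0.01, size=(4, 500)) for name in ["r0", "r1", "r2", "r3"]}
)

# Hypothetical descriptive labels; keys are the original variable names
new_names = {"r0": "band_J", "r1": "band_H", "r2": "band_K", "r3": "band_L"}

# rename returns a new InferenceData (inplace=False by default); idata is untouched
renamed = idata.rename(new_names)

# The y-axis labels of the ridgeplot now show the new names
axes = az.plot_forest(
    renamed,
    var_names=["band_L", "band_J", "band_H", "band_K"],
    combined=True,
    kind="ridgeplot",
    figsize=(4, 4),
)
```

Because the names are changed in the data itself, later label-based access (for example renamed.posterior["band_J"]) also uses the new names.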

For the alternating colors, that's a neat idea. Off the top of my head I can't think of an API, but I suppose the most natural thing to do would be to copy matplotlib's implementation. Thank you for the suggestion!

sushmit86 commented 3 years ago

@canyon289 Can you help me with how to rename variables for a multivariate distribution?

import numpy as np
import pymc3 as pm

def fit_category_model(k, n, list_prod_cat_val):  # function name is illustrative
    with pm.Model() as category_model:
        a = np.ones(k)
        theta = pm.Dirichlet("theta", a=a)
        count_max_categories = pm.Multinomial("count_max_categories", n=n, p=theta, observed=list_prod_cat_val)
        trace_category_model = pm.sample(draws=1000, tune=1000)
    return trace_category_model

In the above, k=3. Instead of theta[0], theta[1], theta[2], I am looking for labels like A, B and C.

OriolAbril commented 3 years ago

To permanently rename what is shown for multidimensional variables, you need to modify the coordinate values of the xarray dataset. For changes that affect plotting but not the InferenceData itself (so label-based indexing stays the same), take a look at https://arviz-devs.github.io/arviz/user_guide/label_guide.html (and use the ArviZ development version; we'll release labellers soon, but that hasn't happened yet).
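Both options can be sketched as follows, assuming an ArviZ version that ships the arviz.labels module (MapLabeller) and the labeller argument of plot_forest. The random Dirichlet draws stand in for trace_category_model, and the dimension name category and the labels A, B, C are illustrative choices.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import arviz as az
import numpy as np
from arviz.labels import MapLabeller

rng = np.random.default_rng(1)
# Stand-in for the theta posterior: 4 chains x 500 draws x k=3 categories
theta_draws = rng.dirichlet(np.ones(3), size=(4, 500))

# Option 1 (permanent): store the labels as coordinate values of a named dimension
idata = az.from_dict(
    posterior={"theta": theta_draws},
    coords={"category": ["A", "B", "C"]},
    dims={"theta": ["category"]},
)
# Label-based indexing now uses the new values
theta_a = idata.posterior["theta"].sel(category="A")

# Option 2 (plots only): keep the default coordinates
# (dimension "theta_dim_0" with values 0, 1, 2) and map them to display labels
idata_default = az.from_dict(posterior={"theta": theta_draws})
labeller = MapLabeller(coord_map={"theta_dim_0": {0: "A", 1: "B", 2: "C"}})
axes = az.plot_forest(idata_default, combined=True, labeller=labeller)
```

With pymc3, the same coords and dims arguments can be passed to az.from_pymc3 for option 1. For option 2, the labeller only changes the displayed text, so idata_default.posterior["theta"].sel(theta_dim_0=0) keeps working.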

jamesvrt commented 1 year ago

I'd also like different colors for variables. There is a colors parameter in az.plot_forest, but if I provide a list of colors or the "cycle" option, it only uses the first color for all variables. Maybe I'm misunderstanding the purpose of this parameter. I'm using ArviZ 0.14.0.

Here's an example using colors="cycle":

import arviz as az

non_centered_data = az.load_arviz_data('non_centered_eight')

axes = az.plot_forest(non_centered_data,
    kind='forestplot',
    var_names=["^the"],
    filter_vars="regex",
    combined=True,
    figsize=(9, 7),
    colors="cycle")

axes[0].set_title('Estimated theta for 8 schools model')


Here's another example, this time using a list of colors:

import arviz as az
import matplotlib.pyplot as plt
import numpy as np

non_centered_data = az.load_arviz_data('non_centered_eight')

n_schools = len(non_centered_data.posterior.coords["school"].values)
# cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap still works
cmap = plt.get_cmap("turbo")
colors = [cmap(x) for x in np.linspace(0, 1, n_schools)]
print(colors)

axes = az.plot_forest(non_centered_data,
    kind='forestplot',
    var_names=["^the"],
    filter_vars="regex",
    combined=True,
    figsize=(9, 7),
    colors=colors)

axes[0].set_title('Estimated theta for 8 schools model')
plt.show()


OriolAbril commented 1 year ago

@jamesvrt colors is currently used for different models (one example at https://python.arviz.org/en/stable/examples/plot_forest_comparison.html), not for different variables. The function would need to be modified to achieve this.
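A sketch of that intended use, assuming the two eight-schools datasets bundled with ArviZ: colors takes one entry per model, in the same order as model_names, and each color applies to every variable of that model.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import arviz as az

centered = az.load_arviz_data("centered_eight")
non_centered = az.load_arviz_data("non_centered_eight")

# colors[i] is used for all variables of model i, not for individual variables
axes = az.plot_forest(
    [centered, non_centered],
    model_names=["Centered", "Non-centered"],
    var_names=["theta"],
    combined=True,
    colors=["C0", "C3"],
    figsize=(9, 7),
)
axes[0].set_title("theta: centered vs non-centered parameterization")
```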

jamesvrt commented 1 year ago

@OriolAbril Thanks, that makes sense. I compare variables within a model more than across models, so it would be nice to have the same functionality for those too. These plots are already fantastic time savers though, so appreciate all your hard work so far!

As an aside for other users: I can produce the plot I want by splitting the input xarray dataset into multiple models. It's not how the API is supposed to be used, but it works as a hack for now.

Update: this approach produced incorrect scaling in the resultant plot.

OriolAbril commented 1 year ago

this approach produced incorrect scaling in the resultant plot.

What do you mean by incorrect scaling? Might be a bug that needs fixing.

I am also not sure about the API for coloring by variable. Feel free to share ideas on the issue.

jamesvrt commented 1 year ago

What do you mean by incorrect scaling? Might be a bug that needs fixing.

Sorry, I should have said: the scaling is incorrect for ridge plots. For forest plots, they look fine. Here's the code:

import arviz as az
import matplotlib.pyplot as plt
import numpy as np

non_centered_data = az.load_arviz_data('non_centered_eight')

schools = non_centered_data.posterior.coords["school"].values
n_schools = len(schools)
# cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap still works
cmap = plt.get_cmap("turbo")
colors = [cmap(x) for x in np.linspace(0, 1, n_schools)]
print(colors)

axes = az.plot_forest(
    [non_centered_data.sel(school=school) for school in schools],
    model_names=schools,
    kind='forestplot',
    var_names=["^the"],
    filter_vars="regex",
    combined=True,
    figsize=(9, 7),
    colors=colors,
    legend=False)

axes[0].set_title('Estimated theta for 8 schools model')
plt.show()


jasontilley commented 1 year ago

I too would like to be able to choose colors for different covariates, particularly to distinguish whether 0 is contained within the 95% interval or the interquartile range. At the moment, it seems the only "lines" exposed in the AxesSubplot object are the circles. MCMCPlot in R has this functionality and it is very useful.
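Until plot_forest supports this, one possible workaround is to draw the intervals directly with matplotlib from az.hdi and the posterior median. This is a minimal sketch, not an ArviZ plotting function: it assumes the bundled non_centered_eight data, and the 95% HDI and the two colors (C3 when the interval excludes 0, C0 otherwise) are arbitrary choices.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import arviz as az

idata = az.load_arviz_data("non_centered_eight")
theta = idata.posterior["theta"]  # dims: chain, draw, school

# 95% HDI per school; the "hdi" dimension has coordinates "lower" and "higher"
hdi = az.hdi(idata, hdi_prob=0.95, var_names=["theta"])["theta"]
medians = theta.median(dim=("chain", "draw"))
schools = theta.coords["school"].values

fig, ax = plt.subplots(figsize=(6, 4))
excludes_zero = []
for y, school in enumerate(schools):
    low = float(hdi.sel(school=school, hdi="lower"))
    high = float(hdi.sel(school=school, hdi="higher"))
    # Highlight intervals that do not contain 0
    excludes = low > 0 or high < 0
    excludes_zero.append(excludes)
    color = "C3" if excludes else "C0"
    ax.hlines(y, low, high, color=color, linewidth=2)  # HDI segment
    ax.plot(float(medians.sel(school=school)), y, "o", color=color)  # median point

ax.axvline(0, color="gray", linestyle="--")
ax.set_yticks(range(len(schools)))
ax.set_yticklabels(schools)
ax.set_xlabel("theta (95% HDI)")
```

Replacing the color rule with something like f"C{y}" would instead give each variable its own color, which is the per-variable coloring requested earlier in the thread.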