elnjensen opened this issue 3 years ago
Hey @elnjensen, here's a workaround: for the names, I modify them in the source object (in this case, the trace), and they get renamed in the plot. Not the most elegant, but it works.
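Here is a minimal sketch of that workaround. It assumes the posterior is held in an xarray `Dataset`, which is how ArviZ stores each `InferenceData` group (e.g. `idata.posterior`). The draws and the descriptive label strings are made up for illustration. `Dataset.rename` returns a new dataset, so anything plotted from it shows the new names.

```python
import numpy as np
import xarray as xr

# Made-up posterior with three scalar variables named like the ones in the issue
rng = np.random.default_rng(0)
posterior = xr.Dataset(
    {f"r{i}": (("chain", "draw"), rng.normal(size=(2, 100))) for i in range(3)}
)

# Hypothetical descriptive labels; rename returns a new Dataset and leaves the original alone
labels = {"r0": "Growth rate", "r1": "Decay rate", "r2": "Offset"}
renamed = posterior.rename(labels)
print(list(renamed.data_vars))
```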
For the alternating colors, that's a neat idea. Off the top of my head I can't think of an API, but I suppose the most natural thing to do would be to copy matplotlib's implementation. Thank you for the suggestion!
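For reference, here is one way a matplotlib-style cycle could assign a color to each variable. The sketch reads the active property cycle from `rcParams` and wraps around when it runs out of colors. `assign_cycled_colors` is a hypothetical helper, not an ArviZ function, and the variable names are placeholders.

```python
import itertools

import matplotlib as mpl


def assign_cycled_colors(var_names):
    """Hypothetical helper: map each variable to the next color of the active
    matplotlib property cycle, starting over when the cycle is exhausted."""
    palette = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
    cycle = itertools.cycle(palette)
    return {name: next(cycle) for name in var_names}


colors = assign_cycled_colors([f"r{i}" for i in range(12)])
print(colors["r0"], colors["r1"])
```

Reusing the active property cycle means the colors follow whatever matplotlib style the user has set, which is the behavior users already expect from `ax.plot`.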
@canyon289 Can you help me rename the variables of a multivariate distribution?
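```python
import numpy as np
import pymc3 as pm

def fit_category_model(list_prod_cat_val, k, n):  # function name assumed
    with pm.Model() as category_model:
        a = np.ones(k)  # flat Dirichlet prior over the k categories
        theta = pm.Dirichlet("theta", a=a)
        count_max_categories = pm.Multinomial("count_max_categories", n=n, p=theta, observed=list_prod_cat_val)
        trace_category_model = pm.sample(draws=1000, tune=1000)
    return trace_category_model
```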
In the above, k=3, and instead of theta[0], theta[1], theta[2] I am looking for labels like A, B and C.
To rename what is shown permanently for multidimensional variables, you need to modify the coordinate values of the xarray dataset. For changes that affect plotting but not the InferenceData itself (so label-based indexing stays the same), take a look at https://arviz-devs.github.io/arviz/user_guide/label_guide.html. This requires the ArviZ development version: labellers will be released soon, but that hasn't happened yet.
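A sketch of the permanent approach with plain xarray follows. It assumes the Dirichlet variable's extra dimension is called `theta_dim_0`, the default name ArviZ gives an unnamed dimension; check the posterior group of your own data for the actual name. The draws are simulated, not taken from the model above.

```python
import numpy as np
import xarray as xr

# Simulated stand-in for the posterior group: theta has shape (chain, draw, k) with k = 3
rng = np.random.default_rng(0)
theta = rng.dirichlet(np.ones(3), size=(2, 100))
posterior = xr.Dataset({"theta": (("chain", "draw", "theta_dim_0"), theta)})

# Replace the positional coordinate values 0, 1, 2 with category labels
posterior = posterior.assign_coords(theta_dim_0=["A", "B", "C"])

# Label-based indexing now uses the new names, and plots show theta[A], theta[B], ...
theta_b = posterior["theta"].sel(theta_dim_0="B")
print(theta_b.shape)
```

With a real model, the same `assign_coords` call would be applied to the posterior group of the InferenceData before plotting.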
I'd also like different colors for variables. There is a `colors` parameter in `az.plot_forest`, but if I provide a list of colors or the `"cycle"` option, it only uses the first color for all variables. Maybe I'm misunderstanding the purpose of this parameter. I'm using ArviZ 0.14.0.
Here's an example using `colors="cycle"`:
```python
import arviz as az
import matplotlib.pyplot as plt

non_centered_data = az.load_arviz_data("non_centered_eight")
axes = az.plot_forest(
    non_centered_data,
    kind="forestplot",
    var_names=["^the"],
    filter_vars="regex",
    combined=True,
    figsize=(9, 7),
    colors="cycle",
)
axes[0].set_title("Estimated theta for 8 schools model")
plt.show()
```
Here's another example, this time using a list of colors:
```python
import arviz as az
import matplotlib.pyplot as plt
import numpy as np

non_centered_data = az.load_arviz_data("non_centered_eight")
n_schools = len(non_centered_data.posterior.coords["school"].values)
# plt.get_cmap instead of cm.get_cmap, which newer matplotlib releases removed
cmap = plt.get_cmap("turbo")
colors = [cmap(x) for x in np.linspace(0, 1, n_schools)]
print(colors)
axes = az.plot_forest(
    non_centered_data,
    kind="forestplot",
    var_names=["^the"],
    filter_vars="regex",
    combined=True,
    figsize=(9, 7),
    colors=colors,
)
axes[0].set_title("Estimated theta for 8 schools model")
plt.show()
```
@jamesvrt `colors` is currently used for different models (one example at https://python.arviz.org/en/stable/examples/plot_forest_comparison.html), not for different variables. The function would need to be modified to achieve this.
@OriolAbril Thanks, that makes sense. I compare variables within a model more than across models, so it would be nice to have the same functionality for those too. These plots are already fantastic time savers though, so appreciate all your hard work so far!
As an aside for other users: I can produce the plot I want by splitting the input xarray dataset into multiple models. It's not how the API is supposed to be used, but it works as a hack for now.
Update: this approach produced incorrect scaling in the resultant plot.
What do you mean by incorrect scaling? Might be a bug that needs fixing.
I am also not sure about the API for coloring by variable. Feel free to share ideas on the issue.
Sorry, I should have said: the scaling is wrong for ridge plots. Forest plots look fine. Here's the code:
```python
import arviz as az
import matplotlib.pyplot as plt
import numpy as np

non_centered_data = az.load_arviz_data("non_centered_eight")
schools = non_centered_data.posterior.coords["school"].values
n_schools = len(schools)
# plt.get_cmap instead of cm.get_cmap, which newer matplotlib releases removed
cmap = plt.get_cmap("turbo")
colors = [cmap(x) for x in np.linspace(0, 1, n_schools)]
print(colors)
# Hack: pass each school as a separate "model" so that each one gets its own color
axes = az.plot_forest(
    [non_centered_data.sel(school=school) for school in schools],
    model_names=schools,
    kind="forestplot",  # the scaling problem shows up with kind="ridgeplot"
    var_names=["^the"],
    filter_vars="regex",
    combined=True,
    figsize=(9, 7),
    colors=colors,
    legend=False,
)
axes[0].set_title("Estimated theta for 8 schools model")
plt.show()
```
I too would like to be able to choose colors for different covariates, particularly to distinguish when 0 is contained within the 95% interval or the interquartile range. At the moment, it seems the only "lines" exposed in the AxesSubplot object are the circles. MCMCPlot in R has this functionality, and it is very useful.
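Until something like this is supported, a hand-rolled forest plot is one possible stopgap. The sketch below uses only NumPy and matplotlib, not ArviZ: it draws equal-tailed 95% intervals and interquartile ranges (not the HDI that ArviZ plots) and colors each row by whether its 95% interval excludes a reference value of 0. The helper `colored_forest`, its parameters, and the covariate names are hypothetical.

```python
import matplotlib.pyplot as plt
import numpy as np


def colored_forest(samples, labels, prob=0.95, ref_val=0.0,
                   inside_color="gray", outside_color="C3"):
    """Hypothetical helper: one row per variable, colored by whether the
    central `prob` interval excludes `ref_val`.

    samples: array of shape (n_draws, n_vars), chains already stacked.
    """
    tail = 100 * (1 - prob) / 2
    lower, upper = np.percentile(samples, [tail, 100 - tail], axis=0)
    q1, q3 = np.percentile(samples, [25, 75], axis=0)
    median = np.median(samples, axis=0)

    # True where the whole interval lies on one side of ref_val
    excludes_ref = (lower > ref_val) | (upper < ref_val)
    colors = [outside_color if e else inside_color for e in excludes_ref]

    fig, ax = plt.subplots(figsize=(6, 0.5 * len(labels) + 1))
    y = np.arange(len(labels))[::-1]  # first variable drawn at the top
    ax.hlines(y, lower, upper, colors=colors, linewidth=1.5)  # central interval
    ax.hlines(y, q1, q3, colors=colors, linewidth=4)  # interquartile range
    ax.scatter(median, y, c=colors, zorder=3)  # median as point estimate
    ax.axvline(ref_val, color="k", linestyle="--", linewidth=0.8)
    ax.set_yticks(y)
    ax.set_yticklabels(labels)
    return fig, ax, colors


rng = np.random.default_rng(1)
samples = np.column_stack([rng.normal(2.0, 0.5, 4000), rng.normal(0.1, 1.0, 4000)])
fig, ax, colors = colored_forest(samples, ["beta_age", "beta_income"])
```

Because every element is drawn directly with `hlines` and `scatter`, the returned axes expose all intervals as collections that can be restyled afterwards, unlike the situation described above.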
Short Description
I'm using arviz to make plots from a trace generated with pymc3, and I'd like more descriptive labels on the plots than the variable names themselves.
Code Example or link
For example, using `plot_forest` I created the ridgeplot shown below:
[ridgeplot image not preserved]
I'd like to label the y axis with something other than 'r0', 'r1', etc. If there's a general solution, the same desire would apply to plots made with `plot_trace`, for example. Ideally I'd also like each variable to be plotted in a different color, if possible.
Thanks in advance for any help!
Python version : 3.7.3 IPython version : 7.19.0
matplotlib: 3.3.2 arviz : 0.10.0 pymc3 : 3.9.3