arviz-devs / arviz

Exploratory analysis of Bayesian models with Python
https://python.arviz.org
Apache License 2.0

select individual levels from the Dimensions to plot pystan #595

Open tommylees112 opened 5 years ago

tommylees112 commented 5 years ago

Short Description

I want to select individual levels of a dimension to plot, because plotting every level of a variable is slow and makes the plot uninterpretable.

Code Example or link

I am trying to reproduce the PyStan example here showing the use of multilevel modelling.

The code and extraction are below:

varying_intercept = """
data {                                             
  int<lower=0> J;                           // the number of counties                  
  int<lower=0> N;                           // the number of observations                  
  int<lower=1,upper=J> county[N];           // the county for each observation                                 
  vector[N] x;                              // predictor/regressor (floor/basement)              
  vector[N] y;                              // the output variable (log radon levels)  
}                                              
parameters {                                             
  vector[J] a;                              // the random intercept               
  real b;                                   // the fixed coefficient (FIXED EFFECT)
  real mu_a;                                // mean of the population of counties (CONSTANT FOR POPULATION)             
  real<lower=0,upper=100> sigma_a;          // the standard deviation across counties (CONSTANT FOR POPULATION) 
  real<lower=0,upper=100> sigma_y;          // the standard deviation of the observations (CONSTANT FOR POPULATION)      
}                                              
transformed parameters {                                              
  vector[N] y_hat;                           // estimated log radon level for each datapoint                  

  for (i in 1:N)                             // for each datapoint                
    y_hat[i] = a[county[i]] + x[i] * b;      //   estimate the mean of the log radon as a simple linear regression                                        
}                                             
model {                                             
  sigma_a ~ uniform(0, 100);                 // variation between the counties                            
  a ~ normal (mu_a, sigma_a);                // the intercept varying (RANDOM EFFECT)                             

  b ~ normal (0, 1);                         // the coefficient                    

  sigma_y ~ uniform(0, 100);                 // the sampling variation of the log-radon                            
  y ~ normal(y_hat, sigma_y);                // model the log radon levels
}                                             
"""

varying_intercept_data = {'N': len(log_radon),
                          'J': len(n_county),
                          'county': county+1, # Stan counts starting at 1
                          'x': floor_measure,
                          'y': log_radon}

varying_intercept_fit = pystan.stan(model_code=varying_intercept, data=varying_intercept_data, iter=1000, chains=2)

I then convert the fit to an ArviZ InferenceData object:

fit = varying_intercept_fit

data = az.from_pystan(posterior=fit,
                      posterior_predictive='y_hat',
                      observed_data=['y'],
                      coords={'county': n_county},
                      dims={'a': ['county']}) #, 'y': ['county'], 'log_lik': ['county'], 'y_hat': ['county'], 'theta_tilde': ['county']})

data

Out[]:
Inference data with groups:
    > posterior
    > sample_stats
    > posterior_predictive
    > observed_data

I want to plot only a few of the counties (the model levels). The following takes a very long time to run because it plots the traces for ALL counties, whereas I want to select a subset.

az.traceplot(data)

I found the following suggestion on the PyMC Discourse:

az.plot_trace(data, var_names='a', coords={'county': range(0, 5)});
az.plot_forest(data.posterior.sel(county=range(0, 5)), var_names='a');
az.plot_parallel(data, var_names='a', coords={'county': range(0, 5)});
az.plot_posterior(data, var_names='a', coords={'county': range(0, 5)});

But I get an error:

---------------------------------------------------------------------------
InvalidIndexError                         Traceback (most recent call last)
<ipython-input-35-46d226304076> in <module>
      1 # select only some of the levels to plot!!!
      2 # https://discourse.pymc.io/t/best-way-to-plot-and-do-ppc-with-variable-that-has-too-many-levels/2276/3
----> 3 az.plot_trace(data, var_names='a', coords={'county': range(0, 5)});
      4 # az.plot_forest(data.posterior.sel(county=range(0, 5)), var_names='a');
      5 # az.plot_parallel(data, var_names='a', coords={'county': range(0, 5)});

~/miniconda3/envs/stan/lib/python3.7/site-packages/arviz/plots/traceplot.py in plot_trace(data, var_names, coords, divergences, figsize, textsize, lines, combined, kde_kwargs, hist_kwargs, trace_kwargs)
    105         lines = ()
    106 
--> 107     plotters = list(xarray_var_iter(get_coords(data, coords), var_names=var_names, combined=True))
    108 
    109     if figsize is None:

~/miniconda3/envs/stan/lib/python3.7/site-packages/arviz/plots/plot_utils.py in get_coords(data, coords)
    322     """
    323     try:
--> 324         return data.sel(**coords)
    325 
    326     except ValueError:

~/miniconda3/envs/stan/lib/python3.7/site-packages/xarray/core/dataset.py in sel(self, indexers, method, tolerance, drop, **indexers_kwargs)
   1608         indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'sel')
   1609         pos_indexers, new_indexes = remap_label_indexers(
-> 1610             self, indexers=indexers, method=method, tolerance=tolerance)
   1611         result = self.isel(indexers=pos_indexers, drop=drop)
   1612         return result._replace_indexes(new_indexes)

~/miniconda3/envs/stan/lib/python3.7/site-packages/xarray/core/coordinates.py in remap_label_indexers(obj, indexers, method, tolerance, **indexers_kwargs)
    353 
    354     pos_indexers, new_indexes = indexing.remap_label_indexers(
--> 355         obj, v_indexers, method=method, tolerance=tolerance
    356     )
    357     # attach indexer's coordinate to pos_indexers

~/miniconda3/envs/stan/lib/python3.7/site-packages/xarray/core/indexing.py in remap_label_indexers(data_obj, indexers, method, tolerance)
    256         else:
    257             idxr, new_idx = convert_label_indexer(index, label,
--> 258                                                   dim, method, tolerance)
    259             pos_indexers[dim] = idxr
    260             if new_idx is not None:

~/miniconda3/envs/stan/lib/python3.7/site-packages/xarray/core/indexing.py in convert_label_indexer(index, label, index_name, method, tolerance)
    192                 raise ValueError('Vectorized selection is not available along '
    193                                  'MultiIndex variable: ' + index_name)
--> 194             indexer = get_indexer_nd(index, label, method, tolerance)
    195             if np.any(indexer < 0):
    196                 raise KeyError('not all values found in index %r'

~/miniconda3/envs/stan/lib/python3.7/site-packages/xarray/core/indexing.py in get_indexer_nd(index, labels, method, tolerance)
    120 
    121     flat_labels = np.ravel(labels)
--> 122     flat_indexer = index.get_indexer(flat_labels, **kwargs)
    123     indexer = flat_indexer.reshape(labels.shape)
    124     return indexer

~/miniconda3/envs/stan/lib/python3.7/site-packages/pandas/core/indexes/base.py in get_indexer(self, target, method, limit, tolerance)
   2737 
   2738         if not self.is_unique:
-> 2739             raise InvalidIndexError('Reindexing only valid with uniquely'
   2740                                     ' valued Index objects')
   2741 

InvalidIndexError: Reindexing only valid with uniquely valued Index objects

Also include the ArviZ version and version of any other relevant packages.

Arviz Version:  0.3.2
numpy Version:  1.15.0
pandas Version:  0.24.1

Relevant documentation or public examples

https://mc-stan.org/users/documentation/case-studies/radon.html

ahartikainen commented 5 years ago

We need more flexible selection for the variables, but I am not sure what the best option for the API would be.

@OriolAbril do we have any update on this issue?

OriolAbril commented 5 years ago

This issue is actually due to how xarray indexes coordinates with repeated values. A simple example, unrelated to ArviZ, is below:

import xarray as xr
import numpy as np

data = xr.DataArray(
    data=np.random.random(size=(4,100,8)), 
    dims=("chain", "draw", "dim1"), 
    coords={"chain": range(4), "draw": range(100), "dim1": np.random.choice([0,1,2], size=8)}
)

print(data)
# output
# <xarray.DataArray (chain: 4, draw: 100, dim1: 8)>
# array([[[0.828962, 0.514844, ..., 0.180102, 0.365011],
# ...
#         [0.548044, 0.621308, ..., 0.373455, 0.586788]]])
# Coordinates:
#   * chain    (chain) int64 0 1 2 3
#   * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
#   * dim1     (dim1) int64 1 2 1 1 2 2 1 0

print(data.sel(dim1=1))
# output
# <xarray.DataArray (chain: 4, draw: 100, dim1: 4)>
# array([[[0.828962, 0.697391, 0.384503, 0.180102],
# ...
#         [0.548044, 0.793319, 0.735403, 0.373455]]])
# Coordinates:
#   * chain    (chain) int64 0 1 2 3
#   * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
#   * dim1     (dim1) int64 1 1 1 1

# but
data.sel(dim1=[1,2])
# output
# InvalidIndexError: Reindexing only valid with uniquely valued Index objects
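
As a quick check (a minimal sketch, not ArviZ functionality), the pandas index behind a coordinate can be inspected to confirm that repeated values are what triggers the error. The toy array and its labels below are made up for illustration; indexes, is_unique and duplicated are standard xarray/pandas attributes.

```python
import numpy as np
import xarray as xr

# Toy array with a deliberately duplicated coordinate (labels are made up).
data = xr.DataArray(
    np.zeros((2, 3, 5)),
    dims=("chain", "draw", "dim1"),
    coords={"dim1": [1, 2, 1, 0, 2]},
)

# The pandas Index that backs the "dim1" coordinate.
index = data.indexes["dim1"]

# True when at least one label appears more than once.
has_duplicates = not index.is_unique

# Labels that occur more than once (each listed a single time).
repeated_labels = sorted(set(index[index.duplicated()].tolist()))

print(has_duplicates)
print(repeated_labels)
```

If has_duplicates is True for the coordinate passed to sel with a list of labels, the InvalidIndexError above is expected.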

To select a subset of a DataArray or Dataset along a coordinate with repeated index values, use where (combined with isin) instead of sel:

data.where(data.dim1.isin((1,2)), drop=True) 
# drop defaults to False, which keeps every element and replaces those that fail the condition with NaN; that is not what we want here

# output
# <xarray.DataArray (chain: 4, draw: 100, dim1: 7)>
# array([[[0.828962, 0.514844, ..., 0.042839, 0.180102],
# ...
#         [0.548044, 0.621308, ..., 0.110815, 0.373455]]])
# Coordinates:
#   * chain    (chain) int64 0 1 2 3
#   * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
#   * dim1     (dim1) int64 1 2 1 1 2 2 1

We could discuss how to implement this in ArviZ. For now, the user has to do it before calling ArviZ functions. In your case, I guess it would be something like:

az.plot_forest(
    data.posterior.where(data.posterior.county.isin(range(0, 5)), drop=True), 
    var_names='a'
);
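
An alternative worth considering (a sketch under assumptions, not something ArviZ provides) is to select by integer position with isel. isel only indexes variables that actually have the given dimension, so scalar parameters keep their shape. The dataset below is a made-up stand-in for data.posterior, and the county labels are invented for illustration.

```python
import numpy as np
import xarray as xr

# Made-up stand-in for data.posterior: "a" varies by county, "b" does not.
county_labels = [0, 1, 1, 2, 7, 0]  # duplicated labels on purpose
posterior = xr.Dataset(
    {
        "a": (("chain", "draw", "county"), np.zeros((2, 10, 6))),
        "b": (("chain", "draw"), np.zeros((2, 10))),
    },
    coords={"county": county_labels},
)

# Integer positions of the counties whose label is in the wanted set.
wanted = [0, 1]
positions = np.flatnonzero(np.isin(posterior["county"].values, wanted))

# Positional selection works even though the labels are not unique.
subset = posterior.isel(county=positions)

print(subset["a"].shape)
print(subset["b"].dims)
```

The resulting subset could then presumably be passed to az.plot_forest with var_names='a' in the same way as the where-based call above.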

ahartikainen commented 1 year ago

@OriolAbril do we have .where described somewhere in the docs? It is a really powerful function.