mjskay / tidybayes

Bayesian analysis + tidy data + geoms (R package)
http://mjskay.github.io/tidybayes
GNU General Public License v3.0

`spread_draws` can't find indexed variables from stanfit generated by rstan::read_stan_csv #132

Open jburos opened 6 years ago

jburos commented 6 years ago

First, thanks for putting together a well-thought-out package. Very useful functions here to support common workflows.

I'm seeing a problem using spread_draws with a fit loaded from CmdStan using rstan::read_stan_csv().

Specifically, in the context of my code:

> priorpd$fit %>% tidybayes::spread_draws(e_beta[coef]) %>% head()
Error in spread_draws_long_(model, variable_names, dimension_names, regex = regex,  : 
  No variables found matching spec: c(e_beta)[coef]

And, when using a fit called m2 similar to that in your ABC_fit vignette (see below for reproducible example):

> m2 %>% tidybayes::spread_draws(., condition_mean[condition])
Error in spread_draws_long_(model, variable_names, dimension_names, regex = regex,  : 
  No variables found matching spec: c(condition_mean)[condition]

What I think is going on

It looks like the stanfit object is formatted slightly differently when the model is fit with CmdStan and loaded via rstan::read_stan_csv than when it is fit with rstan::sampling. This may be an rstan bug rather than a tidybayes bug, but it affects tidybayes more than rstan.

In any case, this causes the data.frame returned by tidy_draws() to name the variables like name_of_param.index instead of name_of_param[index]. Since spread_draws looks for variables using a regex like "^(e_beta)\\[(.+)\\]$", none of these variables is found by spread_draws_long_.

For example:

> priorpd$fit %>% tidybayes::spread_draws(., `e_beta.*`, regex = T) %>% str()
Classes ‘tbl_df’, ‘tbl’ and 'data.frame':   2000 obs. of  6 variables:
 $ .chain    : int  1 1 1 1 1 1 1 1 1 1 ...
 $ .iteration: int  1 2 3 4 5 6 7 8 9 10 ...
 $ .draw     : int  1 2 3 4 5 6 7 8 9 10 ...
 $ e_beta.1  : num  1.42 3.88 -3.03 5.99 -2.63 ...
 $ e_beta.2  : num  -3.439 -0.802 1.04 2.769 2.943 ...
 $ e_beta.3  : num  -1.48 -5.63 3.67 -2.9 3.43 ...

> priorpd$fit %>% tidybayes::tidy_draws() %>% head()
# A tibble: 6 x 12,025
  .chain .iteration .draw e_z_beta.1 e_z_beta.2 e_z_beta.3 e_aux_unscaled.1 e_aux_unscaled.2 e_aux_unscaled.3 e_aux_unscaled.4 e_aux_unscaled.5 e_aux_unscaled.6 e_aux_unscaled.7 e_beta.1
   <int>      <int> <int>      <dbl>      <dbl>      <dbl>            <dbl>            <dbl>            <dbl>            <dbl>            <dbl>            <dbl>            <dbl>    <dbl>
1      1          1     1    0.56842  -1.375590  -0.590039         8.967280         -0.69945         3.952980         -4.15052         1.906610        5.2288300        1.4153300  1.42105
2      1          2     2    1.55006  -0.320843  -2.250100         6.058300          4.05330         8.325630         -5.31994        -3.521040        0.3845300       -0.5067120  3.87515
3      1          3     3   -1.21000   0.415925   1.469620         2.587380          2.40697         0.378813         -4.67446         3.481290       -0.2286870       -0.1528370 -3.02501
4      1          4     4    2.39453   1.107500  -1.158390        -2.680140         -2.33679        -0.278966         -2.00348        -0.989773        0.2302250       -0.0489574  5.98634
5      1          5     5   -1.05236   1.177110   1.370120         2.943620          1.87041        -0.415789          1.95031        -0.511072        0.0171039       -0.5063310 -2.63090
6      1          6     6    1.11844  -1.588150  -1.143740        -0.403702          1.87825        -0.189716         -2.89321        -0.101369        5.9845200       -0.7510650  2.79610
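
The mismatch described above can be checked directly in base R. This is only a minimal sketch: the regex is the one quoted earlier, the dotted names are taken from the output above, and the bracketed names are an assumption about what an rstan::sampling fit would produce for the same parameter.

```r
# Regex quoted above, of the kind spread_draws builds for e_beta[coef]
spec_regex <- "^(e_beta)\\[(.+)\\]$"

# Assumed names from an rstan::sampling fit (bracket style)
bracket_names <- c("e_beta[1]", "e_beta[2]", "e_beta[3]")
# Names seen above for the read_stan_csv fit (dot style)
dotted_names <- c("e_beta.1", "e_beta.2", "e_beta.3")

# The bracket-style names match the spec regex; the dot-style names do not
bracket_matches <- grepl(spec_regex, bracket_names)
dotted_matches <- grepl(spec_regex, dotted_names)
print(bracket_matches)
print(dotted_matches)
```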

Reproducible example

Here is a reproducible example based on your vignette (see full code & sampled chains in this gist).

Apologies, the output is somewhat verbose.

In the first part I set up the environment & fit the ABC model as described in your vignette "Using tidy data with Bayesian Models".

library(tidyverse)
#> Warning: package 'ggplot2' was built under R version 3.4.4
#> Warning: package 'stringr' was built under R version 3.4.4
library(rstan)
#> Loading required package: StanHeaders
#> Warning: package 'StanHeaders' was built under R version 3.4.3
#> rstan (Version 2.17.3, GitRev: 2e1f913d3ca3)
#> For execution on a local, multicore CPU with excess RAM we recommend calling
#> options(mc.cores = parallel::detectCores()).
#> To avoid recompilation of unchanged Stan programs, we recommend calling
#> rstan_options(auto_write = TRUE)
#> 
#> Attaching package: 'rstan'
#> The following object is masked from 'package:tidyr':
#> 
#>     extract
library(tidybayes)
#> Warning: package 'tidybayes' was built under R version 3.4.4
#> NOTE: As of tidybayes version 1.0, several functions, arguments, and output column names
#>       have undergone significant name changes in order to adopt a unified naming scheme.
#>       See help('tidybayes-deprecated') for more information.
# source('~/projects/az-tumorsize/stan_helpers.function.R') # from https://github.com/jburos/cmdstan-wrapper/blob/master/stan_helpers.function.R 
# only needed if you want to fit your own stan code using CmdStan wrapper

set.seed(5)
n <- 10
n_condition <- 5
ABC <-
  data_frame(
    condition = rep(c("A","B","C","D","E"), n),
    response = rnorm(n * 5, c(0,1,2,1,-1), 0.5)
  )

ABC_stancode <- 'data {
  int<lower=1> n;
  int<lower=1> n_condition;
  int<lower=1, upper=n_condition> condition[n];
  real response[n];
}
parameters {
  real overall_mean;
  vector[n_condition] condition_zoffset;
  real<lower=0> response_sd;
  real<lower=0> condition_mean_sd;
}
transformed parameters {
  vector[n_condition] condition_mean;
  condition_mean = overall_mean + condition_zoffset * condition_mean_sd;
}
model {
  response_sd ~ cauchy(0, 1);         // => half-cauchy(0, 1)
  condition_mean_sd ~ cauchy(0, 1);   // => half-cauchy(0, 1)
  overall_mean ~ normal(0, 5);
  condition_zoffset ~ normal(0, 1);   // => condition_mean ~ normal(overall_mean, condition_mean_sd)
  for (i in 1:n) {
    response[i] ~ normal(condition_mean[condition[i]], response_sd);
  }
}
'
tmpd <- "/var/folders/1x/v6qn3g750k1bh84j7m_36gpw0000gp/T//RtmpCKCIYH"
ABC_stanfile <- file.path(tmpd, 'ABC_stancode.stan')
write_lines(ABC_stancode, path = ABC_stanfile)

# fit using rstan
m <-  rstan::stan(file = ABC_stanfile, data = compose_data(ABC), control = list(adapt_delta=0.99))
#> 
#> SAMPLING FOR MODEL 'ABC_stancode' NOW (CHAIN 1).
#> 
#> Gradient evaluation took 0.000116 seconds
#> 1000 transitions using 10 leapfrog steps per transition would take 1.16 seconds.
#> Adjust your expectations accordingly!
#> 
#> 
#> Iteration:    1 / 2000 [  0%]  (Warmup)
#> Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Iteration: 2000 / 2000 [100%]  (Sampling)
#> 
#>  Elapsed Time: 5.10533 seconds (Warm-up)
#>                6.78884 seconds (Sampling)
#>                11.8942 seconds (Total)
#> 
#> 
#> SAMPLING FOR MODEL 'ABC_stancode' NOW (CHAIN 2).
#> 
#> Gradient evaluation took 0.000106 seconds
#> 1000 transitions using 10 leapfrog steps per transition would take 1.06 seconds.
#> Adjust your expectations accordingly!
#> 
#> 
#> Iteration:    1 / 2000 [  0%]  (Warmup)
#> Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Iteration: 2000 / 2000 [100%]  (Sampling)
#> 
#>  Elapsed Time: 6.54299 seconds (Warm-up)
#>                6.69809 seconds (Sampling)
#>                13.2411 seconds (Total)
#> 
#> 
#> SAMPLING FOR MODEL 'ABC_stancode' NOW (CHAIN 3).
#> 
#> Gradient evaluation took 6.1e-05 seconds
#> 1000 transitions using 10 leapfrog steps per transition would take 0.61 seconds.
#> Adjust your expectations accordingly!
#> 
#> 
#> Iteration:    1 / 2000 [  0%]  (Warmup)
#> Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Iteration: 2000 / 2000 [100%]  (Sampling)
#> 
#>  Elapsed Time: 6.57498 seconds (Warm-up)
#>                7.63905 seconds (Sampling)
#>                14.214 seconds (Total)
#> 
#> 
#> SAMPLING FOR MODEL 'ABC_stancode' NOW (CHAIN 4).
#> 
#> Gradient evaluation took 5.7e-05 seconds
#> 1000 transitions using 10 leapfrog steps per transition would take 0.57 seconds.
#> Adjust your expectations accordingly!
#> 
#> 
#> Iteration:    1 / 2000 [  0%]  (Warmup)
#> Iteration:  200 / 2000 [ 10%]  (Warmup)
#> Iteration:  400 / 2000 [ 20%]  (Warmup)
#> Iteration:  600 / 2000 [ 30%]  (Warmup)
#> Iteration:  800 / 2000 [ 40%]  (Warmup)
#> Iteration: 1000 / 2000 [ 50%]  (Warmup)
#> Iteration: 1001 / 2000 [ 50%]  (Sampling)
#> Iteration: 1200 / 2000 [ 60%]  (Sampling)
#> Iteration: 1400 / 2000 [ 70%]  (Sampling)
#> Iteration: 1600 / 2000 [ 80%]  (Sampling)
#> Iteration: 1800 / 2000 [ 90%]  (Sampling)
#> Iteration: 2000 / 2000 [100%]  (Sampling)
#> 
#>  Elapsed Time: 6.66068 seconds (Warm-up)
#>                8.39963 seconds (Sampling)
#>                15.0603 seconds (Total)
print(m)
#> Inference for Stan model: ABC_stancode.
#> 4 chains, each with iter=2000; warmup=1000; thin=1; 
#> post-warmup draws per chain=1000, total post-warmup draws=4000.
#> 
#>                       mean se_mean   sd  2.5%   25%   50%   75% 97.5%
#> overall_mean          0.59    0.02 0.57 -0.59  0.26  0.59  0.93  1.69
#> condition_zoffset[1] -0.39    0.02 0.50 -1.40 -0.73 -0.37 -0.04  0.50
#> condition_zoffset[2]  0.37    0.02 0.47 -0.59  0.05  0.38  0.69  1.26
#> condition_zoffset[3]  1.15    0.02 0.57  0.06  0.76  1.15  1.52  2.28
#> condition_zoffset[4]  0.38    0.02 0.47 -0.58  0.05  0.39  0.70  1.25
#> condition_zoffset[5] -1.41    0.02 0.69 -2.81 -1.89 -1.37 -0.92 -0.16
#> response_sd           0.56    0.00 0.06  0.45  0.52  0.56  0.60  0.70
#> condition_mean_sd     1.22    0.02 0.56  0.61  0.87  1.09  1.43  2.49
#> condition_mean[1]     0.19    0.00 0.17 -0.14  0.08  0.19  0.31  0.55
#> condition_mean[2]     1.01    0.00 0.17  0.66  0.89  1.00  1.12  1.35
#> condition_mean[3]     1.84    0.00 0.18  1.49  1.72  1.84  1.96  2.18
#> condition_mean[4]     1.01    0.00 0.18  0.65  0.89  1.01  1.13  1.37
#> condition_mean[5]    -0.89    0.00 0.18 -1.24 -1.02 -0.89 -0.77 -0.54
#> lp__                  0.19    0.08 2.43 -5.36 -1.23  0.54  1.98  3.87
#>                      n_eff Rhat
#> overall_mean           825    1
#> condition_zoffset[1]   910    1
#> condition_zoffset[2]   935    1
#> condition_zoffset[3]   984    1
#> condition_zoffset[4]   953    1
#> condition_zoffset[5]   913    1
#> response_sd           1698    1
#> condition_mean_sd     1021    1
#> condition_mean[1]     4000    1
#> condition_mean[2]     4000    1
#> condition_mean[3]     4000    1
#> condition_mean[4]     4000    1
#> condition_mean[5]     4000    1
#> lp__                   985    1
#> 
#> Samples were drawn using NUTS(diag_e) at Fri Oct 12 12:43:02 2018.
#> For each parameter, n_eff is a crude measure of effective sample size,
#> and Rhat is the potential scale reduction factor on split chains (at 
#> convergence, Rhat=1).

# extract param values using tidybayes (thanks for this!)
m %>% tidybayes::spread_draws(., condition_mean[condition])
#> # A tibble: 20,000 x 5
#> # Groups:   condition [5]
#>    .chain .iteration .draw condition condition_mean
#>  *  <int>      <int> <int>     <int>          <dbl>
#>  1      1          1     1         1    -0.04883808
#>  2      1          1     1         2     1.19201971
#>  3      1          1     1         3     1.89843059
#>  4      1          1     1         4     0.96179226
#>  5      1          1     1         5    -0.87847760
#>  6      1          2     2         1     0.19303365
#>  7      1          2     2         2     1.07031747
#>  8      1          2     2         3     1.97499544
#>  9      1          2     2         4     1.12423975
#> 10      1          2     2         5    -1.09090160
#> # ... with 19,990 more rows
m %>% tidybayes::spread_draws(., `condition_mean.*`, regex = T)
#> # A tibble: 4,000 x 9
#>    .chain .iteration .draw condition_mean_sd `condition_mean[1]`
#>     <int>      <int> <int>             <dbl>               <dbl>
#>  1      1          1     1         1.1563071         -0.04883808
#>  2      1          2     2         1.2889995          0.19303365
#>  3      1          3     3         1.1908999          0.26339681
#>  4      1          4     4         1.5740539         -0.04230979
#>  5      1          5     5         1.0939859          0.37080496
#>  6      1          6     6         0.7799908          0.37870967
#>  7      1          7     7         0.9591682         -0.15633789
#>  8      1          8     8         0.6684509          0.26893897
#>  9      1          9     9         0.7527150          0.69452160
#> 10      1         10    10         0.8542510          0.55241468
#> # ... with 3,990 more rows, and 4 more variables:
#> #   `condition_mean[2]` <dbl>, `condition_mean[3]` <dbl>,
#> #   `condition_mean[4]` <dbl>, `condition_mean[5]` <dbl>

None of the above is that surprising -- everything works as expected.

Next we try the same model, this time using CmdStan to draw the samples.

# fit using cmdstan, read in samples using rstan::read_stan_csv
# m2 <- fit_stan_model(stan_model = ABC_stanfile, stan_data = compose_data(ABC), with_predict = F)
# (the above is basically a wrapper for this part)
chains <- dir(tmpd, pattern = 'txt', full.names = T) %>%
  purrr::keep(stringr::str_detect, pattern = 'ABC_stancode.stan.chain\\d')
print(chains)
#> [1] "/var/folders/1x/v6qn3g750k1bh84j7m_36gpw0000gp/T//RtmpCKCIYH/ABC_stancode.stan.chain1.default.dba4bacf4301c133ef1f8827d96b7bd0.0.99-15-1-500-500-12345.720904.txt"
#> [2] "/var/folders/1x/v6qn3g750k1bh84j7m_36gpw0000gp/T//RtmpCKCIYH/ABC_stancode.stan.chain2.default.dba4bacf4301c133ef1f8827d96b7bd0.0.99-15-1-500-500-12345.875774.txt"
#> [3] "/var/folders/1x/v6qn3g750k1bh84j7m_36gpw0000gp/T//RtmpCKCIYH/ABC_stancode.stan.chain3.default.dba4bacf4301c133ef1f8827d96b7bd0.0.99-15-1-500-500-12345.760983.txt"
#> [4] "/var/folders/1x/v6qn3g750k1bh84j7m_36gpw0000gp/T//RtmpCKCIYH/ABC_stancode.stan.chain4.default.dba4bacf4301c133ef1f8827d96b7bd0.0.99-15-1-500-500-12345.886125.txt"
m2 <- rstan::read_stan_csv(csvfiles = chains)
# resulting object is a stanfit object
class(m2)
#> [1] "stanfit"
#> attr(,"package")
#> [1] "rstan"
print(m2)
#> Inference for Stan model: ABC_stancode.stan.chain1.default.dba4bacf4301c133ef1f8827d96b7bd0.0.99-15-1-500-500-12345.720904.
#> 4 chains, each with iter=1000; warmup=500; thin=1; 
#> post-warmup draws per chain=500, total post-warmup draws=2000.
#> 
#>                       mean se_mean   sd  2.5%   25%   50%   75% 97.5%
#> overall_mean          0.61    0.03 0.62 -0.55  0.27  0.60  0.96  1.77
#> condition_zoffset[1] -0.38    0.02 0.49 -1.35 -0.73 -0.38 -0.05  0.59
#> condition_zoffset[2]  0.37    0.02 0.49 -0.55  0.03  0.38  0.71  1.32
#> condition_zoffset[3]  1.14    0.03 0.60  0.02  0.73  1.14  1.54  2.39
#> condition_zoffset[4]  0.38    0.02 0.50 -0.56  0.03  0.39  0.72  1.35
#> condition_zoffset[5] -1.39    0.03 0.63 -2.67 -1.78 -1.41 -0.96 -0.19
#> response_sd           0.56    0.00 0.06  0.46  0.52  0.56  0.60  0.69
#> condition_mean_sd     1.21    0.02 0.50  0.61  0.89  1.09  1.38  2.60
#> condition_mean[1]     0.19    0.00 0.18 -0.15  0.08  0.19  0.31  0.55
#> condition_mean[2]     1.00    0.00 0.18  0.65  0.88  1.00  1.12  1.36
#> condition_mean[3]     1.83    0.00 0.18  1.47  1.71  1.84  1.95  2.19
#> condition_mean[4]     1.01    0.00 0.18  0.67  0.90  1.01  1.13  1.36
#> condition_mean[5]    -0.89    0.00 0.18 -1.25 -1.01 -0.89 -0.77 -0.52
#> lp__                  0.19    0.11 2.30 -5.23 -1.22  0.55  1.85  3.72
#>                      n_eff Rhat
#> overall_mean           391 1.01
#> condition_zoffset[1]   470 1.01
#> condition_zoffset[2]   516 1.00
#> condition_zoffset[3]   360 1.00
#> condition_zoffset[4]   502 1.00
#> condition_zoffset[5]   413 1.02
#> response_sd            740 1.00
#> condition_mean_sd      404 1.01
#> condition_mean[1]     2000 1.00
#> condition_mean[2]     2000 1.00
#> condition_mean[3]     2000 1.00
#> condition_mean[4]     2000 1.00
#> condition_mean[5]     2000 1.00
#> lp__                   405 1.01
#> 
#> Samples were drawn using NUTS(diag_e) at Fri Oct 12 12:13:03 2018.
#> For each parameter, n_eff is a crude measure of effective sample size,
#> and Rhat is the potential scale reduction factor on split chains (at 
#> convergence, Rhat=1).
m2 %>% tidybayes::spread_draws(., condition_mean[condition])
#> Error in spread_draws_long_(model, variable_names, dimension_names, regex = regex, : No variables found matching spec: c(condition_mean)[condition]
m2 %>% tidybayes::spread_draws(., `condition_mean.*`, regex = T)
#> # A tibble: 2,000 x 9
#>    .chain .iteration .draw condition_mean_sd condition_mean.1
#>     <int>      <int> <int>             <dbl>            <dbl>
#>  1      1          1     1          2.949930         0.272438
#>  2      1          2     2          2.993570         0.275023
#>  3      1          3     3          2.962110         0.323740
#>  4      1          4     4          2.932010         0.491940
#>  5      1          5     5          0.936189         0.131151
#>  6      1          6     6          1.042090         0.192548
#>  7      1          7     7          2.066500         0.171896
#>  8      1          8     8          1.921560         0.354787
#>  9      1          9     9          1.813920         0.279345
#> 10      1         10    10          1.801490         0.118684
#> # ... with 1,990 more rows, and 4 more variables: condition_mean.2 <dbl>,
#> #   condition_mean.3 <dbl>, condition_mean.4 <dbl>, condition_mean.5 <dbl>

SessionInfo

> sessionInfo()
R version 3.4.2 (2017-09-28)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS High Sierra 10.13.6

Matrix products: default
BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.4/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] parallel  splines   stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] tidybayes_1.0.1    bindrcpp_0.2       loo_2.0.0          digest_0.6.16      glue_1.2.0         lubridate_1.7.4    bayesplot_1.6.0    doParallel_1.0.11  iterators_1.0.8    foreach_1.4.3     
[11] rstan_2.17.3       StanHeaders_2.17.2 empjm_0.0.0.9000   Formula_1.2-3      forcats_0.2.0      stringr_1.3.1      dplyr_0.7.4        purrr_0.2.4        readr_1.1.1        tidyr_0.7.2       
[21] tibble_1.3.4       ggplot2_3.0.0      tidyverse_1.2.1    survival_2.41-3   

And I'm using cmdstan at stan-dev/cmdstan@6bdc8ba

mjskay commented 6 years ago

Thanks for this! There are a couple of different functions provided by rstan for extracting draws from a stanfit object, and it looks like for stanfit objects created by read_stan_csv, some functions give names with [] and others give names with a dot (.). It should be easy enough for me to switch tidybayes to use one of the functions that gives the names with [], which should fix spread_draws and other functions relying on that syntax.

jburos commented 4 years ago

It's been a while since this issue was raised, but I'm posting my work-around here in case it's helpful to others.

My process now is:

  1. read in posterior draws from CSV using vroom::vroom (or, alternatively, readr::read_csv).
    • pass a .repair_names function to rename variables to the format tidybayes expects (see below)
  2. convert this to an mcmc object using coda::as.mcmc() or coda::as.mcmc.list()
  3. proceed as usual for analysis.
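
Steps 1 and 2 can be sketched with base R alone. This is an illustration under assumptions, not the exact code used: read.csv stands in for vroom::vroom, the CSV text is made-up toy data in the shape of a CmdStan output file, only single-index names are renamed, and the coda conversion runs only if coda is installed.

```r
# Toy stand-in for a CmdStan output file (made-up values): comment lines
# start with "#", and indexed parameters appear as name.index in the header.
csv_text <- "# toy CmdStan-style output
lp__,overall_mean,condition_mean.1,condition_mean.2
1,2,3,4
5,6,7,8
"

# Step 1: read the draws, skipping comment lines and leaving names untouched
draws <- read.csv(text = csv_text, comment.char = "#", check.names = FALSE)

# Rename single-index parameters from name.i to name[i]
# (names with several indices would need additional handling)
names(draws) <- sub("\\.([0-9]+)$", "[\\1]", names(draws))
print(names(draws))

# Step 2: convert to a coda mcmc object, only if coda is available
if (requireNamespace("coda", quietly = TRUE)) {
  draws_mcmc <- coda::as.mcmc(as.matrix(draws))
}
```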

Here is my current .repair_names function, in case it helps anyone:

#' change incoming format (`parameter.i.j`) to standard format (`parameter[i,j]`)
#' @param .names character or list of character names
#' @return character names with 1, 2, and 3-d structures renamed to v[i], v[i,j], and v[i,j,k]
.repair_names <- function(.names) {
  .names %>%
    # reformat 1-D parameters -> parameter[i]
    purrr::map_if(.p = ~stringr::str_detect(.x, pattern = '^[^.]+\\.\\d+$'),
                  .f = ~stringr::str_replace(.x, pattern = '\\.(\\d+)$', replacement = '[\\1]')) %>%
    # reformat 2-D parameters -> parameter[i,j]
    purrr::map_if(.p = ~stringr::str_detect(.x, pattern = '^[^.]+\\.\\d+\\.\\d+$'),
                  .f = ~stringr::str_replace(.x, pattern = '\\.(\\d+)\\.(\\d+)$', replacement = '[\\1,\\2]')) %>%
    # reformat 3-D parameters -> parameter[i,j,k]
    purrr::map_if(.p = ~stringr::str_detect(.x, pattern = '^[^.]+\\.\\d+\\.\\d+\\.\\d+$'),
                  .f = ~stringr::str_replace(.x, pattern = '\\.(\\d+)\\.(\\d+)\\.(\\d+)$', replacement = '[\\1,\\2,\\3]')) %>%
    as.character()
}

I'm pretty sure there is a slightly more elegant regex one could use, but I decided to forgo that complexity for now.
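
One possible shape for such a regex, as a minimal base R sketch. The function name repair_names_generic is hypothetical, and the sketch assumes that parameter names contain no dots other than the ones separating indices (which holds for names as written in CmdStan CSV headers).

```r
# Hypothetical helper (not part of tidybayes or rstan): rename
# parameter.i.j... to parameter[i,j,...] for any number of indices.
repair_names_generic <- function(.names) {
  .names <- as.character(.names)
  # Wrap the trailing run of ".<digits>" groups in brackets:
  # "theta.1.2" -> "theta[1.2]"
  .names <- sub("\\.([0-9]+(\\.[0-9]+)*)$", "[\\1]", .names)
  # Turn the dots left inside the trailing brackets into commas:
  # "theta[1.2]" -> "theta[1,2]"
  gsub("\\.(?=[0-9.]*\\]$)", ",", .names, perl = TRUE)
}

repaired <- repair_names_generic(
  c("e_beta.1", "theta.1.2", "z.3.4.5", "lp__", "overall_mean")
)
print(repaired)
```

Names without a trailing numeric index pass through unchanged, so the function can be applied to the whole header at once.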

mjskay commented 4 years ago

Thanks, this is very helpful! Sorry I haven't gotten back to this one yet.

At some point it might be good for spread/gather_draws to support more customizable column-name formats so you don't have to do this manually. A simple change that might fix this is an analog to the sep argument... currently, setting sep = "[.]" in spread_draws() would get you halfway there. Maybe a solution would be to add similar open_bracket and close_bracket arguments, or a more generic way to build up these expressions using regexes.
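
To make the idea concrete, here is a rough sketch of how configurable delimiters could feed into the variable-matching regex. The helper build_spec_regex is hypothetical and is not part of tidybayes; the argument names open_bracket and close_bracket simply follow the suggestion above, and handling of multiple indices via sep is left out.

```r
# Hypothetical helper, not a tidybayes function: build the matching regex
# for a single-index spec such as e_beta[coef] from configurable delimiters.
build_spec_regex <- function(variable, open_bracket = "\\[", close_bracket = "\\]") {
  paste0("^(", variable, ")", open_bracket, "(.+)", close_bracket, "$")
}

# Default delimiters reproduce the bracket-style regex
default_regex <- build_spec_regex("e_beta")
# Dot-style names: a literal dot opens the index and nothing closes it
dotted_regex <- build_spec_regex("e_beta", open_bracket = "[.]", close_bracket = "")

print(grepl(default_regex, "e_beta[1]"))
print(grepl(dotted_regex, "e_beta.1"))
```

With the dot-style delimiters the captured index for a name like e_beta.1.2 would be "1.2", so splitting it into separate dimensions would still depend on something like sep = "[.]".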

Raoul-Kima commented 4 years ago

Here's a simpler workaround. I don't know whether it works in all cases, but in my case this was sufficient to solve the problem.

Assuming "stanfit" is a stanfit object created by rstan::read_stan_csv from a cmdstan csv-file:

for (currentChainId in seq_along(stanfit@sim$samples)) { # for each chain ...
    # ... rename variables to restore tidybayes compatibility.
    names(stanfit@sim$samples[[currentChainId]]) <- names(stanfit)
}