stan-dev / posterior

The posterior R package
https://mc-stan.org/posterior/

easy conversion to draws from rstantools format #251

Open wds15 opened 2 years ago

wds15 commented 2 years ago

I am struggling to correctly convert the posterior draws I get from rstantools-style functions such as posterior_linpred into a draws object. The problem is that the chain information gets dropped. Here is an example illustrating what I'd like to have:

library(posterior)
#> Warning: package 'posterior' was built under R version 4.1.2
#> This is posterior version 1.2.2
#> 
#> Attaching package: 'posterior'
#> The following objects are masked from 'package:stats':
#> 
#>     mad, sd, var

samp <- as_draws_matrix(example_draws())

## posterior_* functions from rstantools return matrices like

rstantools_samp <- matrix(as.matrix(samp), niterations(samp)*nchains(samp), nvariables(samp))
colnames(rstantools_samp) <- variables(samp)
head(rstantools_samp)
#>            mu      tau   theta[1]    theta[2]   theta[3] theta[4]    theta[5]
#> [1,] 2.005831 2.767367  3.9617520  0.27123540 -0.7431706 2.104805  0.92348879
#> [2,] 1.458316 6.979976  0.1237101 -0.06901539  0.9518270 7.281225 -0.06195211
#> [3,] 5.814947 9.677075 21.2510465 14.93055775  1.8290945 1.381443  0.53106337
#> [4,] 6.849586 4.788366 14.6996540  8.58618604  2.6749150 4.393232  4.75807198
#> [5,] 1.805168 2.848165  5.9600546  1.15573721  3.1088628 1.994890  0.76885094
#> [6,] 3.841243 4.083357  5.7601096  9.90920447 -0.9956266 5.328625  5.88894271
#>       theta[6]  theta[7]  theta[8]
#> [1,]  1.650237  3.320019  4.848542
#> [2,] 11.257502  9.621128 -8.640446
#> [3,]  7.155371 14.802013 -1.736363
#> [4,]  8.101547  9.491277  5.281551
#> [5,]  4.656270  1.208251 -4.540236
#> [6,] -1.701463  2.780403  7.075855

dim(rstantools_samp)
#> [1] 400  10

## rows are ordered by chain, so we have

all(rstantools_samp[1:100,1] == subset_draws(samp, variable="mu", chain=1))
#> [1] TRUE
all(rstantools_samp[101:200,1] == subset_draws(samp, variable="mu", chain=2))
#> [1] TRUE

## ideally posterior would offer a function that turns rstantools_samp
## into a draws object that knows the number of chains.
## This does not work:

as_draws_matrix(rstantools_samp, .nchains=4)
#> # A draws_matrix: 400 iterations, 1 chains, and 10 variables
#>     variable
#> draw   mu tau theta[1] theta[2] theta[3] theta[4] theta[5] theta[6]
#>   1  2.01 2.8     3.96    0.271    -0.74      2.1    0.923      1.7
#>   2  1.46 7.0     0.12   -0.069     0.95      7.3   -0.062     11.3
#>   3  5.81 9.7    21.25   14.931     1.83      1.4    0.531      7.2
#>   4  6.85 4.8    14.70    8.586     2.67      4.4    4.758      8.1
#>   5  1.81 2.8     5.96    1.156     3.11      2.0    0.769      4.7
#>   6  3.84 4.1     5.76    9.909    -1.00      5.3    5.889     -1.7
#>   7  5.47 4.0     4.03    4.151    10.15      6.6    3.741     -2.2
#>   8  1.20 1.5    -0.28    1.846     0.47      4.3    1.467      3.3
#>   9  0.15 3.9     1.81    0.661     0.86      4.5   -1.025      1.1
#>   10 7.17 1.8     6.08    8.102     7.68      5.6    7.106      8.5
#> # ... with 390 more draws, and 2 more variables

Created on 2022-08-04 by the reprex package (v2.0.1)

What I can do is crudely set the nchains attribute to the number of chains. Still, I think the call above should just work and give me a draws object with 4 chains. This obviously requires documenting that the rows of the input must be stacked chain by chain, i.e. column-major over an iterations-by-chains layout.
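A minimal sketch of the reshaping this would involve, in base R only. The helper name stacked_to_array is hypothetical and not part of posterior; it assumes the rows are stacked chain by chain and that all chains have the same length, as in the example above.

```r
# Hypothetical helper (not part of posterior): reshape a chain-stacked
# draws matrix into an iteration x chain x variable array.
stacked_to_array <- function(x, n_chains) {
  stopifnot(is.matrix(x), nrow(x) %% n_chains == 0)
  n_iter <- nrow(x) %/% n_chains
  # R fills arrays in column-major order, so rows 1..n_iter of each
  # column become chain 1, the next n_iter rows chain 2, and so on.
  array(
    x,
    dim = c(n_iter, n_chains, ncol(x)),
    dimnames = list(iteration = NULL, chain = NULL, variable = colnames(x))
  )
}

# Toy input: 2 chains of 3 iterations each, stacked chain by chain.
stacked <- cbind(
  mu  = c(1, 2, 3, 11, 12, 13),
  tau = c(4, 5, 6, 14, 15, 16)
)
arr <- stacked_to_array(stacked, n_chains = 2)
dim(arr)
arr[, 2, "mu"]  # draws of mu that belong to chain 2

# Assumption: with posterior loaded, an array in this layout should be
# accepted by posterior::as_draws_array(arr).
```

If the rows were instead interleaved across chains, this reshape would silently assign draws to the wrong chains, which is why the expected ordering needs to be documented.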

mjskay commented 2 years ago

One option is to ingest posterior_...() function output via rvar(), as the internal format of rvar() is (by design) the same as the output of those functions. rvar() lets you set the number of chains:

library(posterior)
library(rstanarm)

mtcars_subset = mtcars[, c("hp", "cyl", "mpg")]

m = stan_glm(mpg ~ hp*cyl, data = mtcars_subset, chains = 4)

epred = rvar(posterior_epred(m), nchains = 4)
epred
#> rvar<1000,4>[32] mean ± sd:
#> 
#>           Mazda RX4       Mazda RX4 Wag          Datsun 710      Hornet 4 Drive 
#>          20 ± 0.79           20 ± 0.79           26 ± 0.93           20 ± 0.79  
#>   Hornet Sportabout             Valiant          Duster 360           Merc 240D 
#>          16 ± 0.88           21 ± 0.80           15 ± 1.01           28 ± 1.18  
#>            Merc 230            Merc 280           Merc 280C          Merc 450SE 
#>          26 ± 0.94           20 ± 0.81           20 ± 0.81           15 ± 0.85  
#>          Merc 450SL         Merc 450SLC  Cadillac Fleetwood Lincoln Continental 
#>          15 ± 0.85           15 ± 0.85           15 ± 0.79           15 ± 0.81  
#>   Chrysler Imperial            Fiat 128         Honda Civic      Toyota Corolla 
#>          15 ± 0.89           28 ± 1.10           29 ± 1.41           28 ± 1.12  
#>       Toyota Corona    Dodge Challenger         AMC Javelin          Camaro Z28 
#>          26 ± 0.97           16 ± 1.09           16 ± 1.09           15 ± 1.01  
#>    Pontiac Firebird           Fiat X1-9       Porsche 914-2        Lotus Europa 
#>          16 ± 0.88           28 ± 1.10           26 ± 0.91           24 ± 1.25  
#>      Ford Pantera L        Ferrari Dino       Maserati Bora          Volvo 142E 
#>          14 ± 1.21           18 ± 1.44           13 ± 2.11           25 ± 1.17

This can be especially useful for the posterior_...() functions, since you can put the resulting rvars in a data frame alongside the data used to make the predictions:

cbind(mtcars_subset, epred = epred)
#>                      hp cyl  mpg                epred
#> Mazda RX4           110   6 21.0 20.48962 ± 0.7908150
#> Mazda RX4 Wag       110   6 21.0 20.48962 ± 0.7908150
#> Datsun 710           93   4 22.8 25.80881 ± 0.9256434
#> Hornet 4 Drive      110   6 21.4 20.48962 ± 0.7908150
#> Hornet Sportabout   175   8 18.7 15.51820 ± 0.8760158
#> Valiant             105   6 18.1 20.70694 ± 0.8036316
#> Duster 360          245   8 14.3 14.55358 ± 1.0112446
#> Merc 240D            62   4 24.4 28.07635 ± 1.1825719
#> Merc 230             95   4 22.8 25.66252 ± 0.9437801
#> Merc 280            123   6 19.2 19.92459 ± 0.8124373
#> Merc 280C           123   6 17.8 19.92459 ± 0.8124373
#> Merc 450SE          180   8 16.4 15.44930 ± 0.8459401
#> Merc 450SL          180   8 17.3 15.44930 ± 0.8459401
#> Merc 450SLC         180   8 15.2 15.44930 ± 0.8459401
#> Cadillac Fleetwood  205   8 10.4 15.10479 ± 0.7862891
#> Lincoln Continental 215   8 10.4 14.96699 ± 0.8091412
#> Chrysler Imperial   230   8 14.7 14.76028 ± 0.8889254
#> Fiat 128             66   4 32.4 27.78376 ± 1.1026873
#> Honda Civic          52   4 30.4 28.80781 ± 1.4145373
#> Toyota Corolla       65   4 33.9 27.85691 ± 1.1217972
#> Toyota Corona        97   4 21.5 25.51622 ± 0.9659045
#> Dodge Challenger    150   8 15.5 15.86271 ± 1.0899219
#> AMC Javelin         150   8 15.2 15.86271 ± 1.0899219
#> Camaro Z28          245   8 13.3 14.55358 ± 1.0112446
#> Pontiac Firebird    175   8 19.2 15.51820 ± 0.8760158
#> Fiat X1-9            66   4 27.3 27.78376 ± 1.1026873
#> Porsche 914-2        91   4 26.0 25.95510 ± 0.9117325
#> Lotus Europa        113   4 30.4 24.34588 ± 1.2535618
#> Ford Pantera L      264   8 15.8 14.29175 ± 1.2067307
#> Ferrari Dino        175   6 19.7 17.66450 ± 1.4376777
#> Maserati Bora       335   8 15.0 13.31335 ± 2.1102372
#> Volvo 142E          109   4 21.4 24.63846 ± 1.1669329

And if you do want it as a draws_matrix, you can use as_draws_matrix():

as_draws_matrix(epred)
#> # A draws_matrix: 1000 iterations, 4 chains, and 32 variables
#>     variable
#> draw x[Mazda RX4] x[Mazda RX4 Wag] x[Datsun 710] x[Hornet 4 Drive]
#>   1            21               21            26                21
#>   2            21               21            25                21
#>   3            18               18            25                18
#>   4            21               21            24                21
#>   5            21               21            25                21
#>   6            21               21            24                21
#>   7            23               23            28                23
#>   8            21               21            26                21
#>   9            21               21            28                21
#>   10           22               22            25                22
#>     variable
#> draw x[Hornet Sportabout] x[Valiant] x[Duster 360] x[Merc 240D]
#>   1                    15         21            13           28
#>   2                    15         21            13           28
#>   3                    14         19            15           27
#>   4                    17         21            17           26
#>   5                    17         21            16           27
#>   6                    17         21            17           26
#>   7                    15         23            12           30
#>   8                    16         21            16           28
#>   9                    15         22            13           30
#>   10                   18         22            17           27
#> # ... with 3990 more draws, and 24 more variables

That said, since draws_matrix() has an .nchains argument, it does seem like as_draws_matrix() should perhaps have one too?

wds15 commented 2 years ago

Hi!

Indeed, as_draws_matrix(rvar(rstantools_samp, nchains=4)) gives me what I want for the example I quoted. Maybe all of the as_draws_* functions should have a .nchains argument? Certainly as_draws_matrix needs it, and the input format posterior expects (column-major, i.e. rows stacked chain by chain) needs to be documented.

Thanks!