rmcelreath / rethinking

Statistical Rethinking course and book package

Fix saving and loading ulam() models when using cmdstanr; Fix log likelihood bugs for multi_normal outcome variables #425

Status: open; opened by timjzee 3 months ago

timjzee commented 3 months ago

By default, cmdstanr writes sampler output to temporary CSV files and reads that data into memory only when it is needed. As a result, when a fitted ulam() model is written to a file, much of the posterior data is not included in that file; it exists only in the temporary CSVs. When the R session ends, those temporary files are deleted, and the saved model can no longer be fully loaded. See the example below:

> library(rethinking)
Loading required package: cmdstanr
This is cmdstanr version 0.7.1
- CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
- CmdStan path: /vol/customopt/cmdstan-2.34.1
- CmdStan version: 2.34.1

> fit_stan <- ulam(
    alist(
        y ~ dnorm( mu , sigma ),
        mu ~ dnorm( 0 , 10 ),
        sigma ~ dexp( 1 )
    ), data=list(y=c(-1,1)), file = "test"
)
Compiling Stan program...
Running MCMC with 1 chain, with 1 thread(s) per chain...

Chain 1 Iteration:   1 / 1000 [  0%]  (Warmup)
Chain 1 Iteration: 100 / 1000 [ 10%]  (Warmup)
Chain 1 Iteration: 200 / 1000 [ 20%]  (Warmup)
Chain 1 Iteration: 300 / 1000 [ 30%]  (Warmup)
Chain 1 Iteration: 400 / 1000 [ 40%]  (Warmup)
Chain 1 Iteration: 500 / 1000 [ 50%]  (Warmup)
Chain 1 Iteration: 501 / 1000 [ 50%]  (Sampling)
Chain 1 Iteration: 600 / 1000 [ 60%]  (Sampling)
Chain 1 Iteration: 700 / 1000 [ 70%]  (Sampling)
Chain 1 Iteration: 800 / 1000 [ 80%]  (Sampling)
Chain 1 Iteration: 900 / 1000 [ 90%]  (Sampling)
Chain 1 Iteration: 1000 / 1000 [100%]  (Sampling)
Chain 1 finished in 0.0 seconds.
Saving result as test.rds

# Temporary files still exist, so we can access our model data:
> precis(fit_stan)
       mean   sd  5.5% 94.5% rhat ess_bulk
mu    -0.12 1.06 -1.91  1.46 1.00   159.00
sigma  1.43 0.79  0.60  2.75 1.01   169.57

> q()

# However, after exiting and starting a new session:

> library(rethinking)
> fit_stan <- readRDS("test.rds")
> precis(fit_stan)
Error in read_cmdstan_csv(files = self$output_files(include_failed = FALSE),  :
  Assertion on 'files' failed: File does not exist: '/scratch/RtmpDWbuV6/ulam_cmdstanr_8b4a70079e6636864e593d38ad729313-202403191850-1-5b97c7.csv'.

The fix consists of two lines of code and was inspired by these pages:
https://discourse.mc-stan.org/t/error-in-read-cmdstan-csv-assertion-on-files-failed/30150
https://mc-stan.org/cmdstanr/reference/read_cmdstan_csv.html
https://mc-stan.org/cmdstanr/reference/fit-method-save_output_files.html

Calling stanfit <- as_cmdstan_fit(cmdstanfit$output_files()) reads the data from the temporary files into stanfit, which is then stored inside the ulam object, so the data is included when the model is saved.
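The failure mode and the fix can be illustrated without Stan. The sketch below is a simplified base-R analogy, not cmdstanr's internals: `make_lazy_fit()`, `get_draws()` and `materialise()` are hypothetical helpers, and `materialise()` plays the role that `as_cmdstan_fit(cmdstanfit$output_files())` plays in the fix, reading the temporary file into the object before it is saved.

```r
# Base-R analogy of the temporary-file problem; not cmdstanr's real internals.
# make_lazy_fit(), get_draws() and materialise() are hypothetical helpers.

make_lazy_fit <- function() {
  # Write "draws" to a temporary CSV and keep only its path, the way a lazily
  # loaded fit keeps references to its output files.
  csv <- tempfile(fileext = ".csv")
  write.csv(data.frame(mu = rnorm(5), sigma = rexp(5)), csv, row.names = FALSE)
  list(output_file = csv, draws = NULL)
}

get_draws <- function(fit) {
  if (!is.null(fit$draws)) return(fit$draws)  # already held in memory
  if (!file.exists(fit$output_file))
    stop("File does not exist: ", fit$output_file)
  read.csv(fit$output_file)                   # lazy read from the temp file
}

materialise <- function(fit) {
  # Analogue of the fix: read the temp file now and store the data in the object.
  fit$draws <- read.csv(fit$output_file)
  fit
}

set.seed(1)
lazy  <- make_lazy_fit()
eager <- materialise(make_lazy_fit())

lazy_rds  <- tempfile(fileext = ".rds")
eager_rds <- tempfile(fileext = ".rds")
saveRDS(lazy, lazy_rds)
saveRDS(eager, eager_rds)

# Simulate ending the R session: the temporary CSV files are deleted.
unlink(c(lazy$output_file, eager$output_file))

lazy_result  <- tryCatch(get_draws(readRDS(lazy_rds)),
                         error = function(e) conditionMessage(e))
eager_result <- get_draws(readRDS(eager_rds))
```

The lazily saved object fails with a "File does not exist" error, mirroring the assertion in the transcript above, while the eagerly materialised one still returns its draws. For plain cmdstanr fits outside ulam, cmdstanr's own `$save_object()` method serves a similar purpose: it reads all output into memory before calling `saveRDS()`.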

With this fix, the saved model loads correctly in a new session:

> library(rethinking)
> fit_ulam <- readRDS("test.rds")
> precis(fit_ulam)
       mean   sd  5.5% 94.5% rhat ess_bulk
mu    -0.03 0.95 -1.62  1.31    1   218.40
sigma  1.53 0.72  0.72  2.88    1   100.27

This was tested on:

timjzee commented 2 days ago

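I also fixed two issues that arise when specifying log_lik=TRUE for models with multivariate outcomes. The first, fixed by commit 1a040c3, is that ulam currently assumes Stan can compute the log likelihood of multivariate outcomes using vectorised notation. This is not the case, as noted by a Stan developer here. A vectorised call such as multi_normal_lpdf returns the sum of the log densities rather than one value per observation, so the pointwise log likelihood has to be computed observation by observation.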
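To see why a single vectorised call is not enough, compare a pointwise and a summed multivariate normal log density. The sketch below is written in base R rather than Stan, and `mvn_logpdf()` is a hypothetical helper: it returns one log density per row, which is what log_lik needs, whereas a vectorised Stan call yields only the equivalent of `total`.

```r
# Pointwise vs. summed multivariate normal log density, in base R.
# mvn_logpdf() is a hypothetical helper written for illustration.

mvn_logpdf <- function(Y, mu, Sigma) {
  # Y: N x K matrix (one row per observation); mu: length-K mean vector;
  # Sigma: K x K covariance matrix.
  K <- ncol(Y)
  L <- t(chol(Sigma))                  # lower Cholesky factor, Sigma = L %*% t(L)
  Z <- forwardsolve(L, t(Y) - mu)      # K x N standardised residuals
  log_det <- 2 * sum(log(diag(L)))
  # One log density per observation (column of Z).
  -0.5 * (K * log(2 * pi) + log_det + colSums(Z^2))
}

set.seed(2)
Y <- matrix(rnorm(10), ncol = 2)       # 5 observations, 2 outcome variables
mu <- c(0, 0)
Sigma <- diag(c(1, 4))

log_lik <- mvn_logpdf(Y, mu, Sigma)    # pointwise: length 5
total   <- sum(log_lik)                # what a vectorised call returns: one number
```

Only the length-5 vector can be used for pointwise criteria such as WAIC or PSIS; the single summed value cannot be split back into per-observation contributions. In Stan, the analogue is a loop that calls multi_normal_lpdf once per observation.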

The second issue turned out to be an indexing problem specific to multivariate outcomes; see this explanation for full details, including an example model. I checked that the fix does not break anything for models with a single outcome variable.

Paging @rmcelreath, as these issues (including the temporary-file issue) are likely to affect other users, and the fixes are small enough to check quickly.