StanJulia / StanSample.jl

WIP: Wrapper package for the sample method in Stan's cmdstan executable.
MIT License

read_samples returns incorrect results for samples from previous run #57

Closed: itsdfish closed this issue 2 years ago

itsdfish commented 2 years ago

Hi Rob,

read_samples returned incorrect results when I tried to reload a chain from a previous run of a model. I assumed the benefit of reading the samples was the ability to reload results from a long-running model. Here is an MWE:

#######################################################################################
#                               set environment variables
#######################################################################################
# replace with your local directory
ENV["CMDSTAN_HOME"] = "/home/dfish/cmdstan"
#######################################################################################
#                                  load packages
#######################################################################################
cd(@__DIR__)
using Pkg
Pkg.activate("..")
using StanSample, SequentialSamplingModels, Random, MCMCChains
using StatsPlots, ACTRModels, DataFrames
tempdir = pwd() * "/tmp"
#######################################################################################
#                                     Generate Data
#######################################################################################
seed = 350
Random.seed!(seed)
n_obs = 50
y = randn(n_obs)

stan_data = Dict(
    "y" => y,
    "n_obs" => n_obs)
#######################################################################################
#                                     Load Model
#######################################################################################
model = "
data{
     // total observations
     int n_obs;
     // observations
    vector[n_obs] y;
}

parameters {
    real mu;
    real<lower=0> sigma;
}

model {
    mu ~ normal(0, 1);
    sigma ~ gamma(1, 1);
    y ~ normal(mu, sigma); 
}"
stan_model = SampleModel("temp", model, tempdir)
#######################################################################################
#                                  estimate parameters
#######################################################################################
# run the sampler
stan_sample(
    stan_model;
    data = stan_data,
    seed,
    num_chains = 4,
    num_samples = 1000,
    num_warmups = 1000,
    save_warmup = false
)

samples = read_samples(stan_model, :mcmcchains)

Start a new session and run everything except for the sampler:

#######################################################################################
#                               set environment variables
#######################################################################################
# replace with your local directory
ENV["CMDSTAN_HOME"] = "/home/dfish/cmdstan"
#######################################################################################
#                                  load packages
#######################################################################################
cd(@__DIR__)
using Pkg
Pkg.activate("..")
using StanSample, SequentialSamplingModels, Random, MCMCChains
using StatsPlots, ACTRModels, DataFrames
tempdir = pwd() * "/tmp"
#######################################################################################
#                                     Generate Data
#######################################################################################
seed = 350
Random.seed!(seed)
n_obs = 50
y = randn(n_obs)

stan_data = Dict(
    "y" => y,
    "n_obs" => n_obs)
#######################################################################################
#                                     Load Model
#######################################################################################
model = "
data{
     // total observations
     int n_obs;
     // observations
    vector[n_obs] y;
}

parameters {
    real mu;
    real<lower=0> sigma;
}

model {
    mu ~ normal(0, 1);
    sigma ~ gamma(1, 1);
    y ~ normal(mu, sigma); 
}"
stan_model = SampleModel("temp", model, tempdir)
samples = read_samples(stan_model, :mcmcchains)

Results

julia> samples = read_samples(stan_model, :mcmcchains)
Chains MCMC chain (1000×2×4 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 4
Samples per chain = 1000
parameters        = mu, sigma
internals         = 

Summary Statistics
  parameters      mean       std   naive_se      mcse       ess      rhat 
      Symbol   Float64   Float64    Float64   Float64   Float64   Float64 

          mu   -0.0554    0.1159     0.0018    0.0122   11.4043    1.8670
       sigma    0.2252    0.3927     0.0062    0.0494    8.1135    9.3601

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          mu   -0.3866    0.0000    0.0000    0.0000    0.0000
       sigma    0.0000    0.0000    0.0000    0.1700    1.0135

Note that rhat is wrong.

Version


Julia 1.8.1
StanSample v6.10.0

goedman commented 2 years ago

Hi Chris, I got this message pretty late yesterday evening (I'm currently traveling in Europe) so I just took a look. Two quick questions:

When you run the 2nd part, is the first process still running? That is pretty dangerous. I have no idea how the timing of the (in your example 4) processes that create the 4 chains works out, or whether the .csv files are in a suitable state to be read.

Also, am I correct that you run this MWE in an environment with all these packages installed and in the directory where functions.jl is stored?

itsdfish commented 2 years ago

Rob, no problem. I hope you are enjoying Europe.

Let me back up and explain my goal. I have a long-running model and I would like to reload the samples into a different session as a chain to avoid re-running the model. So to answer your first question, the first process has completed. My understanding is that it should be safe to reload the completed chain with samples = read_samples(stan_model, :mcmcchains) without re-running stan_sample. As for the second question, you are correct. I am using a project-specific environment. I accidentally forgot to remove the include and Pkg statements (I removed these statements from the code above). I don't think that part of the code was critical to reproducing the problem.

goedman commented 2 years ago

Hi Chris,

If I understand your example correctly, that works for DataFrames but not for MCMCChains:

cd(@__DIR__)
using StanSample, Random, MCMCChains
tempdir = pwd() * "/tmp"
####################################################################
#                                     Generate Data
####################################################################
seed = 350
Random.seed!(seed)
n_obs = 50
y = randn(n_obs)

stan_data = Dict(
    "y" => y,
    "n_obs" => n_obs)

model = "
data{
     // total observations
     int n_obs;
     // observations
    vector[n_obs] y;
}

parameters {
    real mu;
    real<lower=0> sigma;
}

model {
    mu ~ normal(0, 1);
    sigma ~ gamma(1, 1);
    y ~ normal(mu, sigma); 
}"
sm_01 = SampleModel("temp", model, tempdir)

# run the sampler
rc_01 = stan_sample(
    sm_01;
    data = stan_data,
    seed,
    num_chains = 4,
    num_samples = 1000,
    num_warmups = 1000,
    save_warmup = false
)

if success(rc_01)
     post_01 = read_samples(sm_01, :dataframe)
     post_01 |> display
     chn_01 = read_samples(sm_01, :mcmcchains)
     chn_01 |> display
end

sdf_01 = read_summary(sm_01)
sdf_01[:, 1:5] |> display
sdf_01[:, [1, 6,7,8, 9, 10]] |> display

sm_01 = SampleModel("temp", model, tempdir)
if success(rc_01)
     post_02 = read_samples(sm_01, :dataframe)
     post_02 |> display
     chn_02 = read_samples(sm_01, :mcmcchains)
     chn_02 |> display
end

sdf_02 = read_summary(sm_01)
sdf_02[:, 1:5] |> display
sdf_02[:, [1, 6,7,8, 9, 10]] |> display

which produces:

4000×2 DataFrame
  Row │ mu          sigma    
      │ Float64     Float64  
──────┼──────────────────────
    1 │ -0.13932    0.941601
    2 │ -0.282811   0.79359
    3 │ -0.114992   0.897347
    4 │ -0.289828   0.878025
    5 │ -0.340496   0.996096
    6 │ -0.0992094  1.02213
    7 │ -0.311144   0.870863
    8 │  0.0183361  0.986961
    9 │ -0.197492   0.93263
   10 │ -0.118029   0.989948
   11 │ -0.128918   1.01096
   12 │ -0.369459   0.771521
   13 │ -0.105138   0.745827
   14 │ -0.104644   0.970093
   15 │ -0.195001   0.871644
   16 │ -0.287177   0.871411
  ⋮   │     ⋮          ⋮
 3986 │ -0.456994   0.902656
 3987 │ -0.535066   0.958956
 3988 │ -0.220415   1.00202
 3989 │ -0.225893   0.885779
 3990 │ -0.330735   0.886907
 3991 │  0.0115522  0.824976
 3992 │ -0.035688   0.862346
 3993 │ -0.409221   0.870086
 3994 │ -0.333642   0.961561
 3995 │ -0.307725   0.941666
 3996 │ -0.493181   0.890972
 3997 │ -0.33653    0.846751
 3998 │ -0.295537   0.894515
 3999 │ -0.192301   0.854393
 4000 │ -0.284196   0.90213
            3969 rows omitted
Chains MCMC chain (1000×2×4 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 4
Samples per chain = 1000
parameters        = mu, sigma
internals         = 

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess       ⋯
      Symbol   Float64   Float64    Float64   Float64     Float64   Flo ⋯

          mu   -0.2177    0.1259     0.0020    0.0023   2764.8900    1. ⋯
       sigma    0.9007    0.0926     0.0015    0.0018   2664.0283    1. ⋯
                                                         1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          mu   -0.4607   -0.3039   -0.2176   -0.1322    0.0295
       sigma    0.7423    0.8343    0.8927    0.9580    1.1000

9×5 DataFrame
 Row │ parameters     mean        mcse           std        5%         
     │ Symbol         Float64     Float64        Float64    Float64    
─────┼─────────────────────────────────────────────────────────────────
   1 │ lp__           -20.5866      0.024313     1.0027     -22.5778
   2 │ accept_stat__    0.928174    0.00141704   0.0938884    0.739732
   3 │ stepsize__       0.797606    0.000663868  0.0296891    0.764128
   4 │ treedepth__      2.07925     0.0104739    0.639978     1.0
   5 │ n_leapfrog__     4.122       0.0325533    2.01944      1.0
   6 │ divergent__      0.0       NaN            0.0          0.0
   7 │ energy__        21.5715      0.0351367    1.41791     19.9443
   8 │ mu              -0.21774     0.00241207   0.125947    -0.424523
   9 │ sigma            0.900656    0.00180192   0.0925502    0.762674
9×6 DataFrame
 Row │ parameters     50%         95%          ess      n_eff/s   r_hat ⋯
     │ Symbol         Float64     Float64      Float64  Float64   Float ⋯
─────┼───────────────────────────────────────────────────────────────────
   1 │ lp__           -20.2809    -19.6317     1700.84   39554.4    1.0 ⋯
   2 │ accept_stat__    0.962805    1.0        4389.96  102092.0    1.0
   3 │ stepsize__       0.792195    0.845499   2000.0    46511.6    1.0
   4 │ treedepth__      2.0         3.0        3733.46   86824.7    1.0
   5 │ n_leapfrog__     3.0         7.0        3848.31   89495.7    1.0 ⋯
   6 │ divergent__      0.0         0.0         NaN        NaN    NaN
   7 │ energy__        21.2626     24.2826     1628.45   37870.9    1.0
   8 │ mu              -0.217634   -0.0111363  2726.44   63405.5    1.0
   9 │ sigma            0.892779    1.06554    2638.05   61350.0    1.0 ⋯
                                                         1 column omitted
4000×2 DataFrame
  Row │ mu          sigma    
      │ Float64     Float64  
──────┼──────────────────────
    1 │ -0.13932    0.941601
    2 │ -0.282811   0.79359
    3 │ -0.114992   0.897347
    4 │ -0.289828   0.878025
    5 │ -0.340496   0.996096
    6 │ -0.0992094  1.02213
    7 │ -0.311144   0.870863
    8 │  0.0183361  0.986961
    9 │ -0.197492   0.93263
   10 │ -0.118029   0.989948
   11 │ -0.128918   1.01096
   12 │ -0.369459   0.771521
   13 │ -0.105138   0.745827
   14 │ -0.104644   0.970093
   15 │ -0.195001   0.871644
   16 │ -0.287177   0.871411
  ⋮   │     ⋮          ⋮
 3986 │  0.0        0.0
 3987 │  0.0        0.0
 3988 │  0.0        0.0
 3989 │  0.0        0.0
 3990 │  0.0        0.0
 3991 │  0.0        0.0
 3992 │  0.0        0.0
 3993 │  0.0        0.0
 3994 │  0.0        0.0
 3995 │  0.0        0.0
 3996 │  0.0        0.0
 3997 │  0.0        0.0
 3998 │  0.0        0.0
 3999 │  0.0        0.0
 4000 │  0.0        0.0
            3969 rows omitted
Chains MCMC chain (1000×2×4 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 4
Samples per chain = 1000
parameters        = mu, sigma
internals         = 

Summary Statistics
  parameters      mean       std   naive_se      mcse       ess      rh ⋯
      Symbol   Float64   Float64    Float64   Float64   Float64   Float ⋯

          mu   -0.0545    0.1133     0.0018    0.0120   11.3090    1.90 ⋯
       sigma    0.2240    0.3906     0.0062    0.0492    8.1103    9.47 ⋯
                                                         1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          mu   -0.3782    0.0000    0.0000    0.0000    0.0000
       sigma    0.0000    0.0000    0.0000    0.1619    1.0070

9×5 DataFrame
 Row │ parameters     mean        mcse           std        5%         
     │ Symbol         Float64     Float64        Float64    Float64    
─────┼─────────────────────────────────────────────────────────────────
   1 │ lp__           -20.5866      0.024313     1.0027     -22.5778
   2 │ accept_stat__    0.928174    0.00141704   0.0938884    0.739732
   3 │ stepsize__       0.797606    0.000663868  0.0296891    0.764128
   4 │ treedepth__      2.07925     0.0104739    0.639978     1.0
   5 │ n_leapfrog__     4.122       0.0325533    2.01944      1.0
   6 │ divergent__      0.0       NaN            0.0          0.0
   7 │ energy__        21.5715      0.0351367    1.41791     19.9443
   8 │ mu              -0.21774     0.00241207   0.125947    -0.424523
   9 │ sigma            0.900656    0.00180192   0.0925502    0.762674
9×6 DataFrame
 Row │ parameters     50%         95%          ess      n_eff/s   r_hat ⋯
     │ Symbol         Float64     Float64      Float64  Float64   Float ⋯
─────┼───────────────────────────────────────────────────────────────────
   1 │ lp__           -20.2809    -19.6317     1700.84   39554.4    1.0 ⋯
   2 │ accept_stat__    0.962805    1.0        4389.96  102092.0    1.0
   3 │ stepsize__       0.792195    0.845499   2000.0    46511.6    1.0
   4 │ treedepth__      2.0         3.0        3733.46   86824.7    1.0
   5 │ n_leapfrog__     3.0         7.0        3848.31   89495.7    1.0 ⋯
   6 │ divergent__      0.0         0.0         NaN        NaN    NaN
   7 │ energy__        21.2626     24.2826     1628.45   37870.9    1.0
   8 │ mu              -0.217634   -0.0111363  2726.44   63405.5    1.0
   9 │ sigma            0.892779    1.06554    2638.05   61350.0    1.0 ⋯
                                                         1 column omitted

The 2nd set of values produced by MCMCChains is bogus. Is that indeed the issue? The .csv files are still unchanged, as seen when they are read in as a DataFrame or analyzed using read_summary.

itsdfish commented 2 years ago

Hi Rob. That is indeed the issue. I did not test DataFrames, but your example shows that it works. For some reason, :mcmcchains is incorrect on the second read.

goedman commented 2 years ago

Yep, at least the summary produced by MCMCChains is incorrect.

On the other part of your question: how many draws are you trying to collect? The bogus ess in your output fragment is really low for 4000 samples, but that is the same problem. Do you have an idea what n_eff/s is?

itsdfish commented 2 years ago

I am trying to collect 1000 per chain, for a total of 4000 samples. The MCMCChains output does not provide an estimate of n_eff/s. Here is what it provides on the first run:

julia> samples = read_samples(stan_model, :mcmcchains)
Chains MCMC chain (1000×2×4 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 4
Samples per chain = 1000
parameters        = mu, sigma
internals         = 

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64 

          mu   -0.2182    0.1277     0.0020    0.0023   2707.7748    1.0009
       sigma    0.9023    0.0940     0.0015    0.0018   2705.7406    0.9997

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          mu   -0.4664   -0.3065   -0.2196   -0.1301    0.0349
       sigma    0.7441    0.8358    0.8926    0.9615    1.1030

Given that the model runs quickly, I would suspect n_eff/s is several orders of magnitude larger than ess.
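
For reference, the n_eff/s numbers come from cmdstan's stansummary, so they can be pulled from the read_summary DataFrame rather than from MCMCChains; a minimal sketch, assuming the column labels ("ess", "n_eff/s") match the read_summary output shown earlier in this thread:

# Extract ess and n_eff/s for mu and sigma from the stansummary DataFrame.
sdf = read_summary(stan_model)
sdf[in.(sdf.parameters, Ref([:mu, :sigma])), ["parameters", "ess", "n_eff/s"]]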

goedman commented 2 years ago

My remark was based on "I assumed the benefit of reading the samples was the ability to re-load results from a long running model". So I assumed your actual runs take tens of minutes or even hours and you wanted to know after, say, 400 iterations whether the results looked ok. A few years ago I had such a case and decided to collect a limited number of draws (after warmup) and append the DataFrames collected from subsequent (newly started) runs.

Whether this is feasible of course depends on the num_warmups required, the time it takes to complete warmup, etc.
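
A rough sketch of that append pattern, with hypothetical run and draw counts:

# Collect a limited number of draws per run and stack the DataFrames from
# subsequent, newly started runs (3 runs of 400 draws are placeholders).
using DataFrames

runs = DataFrame[]
for run in 1:3
    rc = stan_sample(sm_01; data = stan_data, num_samples = 400, num_warmups = 1000)
    success(rc) && push!(runs, read_samples(sm_01, :dataframe))
end
all_draws = vcat(runs...)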

itsdfish commented 2 years ago

I see how looking at interim results could be helpful. In my particular case, I have a hierarchical model that takes a long time to run (30-90 minutes). What I was hoping to do was have a way to read in samples from a previously completed model run in case I wanted to modify a plot or run additional analyses days or weeks later. Is that an unintended use of read_samples, or do you think there is a bug somewhere?

goedman commented 2 years ago

We should be able to simplify the function below, generate a quick table (a3d_array) and associated column names (cnames), and show the same problem:

using .MCMCChains

function convert_a3d(a3d_array, cnames, ::Val{:mcmcchains};
  start=1,
  kwargs...)
  cnames = String.(cnames)
  # names ending in "__" are cmdstan sampler internals; the rest are model parameters
  pi = filter(p -> length(p) > 2 && p[end-1:end] == "__", cnames)
  p = filter(p -> !(p in pi), cnames)

  MCMCChains.Chains(a3d_array[start:end,:,:],
    cnames,
    Dict(
      :parameters => p,
      :internals => pi
    );
    start=start
  )
end

goedman commented 2 years ago

No, that is perfectly feasible (provided you store the results in a tmp dir, as you do in your example, and hold on to or recreate the SampleModel object).

In the above example I did recreate a new SampleModel (sm_02) and it works fine:

sm_02 = SampleModel("temp", model, tempdir)
if success(rc_01)
     post_02 = read_samples(sm_02, :dataframe)
     post_02 |> display
     chn_02 = read_samples(sm_02, :mcmcchains)
     chn_02 |> display
end

sdf_02 = read_summary(sm_02)
sdf_02[:, 1:5] |> display
sdf_02[:, [1, 6,7,8, 9, 10]] |> display

Same result. So the problem is really why the new MCMCChains object is not correct.

itsdfish commented 2 years ago

The main difference is whether stan_sample is called. I suspect stan_sample might modify the stan_model object in some way that is important for reading the samples correctly into :mcmcchains.
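
One way to test that suspicion is to diff the fields of a freshly constructed SampleModel against the one that stan_sample ran on; a sketch using only generic reflection (no specific field names assumed):

# Compare a freshly built SampleModel with the one that was sampled above.
sm_fresh = SampleModel("temp", model, tempdir)
for f in fieldnames(typeof(sm_fresh))
    old, new = getfield(sm_fresh, f), getfield(stan_model, f)
    old == new || println(f, ": ", old, " -> ", new)
end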

goedman commented 2 years ago

But we're not calling `stan_sample()`. I'll try the above MWE idea of creating an a3d_array for 2 parameters.

goedman commented 2 years ago

I'm getting:

using MCMCChains, Random, Distributions

function convert_a3d(a3d_array, cnames; start=1)
  cnames = String.(cnames)

  MCMCChains.Chains(a3d_array[start:end,:,:],
    cnames,
    Dict(
      :parameters => cnames,
      :internals => []
    );
    start=start
  )
end

N=100
cnames = ["mu", "sigma"]
a3d = hcat(rand(Normal(-1, 1), N), rand(Normal(0.1, 0.1), N))

chains_01 = convert_a3d(a3d, cnames)
chains_01 |> display

chains_02 = convert_a3d(a3d, cnames)
chains_02 |> display

and output:

julia> include("/Users/rob/.julia/dev/StanSample/test/test_chris/chris_02.jl");
Chains MCMC chain (100×2×1 Array{Float64, 3}):

Iterations        = 1:1:100
Number of chains  = 1
Samples per chain = 100
parameters        = mu, sigma
internals         = 

Summary Statistics
  parameters      mean       std   naive_se      mcse       ess      rhat 
      Symbol   Float64   Float64    Float64   Float64   Float64   Float64 

          mu   -0.9000    0.9211     0.0921    0.0927   95.5359    0.9925
       sigma    0.1176    0.1010     0.0101    0.0119   65.6785    1.0002

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          mu   -2.7929   -1.4471   -0.9425   -0.3741    0.8873
       sigma   -0.0698    0.0497    0.1144    0.1962    0.2955

Chains MCMC chain (100×2×1 Array{Float64, 3}):

Iterations        = 1:1:100
Number of chains  = 1
Samples per chain = 100
parameters        = mu, sigma
internals         = 

Summary Statistics
  parameters      mean       std   naive_se      mcse       ess      rhat 
      Symbol   Float64   Float64    Float64   Float64   Float64   Float64 

          mu   -0.9000    0.9211     0.0921    0.0927   95.5359    0.9925
       sigma    0.1176    0.1010     0.0101    0.0119   65.6785    1.0002

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          mu   -2.7929   -1.4471   -0.9425   -0.3741    0.8873
       sigma   -0.0698    0.0497    0.1144    0.1962    0.2955

goedman commented 2 years ago

Extending the MWE to 4 chains also shows convert_a3d() works and produces proper MCMCChains objects both times.

julia> include("/Users/rob/.julia/dev/StanSample/test/test_chris/chris_02.jl");
Chains MCMC chain (100×2×4 Array{Float64, 3}):

Iterations        = 1:1:100
Number of chains  = 4
Samples per chain = 100
parameters        = mu, sigma
internals         = 

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64 

          mu   -0.9515    0.9877     0.0494    0.0471   364.3159    0.9979
       sigma    0.0929    0.1043     0.0052    0.0059   392.5938    0.9994

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          mu   -2.9493   -1.5587   -0.9220   -0.2796    1.0009
       sigma   -0.1109    0.0268    0.0929    0.1583    0.2940

Chains MCMC chain (100×2×4 Array{Float64, 3}):

Iterations        = 1:1:100
Number of chains  = 4
Samples per chain = 100
parameters        = mu, sigma
internals         = 

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64 

          mu   -0.9515    0.9877     0.0494    0.0471   364.3159    0.9979
       sigma    0.0929    0.1043     0.0052    0.0059   392.5938    0.9994

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

          mu   -2.9493   -1.5587   -0.9220   -0.2796    1.0009
       sigma   -0.1109    0.0268    0.0929    0.1583    0.2940

Interesting!

goedman commented 2 years ago

Ok, found a clue. Reading in the draws as an array:

if success(rc_01)
     post_01 = read_samples(sm_01, :array)
     mean(post_01; dims=1) |> display
end

sm_02 = SampleModel("temp", model, tempdir)
if success(rc_01)
     post_02 = read_samples(sm_02, :array)
     mean(post_02; dims=1) |> display
end

produces:

1×2×4 Array{Float64, 3}:
[:, :, 1] =
 -0.218063  0.896114

[:, :, 2] =
 -0.21141  0.902679

[:, :, 3] =
 -0.218086  0.903136

[:, :, 4] =
 -0.223403  0.900697

1×2×4 Array{Float64, 3}:
[:, :, 1] =
 -0.218063  0.896114

[:, :, 2] =
 0.0  0.0

[:, :, 3] =
 0.0  0.0

[:, :, 4] =
 0.0  0.0

Must be a bug in read_samples(). Will have a look.

itsdfish commented 2 years ago

Very interesting indeed. I'm glad you found a clue. Thanks for taking the time to look into this problem.

goedman commented 2 years ago

Ok, found it. It happens when the SampleModel is recreated. The various ..._num_chains settings are updated when stan_sample() is called. Probably introduced when I enabled C++ threads.

Not sure how quickly I can fix this, might have to wait until I'm back in the US later this week.

For now, a workaround is to set sm_02.num_julia_chains = sm_02.num_chains.
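
In context, the workaround looks like this (sketch; the field names are taken from the line above):

# Recreate the SampleModel in the fresh session, then align the Julia chain
# count with the number of chains before reading the draws back in.
sm_02 = SampleModel("temp", model, tempdir)
sm_02.num_julia_chains = sm_02.num_chains
chn_02 = read_samples(sm_02, :mcmcchains)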

itsdfish commented 2 years ago

Awesome. Thanks for the workaround. I will use that until you return and have time to look into it.

goedman commented 2 years ago

Hi Chris,

There was a bug in the handling of the default case for SampleModels (4 chains, Julia processes). I'll fix that.

But in case the defaults are updated when calling stan_sample() (e.g. use_cpp_chains=true or num_chains=1), this can only be done by serializing the SampleModel object:

cd(@__DIR__)
using StanSample, Random, MCMCChains, Serialization
tempdir = pwd() * "/tmp"
####################################################################
#                                     Generate Data
####################################################################
seed = 350
Random.seed!(seed)
n_obs = 50
y = randn(n_obs)

stan_data = Dict(
    "y" => y,
    "n_obs" => n_obs)

model = "
data{
     // total observations
     int n_obs;
     // observations
    vector[n_obs] y;
}

parameters {
    real mu;
    real<lower=0> sigma;
}

model {
    mu ~ normal(0, 1);
    sigma ~ gamma(1, 1);
    y ~ normal(mu, sigma); 
}"
sm_01 = SampleModel("temp", model, tempdir)

# run the sampler
rc_01 = stan_sample(
    sm_01;
    data = stan_data,
    seed,
    num_chains = 4,
    num_samples = 1000,
    num_warmups = 1000,
    save_warmup = false
)

# Serialize after calling stan_sample!
serialize(joinpath(sm_01.tmpdir, "sm_01"), sm_01)

if success(rc_01)
     chn_01 = read_samples(sm_01, :mcmcchains)
     chn_01 |> display
end

sm_02 = deserialize(joinpath(sm_01.tmpdir, "sm_01"))

if success(rc_01)
     chn_02 = read_samples(sm_02, :mcmcchains)
     chn_02 |> display
end

For this purpose I'll add Serialization to the dependencies for StanSample.jl in the next release (v6.10.1).

In your case, you used the defaults in your MWE, so that will work out of the box in v6.10.1.
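
In a later (fresh) session, only the serialized SampleModel and the .csv files in the tmp directory are needed; a sketch of that session, assuming sm_01.tmpdir is the same tmp directory passed to the constructor in the script above:

cd(@__DIR__)
using StanSample, MCMCChains, Serialization
tempdir = pwd() * "/tmp"

# Reload the SampleModel that was serialized right after stan_sample, then read the draws.
sm_02 = deserialize(joinpath(tempdir, "sm_01"))
chn_02 = read_samples(sm_02, :mcmcchains)
chn_02 |> display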