StanJulia / StanSample.jl

WIP: Wrapper package for the sample method in Stan's cmdstan executable.
MIT License

Updates to `inferencedata` method #67

Closed sethaxen closed 1 year ago

sethaxen commented 1 year ago

This PR proposes some changes to inferencedata1. The main ones are:

Here's an example output, using the notebook at https://github.com/StanJulia/Stan.jl/blob/master/Examples_Notebooks/InferenceObjects.jl

julia> idata = StanSample.inferencedata(
           m_schools;
           posterior_predictive_var=:y_hat,
           log_likelihood_var=[:log_lik],
           dims=(; (k => [:school] for k in [:theta, :theta_tilde, :y_hat, :log_lik])...),
       )
InferenceData with groups:
  > posterior
  > posterior_predictive
  > log_likelihood
  > sample_stats
  > warmup_posterior
  > warmup_posterior_predictive
  > warmup_sample_stats
  > warmup_log_likelihood

julia> idata.posterior
Dataset with dimensions: 
  Dim{:school} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:draw} Sampled{Int64} 1001:2000 ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 4 layers:
  :theta_tilde Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×1000×4)
  :mu          Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :tau         Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :theta       Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×1000×4)

with metadata Dict{String, Any} with 1 entry:
  "created_at" => "2022-12-08T22:34:23.197"

julia> idata.sample_stats
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} 1001:2000 ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 7 layers:
  :tree_depth      Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :energy          Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :diverging       Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :acceptance_rate Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :n_steps         Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :lp              Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :step_size       Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)

with metadata Dict{String, Any} with 1 entry:
  "created_at" => "2022-12-08T22:34:23.07"

julia> idata.log_likelihood
Dataset with dimensions: 
  Dim{:school} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:draw} Sampled{Int64} 1001:2000 ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 1 layer:
  :log_lik Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×1000×4)

with metadata Dict{String, Any} with 1 entry:
  "created_at" => "2022-12-08T22:34:23.095"

julia> idata.warmup_posterior
Dataset with dimensions: 
  Dim{:school} Sampled{Int64} Base.OneTo(8) ForwardOrdered Regular Points,
  Dim{:draw} Sampled{Int64} 1:1000 ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 4 layers:
  :theta_tilde Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×1000×4)
  :mu          Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :tau         Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :theta       Float64 dims: Dim{:school}, Dim{:draw}, Dim{:chain} (8×1000×4)

with metadata Dict{String, Any} with 1 entry:
  "created_at" => "2022-12-08T22:34:23.197"
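
The `dims` keyword in the call above is built by splatting a generator of `Symbol => Vector{Symbol}` pairs into a `NamedTuple`. A minimal sketch in plain Julia of what that expression expands to (the variable name `vars` is introduced here only for illustration; nothing below calls StanSample):

```julia
# Variables that share the :school dimension, as listed in the call above.
vars = [:theta, :theta_tilde, :y_hat, :log_lik]

# Splat the generated pairs into a NamedTuple, one field per variable.
dims = (; (k => [:school] for k in vars)...)

# The same value written out by hand.
dims_explicit = (
    theta = [:school],
    theta_tilde = [:school],
    y_hat = [:school],
    log_lik = [:school],
)

println(dims == dims_explicit)
```

Writing the generator form avoids repeating `[:school]` for every variable, at the cost of being slightly harder to read than the explicit tuple.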

Note that in the next breaking release of InferenceObjects, the dimension order of arrays will change (https://github.com/arviz-devs/InferenceObjects.jl/pull/40), and the default indices for all dimensions will be the axes of the underlying arrays (https://github.com/arviz-devs/InferenceObjects.jl/pull/39), so no reindexing will be needed after splitting samples from warmup.
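
To make "the axes of the underlying arrays" concrete, here is a rough sketch using only Base Julia arrays. The sizes mirror the output above, but the array `draws` and all variable names are hypothetical stand-ins, and reading the note as "each split piece would then be indexed from 1" is an interpretation, not something the linked PRs are quoted as saying here.

```julia
# Hypothetical sizes matching the output above:
# 8 schools, 1000 warmup draws + 1000 retained draws, 4 chains.
nschools, nwarmup, nsamples, nchains = 8, 1000, 1000, 4

# Stand-in for the raw draws, laid out as school × draw × chain.
draws = zeros(nschools, nwarmup + nsamples, nchains)

# Split warmup from the retained samples along the draw dimension.
warmup = draws[:, 1:nwarmup, :]
posterior = draws[:, (nwarmup + 1):end, :]

# Each piece is a fresh array, so its draw axis starts at 1.
warmup_draw_axis = axes(warmup, 2)
posterior_draw_axis = axes(posterior, 2)

# Labels that would reproduce the 1001:2000 draw index shown above.
offset_draw_labels = (nwarmup + 1):(nwarmup + nsamples)

println(posterior_draw_axis)
println(offset_draw_labels)
```

In this sketch the sliced arrays carry no memory of their original position, which is why an explicit label range is needed if the original draw numbers are to be kept.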

sethaxen commented 1 year ago

Relates #60