StatisticalRethinkingJulia / StatisticalRethinking.jl

Julia package with selected functions in the R package `rethinking`. Used in the SR2... projects.
MIT License

StatisticalRethinking v4 (SR4) #122

Closed: goedman closed this issue 2 years ago

goedman commented 3 years ago

Notes/ideas:

  1. On the Stan side, SR4 chains will be based on AxisKeys.jl's `KeyedArray`s, e.g.:
    
    julia> chns = read_samples(m5_1s)
    3-dimensional KeyedArray(NamedDimsArray(...)) with keys:
    ↓   iteration ∈ 1000-element UnitRange{Int64}
    →   chain ∈ 4-element UnitRange{Int64}
    □   param ∈ 103-element Vector{Symbol}
    And data, 1000×4×103 Array{Float64, 3}:
    [showing 3 of 103 slices]
    [:, :, 1] ~ (:, :, :a):
         (1)          (2)          (3)           (4)
    (1)    0.164875    -0.0316173    0.117778     -0.0587963
    (2)    0.0683526   -0.112191    -0.152213     -0.132654
    (3)   -0.0570758   -0.0428441    0.0723148    -0.0818758
    (4)    0.0736573    0.0684627   -0.153873      0.0619487
      ⋮                                          
    (996)    0.100458    -0.0313996   -0.111339     -0.0978581
    (997)    0.0250855   -0.228102    -0.109103      0.0180321
    (998)   -0.0987535    0.12663      0.00225604   -0.00989224
    (999)   -0.0401901    0.127126    -0.0615762    -0.00764688
    (1000)    0.0292844   -0.0801682   -0.036483     -0.114162

    [:, :, 52] ~ (:, :, Symbol("mu.49")):
             (1)          (2)          (3)           (4)
    (1)      0.0642467   -0.12793     -0.00297304   -0.147075
    (2)     -0.0459464   -0.245202    -0.255521     -0.213369
    (3)     -0.159283    -0.141249    -0.0534471    -0.169041
    (4)     -0.0377947   -0.0500161   -0.280926     -0.0721203
      ⋮
    (996)    0.0236626   -0.0940149   -0.223359     -0.20557
    (997)   -0.108117    -0.331829    -0.226459     -0.0873184
    (998)   -0.198844     0.0263513   -0.0773827    -0.123203
    (999)   -0.136195     0.0142719   -0.150853     -0.135881
    (1000)  -0.117577    -0.18169     -0.155621     -0.249585

    [:, :, 103] ~ (:, :, Symbol("log_lik.50")):
             (1)          (2)          (3)           (4)
    (1)     -1.00343     -0.793367    -1.05779      -0.863638
    (2)     -0.969577    -0.901926    -0.808977     -0.722824
    (3)     -0.807445    -0.838343    -1.07894      -0.74081
    (4)     -1.04428     -0.991569    -0.875301     -1.1158
      ⋮
    (996)   -0.896697    -0.694199    -0.837646     -0.777261
    (997)   -1.07123     -0.986055    -0.859606     -0.879211
    (998)   -0.769609    -0.864846    -0.633001     -0.816084
    (999)   -0.742747    -0.98111     -0.624656     -1.04728
    (1000)  -1.19845     -0.779751    -0.962009     -0.893048

and:

    julia> axiskeys(chns)
    (1:1000, 1:4, [:a, :bA, :sigma, Symbol("mu.1"), Symbol("mu.2") ... Symbol("mu.49"), Symbol("mu.50"), Symbol("log_lik.1"), Symbol("log_lik.2") ... Symbol("log_lik.49"), Symbol("log_lik.50")])


The `matrix()` method (overloaded from the Tables.jl API) extracts all elements of a vector-valued parameter:

    julia> chns_log_lik = matrix(chns, :log_lik)
    3-dimensional KeyedArray(NamedDimsArray(...)) with keys:
    ↓   iteration ∈ 1000-element UnitRange{Int64}
    →   chain ∈ 4-element UnitRange{Int64}
    □   param ∈ 50-element view(::Vector{Symbol},...)
    And data, 1000×4×50 view(::Array{Float64, 3}, :, :, [54, 55 … 102, 103]) with eltype Float64:
    [showing 3 of 50 slices]
    [:, :, 1] ~ (:, :, Symbol("log_lik.1")):
             (1)          (2)          (3)           (4)
    (1)     -1.70719     -2.21329     -1.82935      -2.08948
    (2)     -1.89324     -2.28388     -2.30598      -2.51493
    (3)     -2.22125     -2.12325     -1.75062      -2.38044
    (4)     -1.72746     -1.89979     -2.24047      -1.81943
      ⋮
    (996)   -1.87282     -2.46479     -2.22433      -2.36482
    (997)   -1.85819     -2.08163     -2.19739      -2.02657
    (998)   -2.34589     -2.09177     -2.59066      -2.38515
    (999)   -2.36031     -1.99972     -2.74304      -1.81838
    (1000)  -1.71913     -2.30865     -1.94024      -2.4405

    [:, :, 26] ~ (:, :, Symbol("log_lik.26")):
             (1)          (2)          (3)           (4)
    (1)     -0.910393    -0.741409    -0.811551     -0.845247
    (2)     -0.807438    -0.676429    -0.763983     -0.718875
    (3)     -0.739511    -0.786707    -0.868633     -0.722171
    (4)     -0.937871    -0.795258    -0.734667     -0.809258
      ⋮
    (996)   -0.897659    -0.713518    -0.749929     -0.691888
    (997)   -0.806372    -0.966478    -0.747799     -0.782096
    (998)   -0.717109    -0.73475     -0.625633     -0.639031
    (999)   -0.687475    -0.754141    -0.589675     -0.872287
    (1000)  -0.862781    -0.717588    -0.834025     -0.617522

    [:, :, 50] ~ (:, :, Symbol("log_lik.50")):
             (1)          (2)          (3)           (4)
    (1)     -1.00343     -0.793367    -1.05779      -0.863638
    (2)     -0.969577    -0.901926    -0.808977     -0.722824
    (3)     -0.807445    -0.838343    -1.07894      -0.74081
    (4)     -1.04428     -0.991569    -0.875301     -1.1158
      ⋮
    (996)   -0.896697    -0.694199    -0.837646     -0.777261
    (997)   -1.07123     -0.986055    -0.859606     -0.879211
    (998)   -0.769609    -0.864846    -0.633001     -0.816084
    (999)   -0.742747    -0.98111     -0.624656     -1.04728
    (1000)  -1.19845     -0.779751    -0.962009     -0.893048
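To make the slicing concrete, here is a minimal Base-Julia sketch of what extracting a vector parameter amounts to, assuming Stan's flattened naming convention visible in the output above (`log_lik.1` … `log_lik.50`). It works on a plain 3-d array plus a vector of parameter names instead of a `KeyedArray`, and the helper `select_vector_param` is hypothetical; it is not claimed to mirror how SR4's `matrix()` is implemented.

```julia
# Hypothetical helper (not part of StatisticalRethinking.jl or AxisKeys.jl):
# picks out all elements of a vector-valued Stan parameter such as :log_lik,
# which appear on the param axis as Symbol("log_lik.1"), Symbol("log_lik.2"), ...
function select_vector_param(data::AbstractArray{<:Real,3},
                             params::AbstractVector{Symbol}, name::Symbol)
    prefix = string(name, ".")
    idx = findall(p -> startswith(string(p), prefix), params)
    # A view keeps the iteration × chain dimensions and avoids copying the draws
    return view(data, :, :, idx), params[idx]
end

# Toy chain with the same layout as above: iteration × chain × param
params = vcat([:a, :bA, :sigma],
              [Symbol("mu.$i") for i in 1:50],
              [Symbol("log_lik.$i") for i in 1:50])
data = randn(1000, 4, length(params))

ll, ll_names = select_vector_param(data, params, :log_lik)
size(ll)          # (1000, 4, 50)
```

As in the REPL output, the result is a view into slices 54 to 103 of the full array, so no draws are copied.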



Other commonly used methods are `DataFrame(chains)`, concatenation of chains, etc.
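A rough sketch of what `DataFrame(chains)` and chain concatenation boil down to on the raw iteration × chain × param array. Both helpers (`stack_chains`, `concat_chains`) are hypothetical and use only Base Julia; SR4's own methods may order rows or handle keys differently.

```julia
# Hypothetical sketch: flatten an iteration × chain × param array into an
# (iteration * chain) × param matrix, i.e. the chains appended one after another.
# Julia arrays are column-major, so reshape keeps iteration as the fastest index:
# rows 1:niter come from chain 1, rows niter+1:2niter from chain 2, and so on.
function stack_chains(data::AbstractArray{<:Real,3})
    niter, nchain, nparam = size(data)
    return reshape(data, niter * nchain, nparam)
end

# Hypothetical: concatenate two sets of chains along the chain dimension
concat_chains(a, b) = cat(a, b; dims = 2)

data = randn(1000, 4, 3)       # e.g. params [:a, :bA, :sigma]
m = stack_chains(data)
size(m)                        # (4000, 3)
```

Column names could then be attached with DataFrames.jl, e.g. `DataFrame(m, [:a, :bA, :sigma])`.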

2. SR4 will decouple plotting choices from the core package (e.g. I plan to introduce separate packages StatisticalRethinking(Stats?)Plots and StatisticalRethinkingMakie).

3. All Stan- and Turing-specific parts will be handled in the corresponding `@require` sections (currently I envisage Turing, StanSample, StanQuap, DiffEqBayesStan, MCMCChains and LogDensityProblems sections). See, e.g., [this issue](https://github.com/StatisticalRethinkingJulia/StatisticalRethinking.jl/issues/118#issuecomment-898441974), where the conversion to a DataFrame requires MCMCChains.jl.

4. MCMCChains will be dropped as a standard dependency of SR4 (it has StatsPlots as a dependency). If required, it must be provided by the (project) environment.

5. I'm considering dropping the heavy use of `@reexport` in SR4; see, e.g., [this issue](https://github.com/StatisticalRethinkingJulia/StatisticalRethinking.jl/issues/121#issue-970370821).

6. I'm considering producing all output results as `KeyedArray`s; see [an example](https://github.com/ParadaCarleton/ParetoSmooth.jl). In some places this will also affect packages in StanJulia, e.g. StanOptimize, StanQuap and DiffEqBayesStan. These will be updated accordingly and all brought to what I consider v4 level. StanSample.jl is pretty much done and by default returns a KeyedArray chain as shown above.

7. Lots of documentation updates: introductory docs for StatisticalRethinking.jl and Stan.jl on GitHub, and online docs for the functions in all subpackages.

goedman commented 3 years ago

Steps/sequencing:

These are very significant changes to SR. Not all of this will be available the first time around.

These are my current priorities for phase one:

  1. Initially I'll work primarily on:

    1.1. StatisticalRethinking.jl v4,
    1.2. StatisticalRethinkingStan.jl (the project), and
    1.3. StatisticalRethinkingPlots.jl.

    Thus Stan, i.e. Stan's cmdstan executable, will remain my primary means to draw MCMC samples.

  2. During the first phase, the work is basically a refactoring of the existing code base (mainly introducing/completing `@require` barriers and moving graphics to StatisticalRethinkingPlots.jl).

  3. Adapting the code base to AxisKeys.jl (KeyedArray) based chains.

  4. Where possible, existing functions will be updated to accept KeyedArray chains and DataFrame based input.

  5. It's simple to convert an `MCMCChains.Chains` object into a KeyedArray chains object, and a method for this will be provided. The conversion may drop some information added by, e.g., Turing.
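The conversion in item 5 is, at its core, a dimension permutation. This minimal sketch assumes the draws have already been pulled out of the `MCMCChains.Chains` object as a plain array ordered iteration × parameter × chain (MCMCChains' internal layout, to my understanding) and reorders them into the iteration × chain × param layout of the SR4 chains shown at the top. `mcmc_to_sr4_layout` is a hypothetical name, not an SR4 function.

```julia
# Hypothetical helper. Assumption: `draws` is ordered iteration × parameter × chain.
# SR4 KeyedArray chains are ordered iteration × chain × param, so the conversion
# swaps the last two dimensions: sr[i, c, p] == draws[i, p, c].
mcmc_to_sr4_layout(draws::AbstractArray{<:Real,3}) = permutedims(draws, (1, 3, 2))

draws = randn(1000, 3, 4)      # 1000 iterations, 3 parameters, 4 chains
sr = mcmc_to_sr4_layout(draws)
size(sr)                       # (1000, 4, 3)
```

Parameter names would then be carried over separately when wrapping the result in a `KeyedArray`; any extra metadata that does not fit the iteration/chain/param keys is where information can get lost.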

ParadaCarleton commented 3 years ago

If you haven't already switched to AxisArray.jl, you might want to check out DimensionalData.jl, which does essentially the same thing, but seems to be better documented and has better support. The AxisArrays/DimensionalData devs have suggested merging the two packages at some point in the future.

goedman commented 3 years ago

Thank you @ParadaCarleton, I will certainly take a look! I'm pretty impressed with AxisKeys.jl.

AxisArrays.jl, in my opinion, for chains at least, was a bit of a mistake. But a lot can improve in a few years!

goedman commented 2 years ago

In StatisticalRethinking v4 most of above ideas have been implemented (and a few more in fact). For now I'll close this issue.