arviz-devs / ArviZ.jl

Exploratory analysis of Bayesian models with Julia
https://julia.arviz.org

`from_cmdstan` #335

Open jfb-h opened 4 months ago

jfb-h commented 4 months ago

Is there something similar in ArviZ.jl to from_cmdstan() in the Python library? If not, is there a specific way to plug into InferenceData that you would recommend? Thanks for all your work in the ArviZ Julia ecosystem, it seems like a great project!

sethaxen commented 4 months ago

> Is there something similar in ArviZ.jl to from_cmdstan() in the Python library? If not, is there a specific way to plug into InferenceData that you would recommend?

Integration with Stan is implemented directly in StanSample.jl. For an example, check out https://julia.arviz.org/ArviZ/stable/quickstart/#Plotting-with-Stan.jl-outputs .

> Thanks for all your work in the ArviZ Julia ecosystem, it seems like a great project!

You're welcome!

jfb-h commented 4 months ago

Perfect, that looks like what I need. I quite like working with CmdStan directly on the command line, so I'll probably try to extract the CSV reading and conversion from StanSample.jl, which I see just uses from_namedtuple under the hood. Would you in principle be interested in a PR with that to mirror the Python interface in ArviZ.jl?

ahartikainen commented 4 months ago

For what it's worth, it is very common to run Stan from the CLI and later read the samples for analysis in a programming language (e.g. R/Python/Julia).

sethaxen commented 4 months ago

> Perfect, that looks like what I need. I quite like working with CmdStan directly on the command line, so I'll probably try to extract the CSV reading and conversion from StanSample.jl, which I see just uses from_namedtuple under the hood. Would you in principle be interested in a PR with that to mirror the Python interface in ArviZ.jl?

This makes sense. We used to wrap Python ArviZ's from_cmdstan function, so adding our own sounds worthwhile. If you contributed one in a PR, that would be great! StanIO.jl recently split the functionality for reading CmdStan draws from CSV files out of StanSample.jl, so using it here might save you some time.

jfb-h commented 4 months ago

Thanks, I'll take a closer look at StanIO.jl. At first sight, it seems to have quite a few dependencies that ArviZ.jl doesn't have. Would you be willing to take on StanIO as a dependency, or would you rather we copy over the necessary bits at the cost of some duplication?

I'd like to do a PR but it might be a while since the semester has just started, and I might have to come back to you for some guidance, if that is ok.

sethaxen commented 4 months ago

> Thanks, I'll take a closer look at StanIO.jl. At first sight, it seems to have quite a few dependencies that ArviZ.jl doesn't have. Would you be willing to take on StanIO as a dependency, or would you rather we copy over the necessary bits at the cost of some duplication?

Indeed it does, DataFrames.jl in particular. I think it would make the most sense to define from_cmdstan in our main package code but only implement it in an extension module named ArviZStanIOExt. Then we can make StanIO a weak dependency, so that this module is only loaded if the user also loads StanIO. We already do something similar for from_mcmcchains and from_samplechains.

But also, StanIO itself can probably make most of its dependencies weak dependencies using extensions as well. I'll open an issue there. Either way, I think the extension route makes the most sense here.
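A minimal sketch of that layout, assuming Julia 1.9+ package extensions. To keep it runnable on its own, both modules are defined inline; `MyArviZ` and `MyArviZStanIOExt` are hypothetical stand-ins and the method body is a placeholder. In a real package the stub would live in ArviZ.jl's `src/`, the implementation in `ext/ArviZStanIOExt.jl` (starting with `using ArviZ, StanIO`), and `Project.toml` would list StanIO under `[weakdeps]` and `ArviZStanIOExt = "StanIO"` under `[extensions]`.

```julia
# Hypothetical stand-in for the main package: declare the function with
# no methods so it can be documented and exported without StanIO.
module MyArviZ
export from_cmdstan

"""
    from_cmdstan(files; kwargs...)

Convert CmdStan CSV output to `InferenceData`. Requires loading StanIO.
"""
function from_cmdstan end
end

# Hypothetical stand-in for ext/ArviZStanIOExt.jl, which Julia would load
# only once both packages are loaded.
module MyArviZStanIOExt
import ..MyArviZ

# Placeholder methods: a real implementation would read the files with
# StanIO and forward the reshaped draws to `from_namedtuple`.
function MyArviZ.from_cmdstan(files::AbstractVector{<:AbstractString}; kwargs...)
    return "would read $(length(files)) file(s) with StanIO"
end
MyArviZ.from_cmdstan(file::AbstractString; kwargs...) = MyArviZ.from_cmdstan([file]; kwargs...)
end

MyArviZ.from_cmdstan("output_1.csv")
```

Until the extension is loaded, the stub has no methods, so calling `from_cmdstan` raises a `MethodError`; a docstring on the stub is a convenient place to tell users which package to load.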

> I'd like to do a PR but it might be a while since the semester has just started, and I might have to come back to you for some guidance, if that is ok.

Of course! Happy to help.

jfb-h commented 4 months ago

Hi @sethaxen,

after a bit of tinkering, I now have a first basic implementation of this.

I ended up not relying on StanIO.jl for now, as I found the StanJulia ecosystem a bit hard to navigate (e.g., there seems to be some overlap in functionality between StanIO.jl and StanSample.jl; perhaps the transition to StanIO.jl is not complete yet?). I'd be happy to take any advice regarding this integration, but the basic functionality turned out to need fairly little code, so maybe it could live directly in ArviZ.jl for the time being?

I did add a dependency on the DelimitedFiles.jl standard library, which handles the CSV parsing cleanly, so we don't have to roll our own parser as StanIO.jl currently does.
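For reference, a Base-only sketch of what the `readdlm(file, ','; comments=true, header=true)` call in `readfiles` has to cope with: CmdStan writes its configuration, adaptation results and timing as `#` comment lines, some of them after the header row. `parse_stan_csv` is a hypothetical helper and the CSV text is made up for illustration; real files have many more comment lines and columns.

```julia
"""
    parse_stan_csv(io) -> (draws::Matrix{Float64}, header::Vector{String})

Hypothetical helper: skip blank and `#` comment lines, use the first
remaining line as the header and parse the rest as comma-separated rows.
"""
function parse_stan_csv(io::IO)
    header = String[]
    rows = Vector{Float64}[]
    for line in eachline(io)
        s = strip(line)
        (isempty(s) || startswith(s, "#")) && continue
        fields = split(s, ',')
        if isempty(header)
            header = String.(fields)
        else
            push!(rows, parse.(Float64, fields))
        end
    end
    # One row per draw, one column per variable.
    draws = isempty(rows) ? zeros(0, length(header)) : permutedims(reduce(hcat, rows))
    return draws, header
end

# Made-up CmdStan-style output with comment lines around the draws.
csv = """
# model = example
lp__,accept_stat__,theta
# Adaptation terminated
# Step size = 0.9
-7.1,0.95,0.25
-6.9,0.80,0.31
#  Elapsed Time: 0.01 seconds
"""
mat, hdr = parse_stan_csv(IOBuffer(csv))
```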

My two main questions right now would be:

  1. I haven't done any real testing yet beyond a very simple example. Do you know if there are some CmdStan models / output files with reference posteriors that I could check this against?

  2. from_cmdstan(...) doesn't do much right now beyond reading and reshaping the samples and passing them to from_namedtuple (it does split out sample_stats). Does that seem like the right approach to you?

I haven't opened a PR yet because I first wanted to ask your opinion on where to put this. Here's what I have so far:

using DelimitedFiles
using InferenceObjects: from_namedtuple

"""
    readheader(file)

Read the header with variable names from a stan csv file.
"""
function readheader(file)
    for line in eachline(file)
        startswith(line, "#") && continue
        return string.(split(line, ","))
    end
end

"""
    readfiles(files)

Read one or more Stan output csv files into an ndraws x nvars x nchains
array. Return a tuple containing the array and the variable names.

Assumes that all files have the same schema.
"""
function readfiles(files)
    header = readheader(first(files))
    values = stack(files) do file
        arr, _ = readdlm(file, ','; comments=true, header=true)
        arr
    end
    return values, header
end
readfiles(file::AbstractString) = readfiles([file])

"""
    vardims_from_names(names)

Parse the dimensions from the variable names as contained in a Stan 
output csv file. Return a dict with the dimensions for each variable.

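A list of names such as `["a.1.1", "a.2.1", "a.1.2", "a.2.2"]` (the order
in which CmdStan writes a 2x2 matrix `a`) would yield `Dict("a" => (2, 2))`.

Assumes that the last name of each variable holds its largest index in
every dimension, as in CmdStan's column-major output.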
"""
function vardims_from_names(names)
    res = Dict{String,NTuple}()
    # this relies on dims being sorted in the csv
    for name in reverse(names)
        var, dims... = split(name, ".")
        var = string(var)
        haskey(res, var) && continue
        res[var] = tuple(parse.(Int, dims)...)
    end

    return res
end

"""
    varindexes(names)

Map variables to their position range in the list of names as
contained in the csv header.
"""
function varindexes(names)
    vars = first.(split.(names, "."))
    vind = map(unique(vars)) do var
        from = findfirst(==(var), vars)
        to = findlast(==(var), vars)
        string(var) => from:to
    end
    return Dict(vind)
end

"""
    to_namedtuple(arr, names)

Split and reshape the draws in `arr` (ndraws x nvars x nchains) according to
the variable dimensions parsed from `names` and return a named tuple of
(draws x chains x vardims...) arrays.
"""
function to_namedtuple(arr, names)
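    draws, _, chains = size(arr)
    vind = varindexes(names)
    vdim = vardims_from_names(names)
    vars = string.(first.(split.(names, ".")))
    res = map(unique(vars)) do var
        # Move the variable axis last (draws x chains x k) so that reshape
        # only splits that axis into the variable's (column-major) dims.
        a = permutedims(arr[:, vind[var], :], (1, 3, 2))
        Symbol(var) => reshape(a, (draws, chains, vdim[var]...))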
    end

    return (; res...)
end

## This is copied from the StanSample.jl extension

const SAMPLE_STATS_KEYMAP = (
    n_leapfrog__=:n_steps,
    treedepth__=:tree_depth,
    energy__=:energy,
    lp__=:lp,
    stepsize__=:step_size,
    divergent__=:diverging,
    accept_stat__=:acceptance_rate,
)

function rekey(nt::NamedTuple, keymap)
    new_keys = map(k -> get(keymap, k, k), keys(nt))
    return NamedTuple{new_keys}(values(nt))
end

function split_post_stats(nt, keymap)
    stats = filter(in(values(keymap)), keys(nt))
    post = filter(!in(values(keymap)), keys(nt))
    return NamedTuple{post}(nt), NamedTuple{stats}(nt)
end

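is_file(arg::AbstractString) = endswith(arg, ".csv")
is_file(arg::AbstractVector{<:AbstractString}) = !isempty(arg) && all(endswith(a, ".csv") for a in arg)
# Anything else (e.g. a NamedTuple of draws) is passed through unchanged.
is_file(arg) = false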

"""
    from_cmdstan(posterior::Union{AbstractString,AbstractVector{<:AbstractString}}; kwargs...)

Create an `InferenceData` from CmdStan CSV files. Sampler statistics in the `posterior` (and optional
`prior`) files are split out into `sample_stats` (and `sample_stats_prior`). Other `kwargs` can be
filenames indicating CmdStan output, such as generated quantities, or named tuples. If they are files,
the contained draws are reshaped into (draws x chains x vardims...) arrays and are passed to
`InferenceObjects.from_namedtuple`.
"""
function from_cmdstan(
    posterior::Union{AbstractString,AbstractVector{<:AbstractString}};
    prior = nothing,
    sample_stats_prior = nothing,
    kwargs...
)
    post, sample_stats = let
        nt = to_namedtuple(readfiles(posterior)...)
        nt = rekey(nt, SAMPLE_STATS_KEYMAP)
        split_post_stats(nt, SAMPLE_STATS_KEYMAP)
    end

    if !isnothing(prior)
        nt = to_namedtuple(readfiles(prior)...)
        nt = rekey(nt, SAMPLE_STATS_KEYMAP)
        prior, sample_stats_prior = split_post_stats(nt, SAMPLE_STATS_KEYMAP)
    end

    kwargs = map(values(kwargs)) do arg
        is_file(arg) || return arg
        to_namedtuple(readfiles(arg)...)
    end

    return from_namedtuple(post; sample_stats, prior, sample_stats_prior, kwargs...)
end
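Since `reshape` reinterprets an array's memory in column-major order, the ndraws x k x nchains slice taken for each variable needs its variable axis moved last before it is split into the variable's dimensions; otherwise chain and element values get interleaved. A small check of this, assuming CmdStan flattens containers in column-major order (`a.1.1, a.2.1, a.1.2, a.2.2`); `reshape_var` and the synthetic array are hypothetical and only use Base Julia:

```julia
"""
    reshape_var(arr, cols, dims)

Hypothetical helper: take columns `cols` of an ndraws x nvars x nchains
array and return an ndraws x nchains x dims... array.
"""
function reshape_var(arr::AbstractArray{<:Real,3}, cols::AbstractUnitRange, dims::Tuple)
    ndraws, _, nchains = size(arr)
    # Put the variable axis last, so reshape only splits that axis.
    a = permutedims(arr[:, cols, :], (1, 3, 2))
    return reshape(a, ndraws, nchains, dims...)
end

# Synthetic draws: 3 draws, 5 columns (a.1.1, a.2.1, a.1.2, a.2.2, lp__),
# 2 chains; each value encodes its position as 100*chain + 10*column + draw.
arr = [100ch + 10c + d for d in 1:3, c in 1:5, ch in 1:2]

A = reshape_var(arr, 1:4, (2, 2))   # 3 x 2 x 2 x 2
lp = reshape_var(arr, 5:5, ())      # 3 x 2

# A[d, ch, i, j] comes from column i + 2(j - 1), e.g. a.2.1 is column 2:
A[3, 2, 2, 1] == arr[3, 2, 2]       # true

# Reshaping the slice directly mixes chains and elements:
naive = reshape(arr[:, 1:4, :], 3, 2, 2, 2)
naive[3, 2, 2, 1] == arr[3, 2, 2]   # false: it picked up a chain-1 value
```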