StanJulia / CmdStan.jl

CmdStan.jl v6 provides an alternative, older Julia wrapper to Stan's `cmdstan` executable. CmdStan.jl will be deprecated in 2022.
MIT License

Add "extract" output format. #99

Closed: yiyuezhuo closed this issue 4 years ago.

yiyuezhuo commented 4 years ago

It would be convenient to collect the columns A.1, A.2, ... into a single array, as `extract` does in PyStan. I have implemented this in a package:

https://github.com/yiyuezhuo/CmdStanExtract.jl

The concrete behavior is shown by this unit test:

```julia
using CmdStanExtract
using Test

# Dummy Stan-style flattened column names: scalar x, vector y[2],
# matrix z[3,2] (listed column-major), and a 5-d array k of size 1 in every dimension.
cnames_dummy = ["x", "y.1", "y.2", "z.1.1", "z.2.1", "z.3.1", "z.1.2", "z.2.2", "z.3.2", "k.1.1.1.1.1"]

# Map each column name to its position in the second (vars) dimension.
key_to_idx = Dict(name => idx for (idx, name) in enumerate(cnames_dummy))

draws = 100
vars = length(cnames_dummy)
chains = 2
chns_dummy = randn(draws, vars, chains)

@testset "CmdStanExtract.jl" begin
    ex_dict = extract(chns_dummy, cnames_dummy)

    # Shapes: the parameter's own dimensions first, then draws and chains.
    @test size(ex_dict["x"]) == (draws, chains)
    @test size(ex_dict["y"]) == (2, draws, chains)
    @test size(ex_dict["z"]) == (3, 2, draws, chains)
    @test size(ex_dict["k"]) == (1, 1, 1, 1, 1, draws, chains)

    # Values: each element equals the corresponding column of the input array.
    @test ex_dict["x"][2,1] == chns_dummy[2, key_to_idx["x"], 1]
    @test ex_dict["y"][2,3,2] == chns_dummy[3, key_to_idx["y.2"], 2]
    @test ex_dict["z"][3, 1, 10, 1] == chns_dummy[10, key_to_idx["z.3.1"], 1]
    @test ex_dict["k"][1,1,1,1,1,draws,2] == chns_dummy[draws, key_to_idx["k.1.1.1.1.1"], 2]
end
```
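The naming convention the test relies on can be sketched independently of either package: split each column name on `.`, treat the first part as the parameter name and the rest as integer indices, and take the largest index per dimension as the parameter's shape. This is a minimal sketch assuming dot-separated integer indices as in `cnames_dummy`; `split_cname` is a hypothetical helper, not an API of CmdStan.jl or CmdStanExtract.jl.

```julia
# Hypothetical helper (not part of CmdStan.jl or CmdStanExtract.jl):
# split a flattened column name such as "z.3.1" into (:z, [3, 1]).
function split_cname(cname::AbstractString)
    parts = split(cname, ".")
    name = Symbol(parts[1])
    idx = Int[parse(Int, p) for p in parts[2:end]]   # empty for scalars
    return name, idx
end

cnames = ["x", "y.1", "y.2", "z.1.1", "z.2.1", "z.3.1", "z.1.2", "z.2.2", "z.3.2"]

# Track the largest index seen in every dimension of each parameter;
# that maximum is the parameter's shape (before draws and chains).
shapes = Dict{Symbol, Vector{Int}}()
for c in cnames
    name, idx = split_cname(c)
    if haskey(shapes, name)
        shapes[name] = max.(shapes[name], idx)
    else
        shapes[name] = idx
    end
end

println(shapes[:z])   # prints [3, 2]
```

In this example the `z` columns are listed in column-major order (the first index varies fastest), which is the same order in which Julia's `CartesianIndices((3, 2))` iterates.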

Would it be suitable to add this to CmdStan.jl?

goedman commented 4 years ago

Thank you, this definitely looks useful. I will look at it and your example and get back to you. I'm traveling right now, so it might take until later this week. And thanks for including a testset!

goedman commented 4 years ago

Hi,

Apologies for not getting back sooner. I like what you did.

I would also like to be able to (optionally) return a NamedTuple, i.e. something like:

```julia
nt = (x = ex_dict["x"], y = ex_dict["y"], z = ex_dict["z"])
nt.y
```

That could be a second extract method, e.g.:

```julia
function extract(chns::Array{Float64,3}, cnames::Vector{String}, ::Val{:NamedTuple})
    res = extract(chns, cnames)   # the Dict-returning version
    # TODO: nt = (construct the NamedTuple from `res` and return it)
end
```

but I haven't had time to figure out how to create the NamedTuple yet.
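One way the missing step could look, as a minimal sketch: `dict_to_namedtuple` is a hypothetical helper (not part of CmdStan.jl or CmdStanExtract.jl) and assumes the Dict has String keys, as in the test above. Sorting the keys gives the NamedTuple a deterministic field order; splatting a Dict directly would follow its hash-based iteration order.

```julia
# Hypothetical helper: convert a Dict{String, <:Array} (like the one returned by
# the Dict-based `extract`) into a NamedTuple with one field per parameter.
function dict_to_namedtuple(d::AbstractDict{<:AbstractString})
    ks = sort!(collect(keys(d)))          # deterministic field order
    names = Tuple(Symbol.(ks))
    vals = Tuple(d[k] for k in ks)
    return NamedTuple{names}(vals)
end

# Dummy input with the same layout: parameter dims first, then draws and chains.
res = Dict("x" => randn(100, 2), "y" => randn(2, 100, 2))
nt = dict_to_namedtuple(res)
size(nt.y)
```

Inside the proposed `Val{:NamedTuple}` method, the placeholder would then become `return dict_to_namedtuple(res)`.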

I propose initially adding your version of `extract`. Do you want to make a PR, or would you prefer that I do it?

goedman commented 4 years ago

In fact, it is so easy to pull `extract` and the test into CmdStan.jl that I might as well do it and publish. How would you like to be referred to in the acknowledgement (@yiyuezhuo?)?

yiyuezhuo commented 4 years ago

Yes, thank you (@goedman).

goedman commented 4 years ago

Hi @yiyuezhuo,

Would the version below work for you? It returns a NamedTuple instead of a Dict.

```julia
using Test

# Group flattened Stan column names ("x", "y.1", "z.3.1", ...) into one array
# per parameter and return them as a NamedTuple. Each array has the
# parameter's own dimensions first, followed by draws and chains.
function extract(chns::Array{Float64,3}, cnames::Vector{String})
    draws, vars, chains = size(chns)

    ex_dict = Dict{Symbol, Array}()

    # Scalars are stored directly; indexed columns are collected per name
    # as (column position, integer indices) pairs.
    group_map = Dict{Symbol, Vector{Any}}()
    for (i, cname) in enumerate(cnames)
        sp_arr = split(cname, ".")
        name = Symbol(sp_arr[1])
        if length(sp_arr) == 1
            ex_dict[name] = chns[:,i,:]
        else
            idx = [parse(Int, s) for s in sp_arr[2:end]]
            push!(get!(group_map, name, Any[]), (i, idx))
        end
    end

    # The largest index in each dimension gives the parameter's shape.
    for (name, group) in group_map
        max_idx = maximum(hcat([idx for (_, idx) in group]...), dims=2)[:,1]
        ex_dict[name] = similar(chns, max_idx..., draws, chains)
    end

    # Copy every column into its slot of the per-parameter array.
    for (name, group) in group_map
        for (i, idx) in group
            ex_dict[name][idx..., :, :] = chns[:,i,:]
        end
    end

    return (;ex_dict...)
end

cnames = ["x", "y.1", "y.2", "z.1.1", "z.2.1", "z.3.1", "z.1.2", "z.2.2", "z.3.2", "k.1.1.1.1.1"]

# Map each column name to its position in the second (vars) dimension.
key_to_idx = Dict(name => idx for (idx, name) in enumerate(cnames))

draws = 100
vars = length(cnames)
chains = 2
chns = randn(draws, vars, chains)
nt = extract(chns, cnames)

@testset "extract" begin
    # Shapes: parameter dimensions first, then draws and chains.
    @test size(nt.x) == (draws, chains)
    @test size(nt.y) == (2, draws, chains)
    @test size(nt.z) == (3, 2, draws, chains)
    @test size(nt.k) == (1, 1, 1, 1, 1, draws, chains)

    # Values: each element matches the corresponding input column.
    @test nt.x[2,1] == chns[2, key_to_idx["x"], 1]
    @test nt.y[2,3,2] == chns[3, key_to_idx["y.2"], 2]
    @test nt.z[3, 1, 10, 1] == chns[10, key_to_idx["z.3.1"], 1]
    @test nt.k[1,1,1,1,1,draws,2] == chns[draws, key_to_idx["k.1.1.1.1.1"], 2]

    @test size(values(nt.z)) == (3, 2, 100, 2)
    @test size(nt.z) == (3, 2, 100, 2)

end
```
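Because every array returned by `extract` keeps the parameter's own dimensions first and draws and chains last, per-element posterior summaries reduce over the two trailing dimensions. A minimal sketch using Base and the `Statistics` standard library; the dummy `nt` only mimics the layout of `extract`'s output and is not produced by CmdStan.jl.

```julia
using Statistics

# Dummy result with the layout produced by `extract`: parameter dimensions
# first, then draws, then chains (a 3x2 parameter, 100 draws, 2 chains).
draws, chains = 100, 2
nt = (x = randn(draws, chains), z = randn(3, 2, draws, chains))

# Posterior mean of each element of z, averaging over draws and chains.
z_mean = dropdims(mean(nt.z; dims=(3, 4)); dims=(3, 4))   # 3x2 matrix

# Pool the chains: since draws and chains are the trailing dimensions,
# reshape merges them into one sample dimension without copying the data.
z_pooled = reshape(nt.z, 3, 2, draws * chains)            # 3x2x200 array
```

Since Julia arrays are column-major, the pooled sample index runs over all draws of chain 1 first, then over the draws of chain 2.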

yiyuezhuo commented 4 years ago

Looks fine; just rename `cnames` to `cnames_dummy` so that it works.

goedman commented 4 years ago

Thanks, yes, fixed that. `extract` is really an extension of the `output_format` handling in `convert_a3d()`.

goedman commented 4 years ago

Included in release CmdStan.jl v6.0.8. By the way, I also ported it to StanSample.jl.