Closed yiyuezhuo closed 4 years ago
Thank you, this definitely looks useful. I will have a look at it and your example and get back to you. I'm traveling right now, so it might take until later this week. And thanks for including a testset!
Hi,
Apologies for not getting back sooner. I like what you did.
I would like to also (optionally) be able to return a NamedTuple, i.e. something like:
nt = (x = ex_dict["x"], y = ex_dict["y"], z = ex_dict["z"])
nt.y
That could be a second extract method, e.g.:
function extract(chns::Array{Float64,3}, cnames::Vector{String}, ::Val{:NamedTuple})
    res = extract(chns, cnames)
    nt = ##### construct the NamedTuple #####
end
but haven't had the time to figure out the creation of the NamedTuple yet.
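For the missing step, here is a minimal sketch of two ways a Dict with Symbol keys could be turned into a NamedTuple. The contents of ex_dict below are made-up stand-ins for what extract returns, not output from CmdStan; only the conversion itself is the point.

```julia
# Hypothetical stand-in for the Dict that extract(chns, cnames) returns
ex_dict = Dict{Symbol, Array}(
    :x => randn(4, 2),
    :y => randn(2, 4, 2),
)

# Option 1: splat the Symbol-keyed Dict into NamedTuple syntax
nt = (; ex_dict...)

# Option 2: build it explicitly from matching key and value tuples
ks = Tuple(keys(ex_dict))
nt2 = NamedTuple{ks}(Tuple(ex_dict[k] for k in ks))

nt.y  # field access instead of ex_dict[:y]
```

Both options require the keys to be Symbols. Field order follows the Dict's iteration order, which is not guaranteed to match the order of the column names.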
I propose to initially add your version of extract. Do you want to make a PR, or would you prefer that I do it?
In fact, it is so easy to pull extract and the test into CmdStan that I might as well do it and publish. How would you like to be referred to in the acknowledgement (@yiyuezhuo?).
Yes, thank you (@goedman).
Hi @yiyuezhuo,
Would the version below work for you? It returns a NamedTuple instead of the Dict.
using Test
function extract(chns::Array{Float64,3}, cnames::Vector{String})
    draws, vars, chains = size(chns)
    ex_dict = Dict{Symbol, Array}()
    group_map = Dict{Symbol, Array}()
    for (i, cname) in enumerate(cnames)
        sp_arr = split(cname, ".")
        name = Symbol(sp_arr[1])
        if length(sp_arr) == 1
            ex_dict[name] = chns[:,i,:]
        else
            if !(name in keys(group_map))
                group_map[name] = Any[]
            end
            push!(group_map[name], (i, [Meta.parse(i) for i in sp_arr[2:end]]))
        end
    end
    for (name, group) in group_map
        max_idx = maximum(hcat([idx for (i, idx) in group]...), dims=2)[:,1]
        ex_dict[name] = similar(chns, max_idx..., draws, chains)
    end
    for (name, group) in group_map
        for (i, idx) in group
            ex_dict[name][idx..., :, :] = chns[:,i,:]
        end
    end
    return (;ex_dict...)
end
cnames = ["x", "y.1", "y.2", "z.1.1", "z.2.1", "z.3.1", "z.1.2", "z.2.2", "z.3.2", "k.1.1.1.1.1"]
key_to_idx = Dict(name => idx for (idx, name) in enumerate(cnames))
draws = 100
vars = length(cnames_dummy)
chains = 2
chns = randn(draws, vars, chains)
nt = extract(chns, cnames)
@testset "extract" begin
    # Write your tests here.
    @test size(nt.x) == (draws, chains)
    @test size(nt.y) == (2, draws, chains)
    @test size(nt.z) == (3, 2, draws, chains)
    @test size(nt.k) == (1, 1, 1, 1, 1, draws, chains)
    @test nt.x[2,1] == chns[2, key_to_idx["x"], 1]
    @test nt.y[2,3,2] == chns[3, key_to_idx["y.2"], 2]
    @test nt.z[3, 1, 10, 1] == chns[10, key_to_idx["z.3.1"], 1]
    @test nt.k[1,1,1,1,1,draws,2] == chns[draws, key_to_idx["k.1.1.1.1.1"], 2]
    @test size(values(nt.z)) == (3, 2, 100, 2)
    @test size(nt.z) == (3, 2, 100, 2)
end
Looks fine, just rename cnames to cnames_dummy to let it work.
Thanks, yes fixed that. Extract is really an extension of the output_format in convert_a3d().
Included in release CmdStan.jl v6.0.8. By the way, I also ported it to StanSample.jl.
It would be convenient to collect items A.1, A.2, ... into an array, like what extract does in PyStan. I have implemented it in a package: https://github.com/yiyuezhuo/CmdStanExtract.jl
The concrete behavior can be shown by this unit test:
Is it suitable to add it into CmdStan.jl?
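To illustrate the idea of collecting A.1, A.2, ... into an array, here is a simplified sketch that handles only scalars and one-dimensional names for a single draw. The helper names split_name and collect_vectors, and the sample values, are hypothetical and are not part of CmdStan.jl or CmdStanExtract.jl.

```julia
# Split "A.2" into (:A, [2]); a plain name such as "b" gives (:b, Int[])
function split_name(cname)
    parts = split(cname, ".")
    return Symbol(parts[1]), Int[parse(Int, s) for s in parts[2:end]]
end

# Gather flat columns into a Dict: scalars stay scalars, A.1, A.2, ... become a vector
function collect_vectors(cnames, vals)
    # First pass: find the largest index seen for each base name
    sizes = Dict{Symbol, Int}()
    for cname in cnames
        name, idx = split_name(cname)
        sizes[name] = max(get(sizes, name, 0), isempty(idx) ? 0 : idx[1])
    end
    # Second pass: fill in the values
    out = Dict{Symbol, Any}()
    for (cname, v) in zip(cnames, vals)
        name, idx = split_name(cname)
        if isempty(idx)
            out[name] = v
        else
            arr = get!(out, name, zeros(sizes[name]))
            arr[idx[1]] = v
        end
    end
    return out
end

out = collect_vectors(["A.1", "A.2", "A.3", "b"], [10.0, 20.0, 30.0, 5.0])
```

Here out[:A] is a 3-element vector and out[:b] is a scalar. Handling several index positions (A.1.1, A.2.1, ...) and keeping the draws and chains dimensions would need the more general approach discussed in this thread.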