fonsp / Pluto.jl

🎈 Simple reactive notebooks for Julia
https://plutojl.org/
MIT License
5k stars 296 forks source link

Turing.jl summary statistics tables do not display #868

Closed evjrob closed 2 years ago

evjrob commented 3 years ago

When I run the code below from the Turing differential equations tutorial in the REPL, everything works fine:

# Probabilistic programming (Turing), probability distributions, and ODE
# solving (DifferentialEquations).
using Turing, Distributions, DifferentialEquations 

# Import MCMCChains, Plots, and StatsPlots for visualizations and diagnostics.
using MCMCChains, Plots, StatsPlots

# Set a seed for reproducibility: the `randn` noise added to the synthetic
# data below draws from the global RNG (and, presumably, so does `sample`
# when no RNG is passed). The trailing `;` suppresses the REPL echo.
using Random
Random.seed!(14);

"""
    lotka_volterra(du, u, p, t)

In-place right-hand side of the Lotka–Volterra equations, in the
`f(du, u, p, t)` form accepted by `ODEProblem`.

- `u = [x, y]`: the two state variables.
- `p = [α, β, γ, δ]`: rate parameters.
- `t`: time; unused, because the system is autonomous.

Writes `dx/dt = (α - βy)x` into `du[1]` and `dy/dt = (δx - γ)y` into `du[2]`.
Returns `nothing`, since the result is communicated only through `du`.
"""
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, γ, δ = p
    du[1] = (α - β * y) * x  # dx/dt = (α - βy)x
    du[2] = (δ * x - γ) * y  # dy/dt = (δx - γ)y
    return nothing
end
# Ground-truth parameters [α, β, γ, δ] and initial state [x0, y0].
p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0,1.0]
# Reference problem on t ∈ [0, 10], solved with Tsit5 and plotted.
prob1 = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
sol = solve(prob1,Tsit5())
plot(sol)

# Re-solve on a fixed 0.1 time grid and add Gaussian noise (scale 0.8) to
# build the synthetic observations; each column of `odedata` corresponds to
# one saved time point (the model below indexes it as data[:, i]).
sol1 = solve(prob1,Tsit5(),saveat=0.1)
odedata = Array(sol1) + 0.8 * randn(size(Array(sol1)))
plot(sol1, alpha = 0.3, legend = false); scatter!(sol1.t, odedata')

# Use ForwardDiff (forward-mode AD) for the gradients NUTS needs.
# NOTE(review): later Turing releases removed `setadbackend` in favour of
# passing the AD backend to the sampler — confirm against your Turing version.
Turing.setadbackend(:forwarddiff)

# Turing model: priors on the noise scale σ and the four Lotka–Volterra
# parameters, plus a likelihood comparing each column of `data` with the ODE
# solution at the corresponding saved time point.
@model function fitlv(data, prob1)
    σ ~ InverseGamma(2, 3) # ~ is the tilde character
    # Normal priors truncated to bounded ranges, e.g. α restricted to [0.5, 2.5].
    α ~ truncated(Normal(1.5,0.5),0.5,2.5)
    β ~ truncated(Normal(1.2,0.5),0,2)
    γ ~ truncated(Normal(3.0,0.5),1,4)
    δ ~ truncated(Normal(1.0,0.5),0,2)

    # Re-solve the ODE with the sampled parameters on a 0.1 grid; for the
    # `odedata`/`prob1` built above this matches the observation grid, so
    # predicted[i] lines up with data[:, i].
    p = [α,β,γ,δ]
    prob = remake(prob1, p=p)
    predicted = solve(prob,Tsit5(),saveat=0.1)

    # NOTE(review): the solver's return code is not checked, so a solve that
    # stops early silently compares fewer time points. Newer Distributions.jl
    # releases also deprecate `MvNormal(μ, σ::Real)` in favour of
    # `MvNormal(μ, σ^2 * I)` — confirm against the installed version.
    for i = 1:length(predicted)
        data[:,i] ~ MvNormal(predicted[i], σ)
    end
end

# Condition the model on the noisy synthetic observations.
model = fitlv(odedata, prob1)

# This next command runs 3 independent chains without using multithreading. 
# Each run draws 1000 NUTS samples (target acceptance rate 0.65), and
# `chainscat` concatenates the runs into a single `Chains` object. In the REPL
# the value of this last expression is displayed as the summary table.
chain = mapreduce(c -> sample(model, NUTS(.65),1000), chainscat, 1:3)

My equivalent Pluto code is here:

### A Pluto.jl notebook ###
# v0.12.18

using Markdown
using InteractiveUtils

# ╔═╡ 52334ae8-5c67-11eb-07a9-bf514e31de43
begin
    # Probabilistic programming (Turing), distributions, and ODE solving.
    using Turing, Distributions, DifferentialEquations 

    # Import MCMCChains, Plots, and StatsPlots for visualizations and diagnostics.
    using MCMCChains, Plots, StatsPlots

    # Set a seed for reproducibility of the `randn` noise in the next cell.
    using Random
    Random.seed!(14);
end

# ╔═╡ b830a986-5c66-11eb-1da2-23fffb0dfc86
begin
    # Same code as the REPL version, wrapped in one Pluto cell. A `begin` block
    # evaluates to its last expression, so this cell's displayed value is the
    # `Chains` object assigned to `chain` at the end; rendering that value is
    # what raises the reported MethodError.

    # In-place Lotka–Volterra right-hand side in the f(du, u, p, t) form.
    function lotka_volterra(du,u,p,t)
      x, y = u
      α, β, γ, δ  = p
      du[1] = (α - β*y)x # dx/dt = (α - βy)x
      du[2] = (δ*x - γ)y # dy/dt = (δx - γ)y
    end
    p = [1.5, 1.0, 3.0, 1.0]
    u0 = [1.0,1.0]
    prob1 = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
    sol = solve(prob1,Tsit5())
    # NOTE(review): only the cell's final expression is displayed, so these
    # intermediate plot calls produce no visible output in Pluto.
    plot(sol)

    # Synthetic observations: the solution on a 0.1 grid plus Gaussian noise.
    sol1 = solve(prob1,Tsit5(),saveat=0.1)
    odedata = Array(sol1) + 0.8 * randn(size(Array(sol1)))
    plot(sol1, alpha = 0.3, legend = false); scatter!(sol1.t, odedata')

    Turing.setadbackend(:forwarddiff)

    # Priors on σ and the four parameters; each data column is compared with
    # the re-solved ODE at the same saved time point.
    @model function fitlv(data, prob1)
        σ ~ InverseGamma(2, 3) # ~ is the tilde character
        α ~ truncated(Normal(1.5,0.5),0.5,2.5)
        β ~ truncated(Normal(1.2,0.5),0,2)
        γ ~ truncated(Normal(3.0,0.5),1,4)
        δ ~ truncated(Normal(1.0,0.5),0,2)

        p = [α,β,γ,δ]
        prob = remake(prob1, p=p)
        predicted = solve(prob,Tsit5(),saveat=0.1)

        for i = 1:length(predicted)
            data[:,i] ~ MvNormal(predicted[i], σ)
        end
    end

    model = fitlv(odedata, prob1)

    # This next command runs 3 independent chains without using multithreading. 
    chain = mapreduce(c -> sample(model, NUTS(.65),1000), chainscat, 1:3)
end

# ╔═╡ Cell order:
# ╠═52334ae8-5c67-11eb-07a9-bf514e31de43
# ╠═b830a986-5c66-11eb-1da2-23fffb0dfc86

Running this yields the following error after the second cell instead of displaying the expected summary statistics table:

Failed to show value:

MethodError: no method matching iterate(::MCMCChains.Chains{Float64,AxisArrays.AxisArray{Float64,3,Array{Float64,3},Tuple{AxisArrays.Axis{:iter,StepRange{Int64,Int64}},AxisArrays.Axis{:var,Array{Symbol,1}},AxisArrays.Axis{:chain,UnitRange{Int64}}}},Missing,NamedTuple{(:parameters, :internals),Tuple{Array{Symbol,1},Array{Symbol,1}}},NamedTuple{(),Tuple{}}})

Closest candidates are:

iterate(!Matched::Base.EnvDict) at env.jl:119

iterate(!Matched::Base.EnvDict, !Matched::Any) at env.jl:119

iterate(!Matched::Tables.DictRowTable) at /home/everett/.julia/packages/Tables/8Ud85/src/dicts.jl:120

...

    isempty(::MCMCChains.Chains{Float64,AxisArrays.AxisArray{Float64,3,Array{Float64,3},Tuple{AxisArrays.Axis{:iter,StepRange{Int64,Int64}},AxisArrays.Axis{:var,Array{Symbol,1}},AxisArrays.Axis{:chain,UnitRange{Int64}}}},Missing,NamedTuple{(:parameters, :internals),Tuple{Array{Symbol,1},Array{Symbol,1}}},NamedTuple{(),Tuple{}}})@essentials.jl:737
    table_data(::MCMCChains.Chains{Float64,AxisArrays.AxisArray{Float64,3,Array{Float64,3},Tuple{AxisArrays.Axis{:iter,StepRange{Int64,Int64}},AxisArrays.Axis{:var,Array{Symbol,1}},AxisArrays.Axis{:chain,UnitRange{Int64}}}},Missing,NamedTuple{(:parameters, :internals),Tuple{Array{Symbol,1},Array{Symbol,1}}},NamedTuple{(),Tuple{}}}, ::IOContext{Base.GenericIOBuffer{Array{UInt8,1}}})@PlutoRunner.jl:827
    show_richest(::IOContext{Base.GenericIOBuffer{Array{UInt8,1}}}, ::Any)@PlutoRunner.jl:599
    #sprint_withreturned#28(::IOContext{Base.DevNull}, ::Int64, ::typeof(Main.PlutoRunner.sprint_withreturned), ::Function, ::MCMCChains.Chains{Float64,AxisArrays.AxisArray{Float64,3,Array{Float64,3},Tuple{AxisArrays.Axis{:iter,StepRange{Int64,Int64}},AxisArrays.Axis{:var,Array{Symbol,1}},AxisArrays.Axis{:chain,UnitRange{Int64}}}},Missing,NamedTuple{(:parameters, :internals),Tuple{Array{Symbol,1},Array{Symbol,1}}},NamedTuple{(),Tuple{}}})@PlutoRunner.jl:549
    format_output_default(::Any, ::Any)@PlutoRunner.jl:474
    #format_output#17@PlutoRunner.jl:491[inlined]
    formatted_result_of(::Base.UUID, ::Bool, ::Nothing)@PlutoRunner.jl:416
    top-level scope@none:1

This occurs both with the generic Linux x86-64 Julia release on an Arch Linux install and with the 64-bit Windows release on Windows 10.

00krishna commented 3 years ago

+1, I ran into this issue as well. The message is confusing because it appears to be an MCMCChains error, even though the cause is something in Pluto. I first thought it was an MCMCChains or Turing error, but then I ran the same code in the console and it worked fine — that was the key clue.

lukavdplas commented 3 years ago

Related to #196?

devmotion commented 3 years ago

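Pluto tries to use the Tables.jl interface to display the chain (instead of the standard `display`/`show` methods used in the REPL), but MCMCChains' implementation violates some of its assumptions. For example, Pluto's `table_data` calls `isempty` on the chain, which requires an `iterate` method that `Chains` does not define (see the stack trace above): https://github.com/TuringLang/MCMCChains.jl/issues/268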

sethaxen commented 3 years ago

This issue has been fixed in MCMCChains v4.7.2.

vini-fda commented 3 years ago

In the Julia REPL, I also get information about the iterations, number of chains, samples per chain, etc. In Pluto, however, I only get the DataFrame, without that information. Is this by design?

On the REPL:

image

Whereas, on the Pluto notebook:

image

rikhuijzer commented 3 years ago

You can use `describe` to get that information:

image