TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License

`MethodError: MCMCChains.Chains` when using `chainscat` on Bayesian Diff Eq Tutorial #1536

Closed: 00krishna closed this issue 3 years ago.

00krishna commented 3 years ago

Hey folks,

I was working through the Bayesian Differential Equations tutorial, and I encountered an error when trying to run the model. One of the folks on Slack suggested I post an issue.

I am using Julia 1.5 and version 4.6.0 of MCMCChains. Note that I was running Turing in a Pluto notebook session.

Below are the code and the resulting error message. It seems like there is a missing method (some unimplemented dispatch) in the MCMCChains package, or something similar?

using Turing, Distributions, DifferentialEquations, DataFrames
using MCMCChains, Plots, StatsPlots
using Random
Random.seed!(14);
Turing.turnprogress(false)
plotlyjs()

# setup problem
function lotka_volterra(du,u,p,t)
  x, y = u
  α, β, γ, δ  = p
  du[1] = (α - β*y)x # dx/dt = (α - βy)x
  du[2] = (δ*x - γ)y # dy/dt = (δx - γ)y
end
p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0,1.0]
prob1 = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
sol = solve(prob1,Tsit5())

# Generate some data for estimation
sol1 = solve(prob1,Tsit5(),saveat=0.1)
odedata = Array(sol1) + 0.8 * randn(size(Array(sol1)))
plot(sol1, alpha = 0.3, legend = false); scatter!(sol1.t, odedata')

# Setup Turing model
Turing.setadbackend(:forwarddiff)

@model function fitlv(data, prob1)
    σ ~ InverseGamma(2, 3) # ~ is the tilde character
    α ~ truncated(Normal(1.5,0.5),0.5,2.5)
    β ~ truncated(Normal(1.2,0.5),0,2)
    γ ~ truncated(Normal(3.0,0.5),1,4)
    δ ~ truncated(Normal(1.0,0.5),0,2)

    p = [α,β,γ,δ]
    prob = remake(prob1, p=p)
    predicted = solve(prob,Tsit5(),saveat=0.1)

    for i = 1:length(predicted)
        data[:,i] ~ MvNormal(predicted[i], σ)
    end
end

model = fitlv(odedata, prob1)
chain = mapreduce(c -> sample(model, NUTS(.65),1000), chainscat, 1:3)

Running the last command (the mapreduce call) and then displaying `chain` in the notebook produces the following error:

chain

Failed to show value:

MethodError: no method matching iterate(::MCMCChains.Chains{Float64,AxisArrays.AxisArray{Float64,3,Array{Float64,3},Tuple{AxisArrays.Axis{:iter,StepRange{Int64,Int64}},AxisArrays.Axis{:var,Array{Symbol,1}},AxisArrays.Axis{:chain,UnitRange{Int64}}}},Missing,NamedTuple{(:parameters, :internals),Tuple{Array{Symbol,1},Array{Symbol,1}}},NamedTuple{(),Tuple{}}})

Closest candidates are:
  iterate(!Matched::Base.RegexMatchIterator) at regex.jl:552
  iterate(!Matched::Base.RegexMatchIterator, !Matched::Any) at regex.jl:552
  iterate(!Matched::Libtask.TRef, !Matched::Any...) at /home/krishnab/.julia/packages/Libtask/00Il9/src/tref.jl:86
  ...

    isempty(::MCMCChains.Chains{Float64,AxisArrays.AxisArray{Float64,3,Array{Float64,3},Tuple{AxisArrays.Axis{:iter,StepRange{Int64,Int64}},AxisArrays.Axis{:var,Array{Symbol,1}},AxisArrays.Axis{:chain,UnitRange{Int64}}}},Missing,NamedTuple{(:parameters, :internals),Tuple{Array{Symbol,1},Array{Symbol,1}}},NamedTuple{(),Tuple{}}})@essentials.jl:737
    table_data(::MCMCChains.Chains{Float64,AxisArrays.AxisArray{Float64,3,Array{Float64,3},Tuple{AxisArrays.Axis{:iter,StepRange{Int64,Int64}},AxisArrays.Axis{:var,Array{Symbol,1}},AxisArrays.Axis{:chain,UnitRange{Int64}}}},Missing,NamedTuple{(:parameters, :internals),Tuple{Array{Symbol,1},Array{Symbol,1}}},NamedTuple{(),Tuple{}}}, ::IOContext{Base.GenericIOBuffer{Array{UInt8,1}}})@PlutoRunner.jl:828
    show_richest(::IOContext{Base.GenericIOBuffer{Array{UInt8,1}}}, ::Any)@PlutoRunner.jl:600
    #sprint_withreturned#28(::IOContext{Base.DevNull}, ::Int64, ::typeof(Main.PlutoRunner.sprint_withreturned), ::Function, ::MCMCChains.Chains{Float64,AxisArrays.AxisArray{Float64,3,Array{Float64,3},Tuple{AxisArrays.Axis{:iter,StepRange{Int64,Int64}},AxisArrays.Axis{:var,Array{Symbol,1}},AxisArrays.Axis{:chain,UnitRange{Int64}}}},Missing,NamedTuple{(:parameters, :internals),Tuple{Array{Symbol,1},Array{Symbol,1}}},NamedTuple{(),Tuple{}}})@PlutoRunner.jl:550
    format_output_default(::Any, ::Any)@PlutoRunner.jl:475
    #format_output#17@PlutoRunner.jl:492[inlined]
    formatted_result_of(::Base.UUID, ::Bool, ::Nothing)@PlutoRunner.jl:417
    top-level scope@none:1

Note also that the following code works just fine:

chain2 = sample(model, NUTS(.65), MCMCThreads(), 5000, 3, progress=false)

devmotion commented 3 years ago

Do you get the same error in the REPL? The stacktrace indicates that the problem might be Pluto-specific.

00krishna commented 3 years ago

@devmotion You know, I just tried this in the REPL and it seems to work, so that is good news. But the error message points to MCMCChains, which threw me off.

I saw the reference to PlutoRunner way down in the stacktrace, but figured that the reference to MCMCChains was the main thing. Perhaps Pluto is having a hard time displaying the summary information from the chain, and that is somehow triggering the error.

I can try to reroute this issue to the Pluto folks.
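The stacktrace is consistent with that reading: Pluto's `table_data` calls `isempty` on the chain, Base's generic `isempty` fallback calls `iterate`, and, as the `MethodError` says, there is no `iterate` method for `Chains`. So `chainscat` itself appears to work, and the failure happens only when Pluto tries to render the result. A minimal sketch to check this outside Pluto, assuming only MCMCChains is installed: it builds three small chains from random numbers as stand-ins for the three NUTS runs (so no Turing model or sampling is needed; the parameter names `:a` and `:b` are hypothetical), concatenates them with the same `mapreduce(..., chainscat, 1:3)` pattern as in the issue, and prints the result as plain text, the way the REPL does.

```julia
using MCMCChains
using Random

Random.seed!(14)

# Hypothetical stand-in for one `sample(model, NUTS(.65), 1000)` run:
# a 100-iteration, 2-parameter, 1-chain Chains object filled with random draws.
fake_run(_) = Chains(randn(100, 2, 1), [:a, :b])

# Same pattern as in the issue: concatenate the runs along the chain dimension.
combined = mapreduce(fake_run, chainscat, 1:3)

# Dimensions are (iterations, parameters, chains).
println(size(combined))

# Plain-text display, i.e. the output path that worked in the REPL.
show(stdout, MIME("text/plain"), combined)
```

If this runs cleanly, the concatenated chain is valid and the problem lies in how the notebook renders `Chains` objects rather than in `chainscat`.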

00krishna commented 3 years ago

@devmotion Yes, this seems to be a known Pluto issue. Just a note: because the error message names MCMCChains as the cause, other folks might open issues on this topic here.

Here is the link to the existing Pluto issue about displaying Turing output: https://github.com/fonsp/Pluto.jl/issues/868