JuliaStats / Klara.jl

MCMC inference in Julia

HMC with reverse-mode autodiff #170

Open remoore opened 6 years ago

remoore commented 6 years ago

I tried running the last example in the README with an HMC sampler substituted for the MALA one:

using Klara
plogtarget(z) = -dot(z, z)
p = BasicContMuvParameter(:p, logtarget=plogtarget, diffopts=DiffOptions(mode=:reverse))
model = likelihood_model(p, false)
sampler = HMC()
mcrange = BasicMCRange(nsteps=10000, burnin=1000)
v0 = Dict(:p=>[5.1, -0.9])
outopts = Dict{Symbol, Any}(:monitor=>[:value, :logtarget, :gradlogtarget], :diagnostics=>[:accept])
job = BasicMCJob(model, sampler, mcrange, v0, tuner=VanillaMCTuner(verbose=true), outopts=outopts)
run(job)
chain = output(job)

It crashes with the following error message when the job is run:

ERROR: MethodError: no method matching reverse_autodiff_gradient(::DiffBase.DiffResult{1,Float64,Tuple{Array{Float64,1}}}, ::Void, ::Array{Float64,1})
Closest candidates are:
  reverse_autodiff_gradient(::DiffBase.DiffResult, ::Union{ReverseDiff.CompiledTape, ReverseDiff.GradientTape}, ::Array{T,1} where T) at /home/remoore/.julia/v0.6/Klara/src/autodiff/reverse.jl:6
Stacktrace:
 [1] (::Klara.##335#373)(::Klara.BasicContMuvParameterState{Float64}, ::Array{Klara.VariableState,1}) at /home/remoore/.julia/v0.6/Klara/src/variables/parameters/BasicContMuvParameter.jl:584
 [2] leapfrog!(::Klara.BasicContMuvParameterState{Float64}, ::Array{Float64,1}, ::Klara.BasicContMuvParameterState{Float64}, ::Array{Float64,1}, ::Float64, ::Klara.##274#291{BasicContMuvParameter,Tuple{Void,Void,Void,Void,Klara.##328#366{Tuple{Void,Void,Void,Void,#plogtarget,Void,Void,Void,Void,Void,Void,Void,Void,Void,Void,Void,Void},Array{Symbol,1}},Void,Void,Klara.##335#373,Void,Void,Void,Void,Void,Void,Klara.##336#374,Void,Void}}) at /home/remoore/.julia/v0.6/Klara/src/samplers/samplers.jl:132
 [3] iterate!(::Klara.BasicMCJob, ::Type{Klara.HMC}, ::Type{Distributions.Multivariate}) at /home/remoore/.julia/v0.6/Klara/src/samplers/iterate/HMC.jl:147
 [4] run(::Klara.BasicMCJob) at /home/remoore/.julia/v0.6/Klara/src/jobs/BasicMCJob.jl:224

It does run in forward mode, though. Any idea why?
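
For what it's worth, the second argument in the failing call is `::Void` (the type of `nothing` in Julia 0.6), while the only candidate method expects a `ReverseDiff.CompiledTape` or `ReverseDiff.GradientTape` there. So it looks as if no tape is being passed on this code path, though that is a guess from the error message alone. Below is a minimal sketch of the same dispatch pattern; `ToyTape`, `toy_gradient` and `dispatch_fails` are hypothetical names for illustration, not Klara's or ReverseDiff's actual code.

```julia
# Hypothetical stand-in for a ReverseDiff tape; not a real Klara/ReverseDiff type.
struct ToyTape end

# Like the candidate listed in the error message, this method only accepts
# a tape as its second argument. The body is the gradient of -dot(x, x).
toy_gradient(result::Vector{Float64}, tape::ToyTape, x::Vector{Float64}) = -2 .* x

# Returns true when the call fails with a MethodError, false otherwise.
function dispatch_fails(tape)
    try
        toy_gradient(zeros(2), tape, [5.1, -0.9])
        return false
    catch err
        return isa(err, MethodError)
    end
end

dispatch_fails(ToyTape())  # a tape is supplied, so a method matches
dispatch_fails(nothing)    # no tape is supplied, so no method matches
```

If that is what happens, forward mode would be unaffected simply because it never needs a tape.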

papamarkou commented 6 years ago

I will look at this soon; I am currently dealing with a wrist health issue that prevents me from working.