TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License

Jags-style samplers #905

Closed itsdfish closed 2 years ago

itsdfish commented 5 years ago

I am opening this feature request after a discussion on Slack regarding the performance of PG. For continuous parameters in particular, particles tend to get stuck. It's not clear to me to what extent this may happen for discrete parameters. Here is an example:

using Turing,Random,StatsPlots
@model model(y) = begin
    μ ~ Normal(0,10)
    σ ~ Truncated(Cauchy(0,1),0,Inf)
    for j in 1:length(y)
        y[j] ~ Normal(μ,σ)
    end
end
Random.seed!(3431)
y = rand(Normal(0,1),50)
chain = sample(model(y),PG(40,4000))
chain = chain[2001:end,:,:]
println(chain)
plot(chain)

[Figure: fig4]

This required about 2.5 minutes to run on my system. Increasing the number of particles to 80 did not help much.

As a basis for comparison, here is the same model coded in Jags:

ENV["JAGS_HOME"] = "usr/bin/jags" #your path here
using Jags, StatsPlots, Random, Distributions
#cd(@__DIR__)
ProjDir = pwd()
Random.seed!(3431)

y = rand(Normal(0,1),50)

Model = "
model {
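      for (i in 1:length(y)) {
            # dnorm takes a precision (1/sd^2), not a standard deviation
            y[i] ~ dnorm(mu, pow(sigma, -2));
      }
      mu  ~ dnorm(0, 0.01);  # precision 0.01 corresponds to Normal(0,10) in the Turing model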
      sigma  ~ dt(0,1,1) T(0, );
  }
"

monitors = Dict(
  "mu" => true,
  "sigma" => true,
  )

jagsmodel = Jagsmodel(
  name="Gaussian",
  model=Model ,
  monitor=monitors,
  ncommands=4, nchains=1,
  #deviance=true, dic=true, popt=true,
  pdir=ProjDir
  )

println("\nJagsmodel that will be used:")
jagsmodel |> display

data = Dict{String, Any}(
  "y" => y,
)

inits = [
  Dict("mu" => 0.0,"sigma" => 1.0,
  ".RNG.name" => "base::Mersenne-Twister")
]

println("Input observed data dictionary:")
data |> display
println("\nInput initial values dictionary:")
inits |> display
println()
#######################################################################################
#                                 Estimate Parameters
#######################################################################################
sim = jags(jagsmodel, data, inits, ProjDir)
sim = sim[5001:end,:,:]
plot(sim)

[Figure: jags]

This required about 0.267 seconds on my machine, which is nearly a 600-fold speedup.

Here is a second example we found to perform poorly:

using Distributions
using Turing

n=500
p=20
X = rand(Float64, (n,p))
beta=[2.0 .^ (-i) for i in 0:(p-1)]
alpha=0
sigma=0.7
eps=rand(Normal(0, sigma), n)
y = alpha .+ X * beta + eps;

@model model(X, y) = begin

    n, p = size(X)

    alpha ~ Normal(0,1)
    sigma ~ Truncated(Cauchy(0,1),0,Inf)
    sigma_beta ~ Truncated(Cauchy(0,1),0,Inf)
    pind ~ Beta(2,8)

    beta = tzeros(Float64, p)
    betaT = tzeros(Float64, p)
    ind = tzeros(Int, p)

    for j in 1:p
        ind[j] ~ Bernoulli(pind)
        betaT[j] ~ Normal(0,sigma_beta)  # random effect
        beta[j] = ind[j] * betaT[j]
    end

    mu = tzeros(Float64, n)

    for i in 1:n
        mu[i] = alpha + X[i,:]' * beta 
        y[i] ~ Normal(mu[i], sigma)
    end

end

steps = 4000
chain = sample(model(X,y),PG(40,steps))

I think this would be a very useful addition. By adding Jags-style samplers, we could have the speed of Jags without the severe limitations of Jags. This would also provide Turing with an ability that Stan struggles to perform.

yebai commented 5 years ago

Thanks for opening this issue. This is already on the Turing team's priority list. Adding support for discrete variables and combining different sampling algorithms into more efficient inference engines were among the original motivations for Turing. However, the challenge is not on the inference side: we can quickly implement the samplers currently available in JAGS. The real barrier is the compiler, which currently tracks only the values of random variables and ignores their dependencies. Without this dependency information, it is hard to derive Gibbs conditionals automatically.
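To make the role of dependency information concrete: for a model that factorises over a directed graph, the Gibbs conditional of a variable involves only its own prior term and the terms of its children. This is the standard Markov-blanket identity; the notation (pa for parents, ch for children) is an assumption of this sketch, not Turing terminology.

```latex
p(x_j \mid x_{-j}) \;\propto\; p\bigl(x_j \mid \mathrm{pa}(x_j)\bigr)
  \prod_{c \,\in\, \mathrm{ch}(x_j)} p\bigl(c \mid \mathrm{pa}(c)\bigr)
```

Without knowing which variables are in ch(x_j), a sampler cannot drop the unrelated factors and has to fall back on evaluating the full joint density for every update.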

One reason why dependency tracking is harder to implement in Turing than in libraries like JAGS and Mamba.jl is that Turing takes a tracing (aka define-by-run) approach to defining models. Libraries like JAGS take a scripting (aka define-and-run) approach. The tracing approach is arguably more general and user-friendly: 1) it supports models with varying dimensionality, like Dirichlet processes; 2) it makes models easier to implement and debug. Unfortunately, these properties also mean that the graphical model underlying a Turing program can be dynamic, i.e. both the edges and the total number of nodes could vary during inference.
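A minimal sketch of what "dynamic" means here, in plain Julia rather than Turing. The function `run_program` and its trace of variable names are hypothetical and only illustrate that the set of random variables can differ from one execution to the next.

```julia
using Random

# Hypothetical toy "tracer": records the name of every random choice made.
function run_program(rng::AbstractRNG)
    trace = String[]
    # The number of components is itself random (1 to 5).
    k = rand(rng, 1:5)
    push!(trace, "k")
    for i in 1:k
        randn(rng)                 # draw the i-th component mean
        push!(trace, "m[$i]")
    end
    return k, trace
end

rng = MersenneTwister(1)
for run in 1:3
    k, trace = run_program(rng)
    println("run $run: k = $k, variables = ", join(trace, ", "))
end
```

Because the variables `m[i]` only exist after `k` has been drawn, a fixed graph cannot be written down before the program runs, which is the difficulty described above.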

To address these issues and add support for JAGS-style inference and other advanced inference methods in Turing, we have started several projects. Below are some related ongoing PRs/work:

As a side note, there is also an alternative approach to avoid dependency tracking. It requires the user to write their models in several smaller Turing programs, and run a different sampler on each Turing program, in a way similar to JAGS, then "glue" together inference results from these smaller models. It only requires a relatively small amount of work to support this approach after the MCMC Interface PR (https://github.com/TuringLang/Turing.jl/pull/793) is merged. I don't really like this approach because it requires the user to break one model into several smaller programs. But it loosely fits into the "models as code" philosophy, in the sense that it encourages modularity in modelling, and encourages building complex models by composing common modelling parts if possible.
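As an illustration of running a different update on each part of a model, here is a plain-Julia sketch for the Gaussian example from the first post. It is not Turing API and not the approach described in the PR above. The names `logcond_sigma` and `blockwise_sampler`, the step size and the initial values are all assumptions of this sketch. It draws the mean exactly from its Normal full conditional and updates the standard deviation with a random-walk Metropolis step.

```julia
using Random, Statistics

# Log density of sigma given mu and y, up to an additive constant:
# Normal likelihood plus a half-Cauchy(0, 1) prior.
function logcond_sigma(sigma, mu, y)
    sigma <= 0 && return -Inf
    return -length(y) * log(sigma) - sum(abs2, y .- mu) / (2 * sigma^2) - log1p(sigma^2)
end

function blockwise_sampler(rng, y, nsamples; step = 0.2)
    n = length(y)
    mu, sigma = 0.0, 1.0
    mus = Vector{Float64}(undef, nsamples)
    sigmas = Vector{Float64}(undef, nsamples)
    for t in 1:nsamples
        # Block 1: exact draw of mu from its Normal full conditional
        # (prior Normal(0, 10), i.e. prior precision 1/100).
        prec = 1 / 100 + n / sigma^2
        mu = (sum(y) / sigma^2) / prec + randn(rng) / sqrt(prec)
        # Block 2: random-walk Metropolis step for sigma.
        proposal = sigma + step * randn(rng)
        if log(rand(rng)) < logcond_sigma(proposal, mu, y) - logcond_sigma(sigma, mu, y)
            sigma = proposal
        end
        mus[t] = mu
        sigmas[t] = sigma
    end
    return mus, sigmas
end

rng = MersenneTwister(3431)
y = randn(rng, 50)                 # synthetic standard-normal data
mus, sigmas = blockwise_sampler(rng, y, 4000)
println("posterior mean of mu:    ", mean(mus[2001:end]))
println("posterior mean of sigma: ", mean(sigmas[2001:end]))
```

Each block only needs the factors that involve its own variable, which is why this style of sampler is cheap per iteration when the dependency structure is known.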

Please let me know if any parts of the above plan are unclear, or if you have any thoughts or suggestions!

itsdfish commented 5 years ago

Thank you for taking the time to write a detailed reply. It looks like some really exciting new features are on the horizon. I realize that this might be difficult to answer, but do you have a rough idea of when Jags-style sampling might be implemented? Approximately six months, or a year? This will help me plan and prioritize some projects, including the benchmarking work I am doing with Rob. Thanks!

yebai commented 5 years ago

We're targeting 3-6 months, but it might take a bit longer.

elizavetasemenova commented 5 years ago

For the record, the second example in the initial post (an important case for my work) takes about 2 hours to run, and the trace plots of some parameters look as follows:

[Figure: Screenshot 2019-09-09 at 14 26 45]

itsdfish commented 5 years ago

On a related note, I also want to point out that the Hidden Markov Model from the tutorial produces very low effective sample sizes, almost all below 10.

Summary Statistics

│ Row │ parameters │ mean      │ std         │ naive_se    │ mcse       │ ess     │ r_hat    │
│     │ Symbol     │ Float64   │ Float64     │ Float64     │ Float64    │ Any     │ Any      │
├─────┼────────────┼───────────┼─────────────┼─────────────┼────────────┼─────────┼──────────┤
│ 1   │ T[1][1]    │ 0.60352   │ 0.0305084   │ 0.00096476  │ 0.00964355 │ 4.23888 │ 1.59418  │
│ 2   │ T[1][2]    │ 0.309543  │ 0.0206837   │ 0.000654076 │ 0.00630832 │ 6.18149 │ 1.26903  │
│ 3   │ T[1][3]    │ 0.086937  │ 0.0135024   │ 0.000426984 │ 0.00439716 │ 4.01606 │ 1.88707  │
│ 4   │ T[2][1]    │ 0.706185  │ 0.0210481   │ 0.0006656   │ 0.00628791 │ 6.92471 │ 1.04481  │
│ 5   │ T[2][2]    │ 0.253944  │ 0.0181811   │ 0.000574936 │ 0.00547099 │ 7.6714  │ 0.999274 │
│ 6   │ T[2][3]    │ 0.0398708 │ 0.00523195  │ 0.000165449 │ 0.00158937 │ 4.01606 │ 2.09816  │
│ 7   │ T[3][1]    │ 0.430283  │ 0.0183518   │ 0.000580334 │ 0.00535442 │ 4.60138 │ 1.64891  │
│ 8   │ T[3][2]    │ 0.450252  │ 0.0186215   │ 0.000588864 │ 0.00555454 │ 4.5526  │ 1.5442   │
│ 9   │ T[3][3]    │ 0.119464  │ 0.00988752  │ 0.000312671 │ 0.0029534  │ 7.08331 │ 1.00227  │
│ 10  │ m[1]       │ 2.30276   │ 0.16282     │ 0.00514881  │ 0.0352831  │ 6.55215 │ 1.03373  │
│ 11  │ m[2]       │ 0.991943  │ 0.0645865   │ 0.00204241  │ 0.0153109  │ 10.7751 │ 1.04687  │
│ 12  │ m[3]       │ 0.159171  │ 0.148796    │ 0.00470534  │ 0.0471829  │ 4.01606 │ 1.76961  │
│ 13  │ s[1]       │ 1.994     │ 0.0772656   │ 0.00244335  │ 0.006      │ 6.49518 │ 1.00505  │
│ 14  │ s[2]       │ 1.991     │ 0.113719    │ 0.0035961   │ 0.009      │ 7.81415 │ 1.00528  │
│ 15  │ s[3]       │ 1.993     │ 0.0834144   │ 0.00263779  │ 0.007      │ 6.96785 │ 1.00607  │
│ 16  │ s[4]       │ 1.991     │ 0.0944877   │ 0.00298796  │ 0.009      │ 6.4939  │ 1.00811  │

xukai92 commented 5 years ago

@itsdfish Do you mean https://turing.ml/dev/tutorials/4-bayeshmm/?

itsdfish commented 5 years ago

Yeah. PG seems to perform poorly on that model. I suppose the number of samples could be increased, but it would slow it down more.

itsdfish commented 4 years ago

Hi @yebai. Just out of curiosity, I was wondering if there are any status updates?

yebai commented 4 years ago

Hi @itsdfish, there is promising progress towards this goal, e.g.

These PRs are gradually paving the way for a JAGS-style sampler. There is still one important missing part: being able to represent and manipulate dynamic computational graphs to automatically derive Gibbs conditionals. It is quite hard to implement this in a generic way, and @phipsgabler is still working on it in DynamicComputationGraphs.jl.

Also, @mohamed82008 found a way to use caching to speed up Gibbs substantially. This is similar in spirit to DynamicComputationGraphs in that it avoids unnecessary computation in Gibbs. See performance tips. We might automate this caching, or make it substantially easier to use (in fact, it's already easy to use) to provide efficient JAGS-style sampling.
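The caching idea can be sketched in plain Julia as follows. This is only an illustration under assumptions: `expensive`, `cached_expensive`, `run_sweeps` and the hand-rolled `Dict` cache are hypothetical names, not the mechanism described in the performance tips.

```julia
# Hypothetical expensive function whose result depends only on `x`.
const CALLS = Ref(0)
function expensive(x::Float64)
    CALLS[] += 1
    return sum(sin(x * k) for k in 1:1_000)   # stand-in for costly work
end

# Minimal cache keyed on the input value.
const CACHE = Dict{Float64, Float64}()
cached_expensive(x::Float64) = get!(() -> expensive(x), CACHE, x)

# Two "Gibbs blocks": the first changes `a`, the second changes only `b`.
# While `b` is being updated, the term that depends on `a` is reused.
function run_sweeps(nsweeps)
    a, b = 0.3, 1.0
    for sweep in 1:nsweeps
        a += 0.1                               # block 1: new value of a
        for inner in 1:5                       # block 2: a is unchanged
            b = 0.5 * b + cached_expensive(a)
        end
    end
    return b
end

run_sweeps(3)
println("expensive() was called ", CALLS[], " times")
```

Only the first call after `a` changes pays for the expensive computation; the remaining sub-iterations in that sweep are cache hits.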

Perhaps improving the compiler to automate caching could be an interesting GSoC project? @mohamed82008 @cpfiffer

mohamed82008 commented 4 years ago

That might be a good one yes. Refactoring Gibbs sampling using traits might also be a good one. Personally though, my availability this summer might be a bit limited because I am having my wedding in July. So it will be hard to commit to any work in July. Let's see. I can write the proposal for now and let's worry about mentoring logistics later.

cscherrer commented 4 years ago

Congrats @mohamed82008!

mohamed82008 commented 4 years ago

Thanks :)

trappmartin commented 4 years ago

As a short update: @phipsgabler is working on a PR for Turing, implementing an interface for Gibbs conditionals. Feel free to comment and help if you feel like it. See: https://github.com/TuringLang/Turing.jl/pull/1172

And in the near future, there will even be a JAGS-style Gibbs sampler. It needs a bit more work, but it seems that Philipp is making good progress.

itsdfish commented 4 years ago

Hello-

Out of curiosity, can you provide a status update? Thanks!

trappmartin commented 4 years ago

Sure.

We recently merged the PR that allows users to use custom Gibbs conditionals, and Philipp is currently finishing up his work on AutoGibbs, which automatically computes Gibbs conditionals for discrete RVs in any Turing model. The AutoGibbs code currently passes the tests for simpler models and will hopefully work for dynamic models soon too. It shouldn't take much longer.
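For readers unfamiliar with the idea, a Gibbs conditional for a discrete variable can be obtained by enumerating its support and normalising the joint density. The sketch below shows this for a toy two-component mixture in plain Julia. It is not the AutoGibbs implementation; the model and all function names are assumptions made for illustration.

```julia
using Random

# Log density of Normal(mu, sigma) at x.
normal_logpdf(x, mu, sigma) = -0.5 * ((x - mu) / sigma)^2 - log(sigma) - 0.5 * log(2π)

# Joint log density of a toy mixture for one observation:
# z ~ Categorical(w), y ~ Normal(means[z], 1).
logjoint(z, y, w, means) = log(w[z]) + normal_logpdf(y, means[z], 1.0)

# Gibbs conditional p(z | y): enumerate the support of z and normalise.
function gibbs_conditional(y, w, means)
    logp = [logjoint(z, y, w, means) for z in eachindex(w)]
    p = exp.(logp .- maximum(logp))   # subtract the max for numerical stability
    return p ./ sum(p)
end

# Draw an index from a probability vector by inverting the CDF.
function sample_discrete(rng, p)
    u = rand(rng)
    c = 0.0
    for (i, prob) in enumerate(p)
        c += prob
        u <= c && return i
    end
    return length(p)
end

w = [0.5, 0.5]
means = [-2.0, 2.0]
p = gibbs_conditional(1.5, w, means)
z = sample_discrete(MersenneTwister(1), p)
println("p(z | y = 1.5) = ", p, ", sampled z = ", z)
```

In a full model, only the factors that depend on `z` need to enter `logjoint`, which is where the dependency information discussed earlier in the thread comes in.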

yebai commented 2 years ago

Closed in favour of https://github.com/TuringLang/AbstractPPL.jl/pull/44