TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License

Transpilation of pure WinBUGS code when reimplementing Prior and Posterior Prediction #2148

Closed CMoebus closed 10 months ago

CMoebus commented 11 months ago

Hi, I am a newcomer to Turing.jl, so I am trying to reimplement the WinBUGS scripts from Lee & Wagenmakers' book BAYESIAN COGNITIVE MODELING. I am now stuck on the problem 'Prior and Posterior Prediction' in ch. 3.4. I want to stay close to the WinBUGS code, which has only 5 lines of code, and I tried not to export any values out of the model macro. But my chain and the plots produce results that are completely misleading. I looked around and found two people, semihcanaktepe and quangtiencs, doing similar reimplementations, but they circumvented the pure WinBUGS equivalent and wrote more verbose code. For documentation purposes, I attach the code below:

begin  # cf. BAYESIAN COGNITIVE MODELING 
    #     Lee & Wagenmakers, 2013, ch.3.4, Prior and Posterior Prediction
    #---------------------------------------------------------------------------
    using Turing, MCMCChains
    using LaTeXStrings
    using StatsPlots, Random
    #---------------------------------------------------------------------------
    @model function priorPosteriorPredictive(n; k=missing)
        #----------------------------------------------------
        # prior on rate θ
        θ ~ Beta(1, 1)
        #----------------------------------------------------
        # likelihood of observed data
        k ~ Binomial(n, θ)
        #----------------------------------------------------
        # prior predictive
        θPriorPred ~ Beta(1, 1)
        kPriorPred ~ Binomial(n, θPriorPred)
        #----------------------------------------------------
        # posterior predictive
        return kPostPred ~ Binomial(n, θ)
        #----------------------------------------------------
    end # function priorPosteriorPredictive
    #---------------------------------------------------------------------------
    modelPriorPredictive = let k = 1
        n = 15
        priorPosteriorPredictive(n)
    end # let
    #---------------------------------------------------------------------------
    chainPriorPredictive =                          # chain is ok
        let iterations = 3000
            sampler = Prior()
            sample(modelPriorPredictive, sampler, iterations)
        end # let       
    #---------------------------------------------------------------------------
    describe(chainPriorPredictive)                  # results are ok
    #---------------------------------------------------------------------------
    plot(chainPriorPredictive; normalize=true)      # plots are ok
    #---------------------------------------------------------------------------
    modelPosteriorPredictive = let k = 1
        datum = k
        n = 15
        # priorPosteriorPredictive(n)           # prior predictive without datum
        priorPosteriorPredictive(n; k=datum)    # posterior predictive including datum
    end # let
    #---------------------------------------------------------------------------
    chainPosteriorPredictive =                      # completely misleading
        let iterations = 3000
            nBurnIn = 1000
            δ = 0.65
            init_ϵ = 0.3
            sampler = NUTS(nBurnIn, δ; init_ϵ=init_ϵ)
            sample(modelPosteriorPredictive, sampler, iterations)
        end # let
    #---------------------------------------------------------------------------
    describe(chainPosteriorPredictive)              # completely misleading
    #---------------------------------------------------------------------------
    plot(chainPosteriorPredictive; normalize=true)  # completely misleading
    #---------------------------------------------------------------------------
end # begin

sunxd3 commented 11 months ago

I think the issue here is that θPriorPred, kPriorPred, and kPostPred throw NUTS off quite a bit.

If the model is written as

using Turing, MCMCChains, StatsPlots

@model function priorPosteriorPredictive(n; k=missing)
    #----------------------------------------------------
    # prior on rate θ
    θ ~ Beta(1, 1)
    #----------------------------------------------------
    # likelihood of observed data
    k ~ Binomial(n, θ)
    #----------------------------------------------------
    # prior predictive
    # θPriorPred ~ Beta(1, 1)
    # kPriorPred ~ Binomial(n, θPriorPred)
    #----------------------------------------------------
    # posterior predictive
    # return kPostPred ~ Binomial(n, θ)
    #----------------------------------------------------
end # function priorPosteriorPredictive

modelPosteriorPredictive = let k = 1
    datum = k
    n = 15
    # priorPosteriorPredictive(n),          # prior predictive without datum
    priorPosteriorPredictive(n; k=datum)    # posterior predictive including datum
end # let

chainPosteriorPredictive =                      # NUTS now samples only the continuous θ
    let iterations = 3000
        nBurnIn = 1000
        δ = 0.65
        init_ϵ = 0.3
        sampler = NUTS(nBurnIn, δ; init_ϵ=init_ϵ)
        sample(modelPosteriorPredictive, sampler, iterations)
    end # let

plot(chainPosteriorPredictive)

(where I commented out the three predictive variables), the plot looks much better (plot_31).
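
With the predictive variables removed from the model, the predictive draws can still be produced afterwards from the θ samples. The following is a minimal sketch of that idea using only Distributions.jl, not the method used in this thread. The names `θdraws`, `θPriorDraws` and `exact` are illustrative, and the θ draws are taken from the exact conjugate posterior Beta(2, 15) as a stand-in for the draws a chain would supply.

```julia
using Distributions, Random, Statistics

Random.seed!(1)
n, k = 15, 1

# Stand-in for posterior draws of θ: with a Beta(1, 1) prior and k successes
# in n trials, the exact posterior is Beta(1 + k, 1 + n - k) = Beta(2, 15).
# In a Turing workflow these draws would come from the chain instead (assumption).
posterior = Beta(1 + k, 1 + n - k)
θdraws = rand(posterior, 3000)

# Posterior predictive: one Binomial draw per posterior draw of θ.
kPostPred = [rand(Binomial(n, θ)) for θ in θdraws]

# Prior predictive: the same recipe, with θ drawn from the prior.
θPriorDraws = rand(Beta(1, 1), 3000)
kPriorPred = [rand(Binomial(n, θ)) for θ in θPriorDraws]

# Exact posterior predictive for comparison: Beta-Binomial(n, 2, 15).
exact = BetaBinomial(n, 1 + k, 1 + n - k)

println("posterior predictive mean (simulated): ", mean(kPostPred))
println("posterior predictive mean (exact):     ", mean(exact))
println("prior predictive mean (simulated):     ", mean(kPriorPred))
```

Generating the discrete predictive quantities outside the model keeps the gradient-based sampler on the single continuous parameter θ, at the cost of moving a little further from the five-line WinBUGS original.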

Alternatively, importance-sampling-based samplers will likely perform better here:

@model function priorPosteriorPredictive(n)
    #----------------------------------------------------
    # prior on rate θ
    θ ~ Beta(1, 1)
    #----------------------------------------------------
    # likelihood of observed data
    k ~ Binomial(n, θ)
    #----------------------------------------------------
    # prior predictive
    θPriorPred ~ Beta(1, 1)
    kPriorPred ~ Binomial(n, θPriorPred)
    #----------------------------------------------------
    # posterior predictive
    return kPostPred ~ Binomial(n, θ)
    #----------------------------------------------------
end # function priorPosteriorPredictive

modelPosteriorPredictive = let k = 1
    datum = k
    n = 15
    # priorPosteriorPredictive(n),          # prior predictive without datum
    priorPosteriorPredictive(n)  | (; k=datum)  # posterior predictive including datum
end # let

chainPosteriorPredictive = sample(modelPosteriorPredictive, PG(10), 3000)

plot(chainPosteriorPredictive)

(PG doesn't support models with keyword arguments, so I did a simple rewrite with the DynamicPPL.condition syntax.) The resulting plot is plot_32.
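
The right-hand side of `|` in the rewrite above is ordinary Julia syntax. As a small illustration (the variable name `obs` is hypothetical), `(; k=datum)` builds a NamedTuple whose field names are the model variables to be treated as observed:

```julia
datum = 1

# `(; k=datum)` is plain Julia syntax for a NamedTuple with a single field `k`.
obs = (; k=datum)

println(obs)
println(keys(obs))   # the field names, i.e. the variables to condition on
println(obs.k)       # the observed value

# With Turing loaded and `priorPosteriorPredictive` defined as above, the
# conditioned model from the thread would then be written as:
#   model = priorPosteriorPredictive(15) | obs
```

Because the observation is passed from outside the model, the model function itself no longer needs the `k=missing` keyword argument.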

CMoebus commented 11 months ago

Hi sunxd3, thank you for the advice to use the 'PG' sampler and the condition bar '|'. I had always wondered what the error message "...does not support keyword arguments" meant. Now, with the '|', this error has disappeared. Furthermore, 'PG' solves the sampling problem. Before your comment I had read (https://turinglang.org/v0.30/tutorials/04-hidden-markov-model/) that one possibility is to sample continuous variables with 'HMC' and discrete ones with 'PG'. So I tried the combined sampler 'Gibbs(HMC(...), PG(...))', but I had some difficulty getting it to run. Do you have experience with e.g. 'Gibbs(HMC(0.01, 50, ...), PG(120, ...))'? All the best, C.

sunxd3 commented 11 months ago

You can specify which sampler is in charge of which variable(s), like this:

@model function priorPosteriorPredictive(n)
   θ ~ Beta(1, 1)
   k ~ Binomial(n, θ)
   θPriorPred ~ Beta(1, 1)
   kPriorPred ~ Binomial(n, θPriorPred)
   kPostPred ~ Binomial(n, θ)
   return θPriorPred, kPriorPred, kPostPred
end

model = priorPosteriorPredictive(15) | (; k=1) # creating the model

chn = sample(model, Gibbs(HMC(0.05, 10, :θ), NUTS(-1, 0.65, :θPriorPred), PG(100, :k, :kPriorPred, :kPostPred)), 1000) # use HMC for `θ`, NUTS for `θPriorPred`, and PG for the rest.

gives

Chains MCMC chain (1000×5×1 Array{Float64, 3}):

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 51.93 seconds
Compute duration  = 51.93 seconds
parameters        = θ, θPriorPred, kPriorPred, kPostPred
internals         = lp

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64

           θ    0.1121    0.0758    0.0068   100.5133   141.6987    0.9997        1.9354
  θPriorPred    0.5095    0.2863    0.0687    19.3442   164.4584    1.0147        0.3725
  kPriorPred    7.6210    4.6185    1.0676    19.8468        NaN    1.0166        0.3822
   kPostPred    1.6550    1.5743    0.1143   169.2209   195.1175    1.0025        3.2584

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           θ    0.0170    0.0537    0.0989    0.1555    0.2934
  θPriorPred    0.0377    0.2466    0.5016    0.7709    0.9697
  kPriorPred    0.0000    3.0000    8.0000   12.0000   15.0000
   kPostPred    0.0000    0.0000    1.0000    2.0000    5.0000
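
As a rough sanity check on this summary, the model is conjugate, so closed-form reference values are available for n = 15 and k = 1 under the Beta(1, 1) prior. The derivation below is a sketch that follows from the model definition rather than anything stated in the thread:

```latex
\begin{align*}
p(\theta \mid k) &\propto \theta^{k}(1-\theta)^{n-k} \cdot 1
  \;\Rightarrow\; \theta \mid k \sim \mathrm{Beta}(k+1,\; n-k+1) = \mathrm{Beta}(2, 15) \\
\mathbb{E}[\theta \mid k] &= \tfrac{2}{17} \approx 0.118, \qquad
\mathrm{sd}[\theta \mid k] = \sqrt{\tfrac{2 \cdot 15}{17^{2} \cdot 18}} \approx 0.076 \\
\mathbb{E}[k_{\text{PostPred}} \mid k] &= n \cdot \tfrac{2}{17} = \tfrac{30}{17} \approx 1.76 \\
\mathbb{E}[k_{\text{PriorPred}}] &= \tfrac{n}{2} = 7.5, \qquad
\mathbb{E}[\theta_{\text{PriorPred}}] = 0.5, \quad
\mathrm{sd}[\theta_{\text{PriorPred}}] = \tfrac{1}{\sqrt{12}} \approx 0.289
\end{align*}
```

The sampled values in the table (0.1121 and 0.0758 for θ, 1.6550 for kPostPred, 7.6210 for kPriorPred, 0.5095 and 0.2863 for θPriorPred) are close to these reference values, with the larger gaps appearing where the effective sample size is low.
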
sunxd3 commented 11 months ago

@CMoebus we also have a package within the Turing ecosystem that supports the BUGS language directly, https://github.com/TuringLang/JuliaBUGS.jl, but it is currently in development and not feature-complete.

We'd appreciate it if you gave it a try and reported issues; there are definitely a lot of them, but we'll try to fix them ASAP.

CMoebus commented 11 months ago

@sunxd3: Thank you again, and thank you for inviting me to become a JuliaBUGS.jl tester. Just a few weeks ago I started transpiling WinBUGS scripts into pure Turing.jl. I liked the declarative, math-oriented style of BUGS. But at the same time, it is tedious if you need some calculations outside the BUGS language. A few years ago I switched to WebPPL. I liked its functional style. But Turing.jl and its embedding in Julia seem to be more promising.

sunxd3 commented 10 months ago

@CMoebus sorry for the late reply.

> I liked the declarative, math-oriented style of BUGS. But at the same time, it is tedious if you need some calculations outside the BUGS language.

One of the goals of the JuliaBUGS project is to make this much easier and to give users access to other Julia packages.

> But Turing.jl and its embedding in Julia seem to be more promising.

As a maintainer and user of Turing, thanks for the support!