TuringLang / AdvancedPS.jl

Implementation of advanced Sequential Monte Carlo and particle MCMC algorithms
https://turinglang.org/AdvancedPS.jl/
MIT License

PGAS example where `state` consists of a tuple of distributions? #82

Open YSanchezAraujo opened 1 year ago

YSanchezAraujo commented 1 year ago

I'm wondering whether a model like the one below is possible. The basic problem is that the state isn't a single distribution but a collection of distributions, all of which evolve in a Markovian manner. I don't know exactly how this works internally, so the code below assumes that the state is propagated forward from initialization, through the transition, to the observation.

# Imports assumed by this snippet (not shown in the original): `logistic` is
# from LogExpFunctions, `Diagonal`/`I` from LinearAlgebra, and `sample` is the
# AbstractMCMC method that AdvancedPS hooks into. `X` (design matrix) and `y`
# (binary responses) are assumed to exist in the enclosing scope.
using AdvancedPS
using Distributions
using LinearAlgebra
using LogExpFunctions: logistic
using Random

n_trials, n_cols = size(X)

Parameters = @NamedTuple begin
    X::Matrix
    lam_lapse_init::Float64
    sigma_set_init::Array{Float64}
    mu_init::Array{Float64}
    n_trials::Int64
    n_cols::Int64
end

mutable struct PF <: AdvancedPS.AbstractStateSpaceModel
    # Buffers for the sampled state trajectory, one row/entry per trial.
    W::Matrix
    lam_lapse::Array
    sigma_set::Matrix
    theta::Parameters
    PF(theta::Parameters) = new(
        zeros(Float64, theta.n_trials, theta.n_cols),
        zeros(Float64, theta.n_trials),
        zeros(Float64, theta.n_trials, theta.n_cols),
        theta
    )
end

function init_step(m::PF)
    return (
        truncated(Normal(m.theta.lam_lapse_init, 0.1), lower=-10),
        truncated(Normal(m.theta.sigma_set_init[1], 0.1), lower=0.),
        truncated(Normal(m.theta.sigma_set_init[2], 0.1), lower=0.),
        truncated(Normal(m.theta.sigma_set_init[3], 0.1), lower=0.),
        MvNormal(m.theta.mu_init, I)  # unit covariance; the scalar form MvNormal(μ, 1.) is deprecated
    )
end

AdvancedPS.initialization(m::PF) = init_step(m)

function transition_step(m::PF, state)
    return (
        truncated(Normal(state[1], 0.1), lower=-10), # lam_lapse
        truncated(Normal(state[2], 0.1), lower=0.), # sigma1
        truncated(Normal(state[3], 0.1), lower=0.), # sigma2
        truncated(Normal(state[4], 0.1), lower=0.), # sigma3
        MvNormal(state[5], Diagonal([state[2], state[3], state[4]])) # mu; Diagonal entries are variances
    )
end

AdvancedPS.transition(m::PF, state) = transition_step(m, state)

function obs_density(m::PF, state, t)
    lam_lapse, _, _, _, mu = state
    lapse = logistic(lam_lapse)
    # Lapse-contaminated Bernoulli: with probability `lapse` the response is a
    # coin flip; otherwise it follows the logistic regression on X[t, :].
    prob = (1 - lapse) * logistic(m.theta.X[t, :]' * mu) + lapse * 0.5
    return Bernoulli(prob)
end

function AdvancedPS.observation(m::PF, state, t)
    return logpdf(obs_density(m, state, t), y[t])  # `y` holds the observed responses
end

AdvancedPS.isdone(m::PF, t) = t > m.theta.n_trials

n_particles = 20
n_samples = 200
rng = MersenneTwister(2342)

theta0 = Parameters(
    (X, -9.0, zeros(3), zeros(3), n_trials, n_cols)
)

model = PF(theta0)
pgas = AdvancedPS.PGAS(n_particles)
chains = sample(rng, model, pgas, n_samples; progress=true);
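
Aside: since the MvNormal component above uses a diagonal covariance, its three coordinates are independent, so the five-part state can be flattened into a single 7-dimensional vector-valued distribution. The sketch below uses Distributions.product_distribution; transition_step_flat and the flat state layout [lam_lapse, sigma1, sigma2, sigma3, mu...] are illustrative names, not AdvancedPS API:

# Sketch of a flattened transition: one product distribution instead of a tuple.
# Note the sqrt: MvNormal's Diagonal entries are variances, while Normal takes a std.
function transition_step_flat(m::PF, state)
    lam_lapse, s1, s2, s3 = state[1], state[2], state[3], state[4]
    mu = state[5:7]
    return product_distribution([
        truncated(Normal(lam_lapse, 0.1), lower=-10),
        truncated(Normal(s1, 0.1), lower=0.0),
        truncated(Normal(s2, 0.1), lower=0.0),
        truncated(Normal(s3, 0.1), lower=0.0),
        Normal(mu[1], sqrt(s1)),
        Normal(mu[2], sqrt(s2)),
        Normal(mu[3], sqrt(s3)),
    ])
end

A single vector-valued distribution also exposes the joint logpdf that PGAS needs for ancestor sampling, which a bare tuple of distributions does not.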
YSanchezAraujo commented 1 year ago

Looking at the advance function in the PGAS implementation, it seems this isn't possible in the current formulation. In my case:


rand(init_step(model))  # will just return one randomly chosen element of the tuple:

(
        truncated(Normal(m.theta.lam_lapse_init, 0.1), lower=-10),
        truncated(Normal(m.theta.sigma_set_init[1], 0.1), lower=0.),
        truncated(Normal(m.theta.sigma_set_init[2], 0.1), lower=0.),
        truncated(Normal(m.theta.sigma_set_init[3], 0.1), lower=0.),
        MvNormal(m.theta.mu_init, I)
    )

It seems the workaround would be to allow for

rand.(init_step(model))

instead?
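
Broadcasting rand over the tuple would handle the forward simulation, but PGAS also needs a log-density of the proposed state for ancestor sampling, so tuple states would need a joint logpdf as well. A minimal sketch of a wrapper that provides both (TupleDist is hypothetical, not part of AdvancedPS or Distributions, and it assumes the components are independent):

# Hypothetical wrapper: makes a tuple of component distributions behave like a
# single distribution, so rand/logpdf-based particle code could stay unchanged.
struct TupleDist{T<:Tuple}
    dists::T
end

# Sample every component; the state is then a tuple of draws.
Base.rand(rng::Random.AbstractRNG, d::TupleDist) = map(dist -> rand(rng, dist), d.dists)
Base.rand(d::TupleDist) = rand(Random.default_rng(), d)

# Joint log-density under independence: sum of the component log-densities.
Distributions.logpdf(d::TupleDist, x) = sum(map(logpdf, d.dists, x))

With this, init_step and transition_step could return a TupleDist instead of a bare tuple; whether the AdvancedPS internals would accept such an object unchanged is exactly the open question here.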

yebai commented 11 months ago

cc @FredericWantiez