TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License
2.05k stars 219 forks source link

HMM example model with missing data yields Malformed dims error #1359

Open finf281 opened 4 years ago

finf281 commented 4 years ago

If we take the HMM example from the tutorials section of this project's documentation, everything works as expected. However, if we set one of the values in the generated `y` data vector to `missing`, creating a `TArray` fails with a `Malformed dims` error.

using Turing, MCMCChains
using Distributions
using StatsPlots
using Random

# Observed emission sequence; entry 8 is deliberately `missing`, which is what
# triggers the reported error.
y = [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, missing, 3.0, 3.0, 2.0, 2.0, 2.0, 1.0, 1.0];

# Sequence length and number of hidden states.
N, K = length(y), 3;

# Turing model definition: the HMM from the Turing tutorials, applied here to a
# `y` that contains a `missing` entry (the case reported to fail).
#
# Arguments
# - `y`: emission sequence; elements may be `missing`.
# - `K`: number of hidden states.
@model BayesHmm(y, K) = begin
    # Get observation length.
    N = length(y)

    # State sequence (one integer state index per time step).
    s = tzeros(Int, N)

    # Emission means, one per state (K entries; `Vector(undef, K)` has eltype Any).
    m = Vector(undef, K)

    # Transition probabilities stored as a vector of vectors: T[j] is the
    # distribution of the next state given current state j.
    T = Vector{Vector}(undef, K)

    # Assign distributions to each element
    # of the transition matrix and the
    # emission means: a symmetric Dirichlet prior per row T[i], and a Normal
    # prior centred at i for the mean of state i.
    for i = 1:K
        T[i] ~ Dirichlet(ones(K)/K)
        m[i] ~ Normal(i, 0.5)
    end

    # Uniform initial state, then its emission.
    # NOTE(review): when y[i] is `missing`, `y[i] ~ ...` is presumably treated as a
    # latent variable rather than an observation; this is the path on which the
    # issue reports the "Malformed dims" TArray error. TODO confirm.
    s[1] ~ Categorical(K)
    y[1] ~ Normal(m[s[1]], 0.1)

    # Remaining steps: transition from the previous state, then emit.
    for i = 2:N
        s[i] ~ Categorical(vec(T[s[i-1]]))
        y[i] ~ Normal(m[s[i]], 0.1)
    end
end;

# Gibbs sampler: HMC for the continuous parameters (m, T) and particle Gibbs
# for the discrete state sequence s.
g = Gibbs(HMC(0.001, 7, :m, :T), PG(20, :s))
c = let model = BayesHmm(y, 3)
    sample(model, g, 100)
end;
trappmartin commented 4 years ago

Unfortunately, this is not easy to fix at the moment (I think).

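Here is a slightly hacky workaround. Pass the observed values (`yobs`) and their time indices (`obsidx`) to the model separately, and declare the missing values as an explicit parameter vector (`ymis`) that the sampler updates: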

# HMM with partially observed emissions (workaround for missing data).
#
# Observed and missing emissions are passed separately, so that only the missing
# ones (`ymis`) become parameters for the sampler.
#
# Arguments
# - `yobs`:   observed emission values; `yobs[k]` is the value at time `obsidx[k]`.
# - `obsidx`: time indices in `1:N` (no duplicates, any order) where y was observed.
# - `K`:      number of hidden states.
# - `N`:      total sequence length (observed + missing).
# - `Ty`:     element type of the missing-value buffer `ymis`. The `Float64` default
#             is presumably swapped by Turing for an AD-compatible type. TODO confirm.
#
# Throws `DimensionMismatch` if `yobs` and `obsidx` differ in length,
# `ArgumentError` on duplicate indices, and `BoundsError` if an index is outside `1:N`.
@model function hmm(yobs, obsidx, K, N, ::Type{Ty} = Float64) where {Ty}
    length(yobs) == length(obsidx) ||
        throw(DimensionMismatch("yobs and obsidx must have the same length"))
    allunique(obsidx) || throw(ArgumentError("obsidx must not contain duplicates"))

    # obspos[i] is the index into `yobs` for time step i, or 0 if y[i] is missing.
    # Mapping through obsidx (instead of a running counter) keeps yobs[k] paired
    # with time obsidx[k] even if obsidx is unsorted. It also avoids an O(length(obsidx))
    # membership test on every step.
    obspos = zeros(Int, N)
    obspos[obsidx] = eachindex(yobs)

    # Missing emissions, filled in time order.
    ymis = Vector{Ty}(undef, N - length(obsidx))

    # Hidden state sequence.
    s = tzeros(Int, N)

    # Emission means: the prior for state k is centred at k.
    m ~ arraydist([Normal(k, 0.5) for k in 1:K])

    # Transition matrix (K×K): column j is the distribution of the next state
    # given current state j.
    T ~ filldist(Dirichlet(K, 1/K), K)

    jm = 1  # next free slot in ymis
    for i in 1:N
        # Uniform initial state; afterwards, follow T's column for the previous state.
        s[i] ~ (i == 1 ? Categorical(K) : Categorical(T[:, s[i-1]]))
        dist = Normal(m[s[i]], 0.1)
        j = obspos[i]
        if j > 0
            yobs[j] ~ dist
        else
            ymis[jm] ~ dist
            jm += 1
        end
    end
end

# Observed sequence with one missing emission (entry 8).
y = [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, missing, 3.0, 3.0, 2.0, 2.0, 2.0, 1.0, 1.0];
N, K = length(y), 3;

# Split y into the observed values and their time indices. The conversion drops
# `Missing` from the element type: y[obsidx] alone would still be a
# Vector{Union{Missing, Float64}}. It would also throw if a missing value slipped through.
obsidx = findall(!ismissing, y)
yobs = convert(Vector{nonmissingtype(eltype(y))}, y[obsidx])

# The global `m` is the model object, unrelated to the emission means `m` inside it.
m = hmm(yobs, obsidx, K, N)
# HMC updates the continuous variables (m, T and the missing emissions ymis);
# particle Gibbs updates the discrete state sequence s.
g = Gibbs(HMC(0.01, 5, :m, :T, :ymis), PG(10, :s))
c = sample(m, g, 100);