ReactiveBayes / RxInfer.jl

Julia package for automated Bayesian inference on a factor graph with reactive message passing

Prediction of data variables when some are missing #355

Closed wouterwln closed 6 days ago

wouterwln commented 6 days ago

When we pass missing in place of some data, RxInfer is supposed to predict these variables (without having to terminate the graph with Uninformative). However, in the following model something goes wrong and no predictions are made. The model is a simple POMDP with a time horizon of 3 and an identity mapping between the states and the observations.

using RxInfer

@model function pred_model(p_s_t, y, goal, p_B, A)
    s[1] ~ p_s_t                      # prior over the initial state
    B ~ p_B                           # prior over the transition matrix
    y[1] ~ Transition(s[1], A)        # observation model, A is the identity here
    for i in 2:3
        s[i] ~ Transition(s[i-1], B)  # latent state transition
        y[i] ~ Transition(s[i], A)
    end
    s[3] ~ Categorical(goal)          # goal prior on the final state
end

pred_model_constraints = @constraints begin
    q(s, B) = q(s)q(B)
end

@initialization function pred_model_init(q_B)
    q(B) = q_B
end

result = infer(
    model          = pred_model(A = diageye(4), goal = [0, 1, 0, 0], p_B = MatrixDirichlet(ones(4, 4)), p_s_t = Categorical([0.7, 0.3, 0, 0])),
    data           = (y = [[1, 0, 0, 0], missing, missing],),
    initialization = pred_model_init(MatrixDirichlet(ones(4, 4))),
    constraints    = pred_model_constraints,
    iterations     = 10
)

julia> last(result.predictions[:y])
3-element Vector{Categorical{Float64, Vector{Float64}}}:
 Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[0.25, 0.25, 0.25, 0.25])
 Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[0.25, 0.25, 0.25, 0.25])
 Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[0.25, 0.25, 0.25, 0.25])

We can recreate this model by introducing an additional latent state s_int with y[i] = s_int[i], i.e. a 'latent' state that is really just a copy of the observations. This way y is aliased by a latent variable, and we should be able to read off predictions over y from the posterior over s_int:

@model function pred_model(p_s_t, y, goal, p_B, A)
    s[1] ~ p_s_t
    B ~ p_B
    s_int[1] ~ Transition(s[1], diageye(4))     # identity copy of the state
    y[1] ~ Transition(s_int[1], A)
    for i in 2:3
        s[i] ~ Transition(s[i-1], B)
        s_int[i] ~ Transition(s[i], diageye(4)) # latent alias of y[i]
        y[i] ~ Transition(s_int[i], A)
    end
    s[3] ~ Categorical(goal)
end

result = infer(
    model          = pred_model(A = diageye(4), goal = [0, 1, 0, 0], p_B = MatrixDirichlet(ones(4, 4)), p_s_t = Categorical([0.7, 0.3, 0, 0])),
    data           = (y = [[1, 0, 0, 0], missing, missing],),
    initialization = pred_model_init(MatrixDirichlet(ones(4, 4))),
    constraints    = pred_model_constraints,
    iterations     = 10
)

julia> last(result.posteriors[:s_int])
3-element Vector{Categorical{Float64, Vector{Float64}}}:
 Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[0.9999999999962583, 1.247264204453763e-12, 1.247264204453763e-12, 1.247264204453763e-12])
 Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[0.9999999999939958, 2.0014457580302767e-12, 2.0014457580302767e-12, 2.0014457580302767e-12])
 Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[2.7357588823252253e-12, 0.9999999999917927, 2.7357588823252253e-12, 2.7357588823252253e-12])

The posterior over s_int is informative while the predictions over y stay uniform, so we can see that the predictions for y are indeed not being updated.

albertpod commented 6 days ago

@bvdmitri relates to #291?

wouterwln commented 6 days ago

I think I know what is up. Because y is factorized out, this rule is being called instead of this one. We take the log-mean of a matrix that contains 0s, which yields -Inf entries; after multiplying with probvec(q_in) and exponentiating elementwise, these turn into NaNs. For some reason RxInfer doesn't throw a warning or an error, although calling the rule manually does raise one. If I add 0.001 to A before passing it to pred_model, the predictions are fine again, so I'm fairly sure this is the problem. I have no idea, though, why no error is being thrown.
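
For illustration, here is a minimal sketch of the offending computation; the actual rule in ReactiveMP may differ in details, and A and q_in below are just the values from this model:

using LinearAlgebra

A    = Matrix{Float64}(I, 4, 4)  # identity mapping, contains exact zeros
q_in = [0.7, 0.3, 0.0, 0.0]      # an incoming probability vector

logA = log.(A)                   # zero entries of A become -Inf
m    = exp.(logA * q_in)         # -Inf * 0.0 = NaN, and exp propagates it
# m == [NaN, NaN, NaN, NaN], so normalizing it cannot yield a valid message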

So it is indeed vaguely related to #291, but there is some other stuff going on here that shouldn't really happen.
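
For reference, the jitter workaround mentioned above would look something like this (a sketch, not a principled fix; the column renormalization is my addition to keep A a valid stochastic matrix):

A_jittered = diageye(4) .+ 0.001                      # no exact zeros, so log.(A) stays finite
A_jittered = A_jittered ./ sum(A_jittered, dims = 1)  # renormalize columns

# then pass A = A_jittered to pred_model in the infer call above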

wouterwln commented 6 days ago

The problem occurs because of the Transition node: the rules for q_in::DiscreteNonParametric and q_a::PointMass behave strangely when the transition matrix contains 0s. I therefore think this is not a bug but an unfortunate consequence of #291, combined with the fact that a structured factorization between the two categorical interfaces of a Transition node introduces a weird rule-level side effect: if the transition matrix contains exact ones or zeros, the incoming probability vector is ignored. The behaviour is, strictly speaking, not wrong, just unexpected.
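
To illustrate that last point, a small sketch of the same log-mean computation as above (with the same caveat that the actual rule implementation may differ):

logA = log.([1.0 0.0; 0.0 1.0])  # exact ones and zeros give entries 0.0 and -Inf
exp.(logA * [0.7, 0.3])          # [0.0, 0.0]
exp.(logA * [0.1, 0.9])          # also [0.0, 0.0]: the incoming probabilities have no effect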