ReactiveBayes / RxInfer.jl

Julia package for automated Bayesian inference on a factor graph with reactive message passing
MIT License

`initmarginals` in `rxinference` confusing (not starting inference) #77

Open bartvanerp opened 1 year ago

bartvanerp commented 1 year ago

I ran into an issue that I think many people who want to try out the `rxinference()` function will hit. Consider the simple model

```julia
using RxInfer, LinearAlgebra  # LinearAlgebra provides `dot`

@model function observation_model(Hs, Qs, R)

    # specify data variables
    μ_s = datavar(Vector{Float64})
    Σ_s = datavar(Matrix{Float64})

    # specify priors
    s_prev ~ MvNormalMeanCovariance(μ_s, Σ_s)

    # add process noise
    s ~ MvNormalMeanCovariance(s_prev, Qs)

    # form observation
    y = datavar(Float64)
    y ~ NormalMeanVariance(dot(Hs, s), R)

    # return variables (`n` dropped: it is never defined in this model)
    return y, s

end
```

with the following autoupdates:

```julia
autoupdates = @autoupdates begin
    μ_s, Σ_s = mean_cov(q(s))
end;
```

I started off with the following implementation:

```julia
rxinference(
    model         = observation_model(Hs, Qs, R),
    data          = (y = data, ),
    autoupdates   = autoupdates,
    initmarginals = (
        s_prev = vague(MvNormalMeanCovariance, deployed_model_signal.dim_in),
    ),
    returnvars    = (:s,),
    keephistory   = length(data),
    historyvars   = (s = KeepLast(),),
    autostart     = true,
)
```

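This did not start the inference procedure; instead it failed with `MethodError: iterate(::Nothing)`, raised as a result of `mean_cov(::Missing)`. It turned out that `rxinference` first sets the marginals from `initmarginals`, then calls the autoupdates, and only then actually starts inference. Since `initmarginals` only set `q(s_prev)`, the autoupdate `mean_cov(q(s))` had no marginal for `s` to work with. The code below did run: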

```julia
rxinference(
    model         = observation_model(Hs, Qs, R),
    data          = (y = data, ),
    autoupdates   = autoupdates,
    initmarginals = (
        s = vague(MvNormalMeanCovariance, deployed_model_signal.dim_in),
    ), # specifies the initial q(s) to get the inference starting (i.e. first autoupdates, then message passing)
    returnvars    = (:s,),
    keephistory   = length(data),
    historyvars   = (s = KeepLast(),),
    autostart     = true,
)
```

I understand why it is implemented this way, but the error message does not make clear what is going wrong. Perhaps we can improve the error handling by first checking whether every marginal used in `autoupdates` is specified in the `initmarginals` argument. If not, we could throw an error explaining that the marginals are set first, then the autoupdates are called, and only then is data fed into the model.
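A minimal sketch of the proposed check, written in plain Julia with no RxInfer dependency. The function `check_autoupdate_initmarginals` is hypothetical (not an RxInfer function), and it assumes the variables whose marginals are read inside `@autoupdates` are passed in explicitly as a tuple of symbols; a real implementation inside RxInfer would presumably extract them from the autoupdate specification itself. `nothing` stands in for the `vague(...)` initial marginals.

```julia
# Hypothetical helper sketching the proposed validation; NOT part of RxInfer.
# `autoupdate_vars`: variables whose marginals q(⋅) are read in `@autoupdates` (here `s`).
# `initmarginals`:  the NamedTuple that would be passed to `rxinference`.
function check_autoupdate_initmarginals(autoupdate_vars, initmarginals::NamedTuple)
    # collect every autoupdated variable without an initial marginal
    missing_vars = [v for v in autoupdate_vars if !haskey(initmarginals, v)]
    if !isempty(missing_vars)
        throw(ArgumentError(
            "The autoupdates read the marginals of $(join(missing_vars, ", ")), " *
            "but `initmarginals` does not provide them. At start-up the marginals " *
            "are set first, then the autoupdates are called, and only then is " *
            "data fed into the model."))
    end
    return nothing
end

# Mirrors the working call from the issue: an initial marginal for `s` is given.
check_autoupdate_initmarginals((:s,), (s = nothing,))

# Mirrors the first call from the issue: only `s_prev` is initialised.
try
    check_autoupdate_initmarginals((:s,), (s_prev = nothing,))
catch err
    println(sprint(showerror, err))  # prints the explanatory ArgumentError message
end
```

Running such a check before the first autoupdate would replace the `iterate(::Nothing)` error with a message that names the uninitialised variable directly.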

bvdmitri commented 10 months ago

61 link