ReactiveBayes / RxInfer.jl

Julia package for automated Bayesian inference on a factor graph with reactive message passing
MIT License

Not sure how to set initmarginals #39

Closed: SupplyChef closed this issue 1 year ago

SupplyChef commented 1 year ago

I am trying to implement some of the models described in 'State Space Time Series Analysis' by Commandeur and Koopman. Overall, RxInfer makes this easy. However, I am having trouble with one model:

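@model function locallevel_seasonal(n, σ², σ²_noise)
    level = randomvar(n)
    seasons = randomvar(12)

    # Independent priors on the 12 seasonal effects
    for i in 1:12
        seasons[i] ~ NormalMeanPrecision(10, 0.1)
    end

    # Observed time series
    y = datavar(Float64, n)

    # Prior on the initial level, followed by a random-walk level;
    # each observation adds the seasonal effect for position mod1(i, 12) of the cycle
    l0 ~ NormalMeanPrecision(10, 0.1)
    level_prev = l0
    for i in 1:n
        level[i] ~ Normal(mean = level_prev, precision = 1 / σ²)
        y[i] ~ Normal(mean = level[i] + seasons[mod1(i, 12)], precision = 1 / σ²_noise)
        level_prev = level[i]
    end
    return level, y
end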

When running, I get a warning about setting the initial marginals:

Variables [ level, seasons, l0 ] have not been updated after an update event. 
Therefore, make sure to initialize all required marginals and messages. See `initmarginals` and `initmessages` keyword arguments for the inference function. 

It's not clear to me how to set the initial marginals, since these variables are vectors of random variables. I tried passing in arrays of distributions (see below), but I keep getting the same warning.

inference(...,
    initmarginals = (
        l0 = NormalMeanPrecision(10, 0.1),
        seasons = fill(NormalMeanPrecision(10, 0.1), 12),
    )
)

Adding some examples on how to set the initial marginals would be very helpful.

Thank you.

bvdmitri commented 1 year ago

Hey @SupplyChef, the factor graph of your model contains loops, so you need to initialise either marginals or messages. Since you do not pass the constraints = ... keyword argument, I assume you are running Loopy Belief Propagation, which means you need to use initmessages = ... instead:

n = 50
inference(
    model = locallevel_seasonal(n, 1.0, 1.0),
    data = (y = rand(n), ),
    initmessages = (
        # for a vector of random variables, a single distribution is used for every element,
        # but you can also pass a vector of distributions to initialise each element individually
        seasons = vague(NormalMeanPrecision),  
    ),
    free_energy = true,
    iterations = 10
)
Inference results:
  Posteriors       | available for (level, seasons, l0)
  Free Energy:     | Real[238.264, 161.561, 145.516, 139.709, 135.625, 132.003, 128.846, 126.068, 123.628, 121.484  …  106.045, 106.045, 106.045, 106.046, 106.046, 106.046, 106.046, 106.046, 106.046, 106.046]

In addition, you need to set the iterations = ... keyword argument, because Loopy Belief Propagation needs several iterations to arrive at an answer. Note, however, that Loopy Belief Propagation has no convergence guarantees and the initialization may affect your results significantly.
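Because Loopy Belief Propagation has no convergence guarantees, it can help to inspect the free-energy trace requested with free_energy = true before trusting the posteriors. Below is a minimal, hypothetical heuristic (fe_converged is not part of RxInfer, and the tol and window defaults are arbitrary choices): it treats the trace as settled when the last few iteration-to-iteration changes are all below a tolerance. It works on any vector of reals, such as the values printed above.

```julia
# Hypothetical helper (not an RxInfer function): returns true when the last
# `window` absolute changes between consecutive free-energy values are all below `tol`.
function fe_converged(fe::AbstractVector{<:Real}; tol::Real = 1e-2, window::Int = 3)
    # Need at least window + 1 values to form `window` differences
    length(fe) <= window && return false
    diffs = abs.(diff(fe))
    return all(d -> d < tol, diffs[end-window+1:end])
end

# First and last ten values from the printed trace above
first_ten = [238.264, 161.561, 145.516, 139.709, 135.625, 132.003, 128.846, 126.068, 123.628, 121.484]
last_ten  = [106.045, 106.045, 106.045, 106.046, 106.046, 106.046, 106.046, 106.046, 106.046, 106.046]

println(fe_converged(first_ten))
println(fe_converged(last_ten))
```

Applied to the first ten printed values it returns false, and applied to the last ten it returns true. A settled trace does not prove that the loopy posteriors are accurate; it only suggests that more iterations would change little.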

bartvanerp commented 1 year ago

Just adding a bit to @bvdmitri's explanation as to why this warning is actually thrown:

From your model specification, it seems to me that the model allows exact inference. This means that RxInfer will perform inference through sum-product message passing, so there is indeed no need to make variational approximations of the marginal distributions.
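To make sum-product message passing a little more concrete, here is a minimal sketch in plain Julia (no RxInfer) of the forward messages on the local-level chain alone. It is a deliberate simplification: the seasons variables are dropped so that the graph is a loop-free chain, the function name forward_messages and its keyword defaults (copied from the l0 prior in the model) are hypothetical, and it computes only the left-to-right messages, which for this linear Gaussian model coincide with a Kalman filter. Full marginals would also require the backward messages.

```julia
# Hypothetical sketch: forward sum-product messages on the chain
#   l0 ~ N(mean m0, precision w0)
#   level[i] ~ N(mean level[i-1], variance σ²)
#   y[i] ~ N(mean level[i], variance σ²_noise)
# Seasons are omitted, so the graph has no loops.
# Returns the mean and precision of level[i] after absorbing y[1..i].
function forward_messages(y::AbstractVector{<:Real}, σ², σ²_noise; m0 = 10.0, w0 = 0.1)
    m, w = float(m0), float(w0)
    means = Float64[]
    precs = Float64[]
    for yi in y
        # Transition factor: the state variance σ² is added to the message variance
        w_pred = 1 / (1 / w + σ²)
        # Combining with the observation message: precisions add,
        # and the mean is the precision-weighted average
        w = w_pred + 1 / σ²_noise
        m = (w_pred * m + yi / σ²_noise) / w
        push!(means, m)
        push!(precs, w)
    end
    return means, precs
end

means, precs = forward_messages([22.0], 1.0, 1.0)
println(means, " ", precs)
```

For example, forward_messages([22.0], 1.0, 1.0) gives a mean of 21.0 and a precision of 12/11 for level[1]: the prior variance 10 grows to 11 through the transition, and the observation then pulls the mean most of the way towards 22.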

The warning specifically mentions initmarginals and initmessages, which suggests that the model does not have enough information to compute its messages and would otherwise wait for them indefinitely. Since you are using exact inference, there is no need to specify the initmarginals argument, as it applies only to variational inference. The initmessages argument, however, is relevant here.

I tried your model with n=12 (which worked) and n=13 (which gave the same warning as in your case), so the problem appears once the graph grows beyond 12 observations. Increasing the size of the graph creates loops through the seasons variables. More specifically, y[13] depends on seasons[1] and on level[13], which in turn depends on level[12], ..., level[1], and level[1] is connected back to seasons[1] through y[1]. Because of this loop, you are actually performing loopy belief propagation for n>12, which requires initialized messages to start the inference procedure.
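The loop argument above can be checked mechanically. The sketch below uses a hypothetical helper, has_loop (not an RxInfer function), that builds a simplified version of the factor graph of locallevel_seasonal and detects a cycle with union-find. It assumes that only factors linking two random variables can close a loop, so the unary priors are skipped, the observed y[i] are treated as clamped, and each observation factor is collapsed into a single edge between level[i] and seasons[mod1(i, 12)]; any intermediate node that may be introduced for the sum level[i] + seasons[...] does not change whether a loop exists.

```julia
# Hypothetical helper: does the (simplified) factor graph of
# locallevel_seasonal(n, ...) contain a loop?
function has_loop(n::Int; period::Int = 12)
    # Node ids: 1 = l0, 2..n+1 = level[1..n], n+2..n+period+1 = seasons[1..period]
    level_id(i) = 1 + i
    season_id(k) = n + 1 + k
    parent = collect(1:(n + period + 1))

    # Union-find root lookup with path halving
    function findroot(x)
        while parent[x] != x
            parent[x] = parent[parent[x]]
            x = parent[x]
        end
        return x
    end

    # Adding an edge between two nodes that are already connected closes a cycle
    function connect!(a, b)
        ra, rb = findroot(a), findroot(b)
        ra == rb && return true
        parent[ra] = rb
        return false
    end

    prev = 1  # start from l0
    for i in 1:n
        # Transition factor: level_prev -> level[i]
        connect!(prev, level_id(i)) && return true
        # Observation factor, collapsed to an edge level[i] -- seasons[mod1(i, period)]
        connect!(level_id(i), season_id(mod1(i, period))) && return true
        prev = level_id(i)
    end
    return false
end

println(has_loop(12))
println(has_loop(13))
```

Here has_loop(12) returns false and has_loop(13) returns true, matching the n=12 / n=13 behaviour described above.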

We hope this solves the issue; however, we will keep it open as a reminder to extend the documentation.