ReactiveBayes / RxInfer.jl

Julia package for automated Bayesian inference on a factor graph with reactive message passing

Matrix{RandomVariable} * RandomVariable errors #163

Closed · gvdr closed this issue 9 months ago

gvdr commented 9 months ago

Hello. I'm trying to extend one of the examples but hitting a wall.

I'm building upon the Gaussian Linear Dynamical System example.

In particular, I'm trying to generalise it so that we don't assume the transition matrix $A$ is known a priori (eventually, I'd like to get to a scenario where we learn most of these matrices). In the example, $A$ is passed as an input and then wrapped in a constvar to be used in the model. I would like to give it a prior in the model instead, and infer it together with the rest of the parameters.
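
For reference, this is roughly how $A$ enters the model in the original example (a sketch from memory of the linked notebook, showing only the lines relevant to $A$):

@model function rotate_ssm(n, x0, A, B, Q, P)
    # in the original example A is a known matrix, wrapped as a constant
    cA = constvar(A)
    # ... the transition then uses the constant:
    # x[i] ~ MvNormalMeanCovariance(cA * x_prev, cQ)
end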

I tried a couple of things, but I can't make it work. A minimal reproducible example consists of redefining the model and inference as follows:

@model function rotate_ssm(n, x0, B, Q, P)

    # We create constvar references for better efficiency
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)

    # `x` is a sequence of hidden states
    x = randomvar(n)
    # THIS IS WHERE I DIVERGE FROM EXAMPLE:
    A = randomvar(2,2)
    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, n)

    x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0))
    x_prev = x_prior

    α ~ Uniform(0.0, 3.0)
    β ~ Uniform(0.0, 3.0)

    A .~ Gamma(α, β)

    for i in 1:n
        x[i] ~ MvNormalMeanCovariance(A*x_prev, cQ)
        y[i] ~ MvNormalMeanCovariance(cB * x[i], cP)
        x_prev = x[i]
    end

end

And I change the inference call accordingly:

result = inference(
    model = rotate_ssm(length(y), x0, B, Q, P), 
    data = (y = y,),
    free_energy = true
)

I tried to specify $A$ in a number of ways, but never had any luck.

When I try to run it, I get the following error:

ERROR: MethodError: no method matching make_node(::typeof(*), ::FactorNodeCreationOptions{FullFactorisation, Nothing, Nothing}, ::RandomVariable, ::Matrix{RandomVariable}, ::RandomVariable)
albertpod commented 9 months ago

Hi @gvdr! Thanks for trying out RxInfer.jl. Long story short, broadcasting isn't supported in RxInfer.jl; the reason lies mainly in the intricacies of graph construction and inference. I see you are trying to build a hierarchical prior by introducing a Uniform prior on top of the parameters of the Gamma. This is also not currently supported out of the box.

One way to circumvent the current problem is somewhat similar to what was suggested in this discussion: https://github.com/biaslab/RxInfer.jl/discussions/156. Check the docs for the Delta node as well. DISCLAIMER: (1) the code below looks ugly, but this is currently the only way to enforce a Gamma prior on each element of your transition matrix; (2) CVI will be slow, hyperparameter-dependent, and inaccurate due to the mean-field constraint in this case. I don't think the inference will be accurate here, so sampling-based toolboxes (such as Turing or numpyro) could make the inference both faster and better.

# Assuming your matrix A has 4 elements.
# Ugly, I know: each element is passed as a separate scalar random variable,
# and RxInfer treats this plain Julia function as a Delta (nonlinear) node,
# whose approximation method is selected in the @meta block below.
function f(x, a, b, c, d)
    A = [a b; c d]
    A * x
end

@model function rotate_ssm(n, x0, B, Q, P)

    # We create constvar references for better efficiency
    cB = constvar(B)
    cQ = constvar(Q)
    cP = constvar(P)

    # `x` is a sequence of hidden states
    x = randomvar(n)
    x̂ = randomvar(n)
    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, n)

    x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0))
    x_prev = x_prior

    a = randomvar(length(Q))

    for i in 1:length(Q)
        a[i] ~ Gamma(α=1.0, β=1.0)
    end

    for i in 1:n
        # ideally you'd write x̂[i] ~ f(x_prev, a); passing a vector of random
        # variables isn't supported yet (it's coming), so we do it the ugly way
        x̂[i] ~ f(x_prev, a[1], a[2], a[3], a[4])
        x[i] ~ MvNormalMeanCovariance(x̂[i], cQ)
        y[i] ~ MvNormalMeanCovariance(cB * x[i], cP)
        x_prev = x[i]
    end

end

using StableRNGs, Optimisers  # assumed imports, for StableRNG and Descent

# approximate the Delta node for f with CVI; the arguments are an RNG,
# the number of samples, the number of iterations, and the optimiser
delta_meta = @meta begin
    f() -> CVI(StableRNG(42), 100, 200, Optimisers.Descent(0.01))
end

x0 = MvNormalMeanCovariance(zeros(2), 1.0 * diageye(2))

result = inference(
    model = rotate_ssm(length(y), x0, B, Q, P),
    data = (y = y,),
    options = (limit_stack_depth = 500,),
    constraints = MeanField(),
    initmarginals = (
        x = MvNormalMeanCovariance(zeros(2), 1e4diageye(2)),
        x̂ = MvNormalMeanCovariance(zeros(2), 1e4diageye(2)),
    ),
    initmessages = (a = GammaShapeRate(1e-2, 1e2),),
    meta = delta_meta,
    free_energy = true,
    showprogress = true,
    iterations = 5,
    returnvars = KeepLast()
)

The initial marginals are needed because of the mean-field constraint; the initial message on a is necessary for the CVI approximation method.

I will convert this issue into a discussion, and I will create two issues associated with your question: (1) throwing a proper error on broadcasting, and (2) actually supporting broadcasting (this will take some time).

There's, of course, always room for creating your own node and the associated rules for it.
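
As a rough illustration, a custom node and one of its rules would be declared with ReactiveMP's @node and @rule macros along these lines (a minimal, untested sketch; the MatrixTransition name and the rule body are hypothetical, and a usable node needs rules for every interface — see the ReactiveMP documentation):

using ReactiveMP

# Hypothetical deterministic node encoding out = A * in
struct MatrixTransition end

@node MatrixTransition Deterministic [out, in, A]

# Sketch of the message towards `out`, assuming a Gaussian message on `in`
# and a point-mass marginal on `A`: out = A * in propagates N(Aμ, AΣAᵀ)
@rule MatrixTransition(:out, Marginalisation) (m_in::MultivariateNormalDistributionsFamily, q_A::PointMass) = begin
    A = mean(q_A)
    μ, Σ = mean_cov(m_in)
    return MvNormalMeanCovariance(A * μ, A * Σ * A')
end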

To get a better idea of what RxInfer.jl can and cannot support, an understanding of Forney-style factor graphs helps significantly.