ReactiveBayes / RxInfer.jl

Julia package for automated Bayesian inference on a factor graph with reactive message passing
MIT License

Enhance Support for Broadcasting #165

Closed albertpod closed 2 months ago

albertpod commented 11 months ago

This issue has been opened following discussion #164, where users encountered difficulties while trying to assign distributions to matrix elements (e.g., A .~ Gamma(α,β)) within the model definition. Currently, RxInfer.jl does not accommodate such broadcasting operations, resulting in a MethodError.
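For concreteness, here is a minimal sketch of the pattern that currently fails (hypothetical model and variable names, using the randomvar/datavar model syntax that appears later in this thread):

using RxInfer

@model function broadcast_example(n)
    y = datavar(Float64, n)
    # a 2x2 matrix of random variables
    A = randomvar(2, 2)
    # assigning an element-wise Gamma prior via broadcasting is what raises the MethodError
    A .~ Gamma(1.0, 1.0)
    for i in 1:n
        y[i] ~ NormalMeanVariance(A[1, 1], 1.0)
    end
end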

We propose two primary enhancements to RxInfer.jl:

  1. Improved Error Handling: Introduce a clear and informative error message that is triggered when users attempt unsupported broadcasting operations. This message should explicitly state that broadcasting is not supported in these scenarios and direct users towards established best practices (the long-awaited sharp bits for RxInfer.jl).

  2. Support for Broadcasting in Models: Work on a feature that enables broadcasting operations within probabilistic models. (This task is complex and might require substantial effort.)

To address these needs, I will create the following subtasks:

gvdr commented 11 months ago

Hi Albert, many thanks for the answer at #164 :-)

Would something like

for i in eachindex(A)
    A[i] ~ Gamma(2.0,3.0)
end

work, instead of broadcasting, in the example I provided?

[edited: I got the point about the priors in Gamma as well, so I'm only talking about the broadcasting]

albertpod commented 11 months ago

@gvdr it's somewhat possible, but multiplying this matrix with a multivariate random variable isn't possible, as there is no multiplication defined between the Matrix{RandomVariable} and RandomVariable types in the toolbox. That's not to say that multiplication of matrix-variate and multivariate distributions cannot be supported. For example, if you want to multiply by a matrix whose elements are Gamma distributed, I would create a node called GammaElementWise and then write something like this:

A ~ GammaElementWise(s = (2, 2), α = 1.0, β = 1.0)
for i in 1:n
    x[i] ~ MvNormalMeanPrecision(A * x_prev, Q)
    ...
end

This would create a node that corresponds to the prior on A, which RxInfer.jl will connect to the multiplication node. The posterior of A would then be the product of the two messages exchanged between the GammaElementWise node and the multiplication node *. This could be an extremely efficient approach, although it requires deriving the messages for these nodes as well as the product of those messages that yields the posterior of A.
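For illustration, a minimal sketch of how such a node could be declared (hypothetical interface names; this only registers the factor node, the message update rules would still have to be derived and implemented via @rule before inference can run):

using RxInfer # re-exports ReactiveMP, which provides the @node macro

# hypothetical element-wise Gamma prior over a matrix with shape `s`
struct GammaElementWise end

# declare the stochastic factor node and its interfaces; outbound messages
# towards `out` still require @rule implementations
@node GammaElementWise Stochastic [out, s, α, β]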

I've tried CVI once more and improved the inference results, although this required tuning the optimizer (changing from Descent to ADAM) and adding a prior on the precision of the transition noise, QQ.

using RxInfer, StableRNGs, Optimisers

# rebuild the 2x2 transition matrix from its scalar entries and apply it to the state
function f(x, a, b, c, d)
    A = [a b; c d]
    A * x
end

@model function rotate_ssm(n, x0, B, Q, P)

    # `x` is a sequence of hidden states
    x = randomvar(n)
    x̂ = randomvar(n)
    # `y` is a sequence of "clamped" observations
    y = datavar(Vector{Float64}, n)

    x_prior ~ MvNormalMeanCovariance(mean(x0), cov(x0))

    x_prev = x_prior

    A = randomvar(2, 2)

    for i in eachindex(A)
        A[i] ~ Gamma(1.0, 1.0)
    end

    # prior on the precision of transition noise
    QQ ~ Wishart(2, 1e-2Q)

    for i in 1:n
        # ideally you'd write x̂[i] ~ f(x_prev, A...) or just f(x_prev, A), but that isn't supported yet,
        # so we do it the ugly way
        x̂[i] ~ f(x_prev, A[1, 1], A[1, 2], A[2, 1], A[2, 2])
        # MvNormalMeanCovariance -> MvNormalMeanPrecision
        x[i] ~ MvNormalMeanPrecision(x̂[i], QQ)
        y[i] ~ MvNormalMeanCovariance(B * x[i], P)
        x_prev = x[i]
    end

end

delta_meta = @meta begin 
    f() -> CVI(StableRNG(42), 100, 200, Optimisers.ADAM(0.1))
end

x0 = MvNormalMeanCovariance(zeros(2), 1.0 * diageye(2))
result = inference(
    model = rotate_ssm(length(y), x0, B, Q, P),
    options = (limit_stack_depth = 500,), constraints = MeanField(),
    initmarginals = (QQ = Wishart(2, 1e-2Q), x = MvNormalMeanCovariance(zeros(2), diageye(2)), x̂ = MvNormalMeanCovariance(zeros(2), diageye(2))),
    initmessages = (A = GammaShapeRate(1.0, 1.0),),
    meta = delta_meta, data = (y = y,),
    free_energy = true, showprogress = true, iterations = 5, returnvars = KeepLast()
)

@wouterwln I believe broadcasting of this kind (#164) is available in the new GraphPPL.jl, and perhaps it wouldn't require creating a special GammaElementWise node? The release is expected this quarter.

wouterwln commented 11 months ago

I think we can rethink broadcasting in the new GraphPPL.jl to incorporate this behaviour, since it's quite nice to be able to do this. @bvdmitri and I will think this through, as it also aligns with some of the other issues that we currently have. @gvdr you can expect a follow-up on this in the coming weeks, and an RxInfer release where this is possible somewhere this quarter.
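As an illustration of what such support would enable, the element-wise prior from the rotate_ssm example above could then collapse into a single broadcast (hypothetical, not-yet-released syntax, shown with the v2-style randomvar declaration used earlier in this thread):

A = randomvar(2, 2)
# element-wise prior over all entries of A, replacing the explicit eachindex loop
A .~ Gamma(1.0, 1.0)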

bvdmitri commented 11 months ago

We can indeed support this in the new version of GraphPPL for the specification part, but the inference part still remains unsolved.

albertpod commented 11 months ago

Okay, so we won't throw an error just yet, since the merge of the new GraphPPL is coming. Perhaps the error should be thrown at the inference step instead.

albertpod commented 2 months ago

@wouterwln this can be closed?

wouterwln commented 2 months ago

I think so; the example as stated in the discussion can be done in any case.