ReactiveBayes / RxInfer.jl

Julia package for automated Bayesian inference on a factor graph with reactive message passing
MIT License

Vector of inputs in `streaming_inference` #318

Open · wouterwln opened this issue 2 weeks ago


If we pass a vector (or tensor) of inputs to streaming inference and create a data variable for every entry, we get the following error:

```
MethodError: no method matching is_data(::Vector{RxInfer.GraphVariableRef})

Closest candidates are:
  is_data(!Matched::RxInfer.GraphVariableRef)
   @ RxInfer ~/.julia/packages/RxInfer/SROpQ/src/model/plugins/reactivemp_inference.jl:229
  is_data(!Matched::GraphPPL.VariableNodeProperties)
   @ GraphPPL ~/.julia/packages/GraphPPL/ke7hR/src/graph_engine.jl:696
```

MWE:

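```julia
using RxInfer

# x is a vector input: one data variable per entry x[1], x[2], x[3]
@model function test_model(x, y, mx, vx)
    for i in 1:3
        x[i] ~ NormalMeanVariance(mx, vx)
    end
    my ~ NormalMeanVariance(0, 1)
    y ~ NormalMeanVariance(my, 1.0)
end

# Ten observations, each a vector x of length 3 and a scalar y
d = [(x = rand(3), y = rand()) for i in 1:10]
datastream = from(d) |> map(NamedTuple{(:x, :y), Tuple{Vector{Float64}, Float64}}, (d) -> d)

foo(x) = 1.0

# mx and vx are recomputed from the posterior of my after each observation
autoupdates = @autoupdates begin
    mx = foo(q(my))
    vx = foo(q(my))
end
```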

The following static (non-streaming) call runs and returns a result:

```julia
infer(model = test_model(mx = 1.0, vx = 1.0), data = (x = rand(3), y = 0.0), iterations = 10, showprogress = true)
```

When we run streaming inference instead, the error above is thrown:

```julia
infer(
    model = test_model(),
    datastream = datastream,
    autoupdates = autoupdates,
    initialization = @initialization begin q(my) = NormalMeanVariance(1.0, 1.0) end
)
```

Defining the following method fixes this, although it might not be the most rigorous fix:

```julia
RxInfer.is_data(vector::Vector{RxInfer.GraphVariableRef}) = all(RxInfer.is_data.(vector))
```
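The `MethodError` is plain Julia multiple dispatch: `is_data` has methods for a single `GraphVariableRef` (and for `GraphPPL.VariableNodeProperties`), but none for a `Vector` of them. The sketch below reproduces that pattern outside RxInfer using a hypothetical stand-in type `VarRef` (not an RxInfer type), then adds a vector method in the spirit of the workaround. It uses `all(is_data, vs)`, which stops at the first non-data entry and avoids building the intermediate array that the broadcast `all(RxInfer.is_data.(vector))` allocates. Treating a vector as data only when every entry is data is an assumption carried over from the workaround, not confirmed RxInfer semantics; note also that `all` returns `true` for an empty vector.

```julia
# Hypothetical stand-in for RxInfer.GraphVariableRef; not an RxInfer type.
struct VarRef
    isdata::Bool
end

# Scalar method only, mirroring the "closest candidates" in the error above.
is_data(v::VarRef) = v.isdata

refs = [VarRef(true), VarRef(false), VarRef(true)]

# With no method for a vector, dispatch fails just like in the report.
try
    is_data(refs)
catch err
    println(err isa MethodError)  # true
end

# Workaround in the style of the issue (assumed semantics): a vector counts
# as data only if every entry does. `all(f, itr)` short-circuits and does
# not allocate an intermediate Bool array.
is_data(vs::AbstractVector{<:VarRef}) = all(is_data, vs)

println(is_data(refs))                          # false
println(is_data([VarRef(true), VarRef(true)]))  # true
```

Dispatching on `AbstractVector{<:VarRef}` rather than `Vector{VarRef}` is a deliberate, slightly broader choice in this sketch so that views and other array types would also match.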