wazizian / OnlineSampling.jl

Online inference on reactive probabilistic models, with SMC and symbolic methods
MIT License

Overloading the broadcast operation (like adding a scalar) #38

Open mlelarge opened 2 years ago

mlelarge commented 2 years ago

When streaming with `belief_propagation`, the algorithm samples nodes that it should keep symbolic. Here is an example:

```julia
using OnlineSampling
using PDMats
using Distributions

speed = 10.0
trans1 = 5.0
trans_noise = 5.0
noise = 0.5

@node function model()
    @init x0 = rand(MvNormal([0.0], ScalMat(1, 1000.0)))
    x0 = rand(MvNormal(@prev(x0), ScalMat(1, speed)))
    # broadcast operations on x0
    x1 = rand(MvNormal(x0 .+ trans1, ScalMat(1, trans_noise)))
    x2 = rand(MvNormal(2 .* x0, ScalMat(1, trans_noise)))
    y1 = rand(MvNormal(x1, ScalMat(1, noise)))
    y2 = rand(MvNormal(x2, ScalMat(1, noise)))
    return x0, x1, x2, y1, y2
end

@node function hmm(obs1, obs2)
    x0, x1, x2, y1, y2 = @nodecall model()
    @observe(y1, obs1)
    @observe(y2, obs2)
    return x0, x1, x2
end

steps = 2
obs = reshape(Vector{Float64}(1:steps), (steps, 1))
# launch the inference with 1 particle for all observations
cloudbp = @noderun particles = 1 algo = belief_propagation hmm(eachrow(obs), eachrow(obs))
d = dist(cloudbp.particles[1])
d[1]  # distribution of x0
```

Here you can see that `d[1]` (the distribution of `x0`) is a Dirac, which it should not be. Here is the corresponding code for pure BP (i.e., without streaming):

```julia
# Continues the session above (reuses trans1, trans_noise, noise)
using OnlineSampling.BP
using OnlineSampling.CD
using LinearAlgebra

gm = BP.GraphicalModel(Int)
x0 = initialize!(gm, MvNormal([0.0], ScalMat(1, 1000.0)))
# linear-Gaussian conditionals mirroring x1 and x2 in the streaming model
x1 = initialize!(gm, CdMvNormal(I(1), [trans1], ScalMat(1, trans_noise)), x0)
x2 = initialize!(gm, CdMvNormal(2.0 * I(1), [0.0], ScalMat(1, trans_noise)), x0)
y1 = initialize!(gm, CdMvNormal(I(1), [0.0], ScalMat(1, noise)), x1)
y2 = initialize!(gm, CdMvNormal(I(1), [0.0], ScalMat(1, noise)), x2)

observe!(gm, y1, [1.0])
observe!(gm, y2, [1.0])
gm.nodes[x0]  # state of x0 after the observations
```

Here `x0` is `Initialized` with a conditional distribution, which is the expected behavior.
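For reference, the pure-BP graph is tree-shaped and linear-Gaussian (x1 = x0 + trans1 + noise, x2 = 2 x0 + noise, each y adding observation noise), so the posterior of `x0` given y1 = y2 = 1.0 has a closed form. The sketch below computes it with plain `LinearAlgebra` only; it does not call OnlineSampling, and it assumes the `CdMvNormal(A, b, Σ)` nodes above mean "A * parent + b with covariance Σ". Under that assumption, an exact BP engine should give this posterior for `x0` once it is marginalized, rather than a Dirac.

```julia
using LinearAlgebra

# Parameters copied from the report
trans1, trans_noise, noise = 5.0, 5.0, 0.5
prior_mean, prior_var = 0.0, 1000.0

# Integrating x1 and x2 out: y = H * x0 + b + e, with e ~ N(0, R)
H = [1.0, 2.0]                                # y1 sees x0, y2 sees 2 * x0
b = [trans1, 0.0]                             # constant offsets
R = Diagonal(fill(trans_noise + noise, 2))    # transition + observation noise
y = [1.0, 1.0]                                # observed values

# Gaussian conditioning (Kalman update) for the scalar x0
S = prior_var * (H * H') + R                  # covariance of y (2x2)
K = prior_var * H' / S                        # gain (1x2 row vector)
post_mean = prior_mean + K * (y - (H * prior_mean + b))
post_var = prior_var - (K * H) * prior_var

# Cross-check with the information (precision) form
prec = 1 / prior_var + dot(H, R \ H)
post_var_info = 1 / prec
post_mean_info = post_var_info * (prior_mean / prior_var + dot(H, R \ (y - b)))

println(round(post_mean; digits = 4))  # prints -0.3996
println(round(post_var; digits = 4))   # prints 1.0988
```

Both forms agree: x0 | y1, y2 ~ N(-4 / 10.011, 11 / 10.011), a proper Gaussian with variance around 1.1.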

mlelarge commented 2 years ago

The problem comes from the broadcast operation, which has not been overloaded. Using a vector for `trans1` and a matrix for the multiplication solves the issue. I am changing the label and the title.
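The workaround changes only how the model is written, not the model itself. In Julia, `x .+ c` lowers to `Base.Broadcast.materialize(Base.Broadcast.broadcasted(+, x, c))` instead of a call to `+`, so a package that overloads `+` and `*` for its symbolic variables, but not the broadcast machinery, will not intercept the dotted forms. That is our reading of the diagnosis above, not a description of OnlineSampling's internals. The minimal sketch below uses plain vectors only (no OnlineSampling calls; `x0_val` is a hypothetical stand-in for a sampled value of `x0`). It checks that the broadcast forms and the vector/matrix forms compute the same values.

```julia
using LinearAlgebra

trans1 = 5.0
x0_val = [1.5]  # hypothetical 1-dimensional value standing in for x0

# Broadcast forms used in the original model
a1 = x0_val .+ trans1
a2 = 2 .* x0_val

# Workaround forms: vector offset and matrix multiplication
b1 = x0_val + [trans1]
b2 = (2.0 * I(1)) * x0_val

# The dotted form goes through the broadcast machinery, not through `+`
a1_explicit = Base.Broadcast.materialize(Base.Broadcast.broadcasted(+, x0_val, trans1))

@assert a1 == b1 == a1_explicit
@assert a2 == b2
```

In the model itself, this corresponds to replacing `x0 .+ trans1` by `x0 + [trans1]` and `2 .* x0` by a matrix product such as `(2.0 * I(1)) * x0`, as the comment describes; we have not verified that exact spelling against OnlineSampling.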