ReactiveBayes / RxInfer.jl

Julia package for automated Bayesian inference on a factor graph with reactive message passing
MIT License

`Linearization` meta does not allow DataVariable input #30

Closed. wmkouw closed this issue 11 months ago.

wmkouw commented 1 year ago

The delta node with Linearization meta does not allow me to define a function with both a RandomVariable and a DataVariable argument.

MWE (adapted from Nonlinear Virus Spread demo):

# g(x, a) and y_data are defined as in the Nonlinear Virus Spread demo (not shown here).
@model function virus_spread(n)

    x = randomvar(n)
    y = datavar(Float64, n)
    a = datavar(Float64)

    x_0 ~ NormalMeanVariance(1.0, 10.0)

    x_prev = x_0
    for i in 1:n
        # g has a random input (x_prev) and a data input (a); this combination triggers the error
        x[i] ~ g(x_prev, a)
        y[i] ~ NormalMeanVariance(x[i], 0.1)
        x_prev = x[i]
    end
end

meta = @meta begin 
    g() -> Linearization()
end

result = inference(
    model = virus_spread(length(y_data)), 
    data = (y = y_data, a = 1.0),
    meta = meta,
    options = (limit_stack_depth = 100, ),
    returnvars = KeepLast(), 
)

returns:

MethodError: no method matching __make_delta_fn_node(::typeof(g), ::FactorNodeCreationOptions{FullFactorisation, Linearization, Nothing}, ::RandomVariable, ::Tuple{RandomVariable, DataVariable{PointMass{Float64}, Rocket.RecentSubjectInstance{Message{PointMass{Float64}}, Subject{Message{PointMass{Float64}}, AsapScheduler, AsapScheduler}}}})
Closest candidates are:
  __make_delta_fn_node(::F, ::FactorNodeCreationOptions, ::AbstractVariable, ::Tuple{Vararg{var"#s2203", N}} where var"#s2203"<:AbstractVariable) where {F<:Function, N}

The closest candidate does seem to allow subtypes of AbstractVariable as inputs. The call passes ::Tuple{RandomVariable, DataVariable{PointMass{Float64} .. and the candidate accepts ::Tuple{Vararg{var"#s2203", N}} where var"#s2203"<:AbstractVariable, so it should match. I don't understand why it throws the error.
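One plausible reading of the MethodError, based only on the candidate signature printed above: var"#s2203" is a single type variable shared by every slot of the tuple, and Julia's diagonal rule requires a type variable that occurs in several covariant positions to be bound to one concrete type. A tuple mixing a RandomVariable and a DataVariable therefore cannot match, whereas Tuple{Vararg{AbstractVariable}} (no type variable) would. The sketch below reproduces this with hypothetical stand-in types; AbstractVar, RandVar, DataVar, accepts_diagonal and accepts_mixed are illustrative names, not RxInfer's API. This only explains the dispatch failure: per the reply that follows, Linearization also needs all inputs to be Gaussian, so a looser signature alone would not make this case work.

```julia
# Hypothetical stand-ins for RxInfer's variable types (illustration only).
abstract type AbstractVar end
struct RandVar <: AbstractVar end
struct DataVar <: AbstractVar end

# Same shape as the candidate in the error:
# ::Tuple{Vararg{T, N}} where T<:AbstractVar. The single T is shared by
# every slot, so the diagonal rule forces all elements to one concrete type.
accepts_diagonal(inputs::NTuple{N, <:AbstractVar}) where {N} = true
accepts_diagonal(inputs::Tuple) = false   # fallback for non-matching tuples

# Tuple{Vararg{AbstractVar}} has no type variable: each slot may be any subtype.
accepts_mixed(inputs::Tuple{Vararg{AbstractVar}}) = true
accepts_mixed(inputs::Tuple) = false      # fallback for non-matching tuples

println(accepts_diagonal((RandVar(), RandVar())))  # true
println(accepts_diagonal((RandVar(), DataVar())))  # false
println(accepts_mixed((RandVar(), DataVar())))     # true
```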

Related to https://github.com/biaslab/RxInfer.jl/issues/9? Also, documentation as suggested by Albert here https://github.com/biaslab/RxInfer.jl/issues/46 would be useful.

A related question: should we allow GraphPPL to generate intermediate expressions such as gg(x) = g(x, a) and apply the delta node to those? Or is creating expressions dynamically too inefficient?

Versions:

bvdmitri commented 1 year ago

@wmkouw Yeah, we should clarify that in the documentation, but this is a known limitation and it is not related to biaslab/ReactiveMP.jl#9. The Linearization approximation, as implemented in both ReactiveMP and ForneyLab, requires all inputs to be Gaussian.

Intermediate expressions will not help in general, as they would work only with constvars, not with datavars.

I think I discussed this issue with someone (Tim?). The consensus was that we would like to support this case (IMO it's perfectly valid to specify it). But, long story short, it's neither straightforward nor trivial to implement and will take some time.

The main problem here is that we need to compute a joint marginal over the input arguments, and it's not entirely clear to me what the joint marginal over the inputs is when one of the variables is clamped. The + node, for example, handles this as a special case with hard-coded rules. So we would need to, again, treat it as a special case and dynamically discard some edges from the resulting marginal. That would change the Bethe Free Energy computation logic as well.
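To make the joint-marginal point concrete, here is a generic first-order Taylor (extended-Kalman-style) forward step through a two-input g, in plain Julia. It is a sketch under stated assumptions, not ReactiveMP's Linearization code: the nonlinearity g(x, a) = exp(a * x) and the helpers linearize_joint and linearize_fused are hypothetical. If a clamped a is treated as a Gaussian with zero variance and zero cross-covariance, the joint covariance becomes degenerate and the result collapses exactly to linearizing x -> g(x, a) alone, i.e. the a edge drops out of the marginal.

```julia
# Hypothetical nonlinearity standing in for the demo's g(x, a), with its partials.
g(x, a) = exp(a * x)
dg_dx(x, a) = a * exp(a * x)   # ∂g/∂x
dg_da(x, a) = x * exp(a * x)   # ∂g/∂a

# First-order Taylor step for a Gaussian joint input
# (x, a) ~ N([mx, ma], [vx cxa; cxa va]); returns the output (mean, variance).
function linearize_joint(mx, ma, vx, va, cxa)
    jx, ja = dg_dx(mx, ma), dg_da(mx, ma)
    # J * V * J' written out for the 1×2 Jacobian J = [jx ja]
    v = jx^2 * vx + 2 * jx * ja * cxa + ja^2 * va
    return g(mx, ma), v
end

# Single-input step for the fused function x -> g(x, a) with a fixed.
linearize_fused(mx, vx, a) = (g(mx, a), dg_dx(mx, a)^2 * vx)

mx, vx, a = 0.5, 2.0, 1.0

# Clamped a: zero variance and zero cross-covariance, so the a row/column
# of the joint covariance contributes nothing.
clamped  = linearize_joint(mx, a, vx, 0.0, 0.0)
fused    = linearize_fused(mx, vx, a)

# "Softened" a: a Gaussian with a tiny but nonzero variance.
softened = linearize_joint(mx, a, vx, 1e-6, 0.0)

println(clamped == fused)                              # true
println(isapprox(softened[2], fused[2]; rtol = 1e-5))  # true
```

The softened case differs from the fused one only by the ja^2 * va term, which is why a very small variance on a behaves almost like clamping it.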

It's good that you opened the issue, so that I don't forget about it, but for now you'll have to soften the input a (or fuse it into the function outside of the @model macro).
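The "fuse it into the function" workaround can be sketched in plain Julia: bind the known value of a before the model is built and give the result its own name, so the delta node sees only the random input. A_VALUE, g_fused and the body of g are hypothetical. This also reflects the earlier point about intermediate expressions: the value is fixed when the function is defined, so this route suits a known constant, not a datavar whose value only arrives at inference time.

```julia
# Hypothetical two-argument nonlinearity (stand-in for the demo's g).
g(x, a) = exp(a * x)

# Known value of the former datavar `a`, fixed before the model is built.
const A_VALUE = 1.0

# Fused single-input function: the delta node would see only x.
g_fused(x) = g(x, A_VALUE)

println(g_fused(0.5) == g(0.5, 1.0))  # true
```

In the MWE, the loop would then read x[i] ~ g_fused(x_prev), the @meta block would use g_fused() -> Linearization(), and a would be dropped from the model and from the data tuple.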

wmkouw commented 1 year ago

Ah, ok. Then I'll remove the bug label and close the issue.

bvdmitri commented 1 year ago

I would like to keep the issue, if you don't mind :), so that we don't forget it and can discuss it during our next RxInfer meeting.

bvdmitri commented 11 months ago

This task has been added to the milestone for tracking and prioritization.