ReactiveBayes / RxInfer.jl

Julia package for automated Bayesian inference on a factor graph with reactive message passing
MIT License

Stuck on MarginalRuleMethodError warning #37

Closed SupplyChef closed 1 year ago

SupplyChef commented 1 year ago

I am trying to modify the GP regression example to create a local level (forecasting) model. I tried to follow the Missing Data example to handle the missing data. However, I get a warning about MarginalRuleMethodError and cannot find documentation or examples on how to proceed. I would echo issue #15 about having more examples for forecasting cases. Thank you!

My model is defined as:

@model function locallevel(n, σ², σ²_noise)
    f_0 ~ Normal(mean = 0, precision = 1 / σ²)

    f = randomvar(n)
    y = datavar(Float64, n) where { allow_missing = true }

    f_prev = f_0

    for i=1:n
        f[i] ~ Normal(mean = f_prev, precision = 1/σ²)
        y[i] ~ Normal(mean = f[i], precision = 1 / σ²_noise)
        f_prev = f[i]
    end
    return f, y
end

The data is the same as in the GP regression example with some missing data for the prediction part:

using Random # needed for Random.seed!

Random.seed!(10)
n = 100
σ²_noise = 0.04;
t = collect(range(-2, 2, length=n)); #timeline
f_true = sinc.(t); # true process
f_noisy = f_true + sqrt(σ²_noise) * randn(n); #noisy process

pos = 1:100 
t_obser = t[pos]; # time where we observe data

y_data = Array{Union{Float64,Missing}}(missing, n)
for i in pos 
    y_data[i] = f_noisy[i]
end
for i in 80:100 
    y_data[i] = missing
end

θ = [1., 1.]; # store [l, σ²]
Δt = [t[1]]; # time difference
append!(Δt, t[2:end] - t[1:end-1]);

I also added the following rules for missing data:

@rule NormalMeanPrecision(:μ, Marginalisation) (q_out::Any, q_τ::Missing) = missing
@rule NormalMeanPrecision(:μ, Marginalisation) (q_out::Missing, q_τ::Any) = missing
@rule NormalMeanPrecision(:μ, Marginalisation) (m_out::Missing, q_τ::PointMass, ) = missing

@rule NormalMeanPrecision(:τ, Marginalisation) (q_out::Any, q_μ::Missing) = missing
@rule NormalMeanPrecision(:τ, Marginalisation) (q_out::Missing, q_μ::Any) = missing

@rule typeof(+)(:in1, Marginalisation) (m_out::Missing, m_in2::Any) = missing
@rule typeof(+)(:in1, Marginalisation) (m_out::Any, m_in2::Missing) = missing

And call the inference:

result = inference(
        model = locallevel(n, 1, 1),
        data = (y = y_data,),
        free_energy = true
    )

When running, I am getting the following warning which I do not know how to handle. Do you have a suggestion? Thank you

MarginalRuleMethodError: no method matching rule for the given arguments

Possible fix, define:

@marginalrule NormalMeanPrecision(:out_μ) (m_out::Missing, m_μ::NormalWeightedMeanPrecision, q_τ::PointMass, ) = begin 
    return ...
end

albertpod commented 1 year ago

Hi @SupplyChef,

Thank you for trying out the example on GPs.

This error occurs because RxInfer.jl tries to compute the free energy for you, and there is no rule for free energy computation when the data contains missing values.

The quick fix would be to turn off free energy computations, e.g.

result = inference(
        model = locallevel(n, 1, 1),
        data = (y = y_data,),
        free_energy = false
    )

I hope @HoangMHNguyen can help you further with this issue.

SupplyChef commented 1 year ago

I see. I was using the free energy because I wanted to optimize the hyperparameters of the model (σ², σ²_noise). I am now thinking of optimizing the hyperparameters without the missing data, and then inferring the missing data with the hyperparameters fixed. Do you think that is a reasonable way forward?
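The two-stage idea above can be sketched in plain Julia as a data-splitting step. This is only an illustration under an assumption that matches the data in this thread: all missing values sit at the end of the series (the forecast horizon). The names `last_observed`, `y_train` and `n_forecast` are hypothetical, and the toy `y_data` stands in for the real one.

```julia
# Toy data with the same layout as in this thread:
# observations first, missing values at the end.
y_data = Union{Float64,Missing}[0.9, 1.1, 1.0, missing, missing]

# Index of the last real observation.
last_observed = findlast(!ismissing, y_data)

# Fully observed prefix, converted to a plain Vector{Float64}.
# This conversion throws if a missing value occurs before `last_observed`.
y_train = Float64.(y_data[1:last_observed])

# Number of points that remain to be predicted.
n_forecast = length(y_data) - last_observed

# Stage 1 (sketch): tune (σ², σ²_noise) on `y_train` with free_energy = true.
# Stage 2 (sketch): rerun inference on the full `y_data` with the tuned
# values and free_energy = false.
```

If missing values also occurred in the middle of the series, dropping them would change the spacing between time steps, so this split would no longer be appropriate.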

bvdmitri commented 1 year ago

@SupplyChef Without going too much into the details - there is a workaround:

  1. Define the missing marginal rule as follows:
@marginalrule NormalMeanPrecision(:out_μ) (m_out::Missing, m_μ::NormalDistributionsFamily, q_τ::PointMass, ) = begin 
    out = @call_rule NormalMeanPrecision(:out, Marginalisation) (m_μ = m_μ, q_τ = q_τ)
    return @call_marginalrule NormalMeanPrecision(:out_μ) (m_out = out, m_μ = m_μ, q_τ = q_τ)
end
  2. Define auxiliary functions for Bethe Free Energy computation:
Distributions.entropy(::Missing) = ReactiveMP.CountingReal(Float64, -1)

@average_energy NormalMeanPrecision (q_out::Missing, q_μ::Any, q_τ::Any) = begin
    # Assume that a node with a missing point should not contribute to the Free Energy, you may change it though
    return 0 
end

I tried your model in this case and the Free Energy is being computed properly.

Details: This model runs with free_energy = false because the inference procedure itself does not require those marginals, whereas the Bethe Free Energy procedure does require them. RxInfer cannot compute the joint marginal distribution given a missing data point; there is simply not enough information to compute those marginals. It would, however, be mathematically correct to factorize the missing points out of the variational distribution. RxInfer does that automatically where it can (e.g. for datavars and constvars), but in your case it happens inside the graph. Generally speaking, support for missing data points is not fully implemented yet, and one must be very careful because a lot of functionality is missing 😅. In my example I attempted to enforce the mean-field approximation manually.
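To make the "factorize out missing points" statement concrete, here is a minimal sketch that does not use RxInfer at all. It runs an exact Kalman filter for the local level model from this thread and skips the update step whenever an observation is missing. The function name `local_level_filter` is hypothetical. It returns filtering marginals only (no backward smoothing pass), which is sufficient for trailing missing points such as a forecast horizon.

```julia
# Exact forward filtering for the local level model from this thread:
#   f_0  ~ Normal(0, σ²)
#   f[i] ~ Normal(f[i-1], σ²)
#   y[i] ~ Normal(f[i], σ²_noise)
# `local_level_filter` is a hypothetical helper, not part of RxInfer.
function local_level_filter(y::AbstractVector, σ², σ²_noise)
    m, v = 0.0, float(σ²)            # prior on f_0: mean 0, variance σ²
    means = zeros(length(y))
    vars = zeros(length(y))
    logevidence = 0.0
    for i in eachindex(y)
        v += σ²                      # predict step: variance grows by σ²
        if !ismissing(y[i])
            s = v + σ²_noise         # predictive variance of y[i]
            r = y[i] - m             # innovation
            logevidence += -0.5 * (log(2π * s) + r^2 / s)
            k = v / s                # Kalman gain
            m += k * r               # update mean
            v *= (1 - k)             # update variance
        end                          # missing y[i]: no update, no evidence term
        means[i], vars[i] = m, v
    end
    return means, vars, logevidence
end

y = Union{Float64,Missing}[0.5, 0.7, missing, missing]
means, vars, logev = local_level_filter(y, 1.0, 0.04)

# Same filter on the observed prefix only.
_, _, logev_observed = local_level_filter(y[1:2], 1.0, 0.04)
```

In this sketch the trailing missing points add nothing to the log-evidence, so `logev` and `logev_observed` agree, while the predicted mean stays constant and the predicted variance grows with each missing step. That is the same modelling choice as returning 0 from the average energy rule above, though the numbers are not claimed to match RxInfer's Bethe Free Energy output.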

P.S. Just in case, you can remove these rules if you don't use the + operator in your model:

@rule typeof(+)(:in1, Marginalisation) (m_out::Missing, m_in2::Any) = missing
@rule typeof(+)(:in1, Marginalisation) (m_out::Any, m_in2::Missing) = missing

SupplyChef commented 1 year ago

Thank you for the workaround and the explanation!

albertpod commented 1 year ago

Can we close this one?

HoangMHNguyen commented 1 year ago

I think we can close this one because Dmitry has explained everything clearly.

bvdmitri commented 1 year ago

Should be a part of the Sharp Bits section. But yea, we can close @albertpod