ReactiveBayes / ReactiveMP.jl

High-performance reactive message-passing based Bayesian inference engine

`Probit` node `:out` marginalisation not defined for `q_in`, which is needed for binary linear classification. #425

Open wmkouw opened 5 days ago

wmkouw commented 5 days ago

In 5SSD0, we have a simple binary classification model:

@model function linear_classification(y, X)

    θ ~ MvNormalMeanCovariance(zeros(D), diageye(D))

    for i in eachindex(y)
        y[i] ~ Probit(dot(θ, X[i]))
    end
end

results = infer(
    model       = linear_classification(),
    data        = (y = y, X = X),
    returnvars  = (θ = KeepLast()),
    predictvars = (y = KeepLast()),
    iterations  = 10,
)

Requesting a prediction will throw:

RuleMethodError: no method matching rule for the given arguments

A possible fix would be to define:

@rule Probit(:out, Marginalisation) (q_in::NormalMeanVariance, meta::ProbitMeta) = begin 
    return ...
end

ReactiveMP complains that the Probit node's rule for making output predictions does not exist. Upon inspecting the source code, I find:

@rule Probit(:out, Marginalisation) (m_in::UnivariateNormalDistributionsFamily, meta::Union{ProbitMeta, Nothing}) = begin
    p = normcdf(mean(m_in) / sqrt(1 + var(m_in)))
    return Bernoulli(p)
end

So the rule does exist, but for m_in, not q_in.
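
For reference, the closed form in that rule is the standard Gaussian-probit identity E_{x~N(μ,σ²)}[Φ(x)] = Φ(μ/√(1+σ²)). A quick Monte Carlo sanity check of that identity (an illustrative snippet, not part of ReactiveMP; assumes Distributions and StatsFuns are available):

using Distributions
using StatsFuns: normcdf

μ, σ² = 0.7, 2.0
closed_form = normcdf(μ / sqrt(1 + σ²))                             # Φ(μ/√(1+σ²))
mc_estimate = mean(normcdf.(rand(Normal(μ, sqrt(σ²)), 1_000_000)))  # E[Φ(x)], x ~ N(μ, σ²)
# closed_form ≈ mc_estimate up to Monte Carlo error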

Seeing as binary classification is a pretty important use case, I think we need a rule for q_in. Can I just copy-paste the m_in rule, or is there some reason we don't have a q_in rule?

albertpod commented 5 days ago

I think it again has something to do with missings in your data (see https://github.com/ReactiveBayes/RxInfer.jl/issues/201). The Probit node shouldn't be used in a mean-field context.

This works:

using RxInfer
using Random

Random.seed!(123)

N = 100  # number of observations
D = 2    # input dimensionality

θ_true = [1.5, -1.0]  # ground-truth weights

X = randn(N, D)
X_vector = [X[i, :] for i in 1:N]  # rows of X as a vector of vectors
z = X * θ_true + randn(N) * 0.1    # noisy linear scores
y = Float64.(z .> 0)               # threshold into binary labels

@model function linear_classification(y, X)

    θ ~ MvNormalMeanCovariance(zeros(D), diageye(D))

    for i in eachindex(y)
        y[i] ~ Probit(dot(θ, X[i]))
    end
end

results = infer(
    model       = linear_classification(),
    data        = (y = y, X = X_vector),
    iterations  = 10,
)

println(mean(results.posteriors[:θ][end]))  # posterior mean estimate of θ
println(θ_true)                             # ground truth for comparison

albertpod commented 5 days ago

If you want to get predictions out of your model with Probit, I suggest writing a separate function for computing predictions. Also, don't forget to use NamedTuples in returnvars and predictvars, i.e. predictvars = (y = KeepLast(),) (note the trailing comma).
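
A minimal sketch of such a prediction function, assuming a Gaussian posterior q(θ) with mean m and covariance Σ (the helper name predict_probit is illustrative, not ReactiveMP API); it applies the same probit marginalisation as the m_in rule quoted above:

using LinearAlgebra
using StatsFuns: normcdf

# Predictive p(y = 1 | x) under a Gaussian posterior over θ:
# dot(θ, x) is then Gaussian with μ = m'x and σ² = x'Σx,
# so p = Φ(μ / √(1 + σ²)), as in the m_in rule.
function predict_probit(m::AbstractVector, Σ::AbstractMatrix, x::AbstractVector)
    μ  = dot(m, x)
    σ² = dot(x, Σ * x)
    return normcdf(μ / sqrt(1 + σ²))
end

θ_post = results.posteriors[:θ][end]
ŷ = [predict_probit(mean(θ_post), cov(θ_post), x) for x in X_vector]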

albertpod commented 5 days ago

@bvdmitri knows better, but the issue is that introducing predictvars enforces a different factorization constraint (mean-field) behind the scenes, even though you don't request it explicitly, hence the error you're seeing makes sense. It's quite unfortunate, but it's not something that can be resolved easily.

wouterwln commented 5 days ago

You can wrap data in UnfactorizedData to avoid enforcing this MF constraint, so you'd get data = (y = UnfactorizedData(y), X = X_vector), and this should do SP (sum-product) message passing around that node.
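
Concretely, a sketch of the adjusted infer call (whether predictvars then works end-to-end here is exactly what this issue is probing):

results = infer(
    model       = linear_classification(),
    data        = (y = UnfactorizedData(y), X = X_vector),  # y wrapped to keep SP messages
    returnvars  = (θ = KeepLast(),),  # note the trailing commas (NamedTuples)
    predictvars = (y = KeepLast(),),
    iterations  = 10,
)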

I don't know what the rule for VMP would look like though. Maybe we can derive it tomorrow at the office.
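
For reference, the missing rule would be the standard VMP outgoing message, i.e. the exponentiated expected log-factor (a sketch of the definition, not a worked-out result):

ν_out(y) ∝ exp( E_{q_in}[ log p(y | in) ] )
         = exp( y · E_q[log Φ(in)] + (1 − y) · E_q[log(1 − Φ(in))] )

That is a Bernoulli whose log-odds involve E_{x~N(μ,σ²)}[log Φ(x)], which, unlike the E[Φ(x)] used by the EP rule, has no closed form and would need approximation.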

wmkouw commented 5 days ago

Thanks for the quick response. There's no rush on this issue; I just wanted to discuss it.

Ok, so ReactiveMP is looking for a rule with q_in because predictvars enforces MeanField() even though I didn't specify MeanField(). That explains that.

But the current m_in rule is not SP, right? It's EP. Since EP is also variational, can't we just copy the m_in rule over to q_in? I know it's not technically the solution to the variational rule, but we could just report that (maybe via the ProbitMeta).
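
For concreteness, the copy-paste fallback being proposed would look like this (the EP computation from the m_in rule reused verbatim under a q_in signature; explicitly not a derived VMP rule):

# Hypothetical fallback: reuse the m_in (EP) message for q_in.
# Not a proper VMP derivation; see the maintainers' response below.
@rule Probit(:out, Marginalisation) (q_in::UnivariateNormalDistributionsFamily, meta::Union{ProbitMeta, Nothing}) = begin
    p = normcdf(mean(q_in) / sqrt(1 + var(q_in)))
    return Bernoulli(p)
end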

wouterwln commented 4 days ago

Specifically, ReactiveMP assumes that data is always factorized out of the joint distribution. This is because, if we actually supply data, we know that the posterior marginal for that datapoint is fixed and independent of any other posterior distributions. predictvars is considered data as well, since it is one of the interfaces that is unknown at model construction time: GraphPPL doesn't know at that point whether you're going to pass data into these nodes or pass missing, so it assumes data and factorizes it out. This is one of the implicit assumptions we always make.

Now, in order to predict something, this might actually not be the case, and we don't always have to make this assumption. To override it, I added UnfactorizedData as a wrapper struct for any kind of data that won't automatically factorize the posterior marginal distribution out of the rest of the joint posterior marginal. So for prediction it might make more sense to wrap the data in UnfactorizedData, since you might be able to send an SP message to the "data" instead of a VMP message.

As for your heuristic for sending this message: there's nothing stopping us from using the EP message as a fallback, but I'd rather not have quick fixes in RMP without a proper derivation.