Open wmkouw opened 5 days ago
I think it again has something to do with missing values in your data (see https://github.com/ReactiveBayes/RxInfer.jl/issues/201). The Probit node shouldn't be used in a mean-field context.
This works:

```julia
using RxInfer
using LinearAlgebra  # for dot
using Random

Random.seed!(123)

N = 100
D = 2
θ_true = [1.5, -1.0]

X = randn(N, D)
X_vector = [vec(X[i, :]) for i in 1:N]
z = X * θ_true + randn(N) * 0.1
y = Float64.(z .> 0)

@model function linear_classification(y, X)
    θ ~ MvNormalMeanCovariance(zeros(D), diageye(D))
    for i in eachindex(y)
        y[i] ~ Probit(dot(θ, X[i]))
    end
end

results = infer(
    model = linear_classification(),
    data = (y = y, X = X_vector),
    iterations = 10,
)

println(mean(results.posteriors[:θ][end]))
println(θ_true)
```
If you want to get predictions out of your model with `Probit`, I suggest writing a separate function for computing predictions. Also, don't forget to use tuples in `returnvars` and `predictvars`, i.e. `predictvars = (y = KeepLast(),)`.
@bvdmitri knows better, but the issue is that introducing `predictvars` enforces a different factorization constraint (mean-field), which you don't specify explicitly but which is applied behind the scenes, hence the error you're seeing makes sense. It's quite unfortunate, but it's not something that can be resolved easily.
You can wrap the data in `UnfactorizedData` to avoid enforcing this mean-field constraint. So you'd get `data = (y = UnfactorizedData(y), X = X_vector)`, and this should do sum-product (SP) message passing around that node.
I don't know what the rule for VMP would look like though. Maybe we can derive it tomorrow at the office.
Thanks for the quick response. There's no rush to this issue. I just wanted to discuss it.
Ok, so ReactiveMP is looking for a rule with `q_in` because `predictvars` enforces `MeanField()`, even though I didn't specify `MeanField()`. That explains that.
But the current `m_in` rule is not SP, right? It's EP. Since EP is also variational, can't we just copy the `m_in` rule to `q_in`? I know it's not technically the solution to a variational rule, but we can just report that (maybe via the `ProbitMeta`).
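For reference, here is the standard moment matching underlying an EP-style `m_in` rule for a probit factor, assuming the likelihood $p(y \mid x) = \Phi\big((2y-1)x\big)$ with $y \in \{0,1\}$ and an incoming Gaussian message $\mathcal{N}(x; m, v)$. This is the textbook derivation (cf. Rasmussen & Williams, §3.6), not necessarily ReactiveMP's exact parametrization:

```math
Z = \int \Phi\big((2y-1)x\big)\,\mathcal{N}(x; m, v)\,\mathrm{d}x = \Phi(z),
\qquad z = \frac{(2y-1)\,m}{\sqrt{1+v}}
```

```math
\mathbb{E}[x] = m + \frac{(2y-1)\,v\,\phi(z)}{\Phi(z)\sqrt{1+v}},
\qquad
\operatorname{Var}[x] = v - \frac{v^2\,\phi(z)}{(1+v)\,\Phi(z)}\left(z + \frac{\phi(z)}{\Phi(z)}\right)
```

The outgoing EP message is then the Gaussian with these moments divided by the incoming message.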
Specifically, `ReactiveMP` assumes that data is always factorized out of the joint distribution. This is because if we actually supply data, we know that the posterior marginal for that datapoint is fixed and independent of any other posterior distributions (`predictvars` is considered data as well, as it is one of the interfaces that is unknown at model construction time; since `GraphPPL` doesn't know at this point whether you're going to pass data into these nodes or pass `missing`, it will assume that it is data and factorize it out). This is one of the implicit assumptions we always make. For prediction, however, this might not actually be the case, and we don't always have to make this assumption. In order to override it, I added `UnfactorizedData` as a wrapper struct for any kind of data that won't automatically factorize the posterior marginal distribution out of the rest of the joint posterior. So for prediction, it might make more sense to wrap the data in `UnfactorizedData`, since you might be able to send an SP message to the "data" instead of a VMP message.
As for your heuristic for sending this message: there's nothing stopping us from using the EP message as a fallback, but I'd rather not have quick fixes in RMP without a proper derivation.
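As a numerical sanity check on the EP-style rule, the key Gaussian–probit integral, $\int \Phi(x)\,\mathcal{N}(x; \mu, \sigma^2)\,\mathrm{d}x = \Phi\big(\mu/\sqrt{1+\sigma^2}\big)$, is easy to verify in plain Julia. This is an illustrative sketch, not ReactiveMP code; the `erf` approximation (Abramowitz–Stegun 7.1.26) is only there to avoid a `SpecialFunctions` dependency:

```julia
using Random, Statistics

# Abramowitz–Stegun 7.1.26 rational approximation to erf (|error| < 1.5e-7)
function erf_approx(x)
    s, x = sign(x), abs(x)
    t = 1 / (1 + 0.3275911 * x)
    poly = ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t -
             0.284496736) * t + 0.254829592) * t
    return s * (1 - poly * exp(-x^2))
end

Φ(t) = 0.5 * (1 + erf_approx(t / sqrt(2)))  # standard normal CDF

Random.seed!(1)
μ, σ = 0.7, 1.3

# Monte Carlo estimate of ∫ Φ(x) N(x; μ, σ²) dx ...
mc = mean(Φ(μ + σ * randn()) for _ in 1:10^6)

# ... versus the closed form Φ(μ / √(1 + σ²)) used by the moment-matching rule
closed = Φ(μ / sqrt(1 + σ^2))

println((mc, closed))
```

The two numbers agree to Monte Carlo precision, which is the identity the EP message's normalizer relies on.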
In 5SSD0, we have a simple binary classification model:
Requesting a prediction will throw:
ReactiveMP complains that the Probit node's rule for making output predictions does not exist. Upon inspection of the source code, I find:
So, the rule does exist, but for `m_in`, not `q_in`. Seeing as binary classification is a pretty important use case, I think we need to have a rule for `q_in`. Can I just copy-paste, or is there some reason we don't have a `q_in` rule?