Closed blolt closed 1 month ago
Hi @blolt. Thanks for this PR! As far as I can see you're 90% of the way there, and definitely heading in the right direction. The last puzzle piece, I would say, is that GraphPPL has to realize that with a `ReactiveMPGraphPPLBackend{True}`, every `Composite` node turns into an `Atomic` node. This could be realized in:
https://github.com/ReactiveBayes/RxInfer.jl/blob/858000c9b5c66b1ebe3951dd638fcc62a3352927/src/model/graphppl.jl#L162-L165
which for now defaults to the default backend implementation (which checks whether something is made with the `@model` macro: if so, it is `Composite`; otherwise it is `Atomic`). Probably, if we add a separate dispatch on `ReactiveMPGraphPPLBackend{True}` that returns `GraphPPL.Atomic()`, the rest of your PR will work perfectly. I think the usage of `Static` is appropriate.
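For concreteness, the suggested dispatch might look roughly like the following sketch (the exact method signature is an assumption on my part, mirroring how the default backend method in the linked file would be written):

```julia
# Sketch (assumed signature): under the contraction-enabled backend,
# treat every functional form as Atomic, so that GraphPPL materializes
# each submodel as a single factor node instead of descending into it.
GraphPPL.NodeType(::ReactiveMPGraphPPLBackend{True}, ::F) where {F} = GraphPPL.Atomic()
```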
Was it not the intention to let the user define custom node logic for a given backend? This was the example in the original issue:

```julia
GraphPPL.NodeType(::RxInferBackend{True}, ::typeof(gcv)) = GraphPPL.Atomic()
```

which uses a user-defined node `gcv`. Then again, I don't see much use for the `Static` parameter here, since the node type should be enough for dispatch. In any case, I just want to clarify that contracting all composite nodes, rather than just particular kinds of composite nodes, is the intended use case.
Ah, you are right. I will check to see if the current implementation works this way.
Thanks Wouter! Just linking a recent paper with some rules on message passing for the HGF, as I know one of the goals of this issue is to produce documentation with custom rules. Perhaps it will prove useful. Interesting paper in its own right, too.
@blolt it's a very cool feature. It has been a very busy period in BIASlab, but we didn't forget about your contribution! @wouterwln, we should review this before the next update meeting on September 18th.
Hi @blolt, I checked your PR. I would advise you to test the functionality in a simpler (sub-)model; there is a reason we like to contract the `gcv` submodel into a node, which is that it is actually really hard to run inference inside it. I tried recreating some of the results with the following simpler model:

```julia
@model function submodel(x, z, y)
    p ~ NormalMeanVariance(0, 1)
    y ~ NormalMeanVariance(x + z + p, 1)
end

@model function larger_model(y)
    x ~ NormalMeanVariance(0, 1)
    z ~ NormalMeanVariance(0, 1)
    y ~ submodel(x = x, z = z)
end
```
This still won't work, and I think it is a bug on our part. As far as I understand, we can label the node as being `Atomic`, and GraphPPL will label it `Stochastic` as well, as is the default for submodels. However, `ReactiveMP.sdtype()` will still be `Deterministic`, which prompts a `DeltaMeta` meta object, which shouldn't happen. @bvdmitri, any idea how to fix this? Is there a reason GraphPPL's and ReactiveMP's implementations of `Stochastic` and `Deterministic` are mixed?
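For context, ReactiveMP ordinarily learns a node's `sdtype` from an `@node` declaration, so one hedged, unverified workaround sketch would be to register the submodel's function as `Stochastic` explicitly. The interface names here follow the `submodel` example above:

```julia
# Hypothetical workaround sketch: declare the submodel's function as a
# Stochastic node in ReactiveMP, so that ReactiveMP.sdtype() agrees with
# the Stochastic label that GraphPPL assigns to submodels.
@node typeof(submodel) Stochastic [y, x, z]
```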
In any case, I would ask you to rewrite the tests with a simpler example. I think the rest of your PR works as intended, as I was able to isolate this issue quite easily and the rest of the behaviour was as expected. Thank you very much!
Thanks for your work, @blolt! I reviewed the PR and made a few adjustments, particularly to the API. Now, a user only needs to define the node using the `@node` macro for the contraction to work. For example:

```julia
@model function gcv(y, x, z, κ, ω)
    log_σ := κ * z + ω
    σ := exp(log_σ)
    y ~ Normal(mean = x, precision = σ)
end

@node typeof(gcv) Stochastic [y, x, z, κ, ω]
```

With `allow_node_contraction = true`, this should be enough to automatically identify the corresponding node, and inference will try to use the rules instead (after this PR is merged we need to redefine the `GCV` node as a submodel, I guess).
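For reference, the user-facing call would then presumably look something like this sketch (the model name and data below are illustrative placeholders, not taken from the PR):

```julia
# Sketch: enabling composite node contraction at inference time.
# `my_model` and the data are hypothetical placeholders.
result = infer(
    model = my_model(),
    data  = (y = randn(100),),
    allow_node_contraction = true,
)
```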
I also updated the tests, but I noticed that inference isn't running for the multi-layer HGF. @wouterwln, could you take a look? The single-layer HGF works perfectly.
Unfortunately, the 2-layer HGF also breaks with the existing GCV node. The problem is with the initialization: it occurs around the 'topmost' GCV node, which is surrounded by `x_2[i]`, `x_2[i - 1]`, `x_3[i]`, `κ_2`, and `ω_2`. Because of the initialization bug we found over the past couple of weeks, we can fire off all existing rules, using up all the initialized marginals, which sends the message-passing algorithm into a deadlock. We can fix this by initializing not every `x_2`, but every other `x_2` (so `x_2[1:2:n]`). This makes the inference run, and it also proves that the problem is not with node contraction, which (luckily for this PR) works perfectly; this weird behavior is (unfortunately for us) due to some other bug. I'll push the fix, and then everything should work as is.
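Sketched in code, the every-other initialization workaround might look roughly like this (the indexed `q(...)` form and the distribution values are assumptions; the exact syntax used in the PR may differ):

```julia
# Sketch: initialize marginals only for every other element of x_2
# (i.e. x_2[1:2:n]), so that not all initialized marginals are consumed
# at once, avoiding the message-passing deadlock described above.
init = @initialization begin
    q(x_2[1:2:n]) = NormalMeanVariance(0.0, 1.0)
end
```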
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 84.94%. Comparing base (9facd12) to head (f5150d9). Report is 10 commits behind head on main.
@bvdmitri I think the PR is in shape to merge now.
Overview
Note: This is a re-opening of a previous PR, which I inadvertently closed.
This PR resolves #287 (or it will), allowing the user to specify custom composite node contraction through an argument on the `infer` function and an implementation of `GraphPPL.NodeType`.

Testing
Three unit tests have been created to test this functionality. The first is a simple test to confirm that the RxInfer backend can be parameterized; the next two test the `infer` function on an HGF with contracted nodes. The test "Static Inference With Node Contraction" is currently failing with the following error:

However, all keyword arguments are being used for the `gcv` macro, so it seems likely that the issue lies elsewhere. I have tried different approaches, but none so far have borne fruit (a passing HGF unit test).

The test "Static Inference With Node Contraction 1" was extended from the RxInfer doc examples of an HGF and ReactiveMP's `GCV` node. It was originally passing with the `GCV` node, but I have not been able to get it to pass with the custom `gcv` model I introduced. Rules are not yet specified for the custom `gcv` node, and the test is currently failing due to incorrect constraints.