ReactiveBayes / RxInfer.jl

Julia package for automated Bayesian inference on a factor graph with reactive message passing
https://rxinfer.ml/
MIT License

Add custom node contraction #342

Closed: blolt closed this 1 month ago

blolt commented 3 months ago

Overview

Note: This is a re-opening of a previous PR, which I inadvertently closed.

This PR resolves (or will resolve) #287, allowing the user to specify custom composite node contraction through an argument to the infer function and an implementation of GraphPPL.NodeType.

Testing

Three unit tests have been created to test this functionality. The first is a simple test to confirm that the RxInfer backend can be parameterized, and the next two test the infer function on an HGF with contracted nodes. The test "Static Inference With Node Contraction" is currently failing with the following error:

ERROR: The `gcv` model macro does not support positional arguments. Use keyword arguments `gcv(κ = ..., ω = ..., z = ..., x = ..., y = ...)` instead.

However, keyword arguments are already used for every argument of the gcv model macro, so the issue likely lies elsewhere. I have tried several approaches, but none has borne fruit so far (i.e., none yields a passing HGF unit test).

The test "Static Inference With Node Contraction 1" was extended from the RxInfer documentation examples of an HGF and ReactiveMP's GCV node. It originally passed with the GCV node, but I have not been able to make it pass with the custom gcv model I introduced. Rules are not yet specified for the custom gcv node, and the test currently fails because of incorrect constraints.

wouterwln commented 3 months ago

Hi @blolt. Thanks for this PR! As far as I can see you're 90% of the way there, and definitely heading in the right direction. The last piece of the puzzle, I would say, is that GraphPPL has to realize that with a ReactiveMPGraphPPLBackend{True}, every Composite node turns into an Atomic node. This could be done in https://github.com/ReactiveBayes/RxInfer.jl/blob/858000c9b5c66b1ebe3951dd638fcc62a3352927/src/model/graphppl.jl#L162-L165, which for now falls back to the default backend implementation (that implementation treats anything made with the @model macro as Composite and everything else as Atomic). If we add a separate dispatch on ReactiveMPGraphPPLBackend{True} that returns GraphPPL.Atomic(), the rest of your PR will probably work perfectly. I think the usage of Static is appropriate.
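To make the proposed dispatch concrete, here is a plain-Julia mock of it. It is only an illustration under stated assumptions, not GraphPPL's code: Atomic, Composite, True, False, MockBackend, MODEL_FUNCTIONS, is_model and nodetype are hypothetical stand-ins for GraphPPL.Atomic/GraphPPL.Composite, Static.True/Static.False, ReactiveMPGraphPPLBackend and GraphPPL.NodeType. The generic method mirrors the default described above (@model functions are Composite, everything else Atomic), and the more specific method on MockBackend{True} returns Atomic for everything.

```julia
# Hypothetical stand-ins for GraphPPL.Atomic / GraphPPL.Composite (not the real types).
struct Atomic end
struct Composite end

# Hypothetical stand-ins for Static.True / Static.False.
struct True end
struct False end

# Hypothetical backend carrying a compile-time "contract everything" flag,
# playing the role of ReactiveMPGraphPPLBackend{True} / {False}.
struct MockBackend{C} end

# Mock registry of functions that were "defined with @model".
const MODEL_FUNCTIONS = Set{Any}()
is_model(f) = f in MODEL_FUNCTIONS

# Default fallback: @model functions are Composite, everything else is Atomic.
nodetype(::MockBackend, f) = is_model(f) ? Composite() : Atomic()

# Separate, more specific dispatch on the True-parameterized backend:
# every node, including submodels, is treated as Atomic.
nodetype(::MockBackend{True}, f) = Atomic()

# Placeholder standing in for an @model-defined submodel.
submodel(x, z) = x + z
push!(MODEL_FUNCTIONS, submodel)

println(nodetype(MockBackend{False}(), submodel))  # Composite()
println(nodetype(MockBackend{True}(), submodel))   # Atomic()
println(nodetype(MockBackend{False}(), exp))       # Atomic()
```

Julia selects the MockBackend{True} method because it is strictly more specific than the generic one, so no ambiguity arises and the default path stays untouched for the False backend.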

blolt commented 2 months ago

Was it not the intention to let the user define custom node logic for a given backend? This was the example in the original issue:

GraphPPL.NodeType(::RxInferBackend{True}, ::typeof(gcv)) = GraphPPL.Atomic()

which uses a user-defined node, gcv. Then again, I don't see much use for the Static parameter here, since the node type alone should be enough for dispatch. In any case, I just want to clarify whether contracting all composite nodes, rather than only particular kinds of composite nodes, is the intended use case.
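The per-node alternative from the original issue can be mocked the same way. Every name here (Atomic, Composite, True, False, MockBackend, nodetype, other_submodel, and this placeholder gcv) is a hypothetical stand-in rather than the real GraphPPL, Static or RxInfer API; the sketch only shows that dispatching on typeof(gcv) contracts that one submodel while other submodels stay Composite.

```julia
# Hypothetical stand-ins (not the real GraphPPL / Static types).
struct Atomic end
struct Composite end
struct True end
struct False end
struct MockBackend{C} end

# Placeholders for two @model-defined submodels (bodies are irrelevant here).
gcv(y, x, z, κ, ω) = nothing
other_submodel(y, x) = nothing

# Default for this mock: every submodel is expanded as a Composite node.
nodetype(::MockBackend, ::Function) = Composite()

# User-supplied override, mirroring the example from the issue:
# only gcv is contracted, and only on the True-parameterized backend.
nodetype(::MockBackend{True}, ::typeof(gcv)) = Atomic()

println(nodetype(MockBackend{True}(), gcv))             # Atomic()
println(nodetype(MockBackend{True}(), other_submodel))  # Composite()
println(nodetype(MockBackend{False}(), gcv))            # Composite()
```

Because the override dispatches on both the backend and the function, dropping the {True} parameter would contract gcv on every backend; that is the trade-off behind the question of whether the Static parameter is needed.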

wouterwln commented 2 months ago

Ah, you are right. I will check whether the current implementation works this way.

blolt commented 2 months ago

Thanks, Wouter! Just linking a recent paper with some message-passing rules for the HGF, since I know one of the goals of this issue is to produce documentation with custom rules. Perhaps it will prove useful; it is an interesting paper in its own right, too.

https://arxiv.org/pdf/2305.10937

bvdmitri commented 2 months ago

@blolt it's a very cool feature! It has been a very busy period in BIASlab, but we didn't forget about your contribution. @wouterwln we should review this before the next update meeting on September 18th.

wouterwln commented 2 months ago

Hi @blolt, I checked your PR. I would advise you to test the functionality in a simpler (sub-)model: we want to contract the gcv submodel into a node precisely because running inference inside of it is really hard. I tried recreating some of the results with the following simpler model:

using RxInfer

@model function submodel(x, z, y)
    p ~ NormalMeanVariance(0, 1)
    y ~ NormalMeanVariance(x + z + p, 1)
end

@model function larger_model(y)
    x ~ NormalMeanVariance(0, 1)
    z ~ NormalMeanVariance(0, 1)
    y ~ submodel(x = x, z = z)
end

This still won't work, and I think it is a bug on our part. As far as I understand, we can label the node as Atomic, and GraphPPL will also label it Stochastic, as is the default for submodels. However, ReactiveMP.sdtype() still returns Deterministic, which prompts a DeltaMeta meta object, and that shouldn't happen. @bvdmitri any idea how to fix this? Is there a reason GraphPPL's and ReactiveMP's implementations of Stochastic and Deterministic are mixed?
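The mismatch described here can be illustrated with two independent traits that disagree. This is a hypothetical mock, not GraphPPL or ReactiveMP code: graphppl_kind, rmp_sdtype and needs_delta_meta are invented names standing in for GraphPPL's Stochastic label, ReactiveMP.sdtype() and the point where a DeltaMeta object becomes necessary.

```julia
# Hypothetical stand-ins for the two separate traits; not GraphPPL or ReactiveMP code.
struct Stochastic end
struct Deterministic end

# Mock "GraphPPL side": submodels are labelled Stochastic by default.
graphppl_kind(f) = Stochastic()

# Mock "ReactiveMP side": a function with no node declaration falls back to Deterministic.
rmp_sdtype(f) = Deterministic()

# In this mock, a Deterministic sdtype is what makes a node require a DeltaMeta object.
needs_delta_meta(f) = rmp_sdtype(f) isa Deterministic

# Placeholder for the @model-defined submodel.
submodel(x, z, y) = nothing

println(graphppl_kind(submodel))     # Stochastic()
println(needs_delta_meta(submodel))  # true: the two traits disagree

# Declaring the node as Stochastic on the ReactiveMP side removes the mismatch.
rmp_sdtype(::typeof(submodel)) = Stochastic()
println(needs_delta_meta(submodel))  # false
```

The last definition corresponds, in spirit, to declaring a node for the submodel on the ReactiveMP side; in the real packages that declaration and its effect may work differently.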

In any case, I would ask you to rewrite the tests with a simpler example. I think the rest of your PR works as intended: I was able to isolate this issue quite easily, and the rest of the behaviour was as expected. Thank you very much!

bvdmitri commented 2 months ago

Thanks for your work, @blolt! I reviewed the PR and made a few adjustments, particularly to the API. Now, a user only needs to define the node using the @node macro for the contraction to work. For example:

using RxInfer

@model function gcv(y, x, z, κ, ω)
    log_σ := κ * z + ω
    σ := exp(log_σ)
    y ~ Normal(mean = x, precision = σ)
end

@node typeof(gcv) Stochastic [ y, x, z, κ, ω ]

With allow_node_contraction = true passed to infer, this should be enough for RxInfer to identify the corresponding node automatically and try to use its rules instead of running inference inside the submodel. (After this PR is merged, I guess we need to redefine the GCV node as a submodel.)
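The behaviour described above (contraction is opt-in through allow_node_contraction and applies only to functions with a declared node) can be sketched as a small decision function. This is a mock under that assumption, not RxInfer's implementation: DECLARED_NODES, declare_node! and submodel_nodetype are hypothetical names, and declare_node! only stands in for what the @node macro records.

```julia
# Hypothetical stand-ins (not RxInfer / GraphPPL types).
struct Atomic end
struct Composite end

# Mock registry standing in for what a node declaration (e.g. via @node) records.
const DECLARED_NODES = Set{Any}()
declare_node!(f) = push!(DECLARED_NODES, f)

# Mock decision: a submodel is contracted only when contraction is allowed
# and a node (with its own message-passing rules) has been declared for it.
function submodel_nodetype(f; allow_node_contraction::Bool = false)
    if allow_node_contraction && f in DECLARED_NODES
        return Atomic()     # treat the submodel as one node and use its rules
    else
        return Composite()  # expand the submodel and run inference inside it
    end
end

# Placeholder for the @model-defined gcv submodel.
gcv(y, x, z, κ, ω) = nothing
declare_node!(gcv)

println(submodel_nodetype(gcv; allow_node_contraction = true))  # Atomic()
println(submodel_nodetype(gcv))                                 # Composite()
println(submodel_nodetype(exp; allow_node_contraction = true))  # Composite()
```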

I also updated the tests, but I noticed that inference isn't running for the multi-layer HGF. @wouterwln, could you take a look? The single-layer HGF works perfectly.

wouterwln commented 2 months ago

Unfortunately, the 2-layer HGF also breaks with the existing GCV node, and the problem lies in the initialization. It occurs around the 'topmost' GCV node, which is surrounded by x_2[i], x_2[i - 1], x_3[i], κ_2 and ω_2. Because of the initialization bug we found over the past couple of weeks, all available rules fire and use up all the initialized marginals, which sends the message passing algorithm into a deadlock. We can fix this by initializing not every x_2 but every other one (x_2[1:2:n]). This makes inference run, and also shows that (luckily for this PR) the problem is not with node contraction, which works perfectly, but (unfortunately for us) with some other bug. I'll push the fix, and then everything should work as is.

codecov[bot] commented 2 months ago

Codecov Report

All modified and coverable lines are covered by tests.

Project coverage is 84.94%. Comparing base (9facd12) to head (f5150d9). Report is 10 commits behind head on main.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #342      +/-   ##
==========================================
+ Coverage   84.84%   84.94%   +0.09%
==========================================
  Files          20       20
  Lines        1511     1521      +10
==========================================
+ Hits         1282     1292      +10
  Misses        229      229
```


wouterwln commented 1 month ago

@bvdmitri I think the PR is in shape to merge now.