ReactiveBayes / GraphPPL.jl

DSL for probabilistic models specification and probabilistic programming.
MIT License

Pass multivariate data as univariate (collected) input to function #246

Open wouterwln opened 2 months ago

wouterwln commented 2 months ago

Let's look at the following model:

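using GraphPPL, Distributions

# An empty generic function, used here only as the functional form of a factor node.
function dot end

@model function foo(x, y)
    local w
    for i in 1:length(x)
        w[i] ~ Normal(0, 1)
        x[i] ~ Normal(0, 1)   # this creates x as a vector of data variables
    end
    y ~ dot(x, w)             # x is passed here as a whole, not indexed
end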

Now x will be created as a vector of data variables because we call x[i] ~ .... However, when we pass it to dot, it is still a ProxyLabel with maycreate=True(), since that is how it was passed into the model. Under the hood this calls getorcreate! without an interface and hence throws ERROR: Variable x is already a vector variable in the model. Nasty bug, and I don't really know how to fix it (yet).
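To make the failure mode concrete, here is a minimal toy sketch in plain Julia. It is not GraphPPL's implementation: ToyModel and toy_getorcreate! are hypothetical stand-ins that only track whether a name was registered as a scalar or as a vector. Once x[i] has registered x as a vector, an un-indexed request for x (which is roughly what the creatable proxy in y ~ dot(x, w) ends up doing) fails with the same kind of error.

```julia
# Hypothetical toy registry; NOT GraphPPL's real data structures or API.
struct ToyModel
    kinds::Dict{Symbol,Symbol}          # variable name => :scalar or :vector
    elements::Dict{Symbol,Set{Int}}     # vector name => indices created so far
end

ToyModel() = ToyModel(Dict{Symbol,Symbol}(), Dict{Symbol,Set{Int}}())

# Un-indexed request: get or create a single variable called `name`.
function toy_getorcreate!(m::ToyModel, name::Symbol)
    if get(m.kinds, name, :none) === :vector
        error("Variable $name is already a vector variable in the model")
    end
    m.kinds[name] = :scalar
    return name
end

# Indexed request: get or create element `i` of a vector variable `name`.
function toy_getorcreate!(m::ToyModel, name::Symbol, i::Int)
    if get(m.kinds, name, :none) === :scalar
        error("Variable $name is already a scalar variable in the model")
    end
    m.kinds[name] = :vector
    push!(get!(m.elements, name, Set{Int}()), i)
    return (name, i)
end

m = ToyModel()
for i in 1:3
    toy_getorcreate!(m, :x, i)   # like `x[i] ~ Normal(0, 1)`: registers x as a vector
end

# Like `y ~ dot(x, w)` with a creatable, un-indexed proxy: asks for plain `x`.
err = try
    toy_getorcreate!(m, :x)
    nothing
catch e
    e
end
println(sprint(showerror, err))  # Variable x is already a vector variable in the model
```

The toy only reproduces the symptom; in GraphPPL the decision of whether to create or fetch goes through the ProxyLabel machinery discussed further down the thread.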

blolt commented 1 month ago

Ran into this recently when specifying a Poisson GLM:

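using GraphPPL, Distributions, LinearAlgebra

@model function poisson_glm(X, y, n, m)
    local θ
    local λ
    for j in 1:m
        θ[j] ~ Normal(0, 1)
    end

    for i in 1:n
        # linear predictor: row i of the data matrix X is passed whole to dot
        λ[i] := dot(X[i, :], θ)
        y[i] ~ Poisson(exp(λ[i]))
    end
end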

Supporting this would allow a nice generalization of this Turing.jl example.

I'm looking to learn more about RxInfer's backend packages, and since I stumbled across this bug independently, it might be a good one for me to pick up. Is there a design document for GraphPPL I could reference?

Edit: Will start here: https://reactivebayes.github.io/GraphPPL.jl/stable/developers_guide/#Developers-guide

wouterwln commented 1 month ago

Hi @blolt, thanks for checking this out! Indeed, the Developers Guide is the closest thing we have to a description of GraphPPL's design. I'll try to give some additional pointers. GraphPPL is split into two separate modules (maybe three, but for the sake of this argument let's keep it at two): a graph engine, containing code for creating and manipulating a probabilistic model represented as a factor graph, and a metaprogramming module, which transforms user code into code the graph engine can interpret. We're mainly interested in the graph engine here.

(The first thing to note is that GraphPPL does not represent the model as a Forney-style factor graph (FFG). Instead, it has both factor nodes and variable nodes, so the entire graph is bipartite.)
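A bipartite factor graph means every edge joins exactly one factor node and one variable node. The toy sketch below only illustrates that invariant; ToyGraph, toy_addvariable!, toy_addfactor! and toy_isbipartite are hypothetical names, not GraphPPL's internal representation.

```julia
# Toy sketch of a bipartite factor graph; NOT GraphPPL's internal types.
struct ToyGraph
    variables::Set{Symbol}
    factors::Set{Symbol}
    edges::Vector{Tuple{Symbol,Symbol}}   # always (factor, variable)
end

ToyGraph() = ToyGraph(Set{Symbol}(), Set{Symbol}(), Tuple{Symbol,Symbol}[])

toy_addvariable!(g::ToyGraph, v::Symbol) = (push!(g.variables, v); v)

# A factor node is connected only to variable nodes, never to another factor.
function toy_addfactor!(g::ToyGraph, f::Symbol, vars::Vector{Symbol})
    for v in vars
        v in g.variables || error("$v is not a variable node")
        push!(g.edges, (f, v))
    end
    push!(g.factors, f)
    return f
end

# Every edge joins one factor node and one variable node, and no node is both.
toy_isbipartite(g::ToyGraph) =
    all(e -> e[1] in g.factors && e[2] in g.variables, g.edges) &&
    isempty(intersect(g.factors, g.variables))

# A single factor :f connecting the variables :x and :y.
g = ToyGraph()
toy_addvariable!(g, :x)
toy_addvariable!(g, :y)
toy_addfactor!(g, :f, [:y, :x])
println(toy_isbipartite(g), " ", length(g.edges))  # true 2
```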

Whenever we create a (sub)model, we assume that all interfaces (arguments to the function) are available. For the top-level model this is trivial, but for nested models some of them may still have to be created. Also, for data, as in this case, we don't know whether we should pass X as a matrix-variate RV, a vector of multivariate RVs or a matrix of univariate RVs. That's why we came up with a clever trick: whenever you pass something to a node or submodel, we don't (yet) materialize the variable node; instead we pass a ProxyLabel: https://github.com/ReactiveBayes/GraphPPL.jl/blob/3c62c3f507a080df5c2055d70bb852217de6a25e/src/graph_engine.jl#L289-L301

The maycreate field denotes whether we are allowed to create a new variable node when the variable is used in the creation of an atomic factor node. This field is handled incorrectly: if we use a label with maycreate=True() on the right-hand side of an equation (as I do in y ~ dot(x, w), where x is still a ProxyLabel because it is supplied from outside the model), GraphPPL still tries to create x instead of fetching it from the existing variables.

This is all deep in the meticulous details of GraphPPL's design, so it's okay if you don't understand all of it. If you are interested in the internal workings of GraphPPL, I would start by playing around with a simple GraphPPL.Model: build a model using the nested model functionality, look at the GraphPPL.Context attached to it and get a feeling for which objects live where.
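The deferred-materialization idea can be sketched in a few lines. This is a hypothetical toy, not GraphPPL's ProxyLabel: ToyProxy, toy_shape and toy_materialize! are illustrative names. The proxy only remembers a name and a maycreate flag; the variable's shape is decided by how it is first indexed, and an already existing variable is fetched rather than created again, which is the behaviour the thread says the real code should have.

```julia
# Hypothetical sketch; NOT GraphPPL's ProxyLabel or its fields.
# The proxy only remembers a name and whether it may create a variable.
struct ToyProxy
    name::Symbol
    maycreate::Bool
end

# Shape chosen for a variable depends on how it is first indexed.
function toy_shape(nidx::Int)
    if nidx == 0
        return :matrix_variate            # X used as a whole
    elseif nidx == 1
        return :vector_of_multivariate    # X[i]
    else
        return :matrix_of_univariate      # X[i, j]
    end
end

# Materialize the proxied variable only when a factor node actually uses it:
# fetch it if it already exists, otherwise create it if that is allowed.
# (Shape mismatches on later uses are not checked in this toy.)
function toy_materialize!(model::Dict{Symbol,Symbol}, p::ToyProxy, idx::Int...)
    haskey(model, p.name) && return model[p.name]
    p.maycreate || error("Variable $(p.name) does not exist and may not be created")
    model[p.name] = toy_shape(length(idx))
    return model[p.name]
end

model = Dict{Symbol,Symbol}()
X = ToyProxy(:X, true)
toy_materialize!(model, X, 1, 2)   # first use is X[1, 2]
println(model[:X])                 # matrix_of_univariate
```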

I think in the end the bug is here: https://github.com/ReactiveBayes/GraphPPL.jl/blob/3c62c3f507a080df5c2055d70bb852217de6a25e/src/graph_engine.jl#L325-L327 since proxied.maycreate | maycreate returns True() whenever X could be created, even though it should not be created in this specific instance. Hope this helps!
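The flag combination described above can be reduced to a truth-table sketch. The function names buggy_maycreate and guarded_maycreate are hypothetical, plain Bools stand in for the True()/False() values mentioned in the thread, and the guard is only one possible idea (skip creation when the variable already exists), not the maintainers' fix.

```julia
# Hypothetical sketch of the flag combination; NOT GraphPPL source code.
# The thread writes the flags as True()/False(); plain Bools stand in here.

# Reported behaviour: OR-ing the two flags means a proxy that was created
# with maycreate = true always reports "may create", even on the right-hand side.
buggy_maycreate(proxied::Bool, caller::Bool) = proxied | caller

# One possible guard (an assumption, not the maintainers' fix):
# never create a variable that already exists in the model.
guarded_maycreate(proxied::Bool, caller::Bool, exists::Bool) =
    !exists && (proxied | caller)

# `x` was passed in from outside (proxied = true), is used on the right-hand
# side of `y ~ dot(x, w)` (caller = false) and already exists as a vector.
println(buggy_maycreate(true, false))          # true  -> tries to create x
println(guarded_maycreate(true, false, true))  # false -> fetch existing x
```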