ReactiveBayes / GraphPPL.jl

DSL for probabilistic models specification and probabilistic programming.
MIT License

Error when running infer due to splatting #240

Open albertpod opened 2 weeks ago

albertpod commented 2 weeks ago

I know the issue with splatting was addressed here, but I am still having trouble with it when running the infer function with the PCA model provided below. The following error is raised:

ERROR: MethodError: no method matching iterate(::GraphPPL.VariableRef{…})

Closest candidates are:
  iterate(::Revise.LineSkippingIterator)
   @ Revise ~/.julia/packages/Revise/bAgL0/src/relocatable_exprs.jl:70
  iterate(::Revise.LineSkippingIterator, ::Any)
   @ Revise ~/.julia/packages/Revise/bAgL0/src/relocatable_exprs.jl:70
  iterate(::Base.MethodSpecializations)
   @ Base reflection.jl:1148
  ...

Stacktrace:
  [1] macro expansion
    @ ~/.julia/dev/GraphPPL/src/model_macro.jl:543 [inlined]
  [2] macro expansion

The error occurs at the line: https://github.com/ReactiveBayes/GraphPPL.jl/blob/b3ae2a6917e898f3bfe653f538a191f61f7bde34/src/model_macro.jl#L542

A minimal code example to reproduce the error:

using RxInfer

PCA_block(x, w...) = hcat(w...)*x
# PCA_block(x, w1, w2) = [w1 w2]*x

@model function pca_mode(y, components, obs_dim, lat_dim)
    local w
    for j in 1:components
        w[j] ~ MvNormal(μ=zeros(obs_dim), Λ=diageye(obs_dim))
    end

    for i in eachindex(y)
        x[i] ~ MvNormalMeanPrecision(ones(lat_dim), diageye(lat_dim))
        y[i] ~ MvNormal(μ=PCA_block(x[i], w...), Λ=diageye(obs_dim))
        # y[i] ~ MvNormal(μ=PCA_block(x[i], w[1], w[2]), Λ=diageye(obs_dim))
    end
end

components = 2
n_samples = 100
obs_dim = 4
lat_dim = 2

w1 = [2.0, -1.0, 0.5, -0.2] 
w2 = [0.8, 1.5, -0.3, 0.1]

latent_x = [rand(MvNormal(zeros(lat_dim), diageye(lat_dim))) for i in 1:n_samples]

y = [PCA_block(latent_x[i], w1, w2) + rand(MvNormal(zeros(obs_dim), 0.1diageye(obs_dim))) for i in 1:n_samples]

pca_meta = @meta begin
    PCA_block() -> Linearization()
end

initialization = @initialization begin
    μ(w) = MvNormalMeanPrecision(zeros(obs_dim), diageye(obs_dim))
end

result = infer(
    model = pca_mode(components=components, obs_dim=obs_dim, lat_dim=lat_dim),
    initialization = initialization,
    data = (y=y,),
    meta = pca_meta,
    free_energy = true,
    iterations = 5,
    showprogress = true,
    returnvars = KeepLast(),
)

The issue appears to be related to the splatting of the w array in the PCA_block function.

bvdmitri commented 2 weeks ago

Can you try this and see if it works as a workaround?

μ=PCA_block(in = [ x[i], w... ])

EDIT: ah, sorry, that probably won't work either.

bvdmitri commented 2 weeks ago

I think it's a real issue, but it should be fixable. We need to define iterate for VariableRef, which should probably reuse some code from Base.broadcastable(ref::VariableRef) (it may even call broadcastable?). We may have had some justification for not including iterate on VariableRef, but I cannot recall it.
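
To illustrate why a missing iterate method breaks splatting, here is a minimal sketch in plain Julia. Julia evaluates f(w...) by iterating w, so any type that appears after `...` needs Base.iterate. LazyRef and stack_columns are hypothetical stand-ins introduced only for this example; they are not GraphPPL.VariableRef or the PCA_block node, and delegating to an inner vector is an assumption about what a real method might do once the reference is resolved.

```julia
# Hypothetical stand-in for a reference that wraps a collection.
# This is NOT GraphPPL.VariableRef; it only mimics "a wrapper around items".
struct LazyRef{T}
    items::Vector{T}
end

# Without this method, `f(ref...)` throws
# `MethodError: no method matching iterate(::LazyRef{...})`.
# Assumption: iteration simply forwards to the wrapped collection.
Base.iterate(ref::LazyRef, state...) = iterate(ref.items, state...)
Base.length(ref::LazyRef) = length(ref.items)
Base.eltype(::Type{LazyRef{T}}) where {T} = T

# Same shape as the PCA_block function from the report, under a different name.
stack_columns(x, w...) = hcat(w...) * x

ref = LazyRef([[2.0, -1.0], [0.8, 1.5]])

# Splatting now works because `iterate(::LazyRef)` is defined.
result = stack_columns([1.0, 1.0], ref...)
```

Whether VariableRef can forward iteration this directly depends on when the underlying variables are actually created, which may be the reason iterate was left out originally.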