Closed by albertpod 3 years ago
To make ForneyLab.jl generate a correct `init()`, we need to provide optional parameters that specify the dimensions of the Nonlinear node's inputs.
The model specification should be changed as follows:
```julia
@RV θ ~ GaussianMeanPrecision(placeholder(:m_θ, dims=(dimensionality,)), placeholder(:W_θ, dims=(dimensionality, dimensionality)))

f(w, x) = 1/(1+exp(-w'x))

for i in 1:T
    @RV z[i] ~ GaussianMeanPrecision(inputs[i, :], 1e4*diageye(dimensionality))
    @RV x[i] ~ Nonlinear{Sampling}(θ, z[i], g=f, in_variates=[Multivariate, Multivariate], out_variate=Univariate)
    @RV y[i] ~ Bernoulli(x[i])
    placeholder(y[i], :y, index=i)
end
```
Thanks to @ThijsvdLaar.
The rest of the issue will be addressed in a future PR.
Given the following model
Following algorithm construction
After running
We get the following error:
DimensionMismatch("first array has length 1 which does not match the length of the second, 2.")
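For context, an error of this type can be reproduced in plain Julia with `f` from the model when one argument has length 1 and the other has length 2. This is only a minimal sketch with made-up values (`w_bad`, `w_ok`, `x`); it assumes the mismatch surfaces in the inner product `w'x` and does not show where ForneyLab's generated code triggers it. The exact wording of the message can differ between Julia versions.

```julia
using LinearAlgebra

# Same link function as in the model specification.
f(w, x) = 1/(1+exp(-w'x))

x     = [1.0, 2.0]   # two-dimensional input (illustrative values)
w_ok  = [0.0, 0.0]   # matching length: the inner product is defined
w_bad = [0.0]        # length 1: the inner product with x is not defined

println(f(w_ok, x))

# Capture the exception instead of letting it abort the script.
err = try
    f(w_bad, x)
catch e
    e
end
println(typeof(err))
```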
The error occurs because ForneyLab generates an improper `init()` function. The messages are supposed to carry `Multivariate` distributions. We can circumvent this issue by creating a custom `init()` function. However, when fixing this issue, we encounter a different problem, i.e.
```
MethodError: no method matching ruleSPNonlinearSOutNGX(::typeof(f), ::Nothing, ::Message{GaussianWeightedMeanPrecision, Multivariate}, ::Message{GaussianMeanPrecision, Multivariate}; variate=Univariate)
Closest candidates are:
  ruleSPNonlinearSOutNGX(::Function, ::Nothing, ::Message{var"#s158", V} where var"#s158"<:Gaussian...; n_samples) where V<:ForneyLab.VariateType at /Users/apodusenko/.julia/dev/ForneyLab/src/engines/julia/update_rules/nonlinear_sampling.jl:73 got unsupported keyword argument "variate"
```
The error happens when ForneyLab tries to compute the following message:

For some reason, when I omit the argument `variate=Univariate`, the message is computed. This is strange because `nonlinear_sampling.jl` defines the method with the `variate` argument, i.e. `ruleSPNonlinearSOutNGX(g::Function, msg_out::Nothing, msgs_in::Vararg{Message{<:Gaussian, <:VariateType}}; n_samples=default_n_samples, variate)`.

I can't figure out why it's happening. Any ideas?
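One possible explanation, offered as a guess rather than a confirmed diagnosis: Julia selects a method from the positional arguments only, and keyword arguments are checked afterwards against the method that was selected. If a more specific method exists whose inputs all share one variate type `V` (as the "closest candidate" in the error suggests) and it accepts only `n_samples`, it would win dispatch for two `Multivariate` messages and then reject `variate`. The sketch below reproduces that behaviour with stand-in types; `VariateKind`, `Uni`, `Multi`, `Msg` and `rule` are hypothetical names and not ForneyLab definitions.

```julia
# Stand-in types for illustration only (NOT ForneyLab's types).
abstract type VariateKind end
struct Uni <: VariateKind end
struct Multi <: VariateKind end
struct Msg{V<:VariateKind} end

# Generic method: inputs may mix variate types; requires the `variate` keyword.
rule(g::Function, msgs::Vararg{Msg{<:VariateKind}}; n_samples=10, variate) = :generic

# More specific method: all inputs share the same V; no `variate` keyword.
rule(g::Function, msgs::Vararg{Msg{V}}; n_samples=10) where {V<:VariateKind} = :same_variate

# Two Multi inputs: the more specific method is selected.
same = rule(identity, Msg{Multi}(), Msg{Multi}())

# Mixed inputs: only the generic method applies, so `variate` is accepted.
mixed = rule(identity, Msg{Multi}(), Msg{Uni}(); variate=Uni)

# Two Multi inputs plus `variate`: dispatch still picks the specific method,
# which does not know the keyword, so the call fails.
kw_err = try
    rule(identity, Msg{Multi}(), Msg{Multi}(); variate=Uni)
catch e
    e
end
println((same, mixed, typeof(kw_err)))
```

If this is what happens in `nonlinear_sampling.jl`, adding the `variate` keyword to the more specific method (or removing the overlap between the two signatures) would be one way to resolve it.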