ReactiveBayes / RxInfer.jl

Julia package for automated Bayesian inference on a factor graph with reactive message passing
MIT License
260 stars 24 forks source link

Some errors with CVIProjection #336

Open acertain opened 1 month ago

acertain commented 1 month ago

Error 1:

ERROR: PosDefException: matrix is not Hermitian; Cholesky factorization failed.
Stacktrace:
   [1] non_hermitian_error()
     @ StaticArrays ~/.julia/packages/StaticArrays/MSJcA/src/cholesky.jl:2
   [2] #cholesky#549
     @ ~/.julia/packages/StaticArrays/MSJcA/src/cholesky.jl:4 [inlined]
   [3] cholesky
     @ ~/.julia/packages/StaticArrays/MSJcA/src/cholesky.jl:3 [inlined]
   [4] fastcholesky(input::StaticArraysCore.SMatrix{2, 2, Float64, 4})
     @ StaticArraysCoreExt ~/.julia/packages/FastCholesky/5C3F6/ext/StaticArraysCoreExt.jl:6
   [5] cholinv
     @ ~/.julia/packages/FastCholesky/5C3F6/src/FastCholesky.jl:115 [inlined]
   [6] (::ExponentialFamilyProjection.CVICostGradientObjective{…})(M::ExponentialFamilyManifolds.NaturalParametersManifold{…}, X::RecursiveArrayTools.ArrayPartition{…}, p::RecursiveArrayTools.ArrayPartition{…})
     @ ExponentialFamilyProjection ~/.julia/packages/ExponentialFamilyProjection/kEJjX/src/cvi.jl:22
   [7] get_cost_and_gradient!
     @ ~/.julia/packages/Manopt/OISFZ/src/plans/gradient_plan.jl:125 [inlined]
   [8] get_gradient!
     @ ~/.julia/packages/Manopt/OISFZ/src/plans/gradient_plan.jl:202 [inlined]
   [9] get_gradient!
     @ ~/.julia/packages/Manopt/OISFZ/src/plans/gradient_plan.jl:145 [inlined]
  [10] IdentityUpdateRule
     @ ~/.julia/packages/Manopt/OISFZ/src/solvers/gradient_descent.jl:80 [inlined]
  [11] (::ExponentialFamilyProjection.BoundedNormUpdateRule{…})(mp::Manopt.DefaultManoptProblem{…}, s::Manopt.GradientDescentState{…}, i::Int64)
     @ ExponentialFamilyProjection ~/.julia/packages/ExponentialFamilyProjection/kEJjX/src/manopt/bounded_norm_update_rule.jl:37
  [12] step_solver!(p::Manopt.DefaultManoptProblem{…}, s::Manopt.GradientDescentState{…}, i::Int64)
     @ Manopt ~/.julia/packages/Manopt/OISFZ/src/solvers/gradient_descent.jl:279
  [13] solve!(p::Manopt.DefaultManoptProblem{…}, s::Manopt.GradientDescentState{…})
     @ Manopt ~/.julia/packages/Manopt/OISFZ/src/solvers/solver.jl:137
  [14] gradient_descent!(M::ExponentialFamilyManifolds.NaturalParametersManifold{…}, mgo::Manopt.ManifoldCostGradientObjective{…}, p::RecursiveArrayTools.ArrayPartition{…}; retraction_method::ManifoldsBase.ExponentialRetraction, stepsize::Manopt.ConstantStepsize{…}, stopping_criterion::Manopt.StopWhenAny{…}, debug::Missing, direction::ExponentialFamilyProjection.BoundedNormUpdateRule{…}, X::RecursiveArrayTools.ArrayPartition{…}, kwargs::@Kwargs{})
     @ Manopt ~/.julia/packages/Manopt/OISFZ/src/solvers/gradient_descent.jl:268
  [15] gradient_descent!
     @ ~/.julia/packages/Manopt/OISFZ/src/solvers/gradient_descent.jl:250 [inlined]
  [16] (::ExponentialFamilyProjection.var"#20#22"{…})(buffer::StaticTools.MallocSlabBuffer)
     @ ExponentialFamilyProjection ~/.julia/packages/ExponentialFamilyProjection/kEJjX/src/projected_to.jl:259
  [17] with_buffer(f::ExponentialFamilyProjection.var"#20#22"{…}, ::Val{…}, ::ProjectionParameters{…})
     @ ExponentialFamilyProjection ~/.julia/packages/ExponentialFamilyProjection/kEJjX/src/projected_to.jl:136
  [18] with_buffer(f::ExponentialFamilyProjection.var"#20#22"{…}, parameters::ProjectionParameters{…})
     @ ExponentialFamilyProjection ~/.julia/packages/ExponentialFamilyProjection/kEJjX/src/projected_to.jl:128
  [19] project_to(::ProjectedTo{…}, ::ProjectionExt.var"#1#2"{…}; initialpoint::Nothing, kwargs::@Kwargs{})
     @ ExponentialFamilyProjection ~/.julia/packages/ExponentialFamilyProjection/kEJjX/src/projected_to.jl:253
  [20] constrain_form(constraint::ProjectedTo{…}, context::ProjectionExt.ProjectionContext{…}, fn::Function)
     @ ProjectionExt ~/.julia/packages/RxInfer/iG4wU/ext/ProjectionExt/ProjectionExt.jl:40
  [21] constrain_form
     @ ~/.julia/packages/RxInfer/iG4wU/ext/ProjectionExt/ProjectionExt.jl:25 [inlined]
  [22] constrain_form
     @ ~/.julia/packages/ReactiveMP/DrjJB/src/constraints/form.jl:117 [inlined]
  [23] constrain_form
     @ ~/.julia/packages/ReactiveMP/DrjJB/src/constraints/form.jl:115 [inlined]
  [24] constrain_form_as_message(message::Message{…}, form_constraint::ReactiveMP.WrappedFormConstraint{…})
     @ ReactiveMP ~/.julia/packages/ReactiveMP/DrjJB/src/message.jl:139
  [25] (::ReactiveMP.var"#15#17"{GenericProd, ReactiveMP.WrappedFormConstraint{…}})(messages::Vector{AbstractMessage})
     @ ReactiveMP ~/.julia/packages/ReactiveMP/DrjJB/src/message.jl:149
  [26] (::ReactiveMP.var"#119#120"{ReactiveMP.var"#15#17"{GenericProd, ReactiveMP.WrappedFormConstraint{…}}})(messages::Vector{AbstractMessage})
     @ ReactiveMP ~/.julia/packages/ReactiveMP/DrjJB/src/variables/variable.jl:36
  [27] next_received!(wrapper::Rocket.CollectLatestObservableWrapper{…}, data::DefferedMessage{…}, index::CartesianIndex{…})
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/observable/collected.jl:103
  [28] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/observable/collected.jl:93 [inlined]
  [29] scheduled_next!(actor::Rocket.CollectLatestObservableInnerActor{…}, value::DefferedMessage{…}, ::AsapScheduler)
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/schedulers/asap.jl:23
  [30] on_next!(subject::Subject{AbstractMessage, AsapScheduler, AsapScheduler}, data::DefferedMessage{Tuple{…}, Tuple{…}, ReactiveMP.MessageMapping{…}})
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/subjects/subject.jl:62
  [31] actor_on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250 [inlined]
  [32] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
  [33] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/subjects/recent.jl:62 [inlined]
  [34] actor_on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250 [inlined]
  [35] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
  [36] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/operators/map.jl:62 [inlined]
  [37] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:206 [inlined]
--- the last 2 lines are repeated 1 more time ---
  [40] next_received!
     @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined.jl:101 [inlined]
  [41] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined.jl:68 [inlined]
  [42] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:206 [inlined]
--- the last 3 lines are repeated 1 more time ---
  [46] next_received!(wrapper::Rocket.CombineLatestUpdatesActorWrapper{…}, data::Marginal{…}, index::Int64)
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/observable/combined_updates.jl:72
  [47] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined_updates.jl:34 [inlined]
  [48] scheduled_next!(actor::Rocket.CombineLatestUpdatesInnerActor{…}, value::Marginal{…}, ::AsapScheduler)
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/schedulers/asap.jl:23
  [49] on_next!(subject::Subject{Marginal, AsapScheduler, AsapScheduler}, data::Marginal{FactorizedJoint{Tuple{NormalMeanVariance{…}, Gamma{…}}}, Nothing})
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/subjects/subject.jl:62
  [50] actor_on_next!(::BaseActorTrait{Marginal}, actor::Subject{Marginal, AsapScheduler, AsapScheduler}, data::Marginal{FactorizedJoint{Tuple{…}}, Nothing})
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250
  [51] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
  [52] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/subjects/recent.jl:62 [inlined]
  [53] actor_on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250 [inlined]
  [54] next!(actor::Rocket.RecentSubjectInstance{Marginal, Subject{Marginal, AsapScheduler, AsapScheduler}}, data::Marginal{FactorizedJoint{Tuple{…}}, Nothing})
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202
  [55] next_received!(wrapper::Rocket.CombineLatestUpdatesActorWrapper{…}, data::Tuple{…}, index::Int64)
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/observable/combined_updates.jl:72
  [56] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined_updates.jl:34 [inlined]
  [57] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:206 [inlined]
  [58] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/operators/map.jl:62 [inlined]
  [59] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:206 [inlined]
  [60] next_received!
     @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined.jl:101 [inlined]
  [61] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined.jl:68 [inlined]
  [62] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:206 [inlined]
  [63] next_received!
     @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined_updates.jl:72 [inlined]
  [64] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined_updates.jl:34 [inlined]
  [65] scheduled_next!(actor::Rocket.CombineLatestUpdatesInnerActor{…}, value::Message{…}, ::AsapScheduler)
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/schedulers/asap.jl:23
  [66] on_next!(subject::Subject{Message, AsapScheduler, AsapScheduler}, data::Message{NormalMeanVariance{Float64}, Nothing})
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/subjects/subject.jl:62
  [67] actor_on_next!(::BaseActorTrait{Message}, actor::Subject{Message, AsapScheduler, AsapScheduler}, data::Message{NormalMeanVariance{Float64}, Nothing})
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250
  [68] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
  [69] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/subjects/recent.jl:62 [inlined]
  [70] actor_on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250 [inlined]
  [71] next!(actor::Rocket.RecentSubjectInstance{Message, Subject{Message, AsapScheduler, AsapScheduler}}, data::Message{NormalMeanVariance{Float64}, Nothing})
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202
  [72] next_received!(wrapper::Rocket.CollectLatestObservableWrapper{…}, data::DefferedMessage{…}, index::CartesianIndex{…})
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/observable/collected.jl:104
  [73] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/observable/collected.jl:93 [inlined]
  [74] scheduled_next!(actor::Rocket.CollectLatestObservableInnerActor{…}, value::DefferedMessage{…}, ::AsapScheduler)
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/schedulers/asap.jl:23
  [75] on_next!(subject::Subject{AbstractMessage, AsapScheduler, AsapScheduler}, data::DefferedMessage{Nothing, Tuple{…}, ReactiveMP.MessageMapping{…}})
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/subjects/subject.jl:62
  [76] actor_on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250 [inlined]
  [77] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
  [78] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/subjects/recent.jl:62 [inlined]
  [79] actor_on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250 [inlined]
  [80] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
  [81] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/operators/map.jl:62 [inlined]
  [82] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:206 [inlined]
  [83] next_received!
     @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined.jl:101 [inlined]
  [84] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined.jl:68 [inlined]
  [85] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:206 [inlined]
  [86] next_received!
     @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined_updates.jl:72 [inlined]
  [87] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined_updates.jl:34 [inlined]
  [88] scheduled_next!(actor::Rocket.CombineLatestUpdatesInnerActor{…}, value::Marginal{…}, ::AsapScheduler)
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/schedulers/asap.jl:23
  [89] on_next!(subject::Subject{Marginal, AsapScheduler, AsapScheduler}, data::Marginal{PointMass{Float64}, Nothing})
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/subjects/subject.jl:62
  [90] actor_on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250 [inlined]
  [91] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
  [92] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/subjects/recent.jl:62 [inlined]
  [93] actor_on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250 [inlined]
  [94] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
  [95] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/operators/map.jl:62 [inlined]
  [96] scheduled_next!(actor::Rocket.MapActor{Message, Rocket.RecentSubjectInstance{…}, typeof(as_marginal)}, value::Message{PointMass{…}, Nothing}, ::AsapScheduler)
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/schedulers/asap.jl:23
  [97] on_next!(subject::Subject{Message, AsapScheduler, AsapScheduler}, data::Message{PointMass{Float64}, Nothing})
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/subjects/subject.jl:62
  [98] actor_on_next!(::BaseActorTrait{Message}, actor::Subject{Message, AsapScheduler, AsapScheduler}, data::Message{PointMass{Float64}, Nothing})
     @ Rocket ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250
  [99] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
 [100] on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/subjects/recent.jl:62 [inlined]
 [101] actor_on_next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250 [inlined]
 [102] next!
     @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
 [103] update!
     @ ~/.julia/packages/ReactiveMP/DrjJB/src/variables/data.jl:85 [inlined]
 [104] update!(datavar::DataVariable{Rocket.RecentSubjectInstance{Message, Subject{…}}, ReactiveMP.MarginalObservable}, data::Float64)
     @ ReactiveMP ~/.julia/packages/ReactiveMP/DrjJB/src/variables/data.jl:84
 [105] batch_inference(; model::GraphPPL.ModelGenerator{…}, data::@NamedTuple{…}, initialization::RxInfer.InitSpecification, constraints::GraphPPL.Constraints{…}, meta::GraphPPL.MetaSpecification, options::Nothing, returnvars::Nothing, predictvars::Nothing, iterations::Nothing, free_energy::Bool, free_energy_diagnostics::Tuple{…}, showprogress::Bool, callbacks::Nothing, addons::Nothing, postprocess::DefaultPostprocess, warn::Bool, catch_exception::Bool)
     @ RxInfer ~/.julia/packages/RxInfer/iG4wU/src/inference/batch.jl:300
 [106] batch_inference
     @ ~/.julia/packages/RxInfer/iG4wU/src/inference/batch.jl:94 [inlined]
 [107] #infer#244
     @ ~/.julia/packages/RxInfer/iG4wU/src/inference/inference.jl:306 [inlined]
 [108] infer
     @ ~/.julia/packages/RxInfer/iG4wU/src/inference/inference.jl:258 [inlined]
 [109] do_infer()
     @ Main ~/Sync/Code/scripts/qself/old/julia-rxinfer/error_repro.jl:30
Some type information was truncated. Use `show(err)` to see complete types.

Model (reproduces Error 1):

using ExponentialFamilyProjection
using RxInfer

"""
    decay(x, dt, f)

Attenuate `x` by the factor `exp(-dt / f)`. This is the value of `x` after a
time step `dt` under exponential decay with time constant `f`.
"""
function decay(x, dt, f)
    factor = exp(-dt / f)
    return x * factor
end

# Generative model for the Error 1 reproduction.
# `dummy` is the observed value; its value is supplied via `data` in `do_infer`.
@model function model(dummy)
    # Prior over the time constant `f` passed to `decay`.
    H_decay ~ GammaShapeRate(1, 1)
    H_prev ~ NormalMeanVariance(5.0, 0.25)
    # Deterministic (`:=`) transform through the user function `decay`; the
    # `decay` node is given `CVIProjection()` meta in `model_meta`.
    H := decay(H_prev, 1.0, H_decay)
    dummy ~ NormalMeanVariance(H, 0.5)
end

# Form constraints: each listed marginal is projected onto the given family.
@constraints function mk_constraints()
    q(H) :: ProjectedTo(NormalMeanVariance)
    q(H_prev) :: ProjectedTo(NormalMeanVariance)
    # NOTE(review): `H_decay` has a `GammaShapeRate` prior in `model`, yet its
    # marginal is projected onto a Normal family here. Per the maintainer's reply,
    # using `ProjectedTo(Gamma)` (with a Gamma initialization) lets the model run.
    q(H_decay) :: ProjectedTo(NormalMeanVariance)
end

# Attach `CVIProjection()` meta to the `decay` node(s), so the deterministic
# transform goes through the CVI-based projection (the Error 1 stack trace shows
# `CVICostGradientObjective` being evaluated).
@meta function model_meta()
    decay() -> CVIProjection()
end

# Initial marginals for `H_decay` and `H`.
# NOTE(review): `q(H_decay)` is initialized with a Normal to match the
# `ProjectedTo(NormalMeanVariance)` constraint in `mk_constraints`, even though
# its prior in `model` is `GammaShapeRate`.
@initialization function init()
    q(H_decay)    = NormalMeanVariance(1.0, 1.0)
    q(H)          = NormalMeanVariance(1.0, 1.0)
end

"""
    do_infer()

Run `infer` on `model` with the single observation `dummy = 2.0`. Uses the
constraints, meta and initialization declared in this snippet and returns the
inference result. As reported, calling it throws the `PosDefException` shown
under Error 1.
"""
function do_infer()
    observations = (dummy = 2.0,)
    return infer(;
        model = model(),
        data = observations,
        constraints = mk_constraints(),
        meta = model_meta(),
        initialization = init(),
    )
end

Error 2:


ERROR: MethodError: objects of type ExponentialFamilyDistribution{NormalMeanVariance, RecursiveArrayTools.ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Float64}}}, Nothing, Nothing} are not callable
Stacktrace:
  [1] map!
    @ ./abstractarray.jl:3278 [inlined]
  [2] #5
    @ ~/.julia/packages/BayesBase/8W6zj/src/statsfuns.jl:438 [inlined]
  [3] InplaceLogpdf
    @ ~/.julia/packages/BayesBase/8W6zj/src/statsfuns.jl:433 [inlined]
  [4] prepare_state!(state::ExponentialFamilyProjection.ControlVariateStrategyState{…}, strategy::ExponentialFamilyProjection.ControlVariateStrategy{…}, targetfn::BayesBase.InplaceLogpdf{…}, distribution::ExponentialFamilyDistribution{…}, supplementary_η::Tuple{})
    @ ExponentialFamilyProjection ~/.julia/packages/ExponentialFamilyProjection/kEJjX/src/strategies/control_variate.jl:185
  [5] prepare_state!(::Nothing, strategy::ExponentialFamilyProjection.ControlVariateStrategy{…}, targetfn::BayesBase.InplaceLogpdf{…}, distribution::ExponentialFamilyDistribution{…}, supplementary_η::Tuple{})
    @ ExponentialFamilyProjection ~/.julia/packages/ExponentialFamilyProjection/kEJjX/src/strategies/control_variate.jl:112
  [6] prepare_state!(strategy::ExponentialFamilyProjection.ControlVariateStrategy{…}, targetfn::ExponentialFamilyDistribution{…}, distribution::ExponentialFamilyDistribution{…}, supplementary_η::Tuple{})
    @ ExponentialFamilyProjection ~/.julia/packages/ExponentialFamilyProjection/kEJjX/src/strategies/control_variate.jl:52
  [7] project_to(::ProjectedTo{…}, ::ExponentialFamilyDistribution{…}; initialpoint::Nothing, kwargs::@Kwargs{})
    @ ExponentialFamilyProjection ~/.julia/packages/ExponentialFamilyProjection/kEJjX/src/projected_to.jl:241
  [8] constrain_form(constraint::ProjectedTo{…}, context::ProjectionExt.ProjectionContext{…}, fn::ExponentialFamilyDistribution{…})
    @ ProjectionExt ~/.julia/packages/RxInfer/iG4wU/ext/ProjectionExt/ProjectionExt.jl:40
  [9] constrain_form
    @ ~/.julia/packages/ReactiveMP/DrjJB/src/constraints/form.jl:117 [inlined]
 [10] constrain_form
    @ ~/.julia/packages/ReactiveMP/DrjJB/src/constraints/form.jl:115 [inlined]
 [11] constrain_form_as_message(message::Message{…}, form_constraint::ReactiveMP.WrappedFormConstraint{…})
    @ ReactiveMP ~/.julia/packages/ReactiveMP/DrjJB/src/message.jl:139
 [12] (::ReactiveMP.var"#15#17"{GenericProd, ReactiveMP.WrappedFormConstraint{…}})(messages::Vector{AbstractMessage})
    @ ReactiveMP ~/.julia/packages/ReactiveMP/DrjJB/src/message.jl:149
 [13] (::ReactiveMP.var"#119#120"{ReactiveMP.var"#15#17"{GenericProd, ReactiveMP.WrappedFormConstraint{…}}})(messages::Vector{AbstractMessage})
    @ ReactiveMP ~/.julia/packages/ReactiveMP/DrjJB/src/variables/variable.jl:36
 [14] next_received!(wrapper::Rocket.CollectLatestObservableWrapper{…}, data::DefferedMessage{…}, index::CartesianIndex{…})
    @ Rocket ~/.julia/packages/Rocket/LrFUI/src/observable/collected.jl:103
 [15] on_next!
    @ ~/.julia/packages/Rocket/LrFUI/src/observable/collected.jl:93 [inlined]
 [16] scheduled_next!(actor::Rocket.CollectLatestObservableInnerActor{…}, value::DefferedMessage{…}, ::AsapScheduler)
    @ Rocket ~/.julia/packages/Rocket/LrFUI/src/schedulers/asap.jl:23
 [17] on_next!(subject::Subject{AbstractMessage, AsapScheduler, AsapScheduler}, data::DefferedMessage{Nothing, Tuple{…}, ReactiveMP.MessageMapping{…}})
    @ Rocket ~/.julia/packages/Rocket/LrFUI/src/subjects/subject.jl:62
 [18] actor_on_next!
    @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250 [inlined]
 [19] next!
    @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
 [20] on_next!
    @ ~/.julia/packages/Rocket/LrFUI/src/subjects/recent.jl:62 [inlined]
 [21] actor_on_next!
    @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250 [inlined]
 [22] next!
    @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
 [23] on_next!
    @ ~/.julia/packages/Rocket/LrFUI/src/operators/map.jl:62 [inlined]
 [24] next!
    @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:206 [inlined]
 [25] next_received!
    @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined.jl:101 [inlined]
 [26] on_next!
    @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined.jl:68 [inlined]
 [27] next!
    @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:206 [inlined]
 [28] next_received!
    @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined_updates.jl:72 [inlined]
 [29] on_next!
    @ ~/.julia/packages/Rocket/LrFUI/src/observable/combined_updates.jl:34 [inlined]
 [30] scheduled_next!(actor::Rocket.CombineLatestUpdatesInnerActor{…}, value::Marginal{…}, ::AsapScheduler)
    @ Rocket ~/.julia/packages/Rocket/LrFUI/src/schedulers/asap.jl:23
 [31] on_next!(subject::Subject{Marginal, AsapScheduler, AsapScheduler}, data::Marginal{PointMass{Float64}, Nothing})
    @ Rocket ~/.julia/packages/Rocket/LrFUI/src/subjects/subject.jl:62
 [32] actor_on_next!
    @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250 [inlined]
 [33] next!
    @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
 [34] on_next!
    @ ~/.julia/packages/Rocket/LrFUI/src/subjects/recent.jl:62 [inlined]
 [35] actor_on_next!
    @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250 [inlined]
 [36] next!
    @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
 [37] on_next!
    @ ~/.julia/packages/Rocket/LrFUI/src/operators/map.jl:62 [inlined]
 [38] scheduled_next!(actor::Rocket.MapActor{Message, Rocket.RecentSubjectInstance{…}, typeof(as_marginal)}, value::Message{PointMass{…}, Nothing}, ::AsapScheduler)
    @ Rocket ~/.julia/packages/Rocket/LrFUI/src/schedulers/asap.jl:23
 [39] on_next!(subject::Subject{Message, AsapScheduler, AsapScheduler}, data::Message{PointMass{Float64}, Nothing})
    @ Rocket ~/.julia/packages/Rocket/LrFUI/src/subjects/subject.jl:62
 [40] actor_on_next!(::BaseActorTrait{Message}, actor::Subject{Message, AsapScheduler, AsapScheduler}, data::Message{PointMass{Float64}, Nothing})
    @ Rocket ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250
 [41] next!
    @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
 [42] on_next!
    @ ~/.julia/packages/Rocket/LrFUI/src/subjects/recent.jl:62 [inlined]
 [43] actor_on_next!
    @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:250 [inlined]
 [44] next!
    @ ~/.julia/packages/Rocket/LrFUI/src/actor.jl:202 [inlined]
 [45] update!
    @ ~/.julia/packages/ReactiveMP/DrjJB/src/variables/data.jl:85 [inlined]
 [46] update!(datavar::DataVariable{Rocket.RecentSubjectInstance{Message, Subject{…}}, ReactiveMP.MarginalObservable}, data::Float64)
    @ ReactiveMP ~/.julia/packages/ReactiveMP/DrjJB/src/variables/data.jl:84
 [47] batch_inference(; model::GraphPPL.ModelGenerator{…}, data::@NamedTuple{…}, initialization::RxInfer.InitSpecification, constraints::GraphPPL.Constraints{…}, meta::GraphPPL.MetaSpecification, options::Nothing, returnvars::Nothing, predictvars::Nothing, iterations::Nothing, free_energy::Bool, free_energy_diagnostics::Tuple{…}, showprogress::Bool, callbacks::Nothing, addons::Nothing, postprocess::DefaultPostprocess, warn::Bool, catch_exception::Bool)
    @ RxInfer ~/.julia/packages/RxInfer/iG4wU/src/inference/batch.jl:300
 [48] batch_inference
    @ ~/.julia/packages/RxInfer/iG4wU/src/inference/batch.jl:94 [inlined]
 [49] #infer#244
    @ ~/.julia/packages/RxInfer/iG4wU/src/inference/inference.jl:306 [inlined]
 [50] infer
    @ ~/.julia/packages/RxInfer/iG4wU/src/inference/inference.jl:258 [inlined]
 [51] do_infer()
    @ Main ~/Sync/Code/scripts/qself/old/julia-rxinfer/error_repro.jl:27
Some type information was truncated. Use `show(err)` to see complete types.

Model (reproduces Error 2):

using ExponentialFamilyProjection
using RxInfer

"""
    decay(x, dt, f)

Attenuate `x` by the factor `exp(-dt / f)`. This is the value of `x` after a
time step `dt` under exponential decay with time constant `f`.
"""
function decay(x, dt, f)
    factor = exp(-dt / f)
    return x * factor
end

# Generative model for the Error 2 reproduction: like the Error 1 model, but the
# time constant passed to `decay` is fixed to 1.0 instead of being inferred.
@model function model(dummy)
    H_prev ~ NormalMeanVariance(5.0, 0.25)
    # Deterministic (`:=`) transform; the `decay` node gets `CVIProjection()` meta.
    H := decay(H_prev, 1.0, 1.0)
    dummy ~ NormalMeanVariance(H, 0.5)
end

# Project the marginals of `H` and `H_prev` onto the `NormalMeanVariance` family.
# As reported, this configuration raises Error 2: `constrain_form` receives an
# `ExponentialFamilyDistribution` where a callable log-density is expected.
@constraints function mk_constraints()
    q(H) :: ProjectedTo(NormalMeanVariance)
    q(H_prev) :: ProjectedTo(NormalMeanVariance)
end

# Attach `CVIProjection()` meta to the `decay` node(s) of the model.
@meta function model_meta()
    decay() -> CVIProjection()
end

# Initial marginal for `H`.
@initialization function init()
    q(H)          = NormalMeanVariance(1.0, 1.0)
end

"""
    do_infer()

Run `infer` on `model` with the single observation `dummy = 2.0`. Uses the
constraints, meta and initialization declared in this snippet and returns the
inference result. As reported, calling it throws the `MethodError` shown under
Error 2.
"""
function do_infer()
    observations = (dummy = 2.0,)
    return infer(;
        model = model(),
        data = observations,
        constraints = mk_constraints(),
        meta = model_meta(),
        initialization = init(),
    )
end
ismailsenoz commented 1 month ago

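Hi @acertain! Thanks for trying out RxInfer. I was able to run the first model by changing the form constraint on `H_decay` from `ProjectedTo(NormalMeanVariance)` to `ProjectedTo(Gamma)` and initializing `q(H_decay)` with a `Gamma` distribution to match. Since `H_decay` has a `GammaShapeRate` prior, it takes only positive values, so a Gamma family is the natural projection target: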

using ExponentialFamilyProjection
using RxInfer

"""
    decay(x, dt, f)

Attenuate `x` by the factor `exp(-dt / f)`. This is the value of `x` after a
time step `dt` under exponential decay with time constant `f`.
"""
function decay(x, dt, f)
    factor = exp(-dt / f)
    return x * factor
end

# Same generative model as in the Error 1 reproduction; only the constraints
# and initialization below differ.
@model function model(dummy)
    # Prior over the time constant `f` passed to `decay`.
    H_decay ~ GammaShapeRate(1, 1)
    H_prev ~ NormalMeanVariance(5.0, 0.25)
    # Deterministic (`:=`) transform; the `decay` node gets `CVIProjection()` meta.
    H := decay(H_prev, 1.0, H_decay)
    dummy ~ NormalMeanVariance(H, 0.5)
end

# Form constraints: each listed marginal is projected onto the given family.
@constraints function mk_constraints()
    q(H) :: ProjectedTo(NormalMeanVariance)
    q(H_prev) :: ProjectedTo(NormalMeanVariance)
    # The fix: project `H_decay` onto `Gamma`, matching its `GammaShapeRate` prior,
    # instead of onto a Normal family.
    q(H_decay) :: ProjectedTo(Gamma)
end

# Attach `CVIProjection()` meta to the `decay` node(s) of the model.
@meta function model_meta()
    decay() -> CVIProjection()
end

# Initial marginals for `H_decay` and `H`.
# `q(H_decay)` now starts from a `Gamma`, consistent with its `ProjectedTo(Gamma)`
# constraint.
@initialization function init()
    q(H_decay)    = Gamma(1.0, 1.0)
    q(H)          = NormalMeanVariance(1.0, 1.0)
end

"""
    do_infer()

Run `infer` on `model` with the single observation `dummy = 2.0`. Uses the
constraints, meta and initialization declared in this snippet and returns the
inference result.
"""
function do_infer()
    observations = (dummy = 2.0,)
    return infer(;
        model = model(),
        data = observations,
        constraints = mk_constraints(),
        meta = model_meta(),
        initialization = init(),
    )
end

# Run the fixed model once when the script is executed.
do_infer()

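As for the second model, you have actually found a bug. An incoming message that is already an `ExponentialFamilyDistribution` is passed to the projection as if it were a log-density function, which causes the `MethodError` above. We will fix this in a future release. In the meantime, the following workaround should do the job. It overrides `constrain_form` so that a distribution already in the target family is converted directly, and any other distribution is projected through its `logpdf`: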

using ExponentialFamilyProjection
using RxInfer, ReactiveMP
# Fetch RxInfer's `ProjectionExt` package extension module so that its
# `ProjectionContext` type can be named in the method signature below.
# NOTE(review): the extension is presumably activated by loading
# ExponentialFamilyProjection above; `Base.get_extension` returns `nothing`
# if the extension has not been loaded, in which case the `using` line fails.
RxInferProjectionExt = Base.get_extension(RxInfer, :ProjectionExt)
using .RxInferProjectionExt

"""
    ReactiveMP.constrain_form(constraint::ProjectedTo, context::ProjectionContext, something)

Temporary workaround for projecting an incoming message that is already a
distribution rather than a log-density function.

- If `something` already has the exponential-family type targeted by
  `constraint`, it is converted to that type, stored in `context.previous`, and
  returned without calling the function-based projection.
- Otherwise its `logpdf` is wrapped in a closure and passed to the
  function-based `constrain_form` method, which performs the projection.

NOTE(review): this adds a method to `ReactiveMP.constrain_form` for types owned
by other packages (type piracy); remove it once the upstream fix is released.
NOTE(review): `context.previous` is presumably used as the starting point of
later projections; confirm against RxInfer's `ProjectionExt`.
"""
function ReactiveMP.constrain_form(
    constraint::ProjectedTo,
    context::RxInferProjectionExt.ProjectionContext,
    something::Union{Distribution,ExponentialFamilyDistribution},
)
    target_type = ExponentialFamilyProjection.get_projected_to_type(constraint)
    source_type = ExponentialFamily.exponential_family_typetag(something)

    # Different family: hand the log-density to the function-based method.
    if target_type !== source_type
        logdensity = x -> logpdf(something, x)
        return ReactiveMP.constrain_form(constraint, context, logdensity)
    end

    # Same family: no projection needed, just convert and remember the result.
    converted = convert(source_type, something)
    context.previous = converted
    return converted
end

"""
    decay(x, dt, f)

Attenuate `x` by the factor `exp(-dt / f)`. This is the value of `x` after a
time step `dt` under exponential decay with time constant `f`.
"""
function decay(x, dt, f)
    factor = exp(-dt / f)
    return x * factor
end

# Same generative model as in the Error 2 reproduction (time constant fixed to 1.0).
@model function model(dummy)
    H_prev ~ NormalMeanVariance(5.0, 0.25)
    # Deterministic (`:=`) transform; the `decay` node gets `CVIProjection()` meta.
    H := decay(H_prev, 1.0, 1.0)
    dummy ~ NormalMeanVariance(H, 0.5)
end

# Project the marginals of `H` and `H_prev` onto the `NormalMeanVariance` family.
# Unchanged from the Error 2 reproduction; the `constrain_form` override defined
# earlier in this snippet is what lets these constraints go through.
@constraints function mk_constraints()
    q(H) :: ProjectedTo(NormalMeanVariance)
    q(H_prev) :: ProjectedTo(NormalMeanVariance)
end

# Attach `CVIProjection()` meta to the `decay` node(s) of the model.
@meta function model_meta()
    decay() -> CVIProjection()
end

# Initial marginal for `H`.
@initialization function init()
    q(H)          = NormalMeanVariance(1.0, 1.0)
end

"""
    do_infer()

Run `infer` on `model` with the single observation `dummy = 2.0`. Uses the
constraints, meta and initialization declared in this snippet and returns the
inference result. This relies on the `constrain_form` workaround defined above.
"""
function do_infer()
    observations = (dummy = 2.0,)
    return infer(;
        model = model(),
        data = observations,
        constraints = mk_constraints(),
        meta = model_meta(),
        initialization = init(),
    )
end

# Run the patched model once when the script is executed.
do_infer()