Issue status: open — opened by @acertain one month ago.
Hi @acertain! Thanks for trying out RxInfer. I was able to run the first model by changing the constraint on `H_decay` to a Gamma projection, `q(H_decay) :: ProjectedTo(Gamma)`:
using ExponentialFamilyProjection
using RxInfer
"""
    decay(x, dt, f)

Return `x` scaled by the exponential decay factor `exp(-dt / f)`, i.e. `x` decayed
over an interval `dt` with time constant `f`.
"""
function decay(x, dt, f)
    factor = exp(-dt / f)
    return x * factor
end
# First model: a Gamma prior on the decay parameter `H_decay`, a Gaussian prior on
# the previous state `H_prev`, a deterministic decayed state `H`, and a Gaussian
# observation `dummy` of `H` (supplied as data when `infer` is called).
@model function model(dummy)
# Prior on the decay time constant: shape 1, rate 1 (positive support).
H_decay ~ GammaShapeRate(1, 1)
# Prior on the previous state: mean 5.0, variance 0.25.
H_prev ~ NormalMeanVariance(5.0, 0.25)
# `:=` makes `H` a deterministic function of its inputs:
# H = H_prev * exp(-1.0 / H_decay). This nonlinear node is handled via the @meta block.
H := decay(H_prev, 1.0, H_decay)
# Observation model: `dummy` is Gaussian around `H` with variance 0.5.
dummy ~ NormalMeanVariance(H, 0.5)
end
# Variational constraints for the first model: each listed marginal is projected
# onto the given exponential-family form (`ProjectedTo`, from ExponentialFamilyProjection).
@constraints function mk_constraints()
q(H) :: ProjectedTo(NormalMeanVariance)
q(H_prev) :: ProjectedTo(NormalMeanVariance)
# Project onto a Gamma so the posterior keeps positive support, matching the
# GammaShapeRate prior on `H_decay`; this is the change that made the model run.
q(H_decay) :: ProjectedTo(Gamma)
end
# Node metadata: `decay` is a nonlinear deterministic node, so its messages are
# approximated with CVI-based projection (`CVIProjection`).
@meta function model_meta()
decay() -> CVIProjection()
end
# Initial marginals for `H_decay` and `H`, used to start inference in the first model.
@initialization function init()
# Distributions.jl `Gamma(α, θ)` is shape–scale, so Gamma(1.0, 1.0) equals the
# GammaShapeRate(1, 1) prior (rate 1 == 1 / scale 1).
q(H_decay) = Gamma(1.0, 1.0)
q(H) = NormalMeanVariance(1.0, 1.0)
end
"""
    do_infer()

Run `infer` on `model`, observing `dummy = 2.0`, with the constraints, meta and
initialization returned by `mk_constraints()`, `model_meta()` and `init()`.
Return whatever `infer` returns.
"""
function do_infer()
    return infer(;
        model          = model(),
        data           = (dummy = 2.0,),
        constraints    = mk_constraints(),
        meta           = model_meta(),
        initialization = init(),
    )
end

# Run inference for the first model when the script is executed.
do_infer()
As for the second model, you actually caught a bug. We will fix it in a future release, but for now the following workaround (which overrides `ReactiveMP.constrain_form`) should do the job:
using ExponentialFamilyProjection
using RxInfer, ReactiveMP
# Look up RxInfer's `ProjectionExt` package extension so its internal types (e.g.
# `ProjectionContext`) can be referenced below. `Base.get_extension` returns
# `nothing` if the extension is not loaded; it is expected to be loaded here because
# ExponentialFamilyProjection is imported above.
RxInferProjectionExt = Base.get_extension(RxInfer, :ProjectionExt)
using .RxInferProjectionExt
# Workaround method: if the input distribution already has the exponential-family
# type requested by the `ProjectedTo` constraint, convert it directly instead of
# projecting it. Otherwise, hand its log-density to the function-based
# `constrain_form` method.
function ReactiveMP.constrain_form(constraint::ProjectedTo, context::RxInferProjectionExt.ProjectionContext, something::Union{Distribution,ExponentialFamilyDistribution})
    target_type = ExponentialFamilyProjection.get_projected_to_type(constraint)
    source_tag = ExponentialFamily.exponential_family_typetag(something)

    if target_type !== source_tag
        # Families differ: fall back to projecting the log-density
        # (presumably the generic projection path — confirm against the extension).
        logdensity = x -> logpdf(something, x)
        return ReactiveMP.constrain_form(constraint, context, logdensity)
    end

    converted = convert(source_tag, something)
    # Store the result on the context (NOTE(review): presumably reused by later
    # projections as a starting point — confirm in RxInfer's ProjectionExt).
    context.previous = converted
    return converted
end
"""
    decay(x, dt, f)

Return `x` scaled by the exponential decay factor `exp(-dt / f)`, i.e. `x` decayed
over an interval `dt` with time constant `f`.
"""
function decay(x, dt, f)
    factor = exp(-dt / f)
    return x * factor
end
# Second model: the decay time constant is fixed to 1.0 instead of being inferred.
# (This redefines `model` if both snippets are run in the same session.)
@model function model(dummy)
# Prior on the previous state: mean 5.0, variance 0.25.
H_prev ~ NormalMeanVariance(5.0, 0.25)
# Deterministic decayed state: H = H_prev * exp(-1.0 / 1.0).
H := decay(H_prev, 1.0, 1.0)
# Observation model: `dummy` is Gaussian around `H` with variance 0.5.
dummy ~ NormalMeanVariance(H, 0.5)
end
# Variational constraints for the second model: project the marginals of `H` and
# `H_prev` onto Gaussians (`ProjectedTo`, from ExponentialFamilyProjection).
@constraints function mk_constraints()
q(H) :: ProjectedTo(NormalMeanVariance)
q(H_prev) :: ProjectedTo(NormalMeanVariance)
end
# Node metadata: approximate messages through the nonlinear `decay` node with
# CVI-based projection (`CVIProjection`).
@meta function model_meta()
decay() -> CVIProjection()
end
# Initial marginal for `H`, used to start inference in the second model.
@initialization function init()
q(H) = NormalMeanVariance(1.0, 1.0)
end
"""
    do_infer()

Run `infer` on `model`, observing `dummy = 2.0`, with the constraints, meta and
initialization returned by `mk_constraints()`, `model_meta()` and `init()`.
Return whatever `infer` returns.
"""
function do_infer()
    return infer(;
        model          = model(),
        data           = (dummy = 2.0,),
        constraints    = mk_constraints(),
        meta           = model_meta(),
        initialization = init(),
    )
end

# Run inference for the second model when the script is executed.
do_infer()
Error 1:
Model:
Error 2:
Model: