ReactiveBayes / RxInfer.jl

Julia package for automated Bayesian inference on a factor graph with reactive message passing
MIT License

Linear model with log-link on sigma (exp(sigma)) errors: are there plans to support it in the future? #357

Closed. DominiqueMakowski closed this issue 4 days ago

DominiqueMakowski commented 5 days ago

I am trying to fit a simple linear model in which some parameters are expressed through a link function (e.g., exp()), a common way to put priors on an unconstrained scale.

using Random, RxInfer

y = randn(100)

RxInfer.@model function model_Gaussian(y)

    # Priors
    μ ~ RxInfer.NormalMeanVariance(0.3, 0.5)
    σ ~ RxInfer.NormalMeanVariance(log(0.2), 3)

    for i in eachindex(y)
        sigma = exp(σ)
        y[i] ~ RxInfer.NormalMeanVariance(μ, sigma)
    end
end

result = infer(
    model=model_Gaussian(),
    data=(y=y,),
)

Unfortunately, the above fails with:

ERROR: MethodError: no method matching exp(::GraphPPL.VariableRef{…})

More generally, will arbitrary link functions (such as those in StatsFuns) be supported in the future?

Also, are there plans to make RxInfer work with a "standard" Distributions.Normal() rather than the bespoke NormalMeanVariance()? Thanks for the clarifications!

albertpod commented 4 days ago

Hi @DominiqueMakowski. There's inherently no issue with RxInfer here, so I will transfer this issue into a discussion. Given the non-conjugacy and nonlinearity of your model, you will likely have to use the CVI projections method. This is documented here. Here's an example of how to solve this model with this method:

using RxInfer
using ExponentialFamilyProjection
using BayesBase
using Distributions

# Define the custom distribution first
struct TransformedNormalDistribution{H,T} <: ContinuousUnivariateDistribution
    h::H
    t::T
end

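# Log-density of an observation: Gaussian with mean h and standard deviation exp(t),
# so t is the log of the standard deviation (the log-link from the question)
BayesBase.logpdf(dist::TransformedNormalDistribution, x) = logpdf(Normal(dist.h, exp(dist.t)), x)
# The distribution is supported on the whole real line
BayesBase.insupport(dist::TransformedNormalDistribution, x) = true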

@node TransformedNormalDistribution Stochastic [out, h, t]

# Define the model
@model function model_Gaussian(y)
    μ ~ Normal(mean=0.3, variance=0.5)
    σ ~ Normal(mean=log(0.2), variance=1.0)
    y .~ TransformedNormalDistribution(μ, σ)
end

# Define constraints
@constraints function non_conjugate_model_constraints()
    q(μ) :: ProjectedTo(NormalMeanVariance)
    q(σ) :: ProjectedTo(NormalMeanVariance)
    q(σ, μ) = q(σ)q(μ)
end

# Define initialization
@initialization function model_initialization()
    q(σ) = NormalMeanVariance(0.0, 1.0)
    # q(σ) = GammaShapeRate(1.0, 1.0) # also works
end

# Generate sample data
y = randn(100)

# Perform inference
result = infer(
    model = model_Gaussian(),
    data = (y = y,),
    constraints = non_conjugate_model_constraints(),
    initialization = model_initialization(),
    options = (rulefallback = NodeFunctionRuleFallback(),),
    showprogress = true,
    iterations = 10
)

println(result.posteriors[:μ][end])
println(result.posteriors[:σ][end])
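
For reference, the log-density that the TransformedNormalDistribution node evaluates can be written out in plain Julia with no packages. This is only a sketch: the function name transformed_normal_logpdf is made up for illustration and is not part of RxInfer or BayesBase. It shows why the log-link is convenient: any real t gives a valid positive standard deviation, and log(exp(t)) simplifies to t.

```julia
# Log-density of a Gaussian with mean h and standard deviation exp(t).
# `transformed_normal_logpdf` is a hypothetical name used only in this sketch.
function transformed_normal_logpdf(x::Real, h::Real, t::Real)
    s = exp(t)                # log-link: any real t maps to a positive standard deviation
    z = (x - h) / s           # standardized residual
    # log N(x | h, s) = -z^2/2 - log(s) - log(2π)/2, and log(s) == t
    return -0.5 * z^2 - t - 0.5 * log(2π)
end

# The same observation scored under a narrow, unit, and wide noise scale.
for t in (-1.0, 0.0, 1.0)
    println("t = ", t, "  logpdf = ", transformed_normal_logpdf(0.5, 0.0, t))
end
```

Assuming Distributions.jl is loaded, this should agree with logpdf(Normal(h, exp(t)), x), which is what the node definition above calls.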

I haven't checked the inference result, but this should give you a good starting point. There are other ways to run inference on your model, but I opted for the fastest one for now.
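
One thing to keep in mind when reading the result: in this model σ lives on the log scale, so the posterior q(σ) describes the log of the standard deviation. If q(σ) is Gaussian with mean m and variance v, then exp(σ) is log-normal, and its median and mean follow in closed form. A minimal sketch, assuming m and v have already been read off the posterior; the helper name std_summary and the numbers passed to it are placeholders, not output of the model above.

```julia
# Summaries of s = exp(σ) when q(σ) is Gaussian with mean m and variance v,
# i.e. s follows a log-normal distribution.
# `std_summary` is a hypothetical helper, not part of RxInfer.
function std_summary(m::Real, v::Real)
    median_s = exp(m)          # median of a log-normal
    mean_s = exp(m + v / 2)    # mean of a log-normal
    return (median = median_s, mean = mean_s)
end

# Placeholder values for illustration only. With RxInfer, m and v would come
# from the mean and variance of result.posteriors[:σ][end].
s = std_summary(0.0, 0.02)
println(s)
```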

As for support for Distributions.jl, I acknowledge it can sometimes be inconvenient. It's important to note that "primitive" nodes in GraphPPL are not the same as distributions from Distributions.jl. However, apart from model definition, RxInfer does support regular Distributions. We augment these with ExponentialFamily.jl, which provides some important features for RxInfer.jl's inference engine.

If you need any further clarification or have additional questions, feel free to ask.