ReactiveBayes / RxInfer.jl

Julia package for automated Bayesian inference on a factor graph with reactive message passing
https://rxinfer.ml/
MIT License

Possibility to exponentiate random variables? #139

Closed: mrazomej closed this issue 1 year ago

mrazomej commented 1 year ago

I just learned about the project and still need to understand the inference pipeline fully. Nevertheless, I tried to adapt a Turing.jl model to work with RxInfer, only to run into the wall that the LogNormal distribution is not available. A simple solution would be to use a Normal distribution and then exponentiate the result, but it turns out that exp is not implemented for random variables either. I wish I had a deeper understanding and more time to make a PR to implement this. Is this something on your immediate to-do list?

albertpod commented 1 year ago

Hi @mrazomej! Thanks for trying out RxInfer.jl; indeed, the LogNormal node isn't implemented in the package.

One way of circumventing the issue is, in fact, to use exp as you said.

Please have a look at the notebooks that make use of different nonlinear nodes:

  1. https://biaslab.github.io/RxInfer.jl/stable/examples/Nonlinear%20Noisy%20Pendulum/
  2. https://biaslab.github.io/RxInfer.jl/stable/examples/Nonlinear%20Rabbit%20Population/
  3. https://biaslab.github.io/RxInfer.jl/stable/examples/Nonlinear%20Sensor%20Fusion/

It would be helpful to look at your model so that we can provide a more comprehensive solution. In case you need help with using exp, I will be happy to assist you further.

P.S. You can also implement the LogNormal node yourself. Check out the following tutorial: https://biaslab.github.io/RxInfer.jl/stable/manuals/custom-node/
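As a rough, untested sketch following that tutorial (the node declaration and rule below are only illustrative; real update rules towards μ and σ would still need to be derived):

using RxInfer, Distributions

# declare Distributions.jl's LogNormal as a stochastic factor node
@node LogNormal Stochastic [ out, μ, σ ]

# illustrative sum-product message towards `out` when μ and σ are point masses
@rule LogNormal(:out, Marginalisation) (m_μ::PointMass, m_σ::PointMass) = begin
    return LogNormal(mean(m_μ), mean(m_σ))
end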

albertpod commented 1 year ago

@mrazomej here's a minimal example (I personally prefer to specify meta outside of the model definition, but I have provided a few other options):

using RxInfer

@model function test()

    y = datavar(Float64)
    # initialize variables
    x ~ Normal(mean=0.0, var=1.0)

    # z ~ exp(x) where {meta = Linearization(), inverse=log} 
    # z ~ exp(x) where {meta = DeltaMeta(; method = Unscented(), inverse = log)}
    # z ~ exp(x) where {meta = DeltaMeta(; method = Linearization(), inverse = log)}
    z ~ exp(x)
    y ~ Normal(mean=z, var=1.0)

end;

@meta function lognormal()
    exp(x) -> Linearization()
end

result = inference(model=test(), data=(y=1.0, ), meta=lognormal())

easy ;)
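You can then read the posterior marginals off the returned object; a small hedged sketch (assuming the inference result exposes its marginals through a posteriors dictionary keyed by variable name):

# inspect the inferred marginal of x
posterior_x = result.posteriors[:x]
mean(posterior_x)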

We are working on a big update of RxInfer that will cover all distributions from the exponential family and more, perhaps late September.

mrazomej commented 1 year ago

Amazing, thank you so much! I will try it out and report back. The project and the idea look absolutely amazing. I am really excited about it, and I will try to dig deeper to understand the method fully. Thanks for the prompt response, and kudos on such a neat library!

mrazomej commented 1 year ago

I am trying to use the library too much as a black box. I naively thought that translating my Turing.jl model would be straightforward, but not fully understanding the method behind your library is biting me. I need to learn how to do simple operations, such as taking a matrix of random variables and normalizing it by row.

In my model, I sample the λ parameters of a bunch of Poisson random variables as a long vector from an MvLogNormal (this is to speed up the computation in Turing, since the distribution has a diagonal covariance matrix). I then reshape this long vector into the correct dimensions so that it is easier to work with (rows represent time points and columns the different tracked variables):

Λ̲̲ = reshape(Λ̲̲, n_row, n_col)

Then, I normalize the entries of this matrix by rows as

F̲̲ = Λ̲̲ ./ sum(Λ̲̲, dims=2)

All of this is because my model requires taking the log ratio of these normalized values between adjacent time points, i.e.,

Γ̲̲ = log.(F̲̲[2:end, :] ./ F̲̲[1:end-1, :])
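For concreteness, this is what those three operations do on a plain numeric matrix (hypothetical sizes, no random variables involved):

n_row, n_col = 3, 2
Λ = collect(1.0:6.0)                     # stand-in for the MvLogNormal draw
Λ = reshape(Λ, n_row, n_col)             # rows = time points, columns = tracked variables
F = Λ ./ sum(Λ, dims=2)                  # normalize each row to sum to 1
Γ = log.(F[2:end, :] ./ F[1:end-1, :])   # log ratios between adjacent time points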

My ignorance is making me hit too many walls with your excellent library. It may well be that RxInfer is not the proper framework for the inference I want to do. Still, I was seduced by the potential to do inference on thousands of parameters, which is currently a limitation of my MCMC-based method.

albertpod commented 1 year ago

@mrazomej sure! I think what happened with your initial implementation is that you tried to use z = exp(x), which won't work because of how the = operator behaves in Julia; you should use the ~ or := operators within the @model macro instead.

Indeed, casting your model from Turing.jl into RxInfer.jl can be very challenging, though not impossible.

To give you some idea, in short: RxInfer.jl creates a graphical representation of the model. Every distribution or function you introduce in the @model macro is mapped to a node in the factor graph. These nodes send so-called messages that propagate information to the other nodes in the form of distributions. When messages collide on a variable, the result is a posterior distribution. No sampling is involved, which makes RxInfer.jl very fast.
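For instance, in a fully conjugate model the colliding messages yield the exact posterior in closed form. A minimal sketch, using the same datavar/inference API as the example above (names are illustrative):

using RxInfer

@model function coin_model(n)
    y = datavar(Float64, n)   # observed flips, encoded as 0.0 / 1.0
    θ ~ Beta(1.0, 1.0)        # conjugate prior on the coin bias
    for i in 1:n
        y[i] ~ Bernoulli(θ)
    end
end

result = inference(model = coin_model(3), data = (y = [1.0, 0.0, 1.0],))
# with 2 heads and 1 tail the posterior is the closed-form Beta(3.0, 2.0)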

More details can be found here.

At the moment (this should be significantly improved in the October version), if a distribution is not represented as a node in RxInfer.jl, as in the LogNormal case, then you'd need to write your own update rules or use one of the approximations (Linearization, Unscented, or CVI).

CVI is not easy to use, as it requires quite some tuning of hyperparameters. That is to say, the package might not be very well suited to your needs.

The equations you've provided make sense, but RxInfer.jl is not that flexible at the moment. For example, a matrix that stores random variables isn't available out of the box.

I imagine you might not want to make your model public, but maybe you can show a minimal example of what your model looks like so I can guide you better.

Cheers

mrazomej commented 1 year ago

Thank you so much, @albertpod! I read a bit of the paper and got the general gist of it, but I still have a long way to go to fully grasp the depth of your elegant method. I am happy to share my model. Here is a minimal version that gets to the operations I need to do:

Imagine you are tracking a population of organisms over time. You have N different organisms that you observe over T time points. So the data consists of a T × N matrix $D$ with integer entries. At each time point, you take a sample of the state of the population, so you draw a Poisson-distributed number of organisms such that each $D_{ti}$ entry is

$$ D_{ti} \sim \text{Poisson}(\lambda_{ti}) $$

where

$$ \lambda_{ti} \propto n_{ti} $$

i.e., the Poisson sample you grab for organism $i$ depends on the current number of organisms of type $i$ in your population. Now, for some biological/evolutionary reasons, you think that the dynamics of the population follow an equation of the form

$$ f_{(t+1)i} = f_{ti} \exp(s_i) $$

where $f_{ti}$ is the frequency of organism $i$ at time $t$. That is,

$$ f_{ti} = \frac{n_{ti}}{\sum_j n_{tj}} = \frac{\lambda_{ti}}{\sum_j \lambda_{tj}} $$

So the dynamics of organism $i$ depend on the population composition. What I care about is the parameter $s_i$ that determines the "relative fitness" of each organism.

The way I constructed the inference in Turing.jl is to assume I have priors of the form

$$ s_i \sim \mathcal{N}(\mu_{s}, \sigma_{s}), $$

and

$$ \lambda_{ti} \sim \log\mathcal{N}(\mu_\lambda, \sigma_\lambda). $$

Furthermore, since my likelihood function is also a log-normal distribution, I add the nuisance parameter

$$ \sigma_i \sim \log\mathcal{N}(\mu_\sigma, \sigma_\sigma), $$

such that the likelihood function for my observations is

$$ D_{ti} \sim \text{Poisson}(\lambda_{ti}), $$

and for the frequencies, I have

$$ \frac{f_{(t+1)i}}{f_{ti}} \sim \log\mathcal{N}(s_i, \sigma_i). $$
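To make the generative story concrete, here is a small simulation sketch of the process above (made-up sizes and hyperparameters; I renormalize the frequencies at each step so they stay on the simplex):

using Distributions, Random

Random.seed!(42)
T, N = 5, 3                              # time points × organisms (hypothetical)
s = rand(Normal(0.0, 0.1), N)            # relative fitnesses s_i
f = fill(1 / N, N)                       # initial frequencies
depth = 10_000                           # sampling depth, so that λ_ti ∝ n_ti
D = Matrix{Int}(undef, T, N)
for t in 1:T
    D[t, :] .= rand.(Poisson.(depth .* f))   # D_ti ~ Poisson(λ_ti)
    f = f .* exp.(s)                         # f_(t+1)i = f_ti * exp(s_i)
    f ./= sum(f)                             # keep frequencies normalized
end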

Here is an example of how one could write this model in Turing.jl. I wrote it using inefficient for-loops that one should avoid when using AutoDiff, but I think this gets at the core of the ideas and operations behind the model. Furthermore, the use of multivariate distributions here is for computational efficiency only.

Turing.@model function pop_dynamics(D)
    # Extract the number of time points and number of mutants
    n_time, n_mut = size(D)

    # Prior on parameter sᵢ
    sᵢ ~ Turing.MvNormal(
        repeat([μ_s], n_mut), LinearAlgebra.I(n_mut) .* σ_s .^ 2
    )

    # Prior on parameter σᵢ
    σᵢ ~ Turing.MvLogNormal(
        repeat([μ_σ], n_mut), LinearAlgebra.I(n_mut) .* σ_σ .^ 2
    )

    # Prior on parameter λₜᵢ
    λₜᵢ ~ Turing.MvLogNormal(
        repeat([μ_λ], n_time * n_mut),
        LinearAlgebra.I(n_time * n_mut) .* σ_λ .^ 2
    )

    # Reshape λₜᵢ to make it easier to manipulate
    λₜᵢ = reshape(λₜᵢ, size(D)...)

    # Compute relative frequencies by time point
    fₜᵢ = λₜᵢ ./ sum(λₜᵢ, dims=2)

    # Compute frequency ratios between adjacent time points
    γₜᵢ = fₜᵢ[2:end, :] ./ fₜᵢ[1:end-1, :]

    # NOTE: I don't use the following for-loops in my model, but this is easier
    # to follow.

    # Likelihood function for data in D

    # Loop through rows
    for t in axes(fₜᵢ, 1)
        # Loop through columns
        for i in axes(fₜᵢ, 2)
            D[t, i] ~ Turing.Poisson(λₜᵢ[t, i])
        end # for
    end # for

    # Likelihood function for frequency ratios. Note that γₜᵢ is a
    # deterministic transform of λₜᵢ, so its likelihood contribution is added
    # to the joint log-density explicitly rather than with `~`.
    # Loop through rows
    for t in axes(γₜᵢ, 1)
        # Loop through columns
        for i in axes(γₜᵢ, 2)
            Turing.@addlogprob! Turing.logpdf(
                Turing.LogNormal(sᵢ[i], σᵢ[i]), γₜᵢ[t, i]
            )
        end # for
    end # for

end # model
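One would then sample the posterior along these lines (a hedged usage sketch; the hyperparameters μ_s, σ_s, μ_σ, σ_σ, μ_λ, and σ_λ are assumed to be defined in the enclosing scope):

import Turing

# hypothetical call: draw 1,000 posterior samples with NUTS
chain = Turing.sample(pop_dynamics(D), Turing.NUTS(0.65), 1_000)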

Hopefully, this explanation and minimal model are enough for you to see whether it can easily be translated into an RxInfer.jl model. The reason I ask is that the number of organisms can be huge, so sampling the posterior distribution becomes impossible. If I could get this to work within your framework, it would be a total game-changer for what I'm trying to do!

albertpod commented 1 year ago

@mrazomej I have assigned @HoangMHNguyen, a PhD student in our lab, to this issue. Hopefully, he'll be able to help you out!

mrazomej commented 1 year ago

Thank you so much, @albertpod! I really hope this is a doable problem. Getting this to work efficiently would completely transform what I'm trying to do. @HoangMHNguyen, please let me know how I can help. I am also happy to get in touch through some channel other than this GitHub issue.

HoangMHNguyen commented 1 year ago

Hi @mrazomej! Thank you for trying out the RxInfer.jl library. As @albertpod mentioned, the current RxInfer.jl lacks some features, such as support for a matrix of random variables or the LogNormal distribution, which makes translating your model from Turing to RxInfer difficult (a big update to RxInfer is on the way). However, a nice feature of RxInfer is that you can define your own nodes and message update rules (https://biaslab.github.io/RxInfer.jl/stable/manuals/custom-node/). For the update rules, you can think of algorithms like Belief Propagation or Variational Message Passing. The following might be an example of your model in RxInfer.jl:

@model function pop_dynamics(n_time, n_organ)
    y = datavar(Vector{Int64}, n_time)
    λ = randomvar(n_time)
    s = randomvar()
    w = randomvar()
    γ = randomvar(n_time)

    # priors
    s ~ MvNormalMeanCovariance(zeros(n_organ), diageye(n_organ))
    w ~ ElementwiseGammaShapeScale(0.01*ones(n_organ), 100*ones(n_organ)) # this is the inverse of σ

    λ_0 ~ ElementwiseGammaShapeScale(0.01*ones(n_organ), 100*ones(n_organ))
    λ_prev = λ_0
    # consider each time step
    for t = 1:n_time
        λ[t] ~ ElementwiseGammaShapeScale(rand()*ones(n_organ), rand()*ones(n_organ)) # prior
        γ[t] ~ FreqRatioNode(λ_prev, λ[t])
        γ[t] ~ ElementwiseLogNormal(s, w)
        y[t] ~ ElementwisePoisson(λ[t])
        λ_prev = λ[t]
    end
end

and the inference call looks like this:

iresult = inference(
    model = pop_dynamics(num_time, num_organ),
    iterations = 2,
    initmarginals = (s = MvNormalMeanPrecision(zeros(num_organ), diageye(num_organ)),
                     w = Gamma.(0.01*ones(num_organ), 100*ones(num_organ)),),
    data  = (y = y_data,),
    constraints = pop_dynamics_constraints(),
    returnvars = (s = KeepLast(),)
)

In the above model, only the MvNormalMeanCovariance node is already built into RxInfer.jl; the other nodes (ElementwiseGammaShapeScale, FreqRatioNode, ElementwiseLogNormal, and ElementwisePoisson) are custom nodes created by the user. Here I changed the prior distributions of λ and σ (changed to its inverse, w) from LogNormal to Gamma, since the Gamma distribution is the conjugate prior for the LogNormal and Poisson distributions here, and conjugate computations result in closed-form solutions. Of course you can use a LogNormal prior, but the computation of the messages might become cumbersome. I changed the prior because I assume you used the LogNormal just to keep your parameters positive.
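For reference, the Gamma-Poisson conjugacy relied on here is the standard result that for a shape-rate prior and a Poisson observation,

$$ \lambda \sim \text{Gamma}(\alpha, \beta), \quad y \sim \text{Poisson}(\lambda) \implies p(\lambda \mid y) = \text{Gamma}(\alpha + y, \beta + 1), $$

which is why the Gamma-based messages in the rules below stay in closed form.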

For the custom nodes above, you can define update rules for the messages. The following are examples:

struct ElementwisePoisson end
@node ElementwisePoisson Stochastic [ y, x ]

@rule ElementwisePoisson(:y, Marginalisation) (q_x::Vector{GammaDistributionsFamily},) = begin
    # Poisson rate set to exp(E_q[log x]) under a shape-rate parameterization
    params_x = Distributions.params.(q_x)
    α_x = first.(params_x)
    β_x = last.(params_x)
    λ = exp.(digamma.(α_x) .- log.(β_x))
    return Poisson.(λ)
end

@rule ElementwisePoisson(:x, Marginalisation) (q_y::Vector{Poisson},) = begin
    λ_y = mean.(q_y)
    return GammaShapeScale.(1 .+ λ_y, ones(length(λ_y)))
end

@rule ElementwisePoisson(:x, Marginalisation) (q_y::PointMass,) = begin
    λ_y = q_y.point
    return GammaShapeScale.(1 .+ λ_y, ones(length(λ_y)))
end
# create an element-wise LogNormal node
struct ElementwiseLogNormal end

@node ElementwiseLogNormal Stochastic [ y, μ, w ]

@rule ElementwiseLogNormal(:y, Marginalisation) (q_μ::MultivariateNormalDistributionsFamily, q_w::Array,) = begin
    mean_μ = mean(q_μ)
    mean_w = mean.(q_w)
    σ = sqrt.(inv.(mean_w))
    return LogNormal.(mean_μ, σ)
end

@rule ElementwiseLogNormal(:w, Marginalisation) (q_y::Array, q_μ::MvNormalWeightedMeanPrecision, ) = begin
    param = Distributions.params.(q_y)
    μ_y = [param[i][1] for i = 1:length(param)]
    σ_y = [param[i][2] for i = 1:length(param)]
    mean_μ, cov_μ = mean_cov(q_μ)
    var_μ = diag(cov_μ)

    return GammaShapeScale.(3/2*ones(length(μ_y)), 1 ./ (0.5 .* (mean_μ.^2 .+ var_μ .- 2 .* mean_μ.*μ_y .+ μ_y.^2 .+ σ_y.^2)))
end

@rule ElementwiseLogNormal(:w, Marginalisation) (q_y::Array, q_μ::MvNormalMeanPrecision, ) = begin
    param = Distributions.params.(q_y)
    μ_y = [param[i][1] for i = 1:length(param)]
    σ_y = [param[i][2] for i = 1:length(param)]
    mean_μ, cov_μ = mean_cov(q_μ)
    var_μ = diag(cov_μ)

    return GammaShapeScale.(3/2*ones(length(μ_y)), 1 ./ (0.5 .* (mean_μ.^2 .+ var_μ .- 2 .* mean_μ.*μ_y .+ μ_y.^2 .+ σ_y.^2)))
end

@rule ElementwiseLogNormal(:μ, Marginalisation) (q_y::Array, q_w::Array, ) = begin 
    μ_y = Distributions.params.(q_y)
    mean_w = mean.(q_w)
    μ_y = [μ_y[i][1] for i=1:length(μ_y)] 
    return MvNormalMeanPrecision(μ_y, Diagonal(mean_w))
end   
struct ElementwiseGammaShapeScale end 

@node ElementwiseGammaShapeScale Stochastic [ y, α, β ]

@rule ElementwiseGammaShapeScale(:y, Marginalisation) (m_α::PointMass, m_β::PointMass,) = begin
    return GammaShapeScale.(m_α.point, m_β.point)
end

@rule ElementwiseGammaShapeScale(:y, Marginalisation) (q_α::PointMass, q_β::PointMass, ) = begin 
    return GammaShapeScale.(q_α.point, q_β.point)
end
struct FreqRatioNode end
@node FreqRatioNode Deterministic [ y, x, z]

@rule FreqRatioNode(:z, Marginalisation) (m_y::Array, m_x::Array, ) = begin
    sample_list = rand.(m_x, 2)
    list_logpdf = [make_message.(sample_list[i], m_y[i]) for i = 1:length(m_y)] # array of log-pdf functions
    return list_logpdf
end

@rule FreqRatioNode(:x, Marginalisation) (m_y::Array, m_z::Array, ) = begin
    sample_list = rand.(m_z, 2)
    list_logpdf = [make_message.(sample_list[i], m_y[i]) for i = 1:length(m_y)] # array of log-pdf functions
    return list_logpdf
end

@rule FreqRatioNode(:y, Marginalisation) (m_x::Array, m_z::Array, ) = begin 
    nsamples = 10
    sample_list_x = hcat(rand.(m_x,nsamples)...) #prev_time
    sample_list_z = hcat(rand.(m_z,nsamples)...) #present_time
    γ = [freq_ratio(sample_list_x[i,:], sample_list_z[i,:]) for i=1:size(sample_list_x,1)]
    logγ = [log.(γ_i) for γ_i in γ]
    logγ² = [(log.(γ_i)).^2 for γ_i in γ]
    E_logγ = sum(logγ,dims=1) / nsamples 
    E_logγ² = sum(logγ²,dims=1) / nsamples
    μ_γ = E_logγ[1]
    σ_γ = sqrt.(E_logγ²[1] .- μ_γ.^2)
    return LogNormal.(μ_γ,σ_γ)
end

Don't worry if you don't know whether you have defined all the necessary rules. The inference function will throw errors like this:

RuleMethodError: no method matching rule for the given arguments

Possible fix, define:

@rule ElementwiseGammaShapeScale(:y, Marginalisation) (q_α::PointMass, q_β::PointMass, ) = begin 
    return ...
end

and then you just add the missing rule. You might also get an error like:

NoAnalyticalProdException:   No analytical rule available to compute a product of....

which tells you to define a product rule for two distributions (this step becomes easier if you have conjugate priors).

To run the model above, I'll provide some more code. These are just examples to give you an idea of how RxInfer.jl works. You can of course define your own computations.

function freq_ratio(λ_prev, λ_present)
    f_prev = λ_prev ./ sum(λ_prev)
    f_present = λ_present ./ sum(λ_present)
    γ = f_present ./ f_prev
    return γ
end

function make_message(samples_A,d_in)
    return let samples_A=samples_A,d_in=d_in
        (x) -> begin
            result = mapreduce(+, zip(samples_A,)) do (sampleA,)
                return pdf(d_in,freq_ratio(x,sampleA))
            end
            return log(result)
        end
    end
end

function ReactiveMP.prod(::ProdAnalytical, left::Vector{Gamma{Float64}}, right::Array)
    # simplification: keep only the Gamma part, ignoring the incoming sample-based message
    return left
end

function ReactiveMP.prod(::ProdAnalytical, left::Vector{GammaShapeScale{Float64}}, right::Vector{GammaShapeScale{Float64}})
    T = ReactiveMP.promote_samplefloattype(left[1], right[1])
    d = [GammaShapeScale(shape(left[i]) + shape(right[i]) - one(T), 1/(rate(left[i]) + rate(right[i]))) for i=1:length(left)] #array 
    return d
end

function ReactiveMP.prod(::ProdAnalytical, left::Vector{LogNormal{Float64}}, right::Vector{LogNormal{Float64}})
    # you can define this rule yourself
    return left
end

@constraints function pop_dynamics_constraints()
    # mean-field factorization over γ, s, and w
    q(γ, s, w) = q(γ)q(s)q(w)
end

I hope my answer is helpful for your work. If you have any questions, I'm happy to help.

cheers.

mrazomej commented 1 year ago

Thank you so much @HoangMHNguyen! I will try this out and report back. Really, thank you for taking the time to do this.

albertpod commented 1 year ago

Hey @mrazomej and @HoangMHNguyen! I will convert this issue into a discussion. We can continue there if needed.