Hi @mrazomej! Thanks for trying out RxInfer.jl! Indeed, the LogNormal node isn't implemented in the package.
One way of circumventing the issue is, in fact, to use exp as you've said.
Please have a look at the notebooks that make use of different nonlinear nodes:
It would be helpful to look at your model so we can provide a more comprehensive solution. In case you need help with using exp, I will be happy to assist you further.
P.S. You can also implement the LogNormal node yourself. Check out the following tutorial: https://biaslab.github.io/RxInfer.jl/stable/manuals/custom-node/
@mrazomej here's a minimal example (I personally prefer to use meta outside of the model definition, but I've provided a few other options):
using RxInfer

@model function test()
    y = datavar(Float64)
    # initialize variables
    x ~ Normal(mean = 0.0, var = 1.0)
    # z ~ exp(x) where {meta = Linearization(), inverse = log}
    # z ~ exp(x) where {meta = DeltaMeta(; method = Unscented(), inverse = log)}
    # z ~ exp(x) where {meta = DeltaMeta(; method = Linearization(), inverse = log)}
    z ~ exp(x)
    y ~ Normal(mean = z, var = 1.0)
end;
@meta function lognormal()
    exp(x) -> Linearization()
end
result = inference(model=test(), data=(y=1.0, ), meta=lognormal())
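Assuming the call runs through, you can then inspect the posterior marginal of x from the returned result (a minimal sketch):

println(mean(result.posteriors[:x]))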
easy ;)
We are working on a big update of RxInfer that will cover all distributions from the exponential family and more, perhaps late September.
Amazing, thank you so much! I will try it out and report back. The project and the idea look absolutely amazing. I am really excited about it, and I will try to dig deeper to understand the method fully. Thanks for the prompt response, and kudos on such a neat library!
I am trying to use the library too much as a black box. I naively thought that translating my Turing.jl model would be straightforward, but not fully understanding the method behind your library is biting me back. I need to learn how to do simple operations, such as taking a matrix of random variables and normalizing it by row.
In my model, I sample the λ parameters of a bunch of Poisson random variables as a long vector from a MvLogNormal (this is to speed up the computation in Turing, since the distribution has a diagonal covariance matrix). Then I reshape this long vector into the correct dimensions so that it is easier to work with (rows represent time points and columns different tracked variables):
Λ̲̲ = reshape(Λ̲̲, n_row, n_col)
Then, I normalize the entries of this matrix by rows as
F̲̲ = Λ̲̲ ./ sum(Λ̲̲, dims=2)
All of this is because my model requires taking the log ratio of these normalized values between adjacent time points, i.e.,
Γ̲̲ = log.(F̲̲[2:end, :] ./ F̲̲[1:end-1, :])
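Outside of any PPL, these operations are plain Julia. Here is a standalone sketch with made-up dimensions:

n_row, n_col = 4, 3                      # time points × tracked variables (made up)
Λ̲̲ = rand(n_row * n_col)                  # long vector of λ parameters
Λ̲̲ = reshape(Λ̲̲, n_row, n_col)             # rows: time points, columns: variables
F̲̲ = Λ̲̲ ./ sum(Λ̲̲, dims=2)                  # each row now sums to one
Γ̲̲ = log.(F̲̲[2:end, :] ./ F̲̲[1:end-1, :])   # log ratios between adjacent time points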
My ignorance is making me hit too many walls with your excellent library. It might even be very possible that RxInfer is not the proper framework for the inference I want to make. Still, I was seduced by the potential to do inference on thousands of parameters, a current limitation of my MCMC-based method.
@mrazomej sure! I think what happened with your initial implementation is that you tried to use z = exp(x), which won't work due to the way the = operator works in Julia; you should use the ~ or := operators within the @model macro.
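For illustration, here is a minimal sketch of the difference (a hypothetical model in the same style as the earlier example):

@model function illustration()
    y = datavar(Float64)
    x ~ Normal(mean = 0.0, var = 1.0)
    # z = exp(x)  # plain `=` is ordinary Julia assignment; it creates no factor node
    z ~ exp(x)    # `~` (or `z := exp(x)`) registers the deterministic transform in the graph
    y ~ Normal(mean = z, var = 1.0)
end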
Indeed, casting your model from Turing.jl into RxInfer.jl can be very challenging, though not impossible. To give you some idea: in short, RxInfer.jl creates a graphical representation of the model. Every distribution or function you introduce in the @model macro is mapped to a node in the factor graph. These nodes send so-called messages that propagate information to the other nodes in the form of distributions. When messages collide at a node, the result is a posterior distribution. No sampling is involved, which makes RxInfer.jl very fast.
More details can be found here.
At the moment (this should be significantly improved in the October version), if a distribution is not represented as a node in RxInfer.jl, as in the LogNormal case, then you need to write your own update rules or use some approximations (Linearization, Unscented, or CVI). CVI is not easy to use, as it requires quite some tuning of hyperparameters. That's to say the package might not be very suitable for your needs.
The equations you've provided make sense, but RxInfer.jl is not that flexible at the moment. For example, a matrix that stores random variables isn't available out of the box.
I imagine you wouldn't like to make your model public, but maybe you can show a minimal example of what your model looks like so I can better guide you.
Cheers
Thank you so much, @albertpod! I read a bit of the paper and got the general gist of it. But, still, I have a long way to go to grasp the depth of your elegant method fully. I am happy to share my model. Here is a minimum version to get to the operations I need to do:
Imagine you are tracking a population of organisms over time. You have $N$ different organisms that you observe over $T$ time points, so the data consist of a $T \times N$ matrix $D$ with integer entries. At each time point, you take a sample of the state of the population, so you draw a Poisson-distributed number of organisms such that each entry $D_{ti}$ is
$$ D_{ti} \sim \text{Poisson}(\lambda_{ti}), $$
where
$$ \lambda_{ti} \propto n_{ti}, $$
i.e., the Poisson sample you grab for organism $i$ depends on the current number of organisms of type $i$ in your population. Now, for some biological/evolutionary reasons, you think that the dynamics of the population follow an equation of the form
$$ f_{(t+1)i} = f_{ti} \exp(s_i), $$
where $f_{ti}$ is the frequency of organism $i$ at time $t$. That is,
$$ f_{ti} = \frac{n_{ti}}{\sum_j n_{tj}} = \frac{\lambda_{ti}}{\sum_j \lambda_{tj}}. $$
So the dynamics of organism $i$ depend on the population composition. What I care about is the parameter $s_i$ that determines the "relative fitness" of each organism.
The way I constructed the inference in Turing.jl is to assume I have priors of the form
$$ s_i \sim \mathcal{N}(\mu_s, \sigma_s), $$
and
$$ \lambda_{ti} \sim \log\mathcal{N}(\mu_\lambda, \sigma_\lambda). $$
Furthermore, since my likelihood function is also a log-normal distribution, I add the nuisance parameter
$$ \sigma_i \sim \log\mathcal{N}(\mu_\sigma, \sigma_\sigma), $$
such that the likelihood function for my observations is
$$ D_{ti} \sim \text{Poisson}(\lambda_{ti}), $$
and for the frequencies, I have
$$ \frac{f_{(t+1)i}}{f_{ti}} \sim \log\mathcal{N}(s_i, \sigma_i). $$
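This likelihood follows directly from taking the log of the dynamics equation above:

$$ \log \frac{f_{(t+1)i}}{f_{ti}} = \log \exp(s_i) = s_i, $$

so with multiplicative noise the log frequency ratio is normally distributed around $s_i$, which is exactly the log-normal statement.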
Here is an example of how one could write this model in Turing.jl. I wrote it using inefficient for-loops that one should avoid when using AutoDiff, but I think it gets at the core of the ideas and operations behind the model. Furthermore, the use of multivariate distributions here is for computational efficiency only.
Turing.@model function pop_dynamics(D)
    # Extract the number of time points and number of mutants
    n_time, n_mut = size(D)
    # Prior on parameter sᵢ
    sᵢ ~ Turing.MvNormal(
        repeat([μ_s], n_mut), LinearAlgebra.I(n_mut) .* σ_s .^ 2
    )
    # Prior on parameter σᵢ
    σᵢ ~ Turing.MvLogNormal(
        repeat([μ_σ], n_mut), LinearAlgebra.I(n_mut) .* σ_σ .^ 2
    )
    # Prior on parameter λₜᵢ
    λₜᵢ ~ Turing.MvLogNormal(
        repeat([μ_λ], n_time * n_mut),
        LinearAlgebra.I(n_time * n_mut) .* σ_λ .^ 2
    )
    # Reshape λₜᵢ to make it easier to manipulate
    λₜᵢ = reshape(λₜᵢ, size(D)...)
    # Compute relative frequencies by time point
    fₜᵢ = λₜᵢ ./ sum(λₜᵢ, dims=2)
    # Compute frequency ratios between adjacent time points
    γₜᵢ = fₜᵢ[2:end, :] ./ fₜᵢ[1:end-1, :]
    # NOTE: I don't use the following for-loops in my model, but this is easier
    # to follow.
    # Likelihood function for data in D
    # Loop through rows
    for t in axes(fₜᵢ, 1)
        # Loop through columns
        for i in axes(fₜᵢ, 2)
            D[t, i] ~ Turing.Poisson(λₜᵢ[t, i])
        end # for
    end # for
    # Likelihood function for frequency ratios
    # Loop through rows
    for t in axes(γₜᵢ, 1)
        # Loop through columns
        for i in axes(γₜᵢ, 2)
            γₜᵢ[t, i] ~ Turing.LogNormal(sᵢ[i], σᵢ[i])
        end # for
    end # for
end # model
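For reference, sampling a model like this follows the usual Turing.jl pattern (a sketch; D and the hyperparameters must be in scope):

# Hypothetical usage: sample the posterior with NUTS
chain = Turing.sample(pop_dynamics(D), Turing.NUTS(), 1_000)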
Hopefully, this explanation and minimal model are enough for you to see whether it can easily be translated into an RxInfer.jl model. The reason I ask is that the number of organisms can be huge, so it becomes impossible to sample the posterior distribution. If I could get this to work within your framework, it would be a total game-changer for what I'm trying to do!
@mrazomej I have assigned a PhD student from our lab, @HoangMHNguyen, to this issue. Hopefully, he'll be able to help you out!
Thank you so much, @albertpod! I really hope this is a doable problem. Getting this to work efficiently would completely transform what I'm trying to do. @HoangMHNguyen, please let me know how I can help. I can also get in touch with you through something other than this GitHub issue.
Hi @mrazomej! Thank you for trying out the RxInfer.jl library. As @albertpod mentioned, the current RxInfer.jl lacks some features, such as support for a matrix of random variables or the LogNormal distribution, and this makes translating your model from Turing to RxInfer difficult (a big update to RxInfer is on the way). However, a nice feature of RxInfer is that you can define your own nodes and rules for messages (https://biaslab.github.io/RxInfer.jl/stable/manuals/custom-node/). For the update rules of messages, you can think of algorithms like Belief Propagation or Variational Message Passing.

The following might be an example of your model in RxInfer.jl:
@model function pop_dynamics(n_time, n_organ)
    y = datavar(Vector{Int64}, n_time)
    λ = randomvar(n_time)
    s = randomvar()
    w = randomvar()
    γ = randomvar(n_time)
    # priors
    s ~ MvNormalMeanCovariance(zeros(n_organ), diageye(n_organ))
    w ~ ElementwiseGammaShapeScale(0.01 * ones(n_organ), 100 * ones(n_organ)) # this is the inverse of σ
    λ_0 ~ ElementwiseGammaShapeScale(0.01 * ones(n_organ), 100 * ones(n_organ))
    λ_prev = λ_0
    # consider each time step
    for t = 1:n_time
        λ[t] ~ ElementwiseGammaShapeScale(rand() * ones(n_organ), rand() * ones(n_organ)) # prior
        γ[t] ~ FreqRatioNode(λ_prev, λ[t])
        γ[t] ~ ElementwiseLogNormal(s, w)
        y[t] ~ ElementwisePoisson(λ[t])
        λ_prev = λ[t]
    end
end
and the inference looks like this:
iresult = inference(
    model = pop_dynamics(num_time, num_organ),
    iterations = 2,
    initmarginals = (
        s = MvNormalMeanPrecision(zeros(num_organ), diageye(num_organ)),
        w = Gamma.(0.01 * ones(num_organ), 100 * ones(num_organ)),
    ),
    data = (y = y_data,),
    constraints = pop_dynamics_constraints(),
    returnvars = (s = KeepLast(),)
)
In the above model, only the MvNormalMeanCovariance node is already built into RxInfer.jl, while the other nodes, ElementwiseGammaShapeScale, FreqRatioNode, ElementwiseLogNormal, and ElementwisePoisson, are custom nodes created by users. Here I changed the prior distributions of λ and σ (changed to w) from LogNormal to Gamma, since the Gamma distribution is the conjugate prior of the LogNormal and Poisson distributions, and conjugate computation results in closed-form solutions. Of course you can use a LogNormal prior, but the computation for messages might become cumbersome. I changed the prior because I assume you use LogNormal just for the positivity of your parameters.
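To see why the Gamma choice gives closed-form updates, recall the standard Gamma-Poisson conjugacy (in shape-scale parametrization):

$$ \lambda \sim \text{Gamma}(\alpha, \theta), \quad y \sim \text{Poisson}(\lambda) \implies \lambda \mid y \sim \text{Gamma}\left(\alpha + y, \frac{\theta}{1 + \theta}\right), $$

so the Poisson likelihood, viewed as a function of λ, acts like a Gamma factor, and products of messages stay in the Gamma family.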
For the custom nodes above, you can define update rules for messages. The following are examples:
struct ElementwisePoisson end
@node ElementwisePoisson Stochastic [y, x]

@rule ElementwisePoisson(:y, Marginalisation) (q_x::Vector{GammaDistributionsFamily},) = begin
    # E[log λ] under a Gamma distribution: digamma(shape) - log(rate)
    λ = exp.(digamma.(shape.(q_x)) .- log.(rate.(q_x)))
    return Poisson.(λ)
end

@rule ElementwisePoisson(:x, Marginalisation) (q_y::Vector{Poisson},) = begin
    λ_y = mean.(q_y)
    return GammaShapeScale.(1 .+ λ_y, ones(length(λ_y)))
end

@rule ElementwisePoisson(:x, Marginalisation) (q_y::PointMass,) = begin
    λ_y = q_y.point
    return GammaShapeScale.(1 .+ λ_y, ones(length(λ_y)))
end
# create element-wise LogNormal node
struct ElementwiseLogNormal end
@node ElementwiseLogNormal Stochastic [y, μ, w]

@rule ElementwiseLogNormal(:y, Marginalisation) (q_μ::MultivariateNormalDistributionsFamily, q_w::Array,) = begin
    mean_μ = mean(q_μ)
    mean_w = mean.(q_w)
    σ = sqrt.(inv.(mean_w))
    return LogNormal.(mean_μ, σ)
end

@rule ElementwiseLogNormal(:w, Marginalisation) (q_y::Array, q_μ::MvNormalWeightedMeanPrecision,) = begin
    params = Distributions.params.(q_y)
    μ_y = first.(params)
    σ_y = last.(params)
    mean_μ, cov_μ = mean_cov(q_μ)
    var_μ = diag(cov_μ)
    return GammaShapeScale.(
        3 / 2 * ones(length(μ_y)),
        1 ./ (0.5 .* (mean_μ .^ 2 .+ var_μ .- 2 .* mean_μ .* μ_y .+ μ_y .^ 2 .+ σ_y .^ 2))
    )
end

@rule ElementwiseLogNormal(:w, Marginalisation) (q_y::Array, q_μ::MvNormalMeanPrecision,) = begin
    params = Distributions.params.(q_y)
    μ_y = first.(params)
    σ_y = last.(params)
    mean_μ, cov_μ = mean_cov(q_μ)
    var_μ = diag(cov_μ)
    return GammaShapeScale.(
        3 / 2 * ones(length(μ_y)),
        1 ./ (0.5 .* (mean_μ .^ 2 .+ var_μ .- 2 .* mean_μ .* μ_y .+ μ_y .^ 2 .+ σ_y .^ 2))
    )
end

@rule ElementwiseLogNormal(:μ, Marginalisation) (q_y::Array, q_w::Array,) = begin
    params = Distributions.params.(q_y)
    μ_y = first.(params)
    mean_w = mean.(q_w)
    return MvNormalMeanPrecision(μ_y, Diagonal(mean_w))
end
struct ElementwiseGammaShapeScale end
@node ElementwiseGammaShapeScale Stochastic [y, α, β]

@rule ElementwiseGammaShapeScale(:y, Marginalisation) (m_α::PointMass, m_β::PointMass,) = begin
    return GammaShapeScale.(m_α.point, m_β.point)
end

@rule ElementwiseGammaShapeScale(:y, Marginalisation) (q_α::PointMass, q_β::PointMass,) = begin
    return GammaShapeScale.(q_α.point, q_β.point)
end
struct FreqRatioNode end
@node FreqRatioNode Deterministic [y, x, z]

@rule FreqRatioNode(:z, Marginalisation) (m_y::Array, m_x::Array,) = begin
    sample_list = rand.(m_x, 2)
    list_logpdf = [make_message.(sample_list[i], m_y[i]) for i = 1:length(m_y)] # array of functions
    return list_logpdf
end

@rule FreqRatioNode(:x, Marginalisation) (m_y::Array, m_z::Array,) = begin
    sample_list = rand.(m_z, 2)
    list_logpdf = [make_message.(sample_list[i], m_y[i]) for i = 1:length(m_y)] # array of functions
    return list_logpdf
end

@rule FreqRatioNode(:y, Marginalisation) (m_x::Array, m_z::Array,) = begin
    nsamples = 10
    sample_list_x = hcat(rand.(m_x, nsamples)...) # previous time point
    sample_list_z = hcat(rand.(m_z, nsamples)...) # present time point
    γ = [freq_ratio(sample_list_x[i, :], sample_list_z[i, :]) for i = 1:size(sample_list_x, 1)]
    logγ  = [log.(γ_i) for γ_i in γ]
    logγ² = [(log.(γ_i)) .^ 2 for γ_i in γ]
    E_logγ  = sum(logγ, dims = 1) / nsamples
    E_logγ² = sum(logγ², dims = 1) / nsamples
    μ_γ = E_logγ[1]
    σ_γ = sqrt.(E_logγ²[1] .- μ_γ .^ 2)
    return LogNormal.(μ_γ, σ_γ)
end
Don't worry if you don't know whether you have defined all the necessary rules. The inference function will throw errors like this:
RuleMethodError: no method matching rule for the given arguments
Possible fix, define:
@rule ElementwiseGammaShapeScale(:y, Marginalisation) (q_α::PointMass, q_β::PointMass, ) = begin
return ...
end
and then you just add the missing rule. You might also get an error like
NoAnalyticalProdException: No analytical rule available to compute a product of....
which tells you to define a product rule for two distributions (this step becomes easier if you have conjugate priors).
To run the model above, I'll provide some more code. These are just examples to give you an idea of how RxInfer.jl works. You can of course define your own computations.
function freq_ratio(λ_prev, λ_present)
    f_prev = λ_prev ./ sum(λ_prev)
    f_present = λ_present ./ sum(λ_present)
    γ = f_present ./ f_prev
    return γ
end

function make_message(samples_A, d_in)
    return let samples_A = samples_A, d_in = d_in
        (x) -> begin
            result = mapreduce(+, zip(samples_A,)) do (sampleA,)
                return pdf(d_in, freq_ratio(x, sampleA))
            end
            return log(result)
        end
    end
end
function ReactiveMP.prod(::ProdAnalytical, left::Vector{Gamma{Float64}}, right::Array)
    return left
end

function ReactiveMP.prod(::ProdAnalytical, left::Vector{GammaShapeScale{Float64}}, right::Vector{GammaShapeScale{Float64}})
    T = ReactiveMP.promote_samplefloattype(left[1], right[1])
    # element-wise product of two vectors of Gamma messages
    d = [GammaShapeScale(shape(left[i]) + shape(right[i]) - one(T), 1 / (rate(left[i]) + rate(right[i]))) for i = 1:length(left)]
    return d
end

function ReactiveMP.prod(::ProdAnalytical, left::Vector{LogNormal{Float64}}, right::Vector{LogNormal{Float64}})
    # you can define the rule by yourself
    return left
end
@constraints function pop_dynamics_constraints()
    q(γ, s, w) = q(γ)q(s)q(w) # mean-field constraint
end
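To actually run the model end to end, you also need data in the shape the model expects (a vector of count vectors, one per time point). Here is a hypothetical synthetic-data snippet; the dimensions and parameter values are made up:

using Distributions, Random
Random.seed!(1234)

num_time, num_organ = 5, 2
λ_true = [rand(Gamma(2.0, 50.0), num_organ) for _ in 1:num_time] # hypothetical true rates
y_data = [rand.(Poisson.(λ)) for λ in λ_true]                    # observed counts per time point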
I hope my answer is helpful for your work. If you have any questions, I'm willing to help.
Cheers.
Thank you so much @HoangMHNguyen! I will try this out and report back. Really, thank you for taking the time to do this.
Hey @mrazomej and @HoangMHNguyen! I will transfer this issue to a discussion. We can continue there if needed.
I just learned about the project and still need to understand the inference pipeline fully. Nevertheless, I tried to adapt a Turing.jl model to work with RxInfer, just to run into the wall that the LogNormal distribution is not available. A simple solution would be to use a Normal distribution and then exponentiate the result. It turns out that exp is also not implemented for random variables. I wish I had a deeper understanding and more time to make a PR to implement this. Is this something on your immediate to-do list?