Closed: HiroIshida closed this issue 2 years ago
Hi @HiroIshida! Thanks for trying out ForneyLab.jl (FL).
Unfortunately, FL does not take advantage of the repetitive graph structure automatically, so you cannot avoid StackOverflow errors for your problem by coding it up the way you did.
However, you have a couple of options to circumvent the problem. First, you can split your data into batches and use a batch learning procedure, where the posterior of the previous batch is used as the prior for the next batch. Here is a code snippet for your problem (also check this issue and this demo):
using ForneyLab
import ProgressMeter
n_samples = 30000
A_data = [0.9 0.0 0.1; 0.1 0.9 0.0; 0.0 0.1 0.9] # Transition probabilities (some transitions are impossible)
B_data = [0.9 0.05 0.05; 0.05 0.9 0.05; 0.05 0.05 0.9] # Observation noise
s_0_data = [1.0, 0.0, 0.0] # Initial state
# Generate some data
s_data = Vector{Vector{Float64}}(undef, n_samples) # one-hot encoding of the states
x_data = Vector{Vector{Float64}}(undef, n_samples) # one-hot encoding of the observations
s_t_min_data = s_0_data
for t = 1:n_samples
    a = A_data*s_t_min_data
    s_data[t] = sample(ProbabilityDistribution(Categorical, p=a./sum(a))) # Simulate state transition
    b = B_data*s_data[t]
    x_data[t] = sample(ProbabilityDistribution(Categorical, p=b./sum(b))) # Simulate observation
    s_t_min_data = s_data[t]
end
# Initialize batch size
batch_size = 100
n_batch = Int(n_samples/batch_size)
g = FactorGraph()
# clamp placeholders to pass priors as data
@RV A ~ Dirichlet(placeholder(:A_0, dims=(3, 3)))
@RV B ~ Dirichlet(placeholder(:B_0, dims=(3, 3)))
@RV s_0 ~ Categorical(1/3*ones(3))
s = Vector{Variable}(undef, batch_size) # one-hot coding
x = Vector{Variable}(undef, batch_size) # one-hot coding
s_t_min = s_0
for t in 1:batch_size
    @RV s[t] ~ Transition(s_t_min, A)
    @RV x[t] ~ Transition(s[t], B)
    s_t_min = s[t]
    placeholder(x[t], :x, index=t, dims=(3,))
end;
pfz = PosteriorFactorization(A, B, [s_0; s], ids=[:A, :B, :S])
algo = messagePassingAlgorithm(free_energy=true)
source_code = algorithmSourceCode(algo, free_energy=true);
eval(Meta.parse(source_code))
# Define values for prior statistics
A_0_prev = ones(3, 3)
B_0_prev = [10.0 1.0 1.0; 1.0 10.0 1.0; 1.0 1.0 10.0]
n_its = 10
s_ = []
A_ = []
B_ = []
F = Matrix{Float64}(undef, n_batch, n_its)
ProgressMeter.@showprogress for i in 1:batch_size:n_samples
    data = Dict(:x => x_data[i:i+batch_size-1], :A_0 => A_0_prev, :B_0 => B_0_prev)
    marginals = Dict{Symbol, ProbabilityDistribution}(:A => vague(Dirichlet, (3,3)), :B => vague(Dirichlet, (3,3)))
    for v in 1:n_its
        stepS!(data, marginals)
        stepB!(data, marginals)
        stepA!(data, marginals)
        # Compute FE for every batch
        F[div(i, batch_size)+1, v] = freeEnergy(data, marginals)
    end
    # Extract posteriors; they become the priors for the next batch
    global A_0_prev = marginals[:A].params[:a]
    global B_0_prev = marginals[:B].params[:a]
    # Save posterior marginals for each batch
    push!(s_, [marginals[Symbol(:s_, t)] for t in 1:batch_size])
    push!(A_, marginals[:A])
    push!(B_, marginals[:B])
end
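The reason carrying marginals[:A].params[:a] over as the next :A_0 is valid is conjugacy: the Dirichlet prior is conjugate to the Categorical/Transition likelihood, so the posterior is again a Dirichlet whose parameters are the prior pseudo-counts plus the (expected) transition counts from the batch. A toy illustration in plain Julia (the matrices and names here are my own simplification, not ForneyLab output):

```julia
# Toy sketch of conjugate Dirichlet-Categorical count updating
# (my own simplification; this is not ForneyLab code).
A_prior = ones(3, 3)  # prior pseudo-counts

# Hypothetical transition counts observed in two successive batches:
batch1 = [5.0 0.0 1.0; 1.0 6.0 0.0; 0.0 1.0 7.0]
batch2 = [4.0 1.0 0.0; 0.0 5.0 1.0; 1.0 0.0 6.0]

# Posterior counts after batch 1 serve as the prior for batch 2:
posterior1 = A_prior + batch1
posterior2 = posterior1 + batch2

# Sequential updating matches processing all the data at once
# with the original prior:
A_once = A_prior + batch1 + batch2
println(posterior2 == A_once)  # true
```

This is why batch learning gives the same Dirichlet posterior over A and B as a single pass, up to the variational approximation over the hidden states.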
I will leave the data extraction to you.
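For example, a point estimate of the learned transition matrix can be read off from the final Dirichlet counts. A minimal sketch in plain Julia (the counts below are made-up stand-ins for marginals[:A].params[:a], and the column-wise normalization assumes each column of A parameterizes an outgoing transition distribution, which you should verify for your model):

```julia
# Hypothetical final Dirichlet counts, standing in for marginals[:A].params[:a]:
A_counts = [90.0 2.0 10.0; 10.0 88.0 1.0; 1.0 11.0 90.0]

# The mean of a Dirichlet is the normalized count vector; here each
# column is normalized so it sums to one:
A_mean = A_counts ./ sum(A_counts, dims=1)

println(A_mean)               # estimated transition probabilities
println(sum(A_mean, dims=1))  # each column sums to 1.0
```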
In case you want to run inference on the full graph with 30000 observations, I refer you to another package developed in our lab, ReactiveMP.jl. In particular, have a look at this demo, which does exactly what you want. Note that when creating big models in ReactiveMP.jl you must supply the optional argument limit_stack_depth to avoid StackOverflow errors, e.g.
... = create_my_model(..., options = (limit_stack_depth = 100, ))
Hope this helps!
@albertpod
Thank you very much for the prompt response and suggestion! I tried your method, and it works.
I will also look deeper into ReactiveMP.jl (the new programming paradigm is hard to learn, but it seems much faster than ForneyLab).
@HiroIshida you are welcome.
Don't be discouraged by the engine of ReactiveMP.jl. While it implements message passing differently, from the user's perspective it is very similar to ForneyLab.jl. Reach out to us at the ReactiveMP repo if you have other questions.
Good luck!
First, thanks for this great project!
My question is about a way to generate the inference code taking advantage of the repetitive graph structure. For example, in the HMM case, the message-passing algorithm could be really short with for loops. However, when I execute the HMM demo varying the number of samples, I found that the amount of generated source code increases linearly as n_samples increases. Not only does this take a long computational time, but a StackOverflow error also occurs when n_samples is large. In my application n_samples would be around 30000, and this will be problematic. I think the mentioned problem could be avoided by generating the code using for loops. So, my question is: is there a way to make ForneyLab generate such loop-based code?
Just for information, I paste the source code (copied from the demo, with n_samples changed) to reproduce the StackOverflow error: tmp.jl