biaslab / ForneyLab.jl

Julia package for automatically generating Bayesian inference algorithms through message passing on Forney-style factor graphs.
MIT License

code generation for repetitive graph structure #179

Closed HiroIshida closed 2 years ago

HiroIshida commented 2 years ago

First, thanks for this great project!

My question is about how to generate code that takes advantage of repetitive graph structure. For example, in the HMM case, the message passing algorithm could be written very compactly with for loops. However, when I ran the HMM demo with different numbers of samples, I found that the size of the generated source code grew linearly with n_samples.

Not only does this take a long time to compute, but a StackOverflowError also occurs when n_samples is large. In my application n_samples would be around 30000, which is problematic. I think this problem could be avoided by generating the code with for loops.
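For the simpler case where A and B are known, a hand-written sum-product (forward-backward) pass over the chain fits in two loops whose code length does not depend on n_samples. The following is a minimal plain-Julia sketch, not ForneyLab-generated code: the function name `forward_backward` is hypothetical, and it assumes the same column convention as the data-generation code below (A[i, j] = p(s_t = i | s_{t-1} = j), B[k, i] = p(x_t = k | s_t = i)) and one-hot observation vectors.

```julia
"""
Hypothetical helper: forward-backward (sum-product) messages for an HMM with
known transition matrix A, observation matrix B, prior p0 over s_0 and
one-hot observations x. Returns the posterior marginals p(s_t | x_1:n), t = 1..n.
"""
function forward_backward(A::Matrix{Float64}, B::Matrix{Float64},
                          p0::Vector{Float64}, x::Vector{Vector{Float64}})
    n = length(x)
    K = length(p0)
    alpha = Vector{Vector{Float64}}(undef, n)  # forward messages
    beta = Vector{Vector{Float64}}(undef, n)   # backward messages

    # Forward pass: predict with A, then weight by the likelihood B' * x[t]
    prev = p0
    for t in 1:n
        a = (B' * x[t]) .* (A * prev)
        alpha[t] = a ./ sum(a)  # normalize to avoid underflow on long chains
        prev = alpha[t]
    end

    # Backward pass: start from an uninformative message at the last step
    beta[n] = ones(K)
    for t in n:-1:2
        b = A' * ((B' * x[t]) .* beta[t])
        beta[t-1] = b ./ sum(b)
    end

    # Combine forward and backward messages into normalized marginals
    return [(alpha[t] .* beta[t]) ./ sum(alpha[t] .* beta[t]) for t in 1:n]
end

A = [0.9 0.0 0.1; 0.1 0.9 0.0; 0.0 0.1 0.9]
B = [0.9 0.05 0.05; 0.05 0.9 0.05; 0.05 0.05 0.9]
x = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]
post = forward_backward(A, B, [1.0, 0.0, 0.0], x)
```

Learning A and B, as in the demo, would additionally require an outer variational loop around these passes, which this sketch does not include.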

So, my question is

For reference, here is the source code that reproduces the StackOverflowError (copied from the demo, with only n_samples changed).

tmp.jl

using ForneyLab

function main()
    n_samples = 30000
    A_data = [0.9 0.0 0.1; 0.1 0.9 0.0; 0.0 0.1 0.9] # Transition probabilities (some transitions are impossible)
    B_data = [0.9 0.05 0.05; 0.05 0.9 0.05; 0.05 0.05 0.9] # Observation noise

    s_0_data = [1.0, 0.0, 0.0] # Initial state

    # Generate some data
    s_data = Vector{Vector{Float64}}(undef, n_samples) # one-hot encoding of the states
    x_data = Vector{Vector{Float64}}(undef, n_samples) # one-hot encoding of the observations
    s_t_min_data = s_0_data
    for t = 1:n_samples
        a = A_data*s_t_min_data
        s_data[t] = sample(ProbabilityDistribution(Categorical, p=a./sum(a))) # Simulate state transition
        b = B_data*s_data[t]
        x_data[t] = sample(ProbabilityDistribution(Categorical, p=b./sum(b))) # Simulate observation

        s_t_min_data = s_data[t]
    end

    g = FactorGraph()

    @RV A ~ Dirichlet(ones(3,3)) # Vague prior on transition model
    @RV B ~ Dirichlet([10.0 1.0 1.0; 1.0 10.0 1.0; 1.0 1.0 10.0]) # Stronger prior on observation model
    @RV s_0 ~ Categorical(1/3*ones(3))

    s = Vector{Variable}(undef, n_samples) # one-hot coding
    x = Vector{Variable}(undef, n_samples) # one-hot coding
    s_t_min = s_0
    for t = 1:n_samples
        @RV s[t] ~ Transition(s_t_min, A)
        @RV x[t] ~ Transition(s[t], B)

        s_t_min = s[t]

        placeholder(x[t], :x, index=t, dims=(3,))
    end;

    pfz = PosteriorFactorization(A, B, [s_0; s], ids=[:A, :B, :S])
    algo = messagePassingAlgorithm(free_energy=true)
    source_code = algorithmSourceCode(algo, free_energy=true);
end
main()

julia> include("tmp.jl")
ERROR: LoadError: StackOverflowError:
Stacktrace:
 [1] promote_eltype(::Set{Edge}, ::Set{Edge}, ::Vararg{Set{Edge}, N} where N) (repeats 39991 times)
   @ Base ./abstractarray.jl:1463
in expression starting at /home/h-ishida/documents/julia/ForneyLab.jl/tmp.jl:45
albertpod commented 2 years ago

Hi @HiroIshida! Thanks for trying out ForneyLab.jl (FL).

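Unfortunately, FL doesn't exploit repetitive graph structure automatically: it generates an explicit update for every message in the graph, so the generated code grows with the number of time steps. This means that you can't avoid the StackOverflowError for your problem by coding it up as you did.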

However, you have a couple of options to circumvent the problem. First, you can split your data into batches and use a batch learning procedure, where the posterior of the previous batch is used as the prior for the next batch. Here is a code snippet for your problem (also check this issue and this demo):

using ForneyLab
import ProgressMeter

n_samples = 30000
A_data = [0.9 0.0 0.1; 0.1 0.9 0.0; 0.0 0.1 0.9] # Transition probabilities (some transitions are impossible)
B_data = [0.9 0.05 0.05; 0.05 0.9 0.05; 0.05 0.05 0.9] # Observation noise

s_0_data = [1.0, 0.0, 0.0] # Initial state

# Generate some data
s_data = Vector{Vector{Float64}}(undef, n_samples) # one-hot encoding of the states
x_data = Vector{Vector{Float64}}(undef, n_samples) # one-hot encoding of the observations
s_t_min_data = s_0_data
for t = 1:n_samples
    a = A_data*s_t_min_data
    s_data[t] = sample(ProbabilityDistribution(Categorical, p=a./sum(a))) # Simulate state transition
    b = B_data*s_data[t]
    x_data[t] = sample(ProbabilityDistribution(Categorical, p=b./sum(b))) # Simulate observation

    s_t_min_data = s_data[t]
end

# Initialize batch size
batch_size = 100
n_batch = Int(n_samples/batch_size)

g = FactorGraph()

# clamp placeholders to pass priors as data
@RV A ~ Dirichlet(placeholder(:A_0, dims=(3, 3))) 
@RV B ~ Dirichlet(placeholder(:B_0, dims=(3, 3)))
@RV s_0 ~ Categorical(1/3*ones(3))

s = Vector{Variable}(undef, batch_size) # one-hot coding
x = Vector{Variable}(undef, batch_size) # one-hot coding
s_t_min = s_0
for t in 1:batch_size
    @RV s[t] ~ Transition(s_t_min, A)
    @RV x[t] ~ Transition(s[t], B)

    s_t_min = s[t]

    placeholder(x[t], :x, index=t, dims=(3,))
end;

pfz = PosteriorFactorization(A, B, [s_0; s], ids=[:A, :B, :S])
algo = messagePassingAlgorithm(free_energy=true)
source_code = algorithmSourceCode(algo, free_energy=true);
eval(Meta.parse(source_code))

# Define values for prior statistics
A_0_prev = ones(3, 3)
B_0_prev = [10.0 1.0 1.0; 1.0 10.0 1.0; 1.0 1.0 10.0]

n_its = 10

s_ = []
A_ = []
B_ = []
F = Matrix{Float64}(undef, n_batch, n_its)

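ProgressMeter.@showprogress for i in 1:batch_size:n_samples
    # The prior statistics are reassigned below; declare them global so this also works when run as a script
    global A_0_prev, B_0_prev

    # Select exactly batch_size observations (non-overlapping batches)
    data = Dict(:x => x_data[i:i+batch_size-1], :A_0 => A_0_prev, :B_0 => B_0_prev)
    marginals = Dict{Symbol, ProbabilityDistribution}(:A => vague(Dirichlet, (3,3)), :B => vague(Dirichlet, (3,3)))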
    for v in 1:n_its
        stepS!(data, marginals)
        stepB!(data, marginals)
        stepA!(data, marginals)

        # Compute FE for every batch
        F[div(i, batch_size)+1, v] = freeEnergy(data, marginals)
    end

    # Extract posteriors
    A_0_prev = marginals[:A].params[:a]
    B_0_prev = marginals[:B].params[:a]

    # Save posterior marginals for each batch
    push!(s_, [marginals[:s_*t] for t in 1:batch_size])
    push!(A_, marginals[:A])
    push!(B_, marginals[:B])
end

I'll leave the data extraction to you.
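Why carrying the posterior forward works can be seen in a simplified, fully observed case: with a conjugate Dirichlet prior on the columns of a transition matrix and an observed state sequence, updating batch by batch gives exactly the same posterior counts as one update on all the data, provided the last state of each batch is carried over the boundary. The following plain-Julia sketch is a hypothetical illustration, not ForneyLab code; `dirichlet_update` and `batched_update` are made-up names. In the HMM snippet above the states are latent and each batch restarts from the uniform s_0 prior, so there the carried-over Dirichlet posteriors only approximate full-data inference.

```julia
# Hypothetical helper: add transition counts from an observed state sequence to
# Dirichlet parameters alpha. Column j counts transitions out of state j, matching
# A[i, j] = p(s_t = i | s_{t-1} = j). `prev` is the state preceding states[1].
function dirichlet_update(alpha::Matrix{Float64}, states::Vector{Int}, prev::Int)
    post = copy(alpha)
    for s in states
        post[s, prev] += 1.0
        prev = s
    end
    return post
end

# Hypothetical helper: process the sequence in batches, using each batch's
# posterior as the prior for the next batch.
function batched_update(alpha0::Matrix{Float64}, states::Vector{Int}, s0::Int, batch_size::Int)
    alpha = alpha0
    prev = s0
    for i in 1:batch_size:length(states)
        batch = states[i:min(i + batch_size - 1, length(states))]
        alpha = dirichlet_update(alpha, batch, prev)
        prev = batch[end]  # carry the last state over the batch boundary
    end
    return alpha
end

states = rand(1:3, 30000)  # synthetic observed state sequence
alpha0 = ones(3, 3)        # vague prior, as for A in the demo
full = dirichlet_update(alpha0, states, 1)
batched = batched_update(alpha0, states, 1, 100)
```

Here `batched == full` holds exactly, because the counts are integer-valued and addition order does not matter.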

If you want to run inference on the full graph with 30000 observations, I refer you to ReactiveMP.jl, another package developed in our lab. In particular, take a look at this demo, which does exactly what you want. Note that when creating large models in ReactiveMP.jl you must supply the optional argument limit_stack_depth to avoid StackOverflow errors, e.g.

... = create_my_model(..., options = (limit_stack_depth = 100, ))

Hope this helps!

HiroIshida commented 2 years ago

@albertpod
Thank you very much for the prompt response and suggestion! I tried your method, and it works. I will also look deeper into ReactiveMP.jl (the new programming paradigm is hard to learn, but it seems much faster than ForneyLab).

albertpod commented 2 years ago

@HiroIshida you are welcome. Don't be discouraged by the engine of ReactiveMP.jl. While it implements message passing differently, from the user's perspective it is very similar to ForneyLab.jl. Reach out to us at the ReactiveMP repo if you have other questions. Good luck!