ReactiveBayes / RxInfer.jl

Julia package for automated Bayesian inference on a factor graph with reactive message passing
MIT License

The `slide()` function in Active Inference Mountain Car example #329

Closed by Flawless1202 2 months ago

Flawless1202 commented 2 months ago

Thanks for your amazing work on Bayesian inference! I have learned a lot from the 'Active Inference Mountain Car' example. But I have a question about the following code:

```julia
# We are going to use some private functionality from ReactiveMP,
# in the future we should expose a proper API for this
import RxInfer.ReactiveMP: getrecent, messageout

function create_agent(;T = 20, Fg, Fa, Ff, engine_force_limit, x_target, initial_position, initial_velocity)
    Epsilon = fill(huge, 1, 1)                # Control prior variance
    m_u = Vector{Float64}[ [ 0.0] for k=1:T ] # Set control priors
    V_u = Matrix{Float64}[ Epsilon for k=1:T ]

    Sigma    = 1e-4*diageye(2) # Goal prior variance
    m_x      = [zeros(2) for k=1:T]
    V_x      = [huge*diageye(2) for k=1:T]
    V_x[end] = Sigma # Set prior to reach goal at t=T

    # Set initial brain state prior
    m_s_t_min = [initial_position, initial_velocity]
    V_s_t_min = tiny * diageye(2)

    # Set current inference results
    result = nothing

    # The `compute` function is the heart of the agent
    # It calls the `infer` function from RxInfer to perform Bayesian inference by message passing
    compute = (upsilon_t::Float64, y_hat_t::Vector{Float64}) -> begin
        m_u[1] = [ upsilon_t ] # Register action with the generative model
        V_u[1] = fill(tiny, 1, 1) # Clamp control prior to performed action

        m_x[1] = y_hat_t # Register observation with the generative model
        V_x[1] = tiny*diageye(2) # Clamp goal prior to observation

        data = Dict(:m_u       => m_u,
                    :V_u       => V_u,
                    :m_x       => m_x,
                    :V_x       => V_x,
                    :m_s_t_min => m_s_t_min,
                    :V_s_t_min => V_s_t_min)

        model  = mountain_car(T = T, Fg = Fg, Fa = Fa, Ff = Ff, engine_force_limit = engine_force_limit)
        result = infer(model = model, data = data)
    end

    # The `act` function returns the inferred best possible action
    act = () -> begin
        if result !== nothing
            return mode(result.posteriors[:u][2])[1]
        else
            return 0.0 # Without inference result we return some 'random' action
        end
    end

    # The `future` function returns the inferred future states
    future = () -> begin
        if result !== nothing
            return getindex.(mode.(result.posteriors[:s]), 1)
        else
            return zeros(T)
        end
    end

    # The `slide` function modifies the `(m_s_t_min, V_s_t_min)` for the next step
    # and shifts (or slides) the array of future goals `(m_x, V_x)` and inferred actions `(m_u, V_u)`
    slide = () -> begin

        model  = RxInfer.getmodel(result.model)
        (s, )  = RxInfer.getreturnval(model)
        varref = RxInfer.getvarref(model, s)
        var    = RxInfer.getvariable(varref)

        slide_msg_idx = 3 # This index is model dependent
        (m_s_t_min, V_s_t_min) = mean_cov(getrecent(messageout(var[2], slide_msg_idx)))

        m_u = circshift(m_u, -1)
        m_u[end] = [0.0]
        V_u = circshift(V_u, -1)
        V_u[end] = Epsilon

        m_x = circshift(m_x, -1)
        m_x[end] = x_target
        V_x = circshift(V_x, -1)
        V_x[end] = Sigma
    end

    return (compute, act, slide, future)
end
```
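
The shifting part of `slide` can be illustrated in isolation. This is a minimal sketch with made-up values; the names `priors` and `new_goal` are hypothetical and not part of the example. It only shows how `circshift(v, -1)` moves every entry one step earlier and how the wrapped-around last slot is then overwritten with a fresh prior.

```julia
# Hypothetical stand-in for one of the prior arrays (e.g. `m_x`), with horizon T = 4
priors   = [[1.0], [2.0], [3.0], [4.0]]
new_goal = [9.0]  # hypothetical prior for the newly exposed final time step

# circshift with -1 moves every element one position towards the front;
# the old first element wraps around to the end. A new array is returned.
shifted = circshift(priors, -1)

# Overwrite the wrapped-around entry so the last slot holds the new prior
shifted[end] = new_goal

println(shifted)
```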

I want to know if this line

```julia
(m_s_t_min, V_s_t_min) = mean_cov(getrecent(messageout(var[2], slide_msg_idx)))
```

could be replaced by the following line

```julia
(m_s_t_min, V_s_t_min) = mean_cov(result.posteriors[:x][2])
```

I have modified the code and got the following result: (image: ai-mountain-car-ai). Could you give me an explanation?

ThijsvdLaar commented 2 months ago

Hi @Flawless1202, thanks for trying it out and for your question. The short answer is no; unfortunately you cannot exchange the message for the marginal.

The longer answer: Informally, a message summarizes the information of the sub-graph that it leaves (aka "closing the box"). Therefore, the forward message (obtained by `mean_cov(getrecent(messageout(var[2], slide_msg_idx)))`) summarizes all past data (actions and observations). The marginal, obtained by `mean_cov(result.posteriors[:x][2])`, is the result of two colliding messages: the forward message that summarizes the past, and the backward message that summarizes the desired future. Therefore, the marginal is biased by the desired future and your current state estimate will be biased as well.
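
To make the difference concrete: for Gaussian messages, the marginal is proportional to the product of the forward and backward messages, so the precisions add and the mean becomes a precision-weighted average. The sketch below uses plain linear algebra rather than any RxInfer API, and all names and numbers (`gaussian_product`, `m_fwd`, `m_bwd`, and so on) are made up for illustration.

```julia
using LinearAlgebra

# Product of two Gaussian messages N(m1, V1) * N(m2, V2), up to normalization:
# precisions add, and the mean is the precision-weighted average of the two means.
function gaussian_product(m1, V1, m2, V2)
    W1, W2 = inv(V1), inv(V2)
    V = inv(W1 + W2)
    m = V * (W1 * m1 + W2 * m2)
    return m, V
end

# Hypothetical forward message: what past actions and observations say about the state
m_fwd, V_fwd = [0.0, 0.0], Matrix(1.0I, 2, 2)
# Hypothetical backward message: what the goal prior says the state should be
m_bwd, V_bwd = [1.0, 0.0], Matrix(1.0I, 2, 2)

# The marginal combines both messages; its mean lies between m_fwd and m_bwd
m_marg, V_marg = gaussian_product(m_fwd, V_fwd, m_bwd, V_bwd)

println("forward mean:  ", m_fwd)
println("marginal mean: ", m_marg)
```

In this toy setting the marginal mean lies halfway between the forward mean and the goal, so feeding it back as the next `m_s_t_min` would carry the goal's pull into the state prior. That is the bias described above.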

Hope this clears things up. For more details you can also consult the original paper:

albertpod commented 2 months ago

I suppose this can be moved into discussion ;)