ReactiveBayes / RxInfer.jl

Julia package for automated Bayesian inference on a factor graph with reactive message passing

Chance constraint example would be helpful #68

Closed: John-Boik closed this issue 1 year ago

John-Boik commented 1 year ago

Is anyone planning to convert the ForneyLab code published for Chance-Constrained Active Inference, van de Laar et al. (2021), into code that will run in RxInfer? If so, great! What is the status?

If not, I might attempt to do so, but I'm still trying to learn RxInfer and could use some guidance to get started with the conversion. Below I'm using the RxInfer mountain car code as an example template. Am I correct that the steps to convert the code would include the following?

Are there major differences in coding a problem for ForneyLab vs. RxInfer, such that, for example, the function truncatedGaussianMoments would no longer have the signature truncatedGaussianMoments(m::Float64, V::Float64, a::Float64, b::Float64)? Or, as another example, would an outgoing rule still have a signature like ruleSPChanceConstraintOutG(msg_out::Message{<:Gaussian, Univariate}, G::Tuple, epsilon::Float64; atol=default_atol)? Is the use of functions such as unsafeMeanCov() and unsafeMode() still required?

Any suggestions/guidance getting started would be welcome. Is there any reason that converting this code would be particularly challenging?
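
As a sanity check, the moments that truncatedGaussianMoments computes in ForneyLab appear to be available directly from Distributions.jl; a minimal sketch with made-up parameter values:

using Distributions

# Mean and variance of a Gaussian N(m, V) truncated to [a, b]; note that
# Normal takes a standard deviation, hence sqrt(V)
m, V, a, b = 2.0, 0.5, 1.0, 3.0
d = truncated(Normal(m, sqrt(V)), a, b)
@show mean(d) var(d)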

wmkouw commented 1 year ago

Hi John-Boik, cool to see you're interested in Thijs' paper. He's the expert and will probably be able to guide you more. I just wanted to let you know that there are indeed plans on our side to incorporate chance constraints in RxInfer. The first hurdle to cross is the creation of a factor node corresponding to a TruncatedGaussian. That distribution exists in ForneyLab, but not in RxInfer / ReactiveMP. Both build on Distributions.jl, which does implement it (link), so that shouldn't be a major challenge. The second hurdle is to specify the chance constraint appropriately, through the constraint specification procedure. That is probably harder for an external contributor.
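
For illustration, a minimal sketch of what the node declaration might look like in ReactiveMP (the type name and interface list here are assumptions, not a committed API):

using ReactiveMP

struct TruncatedGaussian end  # hypothetical placeholder type

# Register a stochastic factor node with an output edge and parameter edges
# for the mean m, variance v, and lower/upper truncation bounds a and b
@node TruncatedGaussian Stochastic [out, m, v, a, b]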

John-Boik commented 1 year ago

Thanks @wmkouw. Perhaps you will want to move this discussion elsewhere. I've written code (below) for running the reference model contained in the @ThijsvdLaar chance constraints paper. I used the mountain car example as a template. Next I'm hoping to extend the reference model to the full chance constraint model. But I have a few questions:

  1. The slide function uses the call mean_var(getrecent(messageout(x[2], slide_msg_idx))), taken from the mountain car example. I understand these are private ReactiveMP functions. Can you explain how the slide_msg_idx can be chosen? Is there some way to inspect what the indexes pertain to? Is there a better approach that does not use private functions?
  2. Regarding your comment about creating a factor node for a TruncatedGaussian, is that done only via the @node macro, or does the @rule macro also need to be employed? (See the sketch after this list.)
  3. Do I understand correctly that to convert the reference model to the full chance constraint model, the only things required are (1) a new TruncatedGaussian node, and (2) appropriate constraints that are set using the @constraints macro?

Any suggestions/guidance regarding items 2 and 3 above would be welcome.
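
Regarding item 2, my current reading of the ReactiveMP docs is that @node only registers the node type and its interfaces, and the message-passing updates still need @rule methods. A sketch of what an outgoing rule might look like, mirroring the hypothetical node above (the rule body is illustrative and untested):

using ReactiveMP, Distributions

struct TruncatedGaussian end  # hypothetical type, as above
@node TruncatedGaussian Stochastic [out, m, v, a, b]

# Hypothetical variational rule: with point-mass parameter marginals, the
# outbound message on :out is the truncated normal itself
@rule TruncatedGaussian(:out, Marginalisation) (q_m::PointMass, q_v::PointMass, q_a::PointMass, q_b::PointMass) = begin
    truncated(Normal(mean(q_m), sqrt(mean(q_v))), mean(q_a), mean(q_b))
end

If that is right, then for item 3 the @constraints macro alone may not suffice, given @wmkouw's comment about the constraint specification procedure.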

The script below, which I save in a file RXRef.jl, will run in the Julia REPL and produces results very close to the published ones. An example output graph is attached after the code.

#=
To run script in REPL:
    include("./RXRef.jl")
=#

module RXRef

import Plots
import Random

# use some private functionality from ReactiveMP
import RxInfer.ReactiveMP: getrecent, messageout

# development conveniences; not required for the simulation itself
using Formatting
using Infiltrator
using Revise
using RxInfer

Random.seed!(51233) # Set random seed for reproducibility

Plots.scalefontsizes()
Plots.scalefontsizes(0.8)

# ==================================================================================================
# --------------------------------------------------------------------------------------------------
function createWorld(; v_wind)
    # functions for interacting with the simulated environment.

    x_0 = 0.0 # Initial position, drone elevation
    x_t0 = x_0
    x_t1 = x_0

    execute = (action_t::Float64, m_wind_t::Float64) -> begin
        # Execute the action
        # action_t = ascent velocity; dx = action_t * t, with t=1
        x_t1 = x_t0 + action_t + m_wind_t + sqrt(v_wind) * randn()  # Update elevation
        x_t0 = x_t1 # Prepare for next step
    end

    observe = () -> begin 
        return x_t0  # Observe the current state
    end

    return (execute, observe)
end

# --------------------------------------------------------------------------------------------------
@model function ref_model(; T)

    # the mean and variance of the observation at the last t
    m_x_t_last = datavar(Float64)
    v_x_t_last = datavar(Float64)

    # sample the observation at the last t
    x_t_last ~ GaussianMeanVariance(m_x_t_last, v_x_t_last)  
    x_k_last = x_t_last

    # control
    m_u = datavar(Float64, T)
    v_u = datavar(Float64, T)

    # obs
    m_x = datavar(Float64, T)
    v_x = datavar(Float64, T)

    # wind
    m_w = datavar(Float64, T)
    v_w = datavar(Float64, T)

    # random variables
    u = randomvar(T)  # control
    x = randomvar(T)  # height
    uw = randomvar(T)  # control + wind variance

    # loop over horizon
    for k = 1:T
        x[k] ~ GaussianMeanVariance(m_x[k], v_x[k]) # goal prior
        u[k] ~ GaussianMeanVariance(m_u[k], v_u[k]) # control prior
        uw[k] ~ GaussianMeanVariance(u[k], v_w[k])  # control + wind variance
        x[k] ~ x_k_last + uw[k] + m_w[k]
        x_k_last = x[k]
    end
    return (x, )
end

# --------------------------------------------------------------------------------------------------
function createAgent(; T, fx_m_wind, v_wind, m_goal, v_goal, lambda)

    # control prior
    m_u = Vector{Float64}([0.0 for k=1:T ])
    v_u = Vector{Float64}([lambda^(-1) for k=1:T])

    # goal
    m_x = Vector{Float64}([m_goal for k=1:T])
    v_x = Vector{Float64}([v_goal for k=1:T])  

    # wind
    m_w = Vector{Float64}([0.0 for k=1:T])
    v_w = Vector{Float64}([v_wind for k=1:T])

    # initial position and variance
    m_x_t_last = 0.0
    v_x_t_last = convert(Float64, tiny)

    # Set current inference results
    result = nothing

    #=
    The `infer` function is the heart of the agent. It calls the `inference` function to perform 
    Bayesian inference by message passing.
    =#
    infer = (agent_action_t::Float64, agent_obs_t::Float64, t::Int) -> begin
        m_w[:] = [fx_m_wind(t+k) for k in 0:T-1]  # mean wind
        m_u[1] = agent_action_t  # register action with the generative model
        v_u[1] = tiny  # clamp control prior to performed action
        m_x[1] = agent_obs_t  # register observation with the generative model
        v_x[1] = tiny  # clamp goal prior to observation

        data = Dict(:m_u => m_u, 
                    :v_u => v_u, 
                    :m_x => m_x, 
                    :v_x => v_x,
                    :m_w => m_w,
                    :v_w => v_w,
                    :m_x_t_last => m_x_t_last,
                    :v_x_t_last => v_x_t_last
                )

        model  = ref_model(; T=T) 

        result = inference(
            model = model, 
            data = data,
            initmarginals = (
                u = NormalMeanVariance(0.0, huge),
                uw = NormalMeanVariance(0.0, huge),
                x = NormalMeanVariance(0.0, huge),
                x_t_last = NormalMeanVariance(0.0, huge),
            ),
        ) 
        #@infiltrate; @assert false
    end

    # The `act` function returns the inferred best possible action
    act = () -> begin
        if result !== nothing
            return mode(result.posteriors[:u][2])[1]
        else
            return 0.0 # Without inference result we return some 'random' action
        end
    end

    # The `future` function returns the inferred future states
    future = () -> begin 
        if result !== nothing 
            return getindex.(mode.(result.posteriors[:x]), 1) 
        else
            return zeros(T)
        end
    end

    #=
    The `slide` function updates `(m_x_t_last, v_x_t_last)` for the next step
    and shifts (or slides) the arrays of future goals `(m_x, v_x)` and inferred actions `(m_u, v_u)`.
    =#
    slide = () -> begin
        (x, ) = result.returnval
        slide_msg_idx = 3 # This index is model dependent
        (m_x_t_last, v_x_t_last) = mean_var(getrecent(messageout(x[2], slide_msg_idx)))

        # these are not actually necessary for this simple problem, as the vectors do not change
        m_u = circshift(m_u, -1)
        m_u[end] = 0.0

        v_u = circshift(v_u, -1)
        v_u[end] = lambda^(-1)

        m_x = circshift(m_x, -1)
        m_x[end] = m_goal

        v_x = circshift(v_x, -1)
        v_x[end] = v_goal
    end

    # @infiltrate; @assert false
    return (infer, act, slide, future)    
end

# --------------------------------------------------------------------------------------------------
function plotTrajectory(obs, act, per_n, fx_m_wind, epsilon)

    (L,N) = size(obs)

    # mean of wind
    p1 = Plots.scatter(0:L, fx_m_wind.(0:L), color="black", xlim=(0, L), grid=true, ylabel="E[w]", legend=false)

    # Trajectory per run
    p2 = Plots.plot(legend=false)
    Plots.hline!(p2, [1.0], color="red", ls=:dash)
    for n=1:per_n:N
        Plots.plot!(1:L, obs[:,n], color="black", alpha=0.1)
    end
    Plots.plot!(xlim=(1,L), ylim=(-5.5, 5.5), grid=true, ylabel="Elevation (x)")

    # Control signal per run
    p3 = Plots.plot(legend=false)
    for n=1:per_n:N
        Plots.plot!(1:L, act[:,n], color="black", alpha=0.1)
    end
    Plots.plot!(xlim=(1, L), ylim=(-1.5, 2.5), grid=true, ylabel="Control Signal (a)")

    # Violation ratio (of runs) over time
    p4 = Plots.plot(legend=false)
    Plots.hline!(p4, [epsilon], color="red", ls=:dash)
    r = vec(mean(obs .< 1.0, dims=2))
    Plots.scatter!(1:L, r, color="black")
    Plots.plot!(xlim=(1,L), ylim=(0, 0.05), grid=true, xlabel="Time (t)", ylabel="Target Violation Ratio")

    Plots.plot(p1,p2,p3,p4, layout=(4,1), size=(800,1000), left_margin=5Plots.mm, )
    Plots.savefig("./sim_reference.png")
    #@infiltrate; @assert false
end

# --------------------------------------------------------------------------------------------------
function main()
    #=
    no hidden states, agent directly observes its elevation and wind velocity mean/variance
    x = elevation (e.g., meters)
    T = time for look-ahead horizon (e.g., seconds)
    action = selected control value, vertical velocity (e.g., m/s)
    t = current moment
    k = time steps in look-ahead horizon, k = t:t+T-1
    m_w = mean of vertical wind velocity (e.g., m/s)
    v_w = variance of vertical wind velocity (e.g., m^2)
    u = control variable = action = vertical velocity (e.g., m/s)
    L = simulation time (e.g., seconds)

    m_ = mean of _
    v_ = variance of _
    =#

    # Simulation parameters
    L = 20
    T = 10
    v_wind = 0.2
    m_goal = 2.0
    v_goal = 0.18478
    lambda = 0.01  # control prior precision
    epsilon = 0.01  # unsafe mass

    fx_m_wind(t::Int64) = 5<=t<10 ? -1.0 : 0.0 # Wind mean as function of time t
    N = 200  # number of trials

    # Step through experimental protocol
    actions =  Matrix{Float64}(undef, L, N) 
    observations = Matrix{Float64}(undef, L, N) 

    for n in 1:N
        # Let there be a world
        (execute_ai, observe_ai) = createWorld(
            v_wind = v_wind,
        ) 

        # Let there be an agent
        (infer_ai, act_ai, slide_ai, future_ai) = createAgent(; 
            T  = T, 
            fx_m_wind = fx_m_wind,
            v_wind = v_wind,
            m_goal = m_goal, 
            v_goal = v_goal, 
            lambda = lambda,
        ) 

        for t=1:L
            actions[t,n] = act_ai()  # invoke an action from the agent
            futures = future_ai()  # fetch the predicted future states
            execute_ai(actions[t,n], fx_m_wind(t))  # the action influences hidden external states
            observations[t,n] = observe_ai()  # observe the current environmental outcome
            infer_ai(actions[t,n], observations[t,n], t)  # infer beliefs from current model state
            slide_ai()  # prepare for next iteration
        end
    end

    E_avg = round(mean(sum(actions.^2, dims=1)), digits=2) # Average quadratic cost of control (over all runs)
    @show E_avg

    per_n = ceil(Int, N/100) # Plot one in every per_n trajectories
    plotTrajectory(observations, actions, per_n, fx_m_wind, epsilon)

end

end  # module ----------------------------------------------------------

RXRef.main()

An example output graph is attached: sim_reference.png (panels: E[w], Elevation (x), Control Signal (a), and Target Violation Ratio over time).

John-Boik commented 1 year ago

This discussion continues in a Discourse post.

John-Boik commented 1 year ago

Minor issue, but it looks like xi_fw will not be defined if the constraint is not active.
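
A hypothetical illustration of the scoping issue (the names here are assumed, not the actual RxInfer code):

# A variable assigned only inside a branch is undefined when the branch is
# not taken
function forward_message(constraint_active::Bool)
    if constraint_active
        xi_fw = 1.0  # only defined when the constraint is active
    end
    return xi_fw     # UndefVarError when constraint_active == false
end

forward_message(true)    # returns 1.0
# forward_message(false) # would throw UndefVarError: xi_fw not defined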

wmkouw commented 1 year ago

Nice catch. Thanks, @John-Boik.

I pushed a small update and sent a new PR.

bvdmitri commented 1 year ago

Addressed by #87 and #89