Closed John-Boik closed 1 year ago
Hi John-Boik, cool to see you're interested in Thijs' paper. He's the expert and will probably be able to guide you more. I just wanted to let you know that there are indeed plans on our side to incorporate chance constraints in RxInfer. The first hurdle to cross is the creation of a factor node corresponding to a TruncatedGaussian. That distribution exists in ForneyLab, but not in RxInfer / ReactiveMP. Those two inherit from Distributions.jl, which does implement it (link), so it shouldn't be a major challenge. The second hurdle is to specify the chance constraint appropriately, through the constraint specification procedure. That is probably harder for an external contributor.
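For orientation, here is a minimal sketch of what Distributions.jl already offers for truncated Gaussians. The helper name truncated_gaussian_moments and its argument order are assumptions made for illustration only; they are not part of RxInfer, ReactiveMP, or ForneyLab. Only truncated, Normal, mean, and var come from Distributions.jl.

```julia
using Distributions

# Hypothetical helper (name and signature are illustrative assumptions):
# mean and variance of a Gaussian N(m, V) truncated to the interval [a, b].
function truncated_gaussian_moments(m::Float64, V::Float64, a::Float64, b::Float64)
    # Normal is parameterized by standard deviation, so pass sqrt(V)
    d = truncated(Normal(m, sqrt(V)), a, b)
    return mean(d), var(d)
end

# Standard normal truncated to [0, Inf) is the half-normal distribution:
# mean = sqrt(2/pi) (about 0.798), variance = 1 - 2/pi (about 0.363)
m_t, v_t = truncated_gaussian_moments(0.0, 1.0, 0.0, Inf)
println("mean = ", m_t, ", variance = ", v_t)
```

A factor node for RxInfer would still need message update rules on top of this; the sketch only shows that the moments themselves are available.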
Thanks @wmkouw. Perhaps you will want to move this discussion elsewhere. I've written code (below) for running the reference model contained in the @ThijsvdLaar chance constraints paper. I used the mountain car example as a template. Next I'm hoping to extend the reference model to the full chance constraint model. But I have a few questions:
The code below uses the line
mean_var(getrecent(messageout(x[2], slide_msg_idx)))
that is taken from the mountain car example. I understand these are private ReactiveMP functions. Can you explain how slide_msg_idx can be chosen? Is there some way to inspect what the indexes pertain to? Is there a better approach that does not use private functions?

Any suggestions/guidance regarding items 2 and 3 above would be welcomed.
The script below, which I save in a file RXRef.jl, will run in the Julia REPL and produce results very close to the published ones. An example output graph is attached after the code.
#=
To run script in REPL:
include("./RXRef.jl")
=#
module RXRef
import Plots
import Random
# Use some private functionality from ReactiveMP
import RxInfer.ReactiveMP: getrecent, messageout
using Formatting
using Infiltrator
using Revise
using RxInfer
Random.seed!(51233) # Set random seed for reproducibility
Plots.scalefontsizes()
Plots.scalefontsizes(0.8)
# ==================================================================================================
# --------------------------------------------------------------------------------------------------
function createWorld(; v_wind)
    # Functions for interacting with the simulated environment.
    x_0 = 0.0 # Initial position, drone elevation
    x_t0 = x_0
    x_t1 = x_0
    execute = (action_t::Float64, m_wind_t::Float64) -> begin
        # Execute the action
        # action_t = ascent velocity, dx = action_t * t, with t=1
        x_t1 = x_t0 + action_t + m_wind_t + sqrt(v_wind) * randn() # Update elevation
        x_t0 = x_t1 # Prepare for next step
    end
    observe = () -> begin
        return x_t0 # Observe the current state
    end
    return (execute, observe)
end
# --------------------------------------------------------------------------------------------------
@model function ref_model(; T)
    # the mean and variance of observation at the last t
    m_x_t_last = datavar(Float64)
    v_x_t_last = datavar(Float64)
    # sample the observation at the last t
    x_t_last ~ GaussianMeanVariance(m_x_t_last, v_x_t_last)
    x_k_last = x_t_last
    # control
    m_u = datavar(Float64, T)
    v_u = datavar(Float64, T)
    # obs
    m_x = datavar(Float64, T)
    v_x = datavar(Float64, T)
    # wind
    m_w = datavar(Float64, T)
    v_w = datavar(Float64, T)
    # random variables
    u = randomvar(T) # control
    x = randomvar(T) # height
    uw = randomvar(T) # control + wind variance
    # loop over horizon
    for k = 1:T
        x[k] ~ GaussianMeanVariance(m_x[k], v_x[k]) # goal prior
        u[k] ~ GaussianMeanVariance(m_u[k], v_u[k]) # control prior
        uw[k] ~ GaussianMeanVariance(u[k], v_w[k]) # control + wind variance
        x[k] ~ x_k_last + uw[k] + m_w[k]
        x_k_last = x[k]
    end
    return (x, )
end
# --------------------------------------------------------------------------------------------------
function createAgent(; T, fx_m_wind, v_wind, m_goal, v_goal, lambda)
    # control prior
    m_u = Vector{Float64}([0.0 for k=1:T])
    v_u = Vector{Float64}([lambda^(-1) for k=1:T])
    # goal
    m_x = Vector{Float64}([m_goal for k=1:T])
    v_x = Vector{Float64}([v_goal for k=1:T])
    # wind
    m_w = Vector{Float64}([0.0 for k=1:T])
    v_w = Vector{Float64}([v_wind for k=1:T])
    # initial position and variance
    m_x_t_last = 0.0
    v_x_t_last = convert(Float64, tiny)
    # Set current inference results
    result = nothing
    #=
    The `infer` function is the heart of the agent. It calls the `inference` function to perform
    Bayesian inference by message passing.
    =#
    infer = (agent_action_t::Float64, agent_obs_t::Float64, t::Int) -> begin
        m_w[:] = [fx_m_wind(t+k) for k in 0:T-1] # mean wind
        m_u[1] = agent_action_t # register action with the generative model
        v_u[1] = tiny # clamp control prior to performed action
        m_x[1] = agent_obs_t # register observation with the generative model
        v_x[1] = tiny # clamp goal prior to observation
        data = Dict(:m_u => m_u,
                    :v_u => v_u,
                    :m_x => m_x,
                    :v_x => v_x,
                    :m_w => m_w,
                    :v_w => v_w,
                    :m_x_t_last => m_x_t_last,
                    :v_x_t_last => v_x_t_last
                    )
        model = ref_model(; T=T)
        result = inference(
            model = model,
            data = data,
            initmarginals = (
                u = NormalMeanVariance(0.0, huge),
                uw = NormalMeanVariance(0.0, huge),
                x = NormalMeanVariance(0.0, huge),
                x_t_last = NormalMeanVariance(0.0, huge),
            ),
        )
        #@infiltrate; @assert false
    end
    # The `act` function returns the inferred best possible action
    act = () -> begin
        if result !== nothing
            return mode(result.posteriors[:u][2])[1]
        else
            return 0.0 # Without inference result we return some 'random' action
        end
    end
    # The `future` function returns the inferred future states
    future = () -> begin
        if result !== nothing
            return getindex.(mode.(result.posteriors[:x]), 1)
        else
            return zeros(T)
        end
    end
    #=
    The `slide` function modifies the `(m_x_t_last, v_x_t_last)` for the next step
    and shifts (or slides) the array of future goals `(m_x, v_x)` and inferred actions `(m_u, v_u)`
    =#
    slide = () -> begin
        (x, ) = result.returnval
        slide_msg_idx = 3 # This index is model dependent
        (m_x_t_last, v_x_t_last) = mean_var(getrecent(messageout(x[2], slide_msg_idx)))
        # these are not actually necessary for this simple problem, as the vectors do not change
        m_u = circshift(m_u, -1)
        m_u[end] = 0.0
        v_u = circshift(v_u, -1)
        v_u[end] = lambda^(-1)
        m_x = circshift(m_x, -1)
        m_x[end] = m_goal
        v_x = circshift(v_x, -1)
        v_x[end] = v_goal
    end
    # @infiltrate; @assert false
    return (infer, act, slide, future)
end
# --------------------------------------------------------------------------------------------------
function plotTrajectory(obs, act, per_n, fx_m_wind, epsilon)
    (L, N) = size(obs)
    # mean of wind
    p1 = Plots.scatter(0:L, fx_m_wind.(0:L), color="black", xlim=(0,L), grid=true, ylabel="E[w]", legend=false)
    # Trajectory per run
    p2 = Plots.plot(legend=false)
    Plots.hline!(p2, [1.0], color="red", ls=:dash)
    for n=1:per_n:N
        Plots.plot!(1:L, obs[:,n], color="black", alpha=0.1)
    end
    Plots.plot!(xlim=(1,L), ylim=(-5.5, 5.5), grid=true, ylabel="Elevation (x)")
    # Control signal per run
    p3 = Plots.plot(legend=false)
    for n=1:per_n:N
        Plots.plot!(1:L, act[:,n], color="black", alpha=0.1)
    end
    Plots.plot!(xlim=(1,L), ylim=(-1.5, 2.5), grid=true, ylabel="Control Signal (a)")
    # Violation ratio (of runs) over time
    p4 = Plots.plot(legend=false)
    Plots.hline!(p4, [epsilon], color="red", ls=:dash)
    r = vec(mean(obs .< 1.0, dims=2))
    Plots.scatter!(1:L, r, color="black")
    Plots.plot!(xlim=(1,L), ylim=(0, 0.05), grid=true, xlabel="Time (t)", ylabel="Target Violation Ratio")
    Plots.plot(p1, p2, p3, p4, layout=(4,1), size=(800,1000), left_margin=5Plots.mm, )
    Plots.savefig("./sim_reference.png")
    #@infiltrate; @assert false
end
# --------------------------------------------------------------------------------------------------
function main()
    #=
    no hidden states, agent directly observes its elevation and wind velocity mean/variance
    x = elevation (e.g., meters)
    T = time for look-ahead horizon (e.g., seconds)
    action = selected control value, vertical velocity (e.g., m/s)
    t = current moment
    k = time steps in look-ahead horizon, k = t:t+T-1
    m_w = mean of vertical wind velocity (e.g., m/s)
    v_w = variance of vertical wind velocity (e.g., m^2/s^2)
    u = control variable = action = vertical velocity (e.g., m/s)
    L = simulation time (e.g., seconds)
    m_ = mean of _
    v_ = variance of _
    =#
    # Simulation parameters
    L = 20
    T = 10
    v_wind = 0.2
    m_goal = 2.0
    v_goal = 0.18478
    lambda = 0.01 # control prior precision
    epsilon = 0.01 # unsafe mass
    fx_m_wind(t::Int64) = 5<=t<10 ? -1.0 : 0.0 # Wind mean as function of time t
    N = 200 # number of trials
    # Step through experimental protocol
    actions = Matrix{Float64}(undef, L, N)
    observations = Matrix{Float64}(undef, L, N)
    for n in 1:N
        # Let there be a world
        (execute_ai, observe_ai) = createWorld(
            v_wind = v_wind,
        )
        # Let there be an agent
        (infer_ai, act_ai, slide_ai, future_ai) = createAgent(;
            T = T,
            fx_m_wind = fx_m_wind,
            v_wind = v_wind,
            m_goal = m_goal,
            v_goal = v_goal,
            lambda = lambda,
        )
        for t=1:L
            actions[t,n] = act_ai() # invoke an action from the agent
            futures = future_ai() # fetch the predicted future states
            execute_ai(actions[t,n], fx_m_wind(t)) # the action influences hidden external states
            observations[t,n] = observe_ai() # observe the current environmental outcome
            infer_ai(actions[t,n], observations[t,n], t) # infer beliefs from current model state
            slide_ai() # prepare for next iteration
        end
    end
    E_avg = round(mean(sum(actions.^2, dims=1)), digits=2) # Average quadratic cost of control (over all runs)
    @show E_avg
    per_n = ceil(Int, N/100) # Plot one in every per_n trajectories
    plotTrajectory(observations, actions, per_n, fx_m_wind, epsilon)
end
end # module ----------------------------------------------------------
RXRef.main()
An example output graph is attached.
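As a rough check on the simulation parameters, the sketch below computes the probability mass that the Gaussian goal prior places below an elevation of 1.0. It assumes the safety limit is x >= 1.0, which is how the plotting code draws the red line and counts violations. The function name unsafe_mass is a hypothetical helper, not something taken from the paper's code or from RxInfer.

```julia
using Distributions

# Hypothetical helper: probability that x ~ N(m, V) falls below `threshold`
unsafe_mass(m, V, threshold) = cdf(Normal(m, sqrt(V)), threshold)

m_goal, v_goal = 2.0, 0.18478   # goal prior mean and variance from the script
epsilon = 0.01                  # allowed unsafe mass from the script
threshold = 1.0                 # assumed safety limit (red line in the plot)

p = unsafe_mass(m_goal, v_goal, threshold)
println("unsafe mass under the goal prior = ", p, " (epsilon = ", epsilon, ")")
```

The result is very close to 0.01, since (1.0 - 2.0) / sqrt(0.18478) is about -2.326, the 1% quantile of the standard normal. In other words, the goal prior variance in the reference model appears to be chosen so that the prior itself puts a mass of about epsilon below the limit.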
This discussion continues on a discourse post.
Minor issue, but it looks like xi_fw will not be defined if the constraint is not active.
Nice catch. Thanks, @John-Boik.
I pushed a small update and sent a new PR.
Addressed by #87 and #89
Is anyone planning to convert the ForneyLab code published for Chance-Constrained Active Inference, van de Laar et al. (2021), into code that will run in RxInfer? If so, great! What is the status?
If not, I might attempt to do so, but I'm still trying to learn RxInfer and could use some guidance to get started with the conversion. Below I'm using the RxInfer mountain car code as an example template. Am I correct that the steps to convert the code would include the following?
1. initializeAgent() code in chance_constrained_agent.jl as is into the new create_agent() function. However, the former returns only (infer, act) while the latter returns (infer, act, slide, future). I assume I would need to include at least a simple slide() function, but perhaps not a future() function as there are no future target observations in this problem.
2. initializeWorld() code in environment.jl as is into the create_world() function. Both return (execute, observe) as required.
3. @rule macro.
4. @node macro.
5. isApplicable() function in updates.jl mostly as is.

Are there major differences in coding a problem for ForneyLab vs. RxInfer, such that, for example, the function truncatedGaussianMoments would no longer have the signature truncatedGaussianMoments(m::Float64, V::Float64, a::Float64, b::Float64)? Or, as another example, would an outgoing rule still have a signature as in ruleSPChanceConstraintOutG(msg_out::Message{<:Gaussian, Univariate}, G::Tuple, epsilon::Float64; atol=default_atol)? Is the use of functions such as unsafeMeanCov() and unsafeMode() still required?

Any suggestions/guidance getting started would be welcome. Is there any reason that converting this code would be particularly challenging?