Closed tyler-ingebrand closed 2 years ago
Update: I spent some time editing PPO's derivation phase. I ended up making numerous changes, so I am going to copy-paste the entire file(s) for now. It seems to be working (i.e., it trains), although it may be failing silently somehow. Here is my script again (few changes):
# ---
# title: JuliaRL_PPO_Pendulum
# cover: assets/JuliaRL_PPO_Pendulum.png
# description: PPO applied to Pendulum
# date: 2021-05-22
# author: "[Jun Tian](https://github.com/findmyway)"
# ---
#+ tangle=true
using ReinforcementLearning
using StableRNGs
using Flux
using Flux.Losses
using Distributions
using IntervalSets
function RL.Experiment(
::Val{:JuliaRL},
::Val{:PPO},
::Val{:Pendulum},
::Nothing;
save_dir = nothing,
seed = 123,
)
rng = StableRNG(seed)
inner_env = PendulumEnv(T = Float32, rng = rng)
A = action_space(inner_env)
low = A.left
high = A.right
ns = length(state(inner_env))
N_ENV = 8
UPDATE_FREQ = 2048
env = MultiThreadEnv([
PendulumEnv(T = Float32, rng = StableRNG(hash(seed + i))) |>
env -> ActionTransformedEnv(env,
action_mapping = x -> clamp(x[1] * 2, low, high),
action_space_mapping = x -> Space([-2.0 .. 2.0, -2.0 .. 2.0])) for i in 1:N_ENV
])
na = 2
init = glorot_uniform(rng)
agent = Agent(
policy = PPOPolicy(
approximator = ActorCritic(
actor = GaussianNetwork(
pre = Chain(
Dense(ns, 64, relu; init = glorot_uniform(rng)),
Dense(64, 64, relu; init = glorot_uniform(rng)),
),
μ = Chain(Dense(64, na, tanh; init = glorot_uniform(rng)), vec),
logσ = Chain(Dense(64, na; init = glorot_uniform(rng)), vec),
),
critic = Chain(
Dense(ns, 64, relu; init = glorot_uniform(rng)),
Dense(64, 64, relu; init = glorot_uniform(rng)),
Dense(64, 1; init = glorot_uniform(rng)),
),
optimizer = ADAM(3e-4),
) |> gpu,
γ = 0.99f0,
λ = 0.95f0,
clip_range = 0.2f0,
max_grad_norm = 0.5f0,
n_epochs = 10,
n_microbatches = 32,
actor_loss_weight = 1.0f0,
critic_loss_weight = 0.5f0,
entropy_loss_weight = 0.00f0,
dist = Normal,
rng = rng,
update_freq = UPDATE_FREQ,
),
trajectory = PPOTrajectory(;
capacity = UPDATE_FREQ,
state = Matrix{Float32} => (ns, N_ENV),
action = Matrix{Float32} => (na, N_ENV),
action_log_prob = Matrix{Float32} => (na, N_ENV),
reward = Vector{Float32} => (N_ENV,),
terminal = Vector{Bool} => (N_ENV,),
),
)
stop_condition = StopAfterStep(100_000, is_show_progress=!haskey(ENV, "CI"))
hook = TotalBatchRewardPerEpisode(N_ENV)
Experiment(agent, env, stop_condition, hook, "# Play Pendulum with PPO")
end
#+ tangle=false
using Plots
ex = E`JuliaRL_PPO_Pendulum`
run(ex)
n = minimum(map(length, ex.hook.rewards))
m = mean([@view(x[1:n]) for x in ex.hook.rewards])
s = std([@view(x[1:n]) for x in ex.hook.rewards])
plot(m,ribbon=s)
# 153, 196
inner_env = PendulumEnv()
A = action_space(inner_env)
low = A.left
high = A.right
env = ActionTransformedEnv( inner_env,
action_mapping = x -> clamp(x[1] * 2, low, high),
action_space_mapping = x -> Space([-2.0 .. 2.0, -2.0 .. 2.0]))
demo = Experiment( ex.policy,
env,
StopWhenDone(),
RolloutHook(display, closeall),
"PPO <-> Demo")
run(demo)
And here is the updated PPO.jl (lots of changes, especially in the derivation stuff):
export PPOPolicy, PPOTrajectory, MaskedPPOTrajectory
const PPOTrajectory = Trajectory{
<:NamedTuple{
(:action_log_prob, SART...),
<:Tuple{
<:CircularArrayBuffer,
<:CircularArrayBuffer,
<:CircularArrayBuffer,
<:CircularArrayBuffer,
<:CircularArrayBuffer,
},
},
}
function PPOTrajectory(; capacity, action_log_prob, kwargs...)
merge(
CircularArrayTrajectory(;
capacity = capacity + 1,
action_log_prob = action_log_prob,
),
CircularArraySARTTrajectory(; capacity = capacity, kwargs...),
)
end
const MaskedPPOTrajectory = Trajectory{
<:NamedTuple{
(:action_log_prob, SLART...),
<:Tuple{
<:CircularArrayBuffer,
<:CircularArrayBuffer,
<:CircularArrayBuffer,
<:CircularArrayBuffer,
<:CircularArrayBuffer,
<:CircularArrayBuffer,
},
},
}
function MaskedPPOTrajectory(; capacity, action_log_prob, kwargs...)
merge(
CircularArrayTrajectory(;
capacity = capacity + 1,
action_log_prob = action_log_prob,
),
CircularArraySLARTTrajectory(; capacity = capacity, kwargs...),
)
end
function Base.length(t::Union{PPOTrajectory,MaskedPPOTrajectory})
x = t[:terminal]
size(x, ndims(x))
end
"""
PPOPolicy(;kwargs)
# Keyword arguments
- `approximator`,
- `γ = 0.99f0`,
- `λ = 0.95f0`,
- `clip_range = 0.2f0`,
- `max_grad_norm = 0.5f0`,
- `n_microbatches = 4`,
- `n_epochs = 4`,
- `actor_loss_weight = 1.0f0`,
- `critic_loss_weight = 0.5f0`,
- `entropy_loss_weight = 0.01f0`,
- `dist = Categorical`,
- `rng = Random.GLOBAL_RNG`,
By default, `dist` is set to `Categorical`, which means it will only work
on environments of discrete actions. To work with environments of continuous
actions `dist` should be set to `Normal` and the `actor` in the `approximator`
should be a `GaussianNetwork`. Using it with a `GaussianNetwork` supports
multi-dimensional action spaces, though it only supports it under the assumption
that the dimensions are independent since the `GaussianNetwork` outputs a single
`μ` and `σ` for each dimension which is used to simplify the calculations.
"""
mutable struct PPOPolicy{A<:ActorCritic,D,R} <: AbstractPolicy
approximator::A
γ::Float32
λ::Float32
clip_range::Float32
max_grad_norm::Float32
n_microbatches::Int
n_epochs::Int
actor_loss_weight::Float32
critic_loss_weight::Float32
entropy_loss_weight::Float32
rng::R
n_random_start::Int
update_freq::Int
update_step::Int
# for logging
norm::Matrix{Float32}
actor_loss::Matrix{Float32}
critic_loss::Matrix{Float32}
entropy_loss::Matrix{Float32}
loss::Matrix{Float32}
end
function PPOPolicy(;
approximator,
update_freq,
n_random_start = 0,
update_step = 0,
γ = 0.99f0,
λ = 0.95f0,
clip_range = 0.2f0,
max_grad_norm = 0.5f0,
n_microbatches = 4,
n_epochs = 4,
actor_loss_weight = 1.0f0,
critic_loss_weight = 0.5f0,
entropy_loss_weight = 0.01f0,
dist = Categorical,
rng = Random.GLOBAL_RNG,
)
PPOPolicy{typeof(approximator),dist,typeof(rng)}(
approximator,
γ,
λ,
clip_range,
max_grad_norm,
n_microbatches,
n_epochs,
actor_loss_weight,
critic_loss_weight,
entropy_loss_weight,
rng,
n_random_start,
update_freq,
update_step,
zeros(Float32, n_microbatches, n_epochs),
zeros(Float32, n_microbatches, n_epochs),
zeros(Float32, n_microbatches, n_epochs),
zeros(Float32, n_microbatches, n_epochs),
zeros(Float32, n_microbatches, n_epochs),
)
end
function RLBase.prob(
p::PPOPolicy{<:ActorCritic{<:GaussianNetwork},Normal},
state::AbstractArray,
mask,
)
if p.update_step < p.n_random_start
@error "todo"
else
μ, logσ = p.approximator.actor(send_to_device(device(p.approximator), state)) |> send_to_host
μ, logσ = reshape(μ,(:, size(state)[2]) ), reshape(logσ, (:, size(state)[2]) )
StructArray{Normal}((μ, exp.(logσ)))
end
end
function RLBase.prob(p::PPOPolicy{<:ActorCritic,Categorical}, state::AbstractArray, mask)
logits = p.approximator.actor(send_to_device(device(p.approximator), state))
if !isnothing(mask)
logits .+= ifelse.(mask, 0f0, typemin(Float32))
end
logits = logits |> softmax |> send_to_host
if p.update_step < p.n_random_start
[
Categorical(fill(1 / length(x), length(x)); check_args = false) for
x in eachcol(logits)
]
else
[Categorical(x; check_args = false) for x in eachcol(logits)]
end
end
function RLBase.prob(p::PPOPolicy, env::MultiThreadEnv)
mask = ActionStyle(env) === FULL_ACTION_SET ? legal_action_space_mask(env) : nothing
prob(p, state(env), mask)
end
function RLBase.prob(p::PPOPolicy, env::AbstractEnv)
s = state(env)
s = Flux.unsqueeze(s, ndims(s) + 1)
mask = ActionStyle(env) === FULL_ACTION_SET ? legal_action_space_mask(env) : nothing
prob(p, s, mask)
end
(p::PPOPolicy)(env::MultiThreadEnv) = rand.(p.rng, prob(p, env))
# !!! https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/pull/533/files#r728920324
(p::PPOPolicy)(env::AbstractEnv) = rand.(p.rng, prob(p, env))
function (agent::Agent{<:PPOPolicy})(env::MultiThreadEnv)
dist = prob(agent.policy, env)
action = rand.(agent.policy.rng, dist)
if ndims(action) == 2
# action_log_prob = sum(logpdf.(dist, action), dims = 1)
action_log_prob = logpdf.(dist, action)
else
action_log_prob = logpdf.(dist, action)
end
EnrichedAction(action; action_log_prob = action_log_prob)
end
function RLBase.update!(
p::PPOPolicy,
t::Union{PPOTrajectory,MaskedPPOTrajectory},
::AbstractEnv,
::PreActStage,
)
length(t) == 0 && return # in the first update, only state & action are inserted into trajectory
p.update_step += 1
if p.update_step % p.update_freq == 0
_update!(p, t)
end
end
function _update!(p::PPOPolicy, t::AbstractTrajectory)
rng = p.rng
AC = p.approximator
γ = p.γ
λ = p.λ
n_epochs = p.n_epochs
n_microbatches = p.n_microbatches
clip_range = p.clip_range
w₁ = p.actor_loss_weight
w₂ = p.critic_loss_weight
w₃ = p.entropy_loss_weight
D = device(AC)
to_device(x) = send_to_device(D, x)
n_envs, n_rollout = size(t[:terminal])
@assert n_envs * n_rollout % n_microbatches == 0 "size mismatch"
microbatch_size = n_envs * n_rollout ÷ n_microbatches
n = length(t)
states_plus = to_device(t[:state])
if t isa MaskedPPOTrajectory
LAM = to_device(t[:legal_actions_mask])
end
states_flatten_on_host = flatten_batch(select_last_dim(t[:state], 1:n))
states_plus_values =
reshape(send_to_host(AC.critic(flatten_batch(states_plus))), n_envs, :)
# TODO: make generalized_advantage_estimation GPU friendly
advantages = generalized_advantage_estimation(
t[:reward],
states_plus_values,
γ,
λ;
dims = 2,
terminal = t[:terminal],
)
returns = to_device(advantages .+ select_last_dim(states_plus_values, 1:n_rollout))
advantages = to_device(advantages)
actions_flatten = flatten_batch(select_last_dim(t[:action], 1:n))
action_log_probs = flatten_batch(select_last_dim(t[:action_log_prob], 1:n))
# TODO: normalize advantage
for epoch in 1:n_epochs
rand_inds = shuffle!(rng, Vector(1:n_envs*n_rollout))
for i in 1:n_microbatches
inds = rand_inds[(i-1)*microbatch_size+1:i*microbatch_size]
if t isa MaskedPPOTrajectory
lam = select_last_dim(
flatten_batch(select_last_dim(LAM, 2:n+1)),
inds,
)
else
lam = nothing
end
# s = to_device(select_last_dim(states_flatten_on_host, inds))
# !!! we need to convert it into a contiguous CuArray, otherwise CUDA.jl will complain about scalar indexing
s = to_device(collect(select_last_dim(states_flatten_on_host, inds)))
a = to_device(collect(select_last_dim(actions_flatten, inds)))
if eltype(a) === Int
a = CartesianIndex.(a, 1:length(a))
end
r = vec(returns)[inds]
log_p = to_device(collect(select_last_dim(action_log_probs, inds)))
adv = vec(advantages)[inds]
# We need to reshape adv here to be the same size as log_p. There is 1 advantage estimate for each action. We have
# 1 mu and variance for each dimension of the action space, so we need to multiply each dimension of a given action
# by its respective advantage. To do so, we copy adv into the same vector repeated, so we can multiply
# elementwise. The following first flips adv from being 512x1 to 1x512, then makes it 2x512 so we can do elementwise ops
# with log_p, which is 2x512
adv = repeat(permutedims(adv), size(log_p)[1], 1)
ps = Flux.params(AC)
gs = gradient(ps) do
v′ = AC.critic(s) |> vec
if AC.actor isa GaussianNetwork
μ, logσ = AC.actor(s)
μ, logσ = reshape(μ,(size(a)[1], :) ), reshape(logσ, (size(a)[1], :) )
log_p′ₐ = normlogpdf(μ, exp.(logσ), a)
entropy_loss = mean(size(logσ, 1) * (log(2.0f0π) + 1) .+ sum(logσ; dims = 1)) / 2
else
# actor is assumed to return discrete logits
raw_logit′ = AC.actor(s)
if isnothing(lam)
logit′ = raw_logit′
else
logit′ = raw_logit′ .+ ifelse.(lam, 0.0f0, typemin(Float32))
end
p′ = softmax(logit′)
log_p′ = logsoftmax(logit′)
log_p′ₐ = log_p′[a]
entropy_loss = -sum(p′ .* log_p′) * 1 // size(p′, 2)
end
ratio = exp.(log_p′ₐ .- log_p)
surr1 = ratio .* adv
surr2 = clamp.(ratio, 1.0f0 - clip_range, 1.0f0 + clip_range) .* adv
actor_loss = -mean(min.(surr1, surr2))
critic_loss = mean((r .- v′) .^ 2)
loss = w₁ * actor_loss + w₂ * critic_loss - w₃ * entropy_loss
ignore() do
p.actor_loss[i, epoch] = actor_loss
p.critic_loss[i, epoch] = critic_loss
p.entropy_loss[i, epoch] = entropy_loss
p.loss[i, epoch] = loss
end
loss
end
p.norm[i, epoch] = clip_by_global_norm!(gs, ps, p.max_grad_norm)
update!(AC, gs)
end
end
end
function RLBase.update!(
trajectory::Union{PPOTrajectory,MaskedPPOTrajectory},
::PPOPolicy,
env::MultiThreadEnv,
::PreActStage,
action::EnrichedAction,
)
push!(
trajectory;
state = state(env),
action = action.action,
action_log_prob = action.meta.action_log_prob,
)
if trajectory isa MaskedPPOTrajectory
push!(trajectory; legal_actions_mask = legal_action_space_mask(env))
end
end
Maybe try this out and see if the changes are good.
A simpler approach is to predict a vector of actions and then reshape it to the necessary size before feeding it to the environment.
But yes, we'd better support it natively.
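The reshape-before-stepping idea above can be sketched with plain arrays. This is only an illustration: the sizes, the `flat_actions` stand-in, and the assumption that the flat vector is ordered environment by environment (column-major) are all hypothetical, not taken from the library.

```julia
# Hypothetical sizes: 2 action dimensions, 3 parallel environments.
na, n_env = 2, 3

# Stand-in for a flattened network output of length na * n_env.
flat_actions = collect(1.0:6.0)

# Column-major reshape: each column holds the `na` components for one environment.
actions = reshape(flat_actions, na, n_env)

# One action vector per environment, ready to hand to each environment in turn.
per_env = [collect(col) for col in eachcol(actions)]
first_env_action = per_env[1]
last_env_action = per_env[end]
```

Whether this matches the ordering a given network produces would need to be checked against the actual model output.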
Thanks for providing the patch code. I think in most cases, you don't need to repeat the adv since we can leverage broadcasting. But I understand your motivation here. I'll review it when porting PPO to the latest code in the master branch.
Few questions: 1) With respect to repeating advantage, is there a better way? Such as a function in Julia to do that broadcast? It is a bit weird since we are broadcasting along only 1 dimension of the array since the others match up. I suppose a for loop would have been better too, but wanted to keep the code clean so I didn't add it. **Edit:** found it. I just needed to take the adjoint of the vector. See https://discourse.julialang.org/t/broadcasting-matrix-vector/16683
2) What is the advantage of leaving it as a vector the entire time, and only reshaping at the end? Are there performance gains? Or is it just stylistic?
3) Does Julia have runtime cost of reshaping an array? Since the underlying data does not change, it seems like there should be a constant time way of changing the dimensions without copying the entire array. But, it depends on implementation.
> Few questions:
>
> - With respect to repeating advantage, is there a better way? Such as a function in Julia to do that broadcast? It is a bit weird since we are broadcasting along only 1 dimension of the array since the others match up. I suppose a for loop would have been better too, but wanted to keep the code clean so I didn't add it.

In many cases, the broadcasting will happen automatically as expected. But I need to check the change you proposed here later.

> - What is the advantage of leaving it as a vector the entire time, and only reshaping at the end? Are there performance gains? Or is it just stylistic?

I think it's just that we didn't consider action space of higher dimensions before.

> - Does Julia have runtime cost of reshaping an array? Since the underlying data does not change, it seems like there should be a constant time way of changing the dimensions without copying the entire array. But, it depends on implementation.

Almost no extra cost with the latest version of Julia. You can verify it with BenchmarkTools.jl
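As a small check of the reshape answer above, the following sketch uses only Base Julia to show that reshaping an `Array` reuses the same underlying memory rather than copying it. The variable names are illustrative.

```julia
# A flat vector, similar in spirit to what `vec` returns.
x = collect(Float32, 1:8)

# Reshape into a 2×4 matrix; for a plain Array this shares the data with `x`.
m = reshape(x, 2, :)

# Write through the reshaped array and observe the change in the original.
m[1, 1] = 100f0
shares_memory = x[1] == 100f0

# Both arrays start at the same memory address.
same_pointer = pointer(x) == pointer(m)
```

For actual timings, something like `@btime reshape($x, 2, :)` from BenchmarkTools.jl could be used, as suggested in the reply.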
Whoops, just figured out how to do 1. See https://discourse.julialang.org/t/broadcasting-matrix-vector/16683. Basically I needed to take the adjoint to switch the vector from being Nx1 to 1xN, then broadcasting works as expected.
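A minimal sketch of that broadcasting trick, using made-up shapes and values (2 action dimensions, 4 samples), compares it with the `repeat` approach from the patch. It assumes real-valued advantages, for which the adjoint is simply a transpose.

```julia
# Hypothetical data: probability ratios with shape (na, batch).
ratio = Float32[1.0 1.1 0.9 1.2;
                0.8 1.0 1.3 0.7]

# One advantage estimate per sample, shape (batch,).
adv = Float32[0.5, -1.0, 2.0, 0.1]

# Approach from the patch: materialise a (na, batch) copy of the advantages.
adv_repeated = repeat(permutedims(adv), size(ratio, 1), 1)
surr_repeat = ratio .* adv_repeated

# Broadcasting approach: a 1×batch row is expanded along the first dimension.
surr_adjoint = ratio .* adv'
surr_permuted = ratio .* permutedims(adv)

same_result = surr_repeat == surr_adjoint == surr_permuted
```

The broadcast versions avoid allocating the repeated matrix. Whether `adv'` behaves the same way on GPU arrays inside the gradient closure has not been checked here.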
Also FYI, TD3 has the same issue. Is there a multidimensional action space example we could use instead of Pendulum? This should force any related bugs to appear when trying to get such an example working (on all algs).
In any case, thank you!
Hello. Was this issue ever actually resolved? I seem to be suffering the same issue when trying to use an action space with 2 dimensions. If it was updated, there doesn't seem to be an example that I can find to demonstrate the proper way to set up the PPOTrajectory. Thanks!
It's been a while, but I think the key thing for getting the trajectory to work is to make it allocate space for every action dim:
```julia
action = Matrix{Float32} => (na, N_ENV),
action_log_prob = Matrix{Float32} => (na, N_ENV),
```
Otherwise it does not have a big enough buffer to store all of the actions. The above code worked for me at some point, but I have not tried in a while or tested on multiple versions.
I am not sure if it's been fixed officially yet either.
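To illustrate the buffer-size point with plain arrays, here is a rough sketch. It does not use `CircularArrayBuffer` or `PPOTrajectory`; the `narrow` and `wide` arrays are hypothetical stand-ins for a per-step slot of shape `(N_ENV,)` versus `(na, N_ENV)`, and the library's real `push!` may behave differently.

```julia
# Hypothetical sizes: 2 action dimensions, 8 environments, 4 stored steps.
na, n_env, capacity = 2, 8, 4

# One step of per-dimension log-probabilities, shape (na, n_env).
step_log_prob = zeros(Float32, na, n_env)

# Storage with one value per environment per step.
narrow = zeros(Float32, n_env, capacity)

# Storage with one value per action dimension per environment per step.
wide = zeros(Float32, na, n_env, capacity)

# Writing a (na, n_env) step into a (n_env,) slot raises DimensionMismatch.
narrow_fails = try
    narrow[:, 1] .= step_log_prob
    false
catch err
    err isa DimensionMismatch
end

# The wider slot accepts the same step without complaint.
wide[:, :, 1] .= step_log_prob
wide_ok = size(wide[:, :, 1]) == size(step_log_prob)
```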
Thanks for the reply!
I did try it while setting the action and action_log_prob like in your comment and still got the message `ERROR: DimensionMismatch: array could not be broadcast to match destination`. I tried a few alternative ways as well and either got the DimensionMismatch error or other errors downstream.
Ah, it looks like the line you have commented out in your solution
```julia
function (agent::Agent{<:PPOPolicy})(env::MultiThreadEnv)
    dist = prob(agent.policy, env)
    action = rand.(agent.policy.rng, dist)
    if ndims(action) == 2
        # action_log_prob = sum(logpdf.(dist, action), dims = 1)
        action_log_prob = logpdf.(dist, action)
    else
        action_log_prob = logpdf.(dist, action)
    end
    EnrichedAction(action; action_log_prob = action_log_prob)
end
```
is still in the version that I have. I do indeed have 2 dims in my action. I guess I'll try overriding this method for now to see if it helps.
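The difference between the two variants of that line can be sketched without any packages. The `normal_logpdf` helper below is written out by hand for illustration (it is not the Distributions.jl function), and the sizes and values are made up. Under the independence assumption stated in the `PPOPolicy` docstring, summing over the action dimension gives the joint log-probability, with shape `(1, N_ENV)`, while the patched line keeps one value per dimension, with shape `(na, N_ENV)`. The trajectory's `action_log_prob` shape presumably has to match whichever variant is in use.

```julia
# Log-density of a univariate normal, written out so the sketch needs no packages.
normal_logpdf(μ, σ, x) = -((x - μ) / σ)^2 / 2 - log(σ) - log(2π) / 2

# Hypothetical policy output: 2 action dimensions, 3 environments.
μ = [0.0 0.5 -0.5;
     1.0 0.0  0.2]
σ = [1.0 0.5 2.0;
     0.3 1.0 1.5]
action = [0.1 0.4 -1.0;
          0.9 0.3  0.0]

# Patched variant: one log-probability per action dimension, shape (na, n_env).
per_dim = normal_logpdf.(μ, σ, action)

# Commented-out variant: joint log-probability per environment, shape (1, n_env).
joint = sum(per_dim; dims = 1)
```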
PPO does not currently work on MDPs with multiple dimensions in their action space. The main example, and probably the one used for debugging, is Pendulum, which only has 1 action. I believe the assumption that there is only 1 action is somewhat baked into the algorithm's current implementation. For example, I've changed that example to have a 2-dimensional action space rather than 1. Note the second dimension is not used, but is still exposed to the algorithm:
This example fails, though it seems like it should be possible. Alternatively, I've tried changing the trajectory to look like so:
Which looks more correct, but also fails. I think it is due to a few points where the action and action log prob are vectorized in the algorithm. I've tried converting the action to the right shape, or removing a vector function call, by doing the following (in ppo.jl):
Add `μ, logσ = reshape(μ,(:, size(state)[2]) ), reshape(logσ, (:, size(state)[2]) )` at line 153. Comment out line 196 and add `action_log_prob = logpdf.(dist, action)` instead.
After making these changes, it still fails in the update phase during derivation because that part also uses vec() a few times. I will try to investigate that as well, but it will take a while as I don't understand PPO that well.
I am on Julia 1.7.2, [158674fc] ReinforcementLearning v0.10.0 [de1b191a] ReinforcementLearningCore v0.8.11 [25e41dd2] ReinforcementLearningEnvironments v0.6.12 [587475ba] Flux v0.12.10
Let me know if this is actually user error.
**Edit:** I may also need to check whether GaussianNetwork, which I believe returns a vector, needs a change. This is the reason for the edit at line 153. https://github.com/JuliaReinforcementLearning/ReinforcementLearningZoo.jl/blob/ca8f3474ba239bde70a3b11071d98befa586ff7c/src/algorithms/policy_gradient/vpg.jl
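A possible reading of why the flattened output only becomes a problem with more than one action dimension is sketched below with plain arrays. The matrices are hypothetical stand-ins for network outputs, and the `vec` call mirrors the last layer of the `μ` and `logσ` chains in the script.

```julia
n_env = 4

# Stand-ins for the actor output before `vec`: (na, n_env) matrices.
μ_single = reshape(collect(1.0:4.0), 1, n_env)   # na = 1
μ_double = reshape(collect(1.0:8.0), 2, n_env)   # na = 2

# Flattening: with na = 1 the result still has one entry per environment,
# with na = 2 the per-environment grouping is no longer visible in the shape.
flat_single = vec(μ_single)
flat_double = vec(μ_double)

# Restoring the batch dimension, in the spirit of the reshape added at line 153.
restored = reshape(flat_double, :, n_env)
round_trip_ok = restored == μ_double
```

Because Julia arrays are column-major, the reshape recovers the original layout exactly, provided the batch size passed to `reshape` matches the one used by the network.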