JuliaPOMDP / POMDPs.jl

MDPs and POMDPs in Julia - An interface for defining, solving, and simulating fully and partially observable Markov decision processes on discrete and continuous spaces.
http://juliapomdp.github.io/POMDPs.jl/latest/
Other
657 stars 100 forks source link

`DictPolicy` and special Q-learning based on key-value storage #459

Open NeroBlackstone opened 1 year ago

NeroBlackstone commented 1 year ago

Suppose we have a generative MDP with discrete states and discrete actions, where the state and action spaces are hard to enumerate, but we still want to solve it with a traditional tabular RL algorithm. For this I implemented a `DictPolicy`, which stores state-action pair values in a dictionary. (Users need to define `Base.isequal()` and `Base.hash()` for their state and action types.)

DictPolicy.jl :

"""
    DictPolicy(mdp, value_dict)

Policy backed by a dictionary that maps `(state, action)` tuples to values, so the
state and action spaces never have to be enumerated up front.

State and action types used in the keys need consistent `Base.isequal` and
`Base.hash` methods for dictionary lookup to work.

Note: because Julia's parametric types are invariant, `value_dict` must have key type
exactly `Tuple` (e.g. `Dict{Tuple,Float64}`); a `Dict{Tuple{Int,Int},Float64}` does
not satisfy `T<:AbstractDict{Tuple,Float64}`.
"""
struct DictPolicy{P<:Union{POMDP,MDP}, T<:AbstractDict{Tuple,Float64}} <: Policy
    mdp::P           # problem definition the policy acts in
    value_dict::T    # (s, a) => value table
end

"""
    action(p::DictPolicy, s)

Return the action with the highest stored value in state `s`.

Candidate actions are `actions(p.mdp, s)`. Any `(s, a)` pair not yet present in
`p.value_dict` is inserted with value `0.0` (so this call mutates the table) and is
ranked with that value. Ties go to the action that appears first in
`actions(p.mdp, s)`.

Throws an `ArgumentError` if `s` has no available actions.
"""
function action(p::DictPolicy, s)
    best_action = nothing
    best_value = -Inf
    for a in actions(p.mdp, s)
        # Unseen pairs default to zero, matching a zero-initialised Q-table.
        value = get!(p.value_dict, (s, a), 0.0)
        # Start from the first action so negative values are compared correctly
        # instead of being masked by an artificial 0 baseline.
        if best_action === nothing || value > best_value
            best_action = a
            best_value = value
        end
    end
    best_action === nothing && throw(ArgumentError("no actions available in state $s"))
    return best_action
end

"""
    actionvalues(p::DictPolicy, s) -> Dict

Return a dictionary mapping every action in `actions(p.mdp, s)` to its stored value
in `p.value_dict`. Pairs that are not in the table are reported as `0.0`. Unlike
`action`, this does not insert missing pairs into the table.
"""
function actionvalues(p::DictPolicy, s)::Dict
    return Dict(a => get(p.value_dict, (s, a), 0.0) for a in actions(p.mdp, s))
end

# Multi-line REPL display for a `DictPolicy` over an MDP: print a summary header,
# then delegate the table rendering to `showpolicy`.
# NOTE(review): `showpolicy` may need to enumerate the state space, which a
# DictPolicy is meant to avoid — confirm it behaves for such MDPs.
function Base.show(io::IO, mime::MIME"text/plain", p::DictPolicy{M}) where M <: MDP
    summary(io, p)
    println(io, ':')
    # The header above consumes one line, so shrink the available display height
    # by one and hand that context to `showpolicy` (previously it was unused).
    ds = get(io, :displaysize, displaysize(io))
    ioc = IOContext(io, :displaysize => (first(ds) - 1, last(ds)))
    showpolicy(ioc, mime, p.mdp, p)
end

Then there is a special Q-learning solver based on key-value storage, so we don't need to enumerate the state and action spaces in the MDP definition. (Most of the code is copied from TabularTDLearning.jl; only the way Q-values are stored and read is changed.)

dict_q_learning.jl :

# Tabular Q-learning solver whose Q-table is a Dict keyed by (state, action) tuples,
# so the state and action spaces do not need to be enumerated.
@with_kw mutable struct QLearningSolver{E<:ExplorationPolicy} <: Solver
   n_episodes::Int64 = 100               # number of training episodes
   max_episode_length::Int64 = 100       # step cap for training episodes and evaluation rollouts
   learning_rate::Float64 = 0.001        # step size of the Q-value update
   exploration_policy::E                 # chooses actions during training; no default, must be given
   Q_vals::Union{Nothing, Dict{Tuple,Float64}} = nothing  # optional initial table; if given, `solve` updates it in place
   eval_every::Int64 = 10                # evaluate the greedy policy every this many episodes
   n_eval_traj::Int64 = 20               # rollouts averaged per evaluation
   rng::AbstractRNG = Random.GLOBAL_RNG  # RNG for initial states, transitions and evaluation rollouts
   verbose::Bool = true                  # print average evaluation returns
end

"""
    solve(solver::QLearningSolver, mdp::MDP) -> DictPolicy

Run tabular Q-learning on `mdp` using the generative interface, storing Q-values in a
`Dict` keyed by `(s, a)`. Actions are chosen by `solver.exploration_policy`. Every
`solver.eval_every` episodes (when positive), the greedy policy is evaluated over
`solver.n_eval_traj` rollouts.

If `solver.Q_vals` is provided, that dictionary is updated in place. Pairs absent from
the table are treated as having value `0.0`.

Returns a `DictPolicy` that shares the learned Q-table.
"""
function solve(solver::QLearningSolver, mdp::MDP)
    rng = solver.rng
    Q = solver.Q_vals === nothing ? Dict{Tuple,Float64}() : solver.Q_vals
    exploration_policy = solver.exploration_policy
    sim = RolloutSimulator(rng=rng, max_steps=solver.max_episode_length)

    on_policy = DictPolicy(mdp, Q)
    step = 0  # global step counter, forwarded to the exploration schedule
    for i in 1:solver.n_episodes
        s = rand(rng, initialstate(mdp))
        t = 0
        while !isterminal(mdp, s) && t < solver.max_episode_length
            a = action(exploration_policy, on_policy, step, s)
            step += 1
            sp, r = @gen(:sp, :r)(mdp, s, a, rng)
            target = r + discount(mdp) * _max_q_value(Q, mdp, sp)
            q_sa = get!(Q, (s, a), 0.0)
            Q[(s, a)] = q_sa + solver.learning_rate * (target - q_sa)
            s = sp
            t += 1
        end
        # Guard against eval_every <= 0, which would otherwise raise a DivideError.
        if solver.eval_every > 0 && i % solver.eval_every == 0
            r_tot = 0.0
            for _ in 1:solver.n_eval_traj
                r_tot += simulate(sim, mdp, on_policy, rand(rng, initialstate(mdp)))
            end
            solver.verbose && println("On Iteration $i, Returns: $(r_tot/solver.n_eval_traj)")
        end
    end
    return on_policy
end

# Largest Q-value over the actions available in `s`, with unseen pairs counted as 0.0.
# Terminal states (and states with no actions) are worth 0.0 because no further
# reward follows. This replaces a scan over every key of `Q` on each step.
function _max_q_value(Q::AbstractDict, mdp::MDP, s)
    isterminal(mdp, s) && return 0.0
    best = -Inf
    for a in actions(mdp, s)
        best = max(best, get(Q, (s, a), 0.0))
    end
    return best == -Inf ? 0.0 : best
end

What's your point of view? Do you have any advice? Thank you for taking the time to read my issue. If you think it's useful, I can open a PR and add some tests. It's okay if you think it's not general enough to be worth including; I wrote it to solve my own MDP.

zsunberg commented 1 year ago

@NeroBlackstone sorry that we never responded to this! This is actually something that people often want to do. If you're still interested in contributing it, I think we can integrate it in with a few small adjustments. Let me know if you're interested in doing that.

NeroBlackstone commented 1 year ago

Hi, thanks for your comment. I'm ready to contribute to this feature.

I will do these things:

  1. Add DictPolicy and some test code in POMDPs.jl.
  2. Once DictPolicy is merged, I will contribute a Q-Learning solver and Prioritized Sweeping using this policy in TabularTDLearning.jl.
  3. A vanilla Prioritized Sweeping in TabularTDLearning.jl

I will open a PR for the first step soon. If there are any problems with the code, please point them out.

Thank you very much again.