JuliaPOMDP / DeepQLearning.jl

Implementation of the Deep Q-learning algorithm to solve MDPs

DeepQLearning

This package provides an implementation of the Deep Q-learning algorithm for solving MDPs. For more information, see https://arxiv.org/pdf/1312.5602.pdf. It uses POMDPs.jl and Flux.jl.

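It supports several innovations, including those enabled by the solver options used in the usage example: double Q-learning (double_q), dueling networks (dueling), prioritized replay (prioritized_replay), and recurrent Q-networks (recurrence).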

Installation

using Pkg
Pkg.add("DeepQLearning")

Usage

using DeepQLearning
using POMDPs
using Flux
using POMDPModels
using POMDPSimulators
using POMDPTools

# load MDP model from POMDPModels or define your own!
mdp = SimpleGridWorld();

# Define the Q network (see Flux.jl documentation)
# the gridworld state is represented by a 2 dimensional vector.
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2))

solver = DeepQLearningSolver(qnetwork=model, max_steps=10000,
                             exploration_policy=exploration,
                             learning_rate=0.005, log_freq=500,
                             recurrence=false, double_q=true, dueling=true,
                             prioritized_replay=true)
policy = solve(solver, mdp)

sim = RolloutSimulator(max_steps=30)
r_tot = simulate(sim, mdp, policy)
println("Total discounted reward for 1 simulation: $r_tot")

Specifying exploration / evaluation policy

An exploration policy and evaluation policy can be specified in the solver parameters.

An exploration policy can be provided as a function that returns an action. The function will be called as f(policy, env, obs, global_step, rng), where policy is the NN policy being trained, env is the environment, obs is the observation at which to take the action, global_step is the interaction step of the solver, and rng is a random number generator. By default, this package provides an epsilon-greedy policy in which epsilon decreases linearly with global_step.
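As a rough illustration of the calling convention above, the sketch below defines a custom exploration function with the f(policy, env, obs, global_step, rng) signature and a linearly decaying epsilon. ToyPolicy, ToyEnv, greedy_action, and linear_epsilon are hypothetical stand-ins written for this example and are not part of DeepQLearning.jl; a real function would query the policy being trained and the actual environment instead.

```julia
using Random

# Linearly anneal epsilon from `start` to `stop` over `steps` solver steps,
# then hold it at `stop`. (Hypothetical helper, not the package's schedule type.)
function linear_epsilon(global_step; start=1.0, stop=0.01, steps=5000)
    frac = min(global_step / steps, 1.0)
    return start + frac * (stop - start)
end

# Hypothetical stand-in for the NN policy: always prefers :up.
struct ToyPolicy end
greedy_action(::ToyPolicy, obs) = :up

# Hypothetical stand-in for the environment: just a list of actions.
struct ToyEnv
    actions::Vector{Symbol}
end

# Exploration function with the signature described above.
function my_exploration(policy, env, obs, global_step, rng)
    epsilon = linear_epsilon(global_step)
    if rand(rng) < epsilon
        return rand(rng, env.actions)   # explore: uniformly random action
    end
    return greedy_action(policy, obs)   # exploit: action preferred by the policy
end

rng = MersenneTwister(1)
env = ToyEnv([:up, :down, :left, :right])
a = my_exploration(ToyPolicy(), env, [1.0, 1.0], 0, rng)
```

At global_step = 0 epsilon is 1.0, so the call above always explores; late in training the function returns the greedy action almost every time.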

An evaluation policy can be provided in a similar manner. The function will be called as f(policy, env, n_eval, max_episode_length, verbose), where policy is the NN policy being trained, env is the environment, n_eval is the number of evaluation episodes, max_episode_length is the maximum number of steps in one episode, and verbose is a boolean that enables or disables printing. The evaluation function must return three elements:

Q-Network

The qnetwork option of the solver should accept any Chain object. The network is expected to be a multi-layer perceptron or a stack of convolutional layers followed by dense layers. If the network ends with dense layers, the dueling option will split all the dense layers at the end of the network.
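For intuition about what the split is for, the sketch below shows the usual dueling aggregation, Q(s, a) = V(s) + A(s, a) - mean over a of A(s, a), applied to the outputs of a value stream and an advantage stream. This is the standard formulation and only an illustration; dueling_q is a hypothetical helper, and the package's internal implementation may differ.

```julia
using Statistics

# value:     1 x batch matrix, one state value V(s) per sample
# advantage: n_actions x batch matrix, one advantage A(s, a) per action and sample
# Returns an n_actions x batch matrix of Q-values.
function dueling_q(value, advantage)
    return value .+ advantage .- mean(advantage, dims=1)
end

value = [1.0 2.0]                  # two samples
advantage = [1.0 0.0; 3.0 0.0]     # two actions, two samples
q = dueling_q(value, advantage)
```

Subtracting the mean advantage keeps the two streams identifiable: adding a constant to every advantage leaves the Q-values unchanged.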

If the observation is a multi-dimensional array (e.g. an image), one can use the flattenbatch function to flatten all the dimensions of the image. This is useful, for example, when connecting convolutional layers to dense layers. flattenbatch flattens all the dimensions except the batch dimension.
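The effect described above can be sketched in plain Julia. flatten_all_but_batch is a hypothetical helper that mimics the described behavior, assuming the batch is the last dimension of the array (the Flux convention); it is not the package's own definition of flattenbatch.

```julia
# Collapse every dimension except the last (batch) one into a single dimension.
flatten_all_but_batch(x::AbstractArray) = reshape(x, :, size(x, ndims(x)))

# A batch of 8 "images" of size 4 x 4 with 3 channels.
x = rand(Float32, 4, 4, 3, 8)
y = flatten_all_but_batch(x)
size(y)   # (48, 8): 4 * 4 * 3 features per sample, 8 samples
```

Each column of y holds one flattened sample, which is the shape a Dense layer expects as input.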

The input size of the network is problem dependent and must be specified when you create the Q-network.

This package exports the type AbstractNNPolicy, which represents a neural-network-based policy. In addition to the functions from POMDPs.jl, AbstractNNPolicy objects support the following:

Saving/Reloading model

See Flux.jl documentation for saving and loading models. The DeepQLearning solver saves the weights of the Q-network as a bson file in solver.logdir/"qnetwork.bson".

Logging

Logging is done through TensorBoardLogger.jl. A log directory can be specified in the solver options. To disable logging, set the logdir option to nothing.

GPU Support

DeepQLearning.jl should support running computations on GPUs through the CuArrays.jl package. To use it, you must check out the gpu-support branch. Note that GPU support has not been tested thoroughly. To run the solver on a GPU, first load CuArrays and then proceed as usual.

using CuArrays
using DeepQLearning
using POMDPs
using Flux
using POMDPModels

mdp = SimpleGridWorld();

# the model weights will be sent to the GPU in the call to solve
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

solver = DeepQLearningSolver(qnetwork=model, max_steps=10000,
                             learning_rate=0.005, log_freq=500,
                             recurrence=false, double_q=true, dueling=true,
                             prioritized_replay=true)
policy = solve(solver, mdp)

Solver Options

Fields of the Q Learning solver: