chainer / chainerrl

ChainerRL is a deep reinforcement learning library built on top of Chainer.

TRPO on Atari #415

ling-pan opened this issue 5 years ago

ling-pan commented 5 years ago

Hi,

I am wondering whether ChainerRL supports running TRPO on Atari. I tried to do so by following the code for training PPO on Atari, but I ran into the following error:

Traceback (most recent call last):
  File "train_trpo_ale.py", line 187, in <module>
    main()
  File "train_trpo_ale.py", line 182, in main
    eval_interval=args.eval_interval,
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/chainerrl-0.6.0-py3.6.egg/chainerrl/experiments/train_agent.py", line 174, in train_agent_with_evaluation
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/chainerrl-0.6.0-py3.6.egg/chainerrl/experiments/train_agent.py", line 59, in train_agent
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/chainerrl-0.6.0-py3.6.egg/chainerrl/agents/trpo.py", line 521, in act_and_train
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/chainerrl-0.6.0-py3.6.egg/chainerrl/misc/batch_states.py", line 23, in batch_states
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/chainer-6.0.0b3-py3.6.egg/chainer/dataset/convert.py", line 58, in wrap_call
    return func(*args, **kwargs)
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/chainer-6.0.0b3-py3.6.egg/chainer/dataset/convert.py", line 249, in concat_examples
    return to_device(device, _concat_arrays(batch, padding))
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/chainer-6.0.0b3-py3.6.egg/chainer/dataset/convert.py", line 256, in _concat_arrays
    arrays = numpy.asarray(arrays)
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/numpy/core/numeric.py", line 538, in asarray
    return array(a, dtype, copy=False, order=order)
TypeError: int() argument must be a string, a bytes-like object or a number, not 'LazyFrames'

It seems that PPO can handle LazyFrames, so I don't understand why it fails with TRPO.
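
For reference, the LazyFrames object produced by the atari_wrappers frame stacking only materializes the stacked frames once it is converted with np.asarray, so a feature extractor like the phi defined in the script below is what turns an observation into a plain float32 array before it reaches the batching code. A minimal sketch of that conversion:

import numpy as np

def phi(obs):
    # Minimal sketch: np.asarray materializes the lazily stored frames into a
    # single float32 array, scaled to [0, 1].
    return np.asarray(obs, dtype=np.float32) / 255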

Thanks!

muupan commented 5 years ago

I think it should work, but it's not tested. Can you share train_trpo_ale.py?

ling-pan commented 5 years ago

It's here:

from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
from builtins import *  # NOQA
from future import standard_library
standard_library.install_aliases()  # NOQA

import argparse
import logging
import os

import chainer
from chainer import functions as F
import gym
import gym.wrappers
import numpy as np

import chainerrl
from chainerrl import links
import cupy

from chainerrl.wrappers import atari_wrappers

def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4',
                        help='Gym Env ID')
    parser.add_argument('--gpu', type=int, default=0,
                        help='GPU device ID. Set to -1 to use CPUs only.')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed [0, 2 ** 32)')
    parser.add_argument('--outdir', type=str, default='results',
                        help='Directory path to save output files.'
                             ' If it does not exist, it will be created.')
    parser.add_argument('--steps', type=int, default=10 ** 6,
                        help='Total time steps for training.')
    parser.add_argument('--max-frames', type=int, default=30 * 60 * 60,  # 30 minutes with 60 fps
                        help='Maximum number of frames for each episode.')
    parser.add_argument('--eval-interval', type=int, default=10 ** 5,
                        help='Interval between evaluation phases in steps.')
    parser.add_argument('--eval-n-runs', type=int, default=10,
                        help='Number of episodes ran in an evaluation phase')
    parser.add_argument('--demo', action='store_true', default=False,
                        help='Run demo episodes, not training')
    parser.add_argument('--load', type=str, default='',
                        help='Directory path to load a saved agent data from'
                             ' if it is a non-empty string.')
    parser.add_argument('--logger-level', type=int, default=logging.INFO,
                        help='Level of the root logger.')
    parser.add_argument('--render', action='store_true', default=False,
                        help='Render the env')
    parser.add_argument('--monitor', action='store_true',
                        help='Monitor the env by gym.wrappers.Monitor.'
                             ' Videos and additional log will be saved.')
    parser.add_argument('--trpo-update-interval', type=int, default=5000,
                        help='Interval steps of TRPO iterations.')
    args = parser.parse_args()

    logging.basicConfig(level=args.logger_level)

    # Set random seed
    chainerrl.misc.set_random_seed(args.seed, gpus=(args.gpu,))

    args.outdir = chainerrl.experiments.prepare_output_dir(args, args.outdir)
    print('Output files are saved in {}'.format(args.outdir))

    def make_env(test):
        # Use different random seeds for train and test envs
        env_seed = 2 ** 32 - args.seed if test else args.seed
        env = atari_wrappers.wrap_deepmind(
            atari_wrappers.make_atari(args.env, max_frames=args.max_frames),
            episode_life=not test,
            clip_rewards=not test)
        env.seed(int(env_seed))
        if args.monitor:
            env = gym.wrappers.Monitor(
                env, args.outdir,
                mode='evaluation' if test else 'training')
        if args.render:
            env = chainerrl.wrappers.Render(env)
        return env

    env = make_env(test=False)
    eval_env = make_env(test=True)
    obs_space = env.observation_space
    action_space = env.action_space
    print('Observation space:', obs_space)
    print('Action space:', action_space)

    if not isinstance(obs_space, gym.spaces.Box):
        print("""This example only supports gym.spaces.Box observation spaces. To apply it to other observation spaces, use a custom phi function that convert an observation to numpy.ndarray of numpy.float32.""")  # NOQA
        return

    # Normalize observations based on their empirical mean and variance
    obs_normalizer = chainerrl.links.EmpiricalNormalization(obs_space.low.size)

    # Use a Softmax policy for discrete action spaces
    policy = chainerrl.policies.FCSoftmaxPolicy(
        obs_space.low.size,
        action_space.n,
        n_hidden_channels=64,
        n_hidden_layers=2,
        last_wscale=0.01,
        nonlinearity=F.tanh,
    )

    # Use a value function to reduce variance
    vf = chainerrl.v_functions.FCVFunction(
        obs_space.low.size,
        n_hidden_channels=64,
        n_hidden_layers=2,
        last_wscale=0.01,
        nonlinearity=F.tanh,
    )

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        policy.to_gpu(args.gpu)
        vf.to_gpu(args.gpu)
        obs_normalizer.to_gpu(args.gpu)

    # TRPO's policy is optimized via CG and line search, so it doesn't require
    # a chainer.Optimizer. Only the value function needs it.
    vf_opt = chainer.optimizers.Adam()
    vf_opt.setup(vf)

    # Draw the computational graph and save it in the output directory.
    if policy.xp == cupy:
        formatted_obs_space_low = cupy.array(obs_space.low)
    else:
        formatted_obs_space_low = obs_space.low

    fake_obs = chainer.Variable(
        policy.xp.zeros_like(formatted_obs_space_low, dtype=policy.xp.float32)[None],
        name='observation',
    )
    chainerrl.misc.draw_computational_graph(
        [policy(fake_obs)], os.path.join(args.outdir, 'policy'))
    chainerrl.misc.draw_computational_graph(
        [vf(fake_obs)], os.path.join(args.outdir, 'vf'))

    # Feature extractor
    def phi(x):
        return np.asarray(x, dtype=np.float32) / 255

    # Hyperparameters in http://arxiv.org/abs/1709.06560
    agent = chainerrl.agents.TRPO(
        policy=policy,
        vf=vf,
        vf_optimizer=vf_opt,
        obs_normalizer=obs_normalizer,
        update_interval=args.trpo_update_interval,
        conjugate_gradient_max_iter=20,
        conjugate_gradient_damping=1e-1,
        gamma=0.995,
        lambd=0.97,
        vf_epochs=5,
        entropy_coef=0,
    )

    if args.load:
        agent.load(args.load)

    if args.demo:
        env = make_env(test=True)
        eval_stats = chainerrl.experiments.eval_performance(
            env=eval_env,
            agent=agent,
            n_steps=None,
            n_episodes=args.eval_n_runs,
        )
        print('n_runs: {} mean: {} median: {} stdev {}'.format(
            args.eval_n_runs, eval_stats['mean'], eval_stats['median'], eval_stats['stdev']))
    else:
        chainerrl.experiments.train_agent_with_evaluation(
            agent=agent,
            env=env,
            eval_env=eval_env,
            outdir=args.outdir,
            steps=args.steps,
            eval_n_steps=None,
            eval_n_episodes=args.eval_n_runs,
            eval_interval=args.eval_interval,
        )

if __name__ == '__main__':
    main()
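
As a side note (a sketch only, not verified): the script defines phi but never passes it to the TRPO constructor, so observations reach batch_states as raw LazyFrames. Assuming TRPO accepts a phi keyword like other ChainerRL agents do, the agent construction would look roughly like this:

    # Hypothetical sketch: pass the feature extractor so that LazyFrames are
    # converted to float32 arrays before batching. Whether TRPO accepts a phi
    # keyword is an assumption here, not something verified against the library.
    agent = chainerrl.agents.TRPO(
        policy=policy,
        vf=vf,
        vf_optimizer=vf_opt,
        obs_normalizer=obs_normalizer,
        phi=phi,
        update_interval=args.trpo_update_interval,
        conjugate_gradient_max_iter=20,
        conjugate_gradient_damping=1e-1,
        gamma=0.995,
        lambd=0.97,
        vf_epochs=5,
        entropy_coef=0,
    )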