hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License

ACKTR model crashes using CnnLnLstmPolicy #387

[Open] MartinBertran opened this issue 5 years ago

MartinBertran commented 5 years ago

Describe the bug: the ACKTR example code crashes when modified to use CnnLnLstmPolicy. Apparent bug in the KFAC code.

Code example

import gym
import vizdoomgym  # registers the VizDoom environments with gym

from stable_baselines.common.policies import CnnLnLstmPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines import ACKTR

n_cpu = 4

if __name__ == "__main__":
    # vectorized VizDoom environment across n_cpu worker processes
    env = SubprocVecEnv([lambda: gym.make('VizdoomCorridor-v0') for i in range(n_cpu)])
    model = ACKTR(CnnLnLstmPolicy, env, verbose=1)
    model.learn(total_timesteps=20000000)

results in:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1658   try:
-> 1659     c_op = c_api.TF_FinishOperation(op_desc)
   1660   except errors.InvalidArgumentError as e:

InvalidArgumentError: Shape must be rank 2 but is rank 1 for 'kfac/MatMul_2' (op: 'MatMul') with input shapes: [32], [32].

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-2-12e263ec93c1> in <module>
     11 
     12     env = SubprocVecEnv([lambda: gym.make('VizdoomCorridor-v0') for i in range(n_cpu)])
---> 13     model = ACKTR(CnnLnLstmPolicy, env, verbose=1)
     14 
     15     model.learn(total_timesteps=20000000)

~/ReinforcementLearning/stable-baselines/stable_baselines/acktr/acktr_disc.py in __init__(self, policy, env, gamma, nprocs, n_steps, ent_coef, vf_coef, vf_fisher_coef, learning_rate, max_grad_norm, kfac_clip, lr_schedule, verbose, tensorboard_log, _init_setup_model, async_eigen_decomp, policy_kwargs, full_tensorboard_log)
    101 
    102         if _init_setup_model:
--> 103             self.setup_model()
    104 
    105     def _get_pretrain_placeholders(self):

~/ReinforcementLearning/stable-baselines/stable_baselines/acktr/acktr_disc.py in setup_model(self)
    195 
    196                         print(self.joint_fisher)
--> 197                         optim.compute_and_apply_stats(self.joint_fisher, var_list=params)
    198 
    199                 self.train_model = train_model

~/ReinforcementLearning/stable-baselines/stable_baselines/acktr/kfac.py in compute_and_apply_stats(self, loss_sampled, var_list)
    332             varlist = tf.trainable_variables()
    333 
--> 334         stats = self.compute_stats(loss_sampled, var_list=varlist)
    335         return self.apply_stats(stats)
    336 

~/ReinforcementLearning/stable-baselines/stable_baselines/acktr/kfac.py in compute_stats(self, loss_sampled, var_list)
    475 
    476                     cov_b = tf.matmul(bprop_factor, bprop_factor,
--> 477                                       transpose_a=True) / tf.cast(tf.shape(bprop_factor)[0], tf.float32)
    478 
    479                     update_ops.append(cov_b)

~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py in matmul(a, b, transpose_a, transpose_b, adjoint_a, adjoint_b, a_is_sparse, b_is_sparse, name)
   2453     else:
   2454       return gen_math_ops.mat_mul(
-> 2455           a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
   2456 
   2457 

~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py in mat_mul(a, b, transpose_a, transpose_b, name)
   5331   _, _, _op = _op_def_lib._apply_op_helper(
   5332         "MatMul", a=a, b=b, transpose_a=transpose_a, transpose_b=transpose_b,
-> 5333                   name=name)
   5334   _result = _op.outputs[:]
   5335   _inputs_flat = _op.inputs

~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    786         op = g.create_op(op_type_name, inputs, output_types, name=scope,
    787                          input_types=input_types, attrs=attr_protos,
--> 788                          op_def=op_def)
    789       return output_structure, op_def.is_stateful, op
    790 

~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in create_op(***failed resolving arguments***)
   3298           input_types=input_types,
   3299           original_op=self._default_original_op,
-> 3300           op_def=op_def)
   3301       self._create_op_helper(ret, compute_device=compute_device)
   3302     return ret

~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in __init__(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)
   1821           op_def, inputs, node_def.attr)
   1822       self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
-> 1823                                 control_input_ops)
   1824 
   1825     # Initialize self._outputs.

~/anaconda3/envs/pythonRL/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1660   except errors.InvalidArgumentError as e:
   1661     # Convert to ValueError for backwards compatibility.
-> 1662     raise ValueError(str(e))
   1663 
   1664   return c_op

ValueError: Shape must be rank 2 but is rank 1 for 'kfac/MatMul_2' (op: 'MatMul') with input shapes: [32], [32].

System Info: see the package versions listed in a later comment below.

Additional context: the KFAC code seems to expect bprop_factor to be a batch x channel tensor, but here it is a rank-1 batch tensor. The failure stems from optim.compute_and_apply_stats(self.joint_fisher, var_list=params); joint_fisher is a (838980, 32) tensor.
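For illustration, here is a minimal sketch (TF 1.x; the placeholder shapes are taken from the error message, not from the actual policy graph) of why the cov_b line in kfac.py fails on a rank-1 factor:

import tensorflow as tf

# rank-1 factor, as in the crash ('kfac/MatMul_2' with input shapes: [32], [32])
bprop_factor = tf.placeholder(tf.float32, shape=[32])
try:
    # the failing pattern from kfac.py: MatMul requires rank-2 inputs,
    # and TF converts the InvalidArgumentError to a ValueError at graph build
    cov_b = tf.matmul(bprop_factor, bprop_factor, transpose_a=True)
except ValueError as e:
    print(e)  # Shape must be rank 2 but is rank 1 ...

# the same statistic on a rank-2 [batch, channels] factor builds fine
bprop_2d = tf.placeholder(tf.float32, shape=[32, 1])
cov_b = tf.matmul(bprop_2d, bprop_2d, transpose_a=True) / tf.cast(tf.shape(bprop_2d)[0], tf.float32)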

araffin commented 5 years ago

Hello,

It seems that it may be related to your custom environment. The following code works on my machine:

from stable_baselines.common.cmd_util import make_atari_env
from stable_baselines import ACKTR

env = make_atari_env("BreakoutNoFrameskip-v4", num_env=2, seed=1)
# Reduce number of steps to avoid memory issue
model = ACKTR("CnnLnLstmPolicy", env, n_steps=4, verbose=1)
model.learn(1000)

MartinBertran commented 5 years ago

That code snippet does not work for me either:

Process ForkProcess-1:
Traceback (most recent call last):
  File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/martin/ReinforcementLearning/stable-baselines/stable_baselines/common/vec_env/subproc_vec_env.py", line 13, in _worker
    env = env_fn_wrapper.var()
  File "/home/martin/ReinforcementLearning/stable-baselines/stable_baselines/common/cmd_util.py", line 38, in _thunk
    env = make_atari(env_id)
  File "/home/martin/ReinforcementLearning/stable-baselines/stable_baselines/common/atari_wrappers.py", line 284, in make_atari
    env = gym.make(env_id)
  File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/site-packages/gym/envs/registration.py", line 156, in make
    return registry.make(id, **kwargs)
  File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/site-packages/gym/envs/registration.py", line 101, in make
    env = spec.make(**kwargs)
  File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/site-packages/gym/envs/registration.py", line 73, in make
    env = cls(**_kwargs)
  File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/site-packages/gym/envs/atari/atari_env.py", line 69, in __init__
    self.seed()
  File "/home/martin/anaconda3/envs/pythonRL/lib/python3.6/site-packages/gym/envs/atari/atari_env.py", line 93, in seed
    modes = self.ale.getAvailableModes()
AttributeError: 'ALEInterface' object has no attribute 'getAvailableModes'

Process ForkProcess-2: (identical traceback)

ConnectionResetError                      Traceback (most recent call last)

<ipython-input> in <module>
     19 from stable_baselines import ACKTR
     20
---> 21 env = make_atari_env("BreakoutNoFrameskip-v4", num_env=2, seed=1)
     22 # Reduce number of steps to avoid memory issue
     23 model = ACKTR("CnnLnLstmPolicy", env, n_steps=4, verbose=1)

~/ReinforcementLearning/stable-baselines/stable_baselines/common/cmd_util.py in make_atari_env(env_id, num_env, seed, wrapper_kwargs, start_index, allow_early_resets, start_method)
     49
     50     return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)],
---> 51                          start_method=start_method)
     52
     53

~/ReinforcementLearning/stable-baselines/stable_baselines/common/vec_env/subproc_vec_env.py in __init__(self, env_fns, start_method)
     91
     92         self.remotes[0].send(('get_spaces', None))
---> 93         observation_space, action_space = self.remotes[0].recv()
     94         VecEnv.__init__(self, len(env_fns), observation_space, action_space)
     95

~/anaconda3/envs/pythonRL/lib/python3.6/multiprocessing/connection.py in recv(self)
    248         self._check_closed()
    249         self._check_readable()
--> 250         buf = self._recv_bytes()
    251         return _ForkingPickler.loads(buf.getbuffer())
    252

~/anaconda3/envs/pythonRL/lib/python3.6/multiprocessing/connection.py in _recv_bytes(self, maxsize)
    405
    406     def _recv_bytes(self, maxsize=None):
--> 407         buf = self._recv(4)
    408         size, = struct.unpack("!i", buf.getvalue())
    409         if maxsize is not None and size > maxsize:

~/anaconda3/envs/pythonRL/lib/python3.6/multiprocessing/connection.py in _recv(self, size, read)
    377         remaining = size
    378         while remaining > 0:
--> 379             chunk = read(handle, remaining)
    380             n = len(chunk)
    381             if n == 0:

ConnectionResetError: [Errno 104] Connection reset by peer

The same happens for other Atari environments using ACKTR + CnnLstm policies, e.g.:

env = SubprocVecEnv([lambda: gym.make('Breakout-v0') for i in range(n_cpu)])
model = ACKTR(CnnLnLstmPolicy, env, verbose=False, tensorboard_log="./test/")

But it works fine with MlpLnLstmPolicy:

if __name__=="__main__":
    env = SubprocVecEnv([lambda: gym.make('CartPole-v0') for i in range(n_cpu)])
    model = ACKTR(MlpLnLstmPolicy, env, verbose=False, tensorboard_log="./test/")

This seems to be an ACKTR-specific issue for me; PPO2 works for all of the examples listed above.
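For completeness, a minimal sketch of the PPO2 counterpart (the timestep budget and env choice here are illustrative assumptions, not from the original runs):

import gym
from stable_baselines import PPO2
from stable_baselines.common.policies import CnnLnLstmPolicy
from stable_baselines.common.vec_env import SubprocVecEnv

n_cpu = 4

if __name__ == "__main__":
    # same vectorized Atari setup as above, but with PPO2 instead of ACKTR
    env = SubprocVecEnv([lambda: gym.make('Breakout-v0') for i in range(n_cpu)])
    model = PPO2(CnnLnLstmPolicy, env, verbose=1)
    model.learn(total_timesteps=10000)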

araffin commented 5 years ago

What is your gym version? (+ associated packages, like atari-py)

MartinBertran commented 5 years ago

These are all I could think of:

stable_baselines.__version__ == '2.6.1a0'
atari-py==0.1.15
gym==0.13.0
tensorboard==1.14.0
tensorflow==1.13.1
tensorflow-estimator==1.14.0
tensorflow-gpu==1.14.0
vizdoom==1.1.7
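(As an aside on the earlier AttributeError: a quick sketch, using only names that appear in that traceback, to check whether the installed atari-py exposes the method gym's atari_env.py calls during seeding:)

import atari_py

# gym calls ale.getAvailableModes() when seeding an Atari env;
# the earlier traceback shows this attribute is missing on atari-py 0.1.15
ale = atari_py.ALEInterface()
print(hasattr(ale, 'getAvailableModes'))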

araffin commented 4 years ago

The error seems related to the tensorflow version (I could reproduce the bug in Google Colab).

ChengYen-Tang commented 4 years ago

@araffin I also got this error, may I ask which version of tensorflow you are using?

araffin commented 4 years ago

> I also got this error, may I ask which version of tensorflow you are using?

tensorflow==1.8.0

So, the 1.8.0 CPU version.

ChengYen-Tang commented 4 years ago

Thank you.