hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License

[question] GAIL and discretized observation #603

Open IwamotoTaro opened 4 years ago

IwamotoTaro commented 4 years ago

Thank you for the wonderful tool.

(1) I was able to complete training with GAIL using CartPole expert data.
(2) Next, I added a wrapper to CartPole that discretizes the observations, and training with TRPO completed successfully.
(3) Finally, I tried GAIL training with the discretized observations, and an error occurred (line 121 of gail/adversary.py).

Is GAIL training with discretized observations currently supported in stable-baselines?
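The discretization step in (2) can be sketched as a plain NumPy function. The bin edges, the names `BIN_EDGES`, `discretize` and `n_states`, and the mixed-radix encoding into a single integer are hypothetical choices for illustration, not the wrapper from the uploaded zip; a single integer per step is at least consistent with the `obs (315, 1)` shape printed in the log below.

```python
import numpy as np

# Hypothetical bin edges for the four CartPole observation components.
# The ranges and bin counts are illustrative assumptions, not the issue's values.
BIN_EDGES = [
    np.linspace(-2.4, 2.4, 5),    # cart position
    np.linspace(-3.0, 3.0, 5),    # cart velocity (assumed range)
    np.linspace(-0.21, 0.21, 5),  # pole angle in radians
    np.linspace(-3.0, 3.0, 5),    # pole angular velocity (assumed range)
]


def n_states(bin_edges=BIN_EDGES):
    """Number of distinct states, i.e. the size of a Discrete observation space."""
    # np.digitize returns a bucket in 0..len(edges), so each dimension has len(edges) + 1 buckets.
    return int(np.prod([len(edges) + 1 for edges in bin_edges]))


def discretize(obs, bin_edges=BIN_EDGES):
    """Map a continuous 4-dim observation to one integer state index."""
    index = 0
    for value, edges in zip(obs, bin_edges):
        bucket = int(np.digitize(value, edges))
        # Mixed-radix encoding: every dimension becomes one "digit" of the index.
        index = index * (len(edges) + 1) + bucket
    return index


# Usage: all four components below the lowest edges map to state 0.
print(n_states(), discretize([-10.0, -10.0, -10.0, -10.0]))  # prints 1296 0
```

Inside a `gym.ObservationWrapper`, the `observation()` method could return `discretize(observation)` and `observation_space` could be set to `gym.spaces.Discrete(n_states())`; the exact wrapper depends on the gym version in use.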

araffin commented 4 years ago

Hello, please fill in the issue template completely.

IwamotoTaro commented 4 years ago

Code example

The file has been uploaded: GAIL and discretized observation.zip
・ TRPO_cartpoleD.py: TRPO training on CartPole with discretized observations
・ Gen_cartpoleD.py: creates expert data for 5 episodes (the cart is operated with the left and right arrow keys)
・ Cartpole_trajD.npz: sample of the expert data
・ GAIL_cartpoleD.py: GAIL training with discretized observations
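Before passing a file like `Cartpole_trajD.npz` to GAIL, a quick check of its layout can help narrow things down. This sketch assumes the key names printed in the log below (`actions`, `obs`, `rewards`, `episode_returns`, `episode_starts`); the function `check_expert_data` is hypothetical. It also reports whether the stored observations are integers, since the traceback shows an int64 observation input clashing with a float32 value.

```python
import numpy as np

# Key names as printed when the expert file is loaded (see the log in this issue).
REQUIRED_KEYS = ("actions", "obs", "rewards", "episode_returns", "episode_starts")


def check_expert_data(data):
    """Hypothetical sanity check for an expert-trajectory dict or np.load(...) result.

    Returns the shape of every array, the number of episode starts and returns,
    and whether the observations are stored as integers.
    """
    missing = [key for key in REQUIRED_KEYS if key not in data]
    if missing:
        raise KeyError("missing keys: {}".format(missing))

    n_transitions = len(data["obs"])
    for key in ("actions", "rewards", "episode_starts"):
        if len(data[key]) != n_transitions:
            raise ValueError("{} has {} entries, expected {}".format(
                key, len(data[key]), n_transitions))

    obs_dtype = np.asarray(data["obs"]).dtype
    return {
        "shapes": {key: np.shape(data[key]) for key in REQUIRED_KEYS},
        # Reported separately; how these relate depends on how the file was generated.
        "n_episode_starts": int(np.sum(data["episode_starts"])),
        "n_episode_returns": len(data["episode_returns"]),
        "obs_dtype": obs_dtype,
        # Integer observations (e.g. discretized CartPole) are the kind of input
        # that fails against a float32 mean in the traceback below.
        "integer_obs": bool(np.issubdtype(obs_dtype, np.integer)),
    }


# Usage with a tiny hand-made example (not real expert data):
example = {
    "actions": np.array([[0], [1], [1]]),
    "obs": np.array([[3], [7], [12]]),
    "rewards": np.ones(3),
    "episode_returns": np.array([3.0]),
    "episode_starts": np.array([True, False, False]),
}
print(check_expert_data(example)["integer_obs"])  # prints True
```

With the uploaded file, the same call would be `check_expert_data(np.load("Cartpole_trajD.npz"))`.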

Error messages and stack traces

C:\anaconda3\lib\site-packages\gym\envs\registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated.  Call .resolve and .require separately.
  result = entry_point.load(False)
actions (315, 1)
obs (315, 1)
rewards (315,)
episode_returns (5,)
episode_starts (315,)
Total trajectories: -1
Total transitions: 315
Average returns: 63.0
Std for returns: 30.469657037781044
WARNING:tensorflow:From C:\anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
Traceback (most recent call last):
  File "C:\anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 511, in _apply_op_helper
    preferred_dtype=default_dtype)
  File "C:\anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1175, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "C:\anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 977, in _TensorTensorConversionFunction
    (dtype.name, t.dtype.name, str(t)))
ValueError: Tensor conversion requested dtype int64 for Tensor with dtype float32: 'Tensor("adversary/obfilter/Cast:0", shape=(), dtype=float32)'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "GAIL_cartpoleD.py", line 49, in <module>
    model = GAIL(MlpPolicy, env, dataset, verbose=1)
  File "C:\anaconda3\lib\site-packages\stable_baselines\gail\model.py", line 49, in __init__
    self.setup_model()
  File "C:\anaconda3\lib\site-packages\stable_baselines\trpo_mpi\trpo_mpi.py", line 129, in setup_model
    entcoeff=self.adversary_entcoeff)#
  File "C:\anaconda3\lib\site-packages\stable_baselines\gail\adversary.py", line 77, in __init__
    generator_logits = self.build_graph(self.generator_obs_ph, self.generator_acs_ph, reuse=False)
  File "C:\anaconda3\lib\site-packages\stable_baselines\gail\adversary.py", line 121, in build_graph
    obs = (obs_ph - self.obs_rms.mean) / self.obs_rms.std
  File "C:\anaconda3\lib\site-packages\tensorflow\python\ops\math_ops.py", line 812, in binary_op_wrapper
    return func(x, y, name=name)
  File "C:\anaconda3\lib\site-packages\tensorflow\python\ops\gen_math_ops.py", line 10130, in sub
    "Sub", x=x, y=y, name=name)
  File "C:\anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 547, in _apply_op_helper
    inferred_from[input_arg.type_attr]))
TypeError: Input 'y' of 'Sub' Op has type float32 that does not match type int64 of argument 'x'.

System Info

araffin commented 4 years ago

obs = (obs_ph - self.obs_rms.mean) / self.obs_rms.std

It looks like this normalization should be deactivated when using discrete observations; otherwise a cast error occurs. I'm also not sure whether the correct preprocessing is applied to discrete observations (I only had a quick look).

The error comes from this line: https://github.com/hill-a/stable-baselines/blob/master/stable_baselines/gail/adversary.py#L118 (and it seems that the docstring is wrong)
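One way to follow this suggestion is sketched below in NumPy rather than in the TensorFlow graph of `adversary.py`: skip the mean/std normalization for discrete observations and one-hot encode them as float32 instead. As far as I can tell, this mirrors what stable-baselines policies do for `Discrete` spaces in `stable_baselines.common.input.observation_input`. The function `preprocess_obs` and its parameters are hypothetical; an actual fix would need to branch inside `build_graph`.

```python
import numpy as np


def preprocess_obs(obs, n_states=None, obs_mean=None, obs_std=None):
    """Hypothetical discriminator-input preprocessing (NumPy sketch).

    Discrete observations (n_states given) are one-hot encoded and NOT normalized.
    Continuous observations are standardized, like
    obs = (obs_ph - self.obs_rms.mean) / self.obs_rms.std in adversary.py.
    The result is always float32.
    """
    obs = np.asarray(obs)
    if n_states is not None:
        # Integer state indices, e.g. shape (batch, 1) -> one-hot of shape (batch, n_states).
        indices = obs.reshape(-1).astype(np.int64)
        return np.eye(n_states, dtype=np.float32)[indices]
    obs = obs.astype(np.float32)
    if obs_mean is None or obs_std is None:
        return obs
    return ((obs - obs_mean) / obs_std).astype(np.float32)


# Usage: two discrete states out of 4 possible ones.
print(preprocess_obs(np.array([[0], [3]]), n_states=4))
```

Casting everything to float32 before any arithmetic is the point of the design: the error above comes from subtracting a float32 tensor from an int64 one, which TensorFlow refuses instead of upcasting.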