HumanCompatibleAI / imitation

Clean PyTorch implementations of imitation and reward learning algorithms
https://imitation.readthedocs.io/
MIT License

Observation shape mismatch #547

Closed: BryanZ666 closed this issue 1 year ago

BryanZ666 commented 2 years ago

Bug description

I tried to roll out trajectories with PPO/A2C for the task "BreakoutNoFrameskip-v4" and got the error "ValueError: Observation spaces do not match". I think the observation got transposed: (4, 84, 84) != (84, 84, 4).

Steps to reproduce

from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecFrameStack

from imitation.data import rollout

if __name__ == "__main__":
    seed = 0
    # Channels-last Atari observations: (84, 84, 4) after frame stacking.
    env = make_atari_env("BeamRiderNoFrameskip-v4", n_envs=1, seed=seed)
    env = VecFrameStack(env, n_stack=4)

    expert = A2C(policy="CnnPolicy", env=env, seed=seed)
    expert.learn(25000)
    expert.save("./expert")

    # load is a classmethod that returns a new model; reassign the result.
    expert = A2C.load("./expert")

    reward, _ = evaluate_policy(expert, env, 10)
    print("Expert performance:", reward)

    # Raises "ValueError: Observation spaces do not match": the policy was
    # trained on channels-first (4, 84, 84) observations, but `env` is
    # channels-last (84, 84, 4).
    rollouts = rollout.rollout(
        expert,
        env,
        rollout.make_sample_until(min_timesteps=None, min_episodes=2),
    )
    transitions = rollout.flatten_trajectories(rollouts)

Environment

$ pip3.9 freeze --all
absl-py==1.2.0
ale-py==0.7.4
astor==0.8.1
astunparse==1.6.3
AutoROM==0.4.2
AutoROM.accept-rom-license==0.4.2
# Editable install with no version control (baselines==0.1.5)
-e /data/brzheng/Project/Explainable IL/baselines
bleach==1.5.0
cachetools==5.2.0
certifi==2022.6.15
cffi==1.15.1
chai-sacred==0.8.3
charset-normalizer==2.1.1
click==8.1.3
cloudpickle==1.2.2
colorama==0.4.5
cycler==0.11.0
Cython==0.29.32
dill==0.3.5.1
docopt==0.6.2
flatbuffers==1.12
fonttools==4.37.1
future==0.18.2
gast==0.4.0
gitdb==4.0.9
GitPython==3.1.27
glfw==2.5.4
google-auth==2.11.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.34.1
gym==0.21.0
gym-notices==0.0.8
h5py==3.1.0
html5lib==0.9999999
idna==3.3
imageio==2.21.2
imitation==0.3.1
importlib-metadata==4.12.0
importlib-resources==5.9.0
joblib==1.1.0
jsonpickle==2.2.0
keras==2.9.0
keras-nightly==2.5.0.dev2021032900
Keras-Preprocessing==1.1.2
kiwisolver==1.4.4
libclang==14.0.6
lockfile==0.12.2
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib==3.5.3
mpi4py==3.1.3
mujoco-py==1.50.1.68
munch==2.5.0
numpy==1.23.2
oauthlib==3.2.0
opencv-python==4.6.0.66
opt-einsum==3.3.0
packaging==21.3
pandas==1.4.3
patchelf==0.15.0.0
Pillow==9.2.0
pip==22.2.2
progressbar2==4.0.0
protobuf==3.19.4
py-cpuinfo==8.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pyglet==1.3.2
pyparsing==3.0.9
python-dateutil==2.8.2
python-utils==3.3.3
pytz==2022.2.1
pyzmq==23.2.1
requests==2.28.1
requests-oauthlib==1.3.1
rsa==4.9
scikit-learn==1.1.2
scipy==1.9.1
seals==0.1.2
setuptools==49.2.1
six==1.15.0
smmap==5.0.0
stable-baselines3 @ git+https://github.com/DLR-RM/stable-baselines3@a7f30b04e3285b62ed72ed3a7183972c03358681
tensorboard==2.9.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.9.1
tensorflow-estimator==2.9.0
tensorflow-io-gcs-filesystem==0.26.0
tensorflow-tensorboard==1.5.1
termcolor==1.1.0
threadpoolctl==3.1.0
torch==1.11.0
tqdm==4.64.0
typing-extensions==3.7.4.3
urllib3==1.26.12
Werkzeug==2.2.2
wheel==0.37.1
wrapt==1.12.1
zipp==3.8.1
zmq==0.0.0

dfilan commented 2 years ago

Indeed - Stable Baselines3 internally transposes the env to be channels-first for training, but rollout.rollout is using the env that you passed in. The solution is to instead call

rollouts = rollout.rollout(
    expert,
    expert.get_env(),
    rollout.make_sample_until(min_timesteps=None, min_episodes=2),
)

Then you get a different error that I don't really understand - unwrap_traj ends up accessing the 'rollout' key of some object that doesn't have that key.

At any rate, we should probably document this issue of rolling out on the wrong env somewhere.
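A possible alternative sketch, assuming SB3's VecTransposeImage wrapper (SB3 skips its own implicit transposition when the env is already channels-first): transpose the env yourself before training, so the env you hold matches what the policy sees and can be passed to rollout.rollout directly.

from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage

env = make_atari_env("BeamRiderNoFrameskip-v4", n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)
# Transpose (84, 84, 4) -> (4, 84, 84) up front, so the env passed to
# rollout.rollout uses the same observation space the policy trained on.
env = VecTransposeImage(env)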

dfilan commented 2 years ago

Related: https://github.com/HumanCompatibleAI/imitation/issues/486, https://github.com/HumanCompatibleAI/imitation/pull/519

BryanZ666 commented 2 years ago

Thanks for your comments. Is there anything I could try to overcome this problem? What if I set unwrap to False to bypass the unwrap_traj function? By the way, I simply changed the expert to PPO, but the agent doesn't seem able to learn from the environment. What could be the problem?

AdamGleave commented 2 years ago

Then you get a different error that I don't really understand - unwrap_traj ends up accessing the 'rollout' key of some object that doesn't have that key.

https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/data/wrappers.py#L160 is what adds the "rollout" key. I do not know why it would be getting lost in algorithm.get_env() -- I guess one of the wrappers SB3 applies might eliminate the info dict, though that seems counterintuitive.

Setting unwrap to False is probably a viable workaround. It does mean you'll lose the last observation of each episode (the VecEnv discards it as it auto-resets), but that might be OK for your use case.
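A minimal sketch of that workaround, assuming unwrap is passed as a keyword argument to rollout.rollout:

# With unwrap=False, unwrap_traj is never called, so the missing "rollout"
# info key no longer matters; the trade-off is losing each episode's final
# observation to the VecEnv auto-reset.
rollouts = rollout.rollout(
    expert,
    expert.get_env(),
    rollout.make_sample_until(min_timesteps=None, min_episodes=2),
    unwrap=False,
)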

BryanZ666 commented 2 years ago

Yeah, now it can roll out trajectories! But the agent could not learn a meaningful policy using the default BC parameters.

from imitation.algorithms import bc

bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
)
bc_trainer.train(n_epochs=300)

Is there anything I could try to improve it?

AdamGleave commented 2 years ago

To sanity check: are the transitions you're training BC on transposed correctly relative to the environment you're giving BC? I imagine it would error out if not, but if it were treating the channel dimension as the batch dimension, that would certainly prevent it from learning...
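One quick way to check, as a hypothetical snippet (assuming transitions came from rollout.flatten_trajectories and env is the env given to BC):

# Each row of transitions.obs should have exactly the env's observation shape.
print(transitions.obs.shape)        # e.g. (N, 4, 84, 84) if channels-first
print(env.observation_space.shape)  # should equal transitions.obs.shape[1:]
assert transitions.obs.shape[1:] == env.observation_space.shape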

Otherwise, I don't have much to advise other than "tune hyperparameters". Behavioral cloning is a pretty weak imitation learning algorithm anyway, so it's common for it to require methods like data augmentation. Something like GAIL tends to be a lot more robust.
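For reference, a hedged sketch of GAIL training with imitation; the module paths and constructor arguments follow a recent release and may differ in older versions such as 0.3.1:

from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from stable_baselines3 import PPO

venv = expert.get_env()  # reuse the channels-first training env
learner = PPO("CnnPolicy", venv, seed=0)
# BasicRewardNet is an MLP; for image observations a CNN-based reward net
# would likely work better, but this keeps the sketch minimal.
reward_net = BasicRewardNet(
    venv.observation_space,
    venv.action_space,
    normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=32,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
)
gail_trainer.train(total_timesteps=100_000)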

AdamGleave commented 1 year ago

Closing due to inactivity.