HumanCompatibleAI / imitation

Clean PyTorch implementations of imitation and reward learning algorithms
https://imitation.readthedocs.io/
MIT License

RewardNetwork predict_processed doesn't work without next_state and done #836

Open gustavodemari opened 6 months ago

gustavodemari commented 6 months ago

Bug description

The RewardNet predict_processed method only works when given state, action, next_state, and done, even though the network may have been trained using only state and action.

For example, BasicRewardNet by default trains a network on state and action only, i.e., $R(s, a)$. However, predict_processed still requires the state, action, next_state, and done arguments.

Thus, predict_processed should perhaps make next_state and done optional (see the signature below and the sketch that follows it), and the method should check whether next_state and done are None and adjust its behavior accordingly.

def predict_processed(
        self,
        state: np.ndarray,
        action: np.ndarray,
        next_state: Optional[np.ndarray] = None,
        done: Optional[np.ndarray] = None,
        **kwargs,
    ) -> np.ndarray:
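
Under that signature, a minimal sketch of what the body could do: fall back to placeholder arrays when the caller omits next_state and done. The placeholder strategy and the delegation to predict are assumptions for illustration, not the library's current implementation, and they only make sense when the network was configured to ignore those inputs:

        # Sketch of the body under the proposed optional signature.
        if next_state is None:
            # Placeholder; only sensible if the network ignores next_state.
            next_state = state
        if done is None:
            # Placeholder; only sensible if the network ignores done.
            done = np.zeros(len(state), dtype=bool)
        # Delegate to predict() with the filled-in arrays.
        return self.predict(state, action, next_state, done)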

Steps to reproduce

#!/usr/bin/env python
# coding: utf-8

import numpy as np
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
from imitation.data.wrappers import RolloutInfoWrapper

SEED = 42

env = make_vec_env(
    "seals:seals/CartPole-v0",
    rng=np.random.default_rng(SEED),
    n_envs=8,
    post_wrappers=[
        lambda env, _: RolloutInfoWrapper(env)
    ],  # needed for computing rollouts later
)
expert = load_policy(
    "ppo-huggingface",
    organization="HumanCompatibleAI",
    env_name="seals/CartPole-v0",
    venv=env,
)

from imitation.data import rollout

rollouts = rollout.rollout(
    expert,
    env,
    rollout.make_sample_until(min_timesteps=None, min_episodes=60),
    rng=np.random.default_rng(SEED),
)

from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.evaluation import evaluate_policy

learner = PPO(
    env=env,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
)
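# By default, BasicRewardNet conditions the reward only on (state, action), i.e. R(s, a).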
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
)

env.seed(SEED)
learner_rewards_before_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)

gail_trainer.train(200_000)

env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)

print(
    "Rewards before training:",
    np.mean(learner_rewards_before_training),
    "+/-",
    np.std(learner_rewards_before_training),
)
print(
    "Rewards after training:",
    np.mean(learner_rewards_after_training),
    "+/-",
    np.std(learner_rewards_after_training),
)

n_samples = 10

print(f"Generating {n_samples} samples")

obs = np.vstack([env.observation_space.sample() for i in range(n_samples)])
action = np.vstack([env.action_space.sample() for i in range(n_samples)])
next_obs = np.vstack([env.observation_space.sample() for i in range(n_samples)])
done = np.array([False] * len(obs))

print(f"Predicting rewards using {n_samples} samples")
rewards_predict_processed = reward_net.predict_processed(state=obs, action=action, next_state=next_obs, done=done)
print(f"Rewards: {rewards_predict_processed}")

print(f"Predicting rewards using {n_samples} samples, without next_state and done")
reward_net.predict_processed(state=obs, action=action)
reward_net.predict_processed(state=obs, action=action, next_state=None, done=None)

Environment

absl-py==2.0.0 aiohttp==3.9.1 aiosignal==1.3.1 alembic==1.13.1 anyio==4.2.0 argon2-cffi==23.1.0 argon2-cffi-bindings==21.2.0 arrow==1.3.0 asttokens==2.4.1 async-lru==2.0.4 async-timeout==4.0.3 attrs==23.2.0 Babel==2.14.0 backcall==0.2.0 beautifulsoup4==4.12.2 bleach==6.1.0 cachetools==5.3.2 certifi==2023.11.17 cffi==1.16.0 charset-normalizer==3.3.2 cloudpickle==3.0.0 colorama==0.4.6 colorlog==6.8.0 comm==0.2.1 contourpy==1.1.1 cycler==0.12.1 Cython==3.0.7 dataclasses==0.6 datasets==2.16.1 debugpy==1.8.0 decorator==5.1.1 defusedxml==0.7.1 dfa==2.1.2 dill==0.3.7 docopt==0.6.2 exceptiongroup==1.2.0 execnet==2.0.2 executing==2.0.1 Farama-Notifications==0.0.4 fastjsonschema==2.19.1 filelock==3.13.1 fonttools==4.47.0 fqdn==1.5.1 frozenlist==1.4.1 fsspec==2023.10.0 funcy==1.18 gitdb==4.0.11 GitPython==3.1.40 google-auth==2.26.1 google-auth-oauthlib==1.0.0 GPy==1.10.0 GPyOpt==1.2.6 greenlet==3.0.3 grpcio==1.60.0 gym==0.26.2 gym-notices==0.0.8 gymnasium==0.29.1 h5py==3.10.0 huggingface-hub==0.20.1 huggingface-sb3==3.0 idna==3.6 imitation==1.0.0 importlib-metadata==7.0.1 importlib-resources==6.1.1 iniconfig==2.0.0 ipykernel==6.28.0 ipython==8.12.3 isoduration==20.11.0 istype==0.2.0 jedi==0.19.1 Jinja2==3.1.2 joblib==1.3.2 json5==0.9.14 jsonpickle==3.0.2 jsonpointer==2.4 jsonschema==4.20.0 jsonschema-specifications==2023.12.1 jupyter-events==0.9.0 jupyter-lsp==2.2.1 jupyter_client==8.6.0 jupyter_core==5.7.0 jupyter_server==2.12.2 jupyter_server_terminals==0.5.1 jupyterlab==4.0.10 jupyterlab_pygments==0.3.0 jupyterlab_server==2.25.2 kiwisolver==1.4.5 lazytree==0.3.2 lenses==0.5.0 Mako==1.3.0 Markdown==3.5.1 markdown-it-py==3.0.0 MarkupSafe==2.1.3 matplotlib==3.7.4 matplotlib-inline==0.1.6 mdurl==0.1.2 mistune==3.0.2 mpmath==1.3.0 multidict==6.0.4 multiprocess==0.70.15 munch==4.0.0 mypy-extensions==1.0.0 nbclient==0.9.0 nbconvert==7.14.0 nbformat==5.9.2 nest-asyncio==1.5.8 networkx==3.1 notebook_shim==0.2.3 numpy==1.24.4 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.18.1 nvidia-nvjitlink-cu12==12.3.101 nvidia-nvtx-cu12==12.1.105 oauthlib==3.2.2 optuna==3.5.0 orderedset==2.0.3 overrides==7.4.0 packaging==23.2 pandas==2.0.3 pandocfilters==1.5.0 paramz==0.9.5 parso==0.8.3 pexpect==4.9.0 pickleshare==0.7.5 pillow==10.2.0 pip==23.3.1 pkgutil_resolve_name==1.3.10 platformdirs==4.1.0 pluggy==1.3.0 probabilistic-automata==0.4.2 prometheus-client==0.19.0 prompt-toolkit==3.0.43 protobuf==4.25.1 psutil==5.9.7 ptyprocess==0.7.0 pure-eval==0.2.2 py==1.11.0 py-cpuinfo==9.0.0 py-spy==0.3.14 pyarrow==14.0.2 pyarrow-hotfix==0.6 pyasn1==0.5.1 pyasn1-modules==0.3.0 pycparser==2.21 pygame==2.5.2 Pygments==2.17.2 pyparsing==3.1.1 pyrsistent==0.20.0 pytest==7.4.4 pytest-forked==1.6.0 pytest-xdist==2.5.0 python-dateutil==2.8.2 python-json-logger==2.0.7 pytz==2023.3.post1 PyYAML==6.0.1 pyzmq==25.1.2 referencing==0.32.1 requests==2.31.0 requests-oauthlib==1.3.1 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 rich==13.7.0 rpds-py==0.16.2 rsa==4.9 sacred==0.8.5 scikit-learn==1.3.2 scipy==1.10.1 seals==0.2.1 Send2Trash==1.8.2 setuptools==68.2.2 singledispatch==4.1.0 six==1.16.0 smmap==5.0.1 sniffio==1.3.0 soupsieve==2.5 SQLAlchemy==2.0.25 stable-baselines3==2.2.1 stack-data==0.6.3 structlog==23.3.0 sympy==1.12 tensorboard==2.14.0 tensorboard-data-server==0.7.2 
terminado==0.18.0 threadpoolctl==3.2.0 tinycss2==1.2.1 tomli==2.0.1 torch==2.1.2 tornado==6.4 tqdm==4.66.1 traitlets==5.14.1 triton==2.1.0 types-python-dateutil==2.8.19.20240106 typing-inspect==0.5.0 typing_extensions==4.9.0 tzdata==2023.4 uri-template==1.3.0 urllib3==2.1.0 wasabi==1.1.2 wcwidth==0.2.12 webcolors==1.13 webencodings==0.5.1 websocket-client==1.7.0 Werkzeug==3.0.1 wheel==0.41.2 wrapt==1.16.0 xeus-python==0.15.12 xeus-python-shell==0.5.0 xxhash==3.4.1 yarl==1.9.4 zipp==3.17.0

CAI23sbP commented 2 months ago

How are you, @gustavodemari? In my opinion, this is not a bug. See this link: flatten_trajectories creates next_obs and dones automatically. In this code, which GAIL uses for training, you can see a member of the flatten_trajectories family, flatten_trajectories_with_rew, being called. So when you initialize BasicRewardNet, you simply choose whether or not to use dones and next_obs.
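
For illustration, this is roughly what that suggestion looks like in code, reusing env, obs, and action from the reproduction script above and assuming the use_state/use_action/use_next_state/use_done constructor flags described in the imitation documentation (the values shown match what I understand to be the defaults). The reward MLP itself can be told to ignore next_state and done, but predict_processed still expects arrays for all four arguments, so placeholders must be passed:

import numpy as np
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm

# Choose at construction time which inputs the reward MLP consumes.
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    use_state=True,        # condition on s
    use_action=True,       # condition on a
    use_next_state=False,  # network ignores s'
    use_done=False,        # network ignores done
    normalize_input_layer=RunningNorm,
)

# predict_processed still expects all four arrays; the unused ones can be
# placeholders of the right shape, since the network never reads them.
placeholder_done = np.zeros(len(obs), dtype=bool)
rewards = reward_net.predict_processed(
    state=obs, action=action, next_state=obs, done=placeholder_done
)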