HumanCompatibleAI / imitation

Clean PyTorch implementations of imitation and reward learning algorithms
https://imitation.readthedocs.io/
MIT License

GAIL always raises variable horizon error #669

Closed mertalbaba closed 1 year ago

mertalbaba commented 1 year ago

Bug description

When trying to train GAIL on Humanoid, I always get a variable horizon error. I am using the code provided in your documentation, which is shown below.

Steps to reproduce

import numpy as np
import gym
from stable_baselines3 import PPO
from stable_baselines3 import SAC
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.ppo import MlpPolicy

from imitation.algorithms.adversarial.gail import GAIL
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env

env_name = "Humanoid-v3"
env = gym.make(env_name)
expertAgent = SAC("MlpPolicy", env, verbose=1)
expertAgent.learn(10000)

rng = np.random.default_rng()  # seed source required by imitation's rollout/env helpers

print("Rollouts...")
rollouts = rollout.rollout(
    expertAgent,
    make_vec_env(
        env_name,
        n_envs=4,
        post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
        rng=rng,
    ),
    rollout.make_sample_until(min_timesteps=1000000, min_episodes=60),
    rng=rng,
)

print("Training...")
venv = make_vec_env(env_name, n_envs=8, rng=rng)
learner = PPO("MlpPolicy", venv, verbose=1)
reward_net = BasicRewardNet(
    venv.observation_space,
    venv.action_space,
    normalize_input_layer=RunningNorm,
)

gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=2048,
    n_disc_updates_per_round=4,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
    allow_variable_horizon=True
)

gail_trainer.train(10000000)
rewards, _ = evaluate_policy(learner, venv, 100, return_episode_rewards=True)
print("Rewards:", rewards)

Environment

ernestum commented 1 year ago

That is probably because the "Humanoid-v3" environment has a variable horizon. Read more here for why this is an issue. You probably want to use the "seals/Humanoid-v0" environment from the seals package instead.
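For context, the error comes from a safety check: with a learned reward, the policy can game episode length itself (e.g. accumulate reward just by surviving longer), so imitation refuses to train on environments whose episodes end at different times unless you explicitly opt out. A minimal sketch of that kind of check (illustrative only, not imitation's actual code; the function name is hypothetical):

```python
def check_fixed_horizon(episode_lengths):
    """Raise ValueError unless all finished episodes share one horizon.

    Hypothetical stand-in for the check imitation performs on collected
    trajectories before adversarial training.
    """
    horizons = set(episode_lengths)
    if len(horizons) > 1:
        raise ValueError(
            f"Episodes of different lengths detected: {sorted(horizons)}. "
            "Variable-horizon environments bias adversarial reward learning; "
            "use a fixed-horizon variant such as seals/Humanoid-v0."
        )

# Humanoid-v3 terminates early whenever the robot falls, so episode
# lengths vary and the check fails; seals/Humanoid-v0 always runs for
# the full horizon, so every length is identical and the check passes.
check_fixed_horizon([1000, 1000, 1000])  # single horizon: no error
```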

AdamGleave commented 1 year ago

@mertalbaba can you link us to where we provide that code in our docs? If an example is not working we should certainly fix that.

mertalbaba commented 1 year ago

@ernestum Thanks for the solution, it works now. @AdamGleave the example in the docs works as written, since it uses seals/CartPole-v0. When I changed the environment to Humanoid-v3, I didn't realize I also needed to switch to seals/Humanoid-v0, which is why it failed.
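For anyone landing here with the same error: the only change needed in the reproduction script above is the environment name (this assumes the seals package is installed; its environments are fixed-horizon wrappers of the standard Gym tasks):

```python
import seals  # noqa: F401 -- importing seals registers the seals/* environment IDs with gym

env_name = "seals/Humanoid-v0"  # fixed-horizon counterpart of Humanoid-v3
```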