araffin / rl-baselines-zoo

A collection of 100+ pre-trained RL agents using Stable Baselines, training and hyperparameter optimization included.
https://stable-baselines.readthedocs.io/
MIT License

SAC Agent For Ant (PyBulletEnv-v0) Has Dimension Mismatch (Training with GAIL) #93

Open zrobertson466920 opened 4 years ago

zrobertson466920 commented 4 years ago

Important Note: We do not do technical support, nor consulting and don't answer personal questions per email.

Describe the bug

When I use the available SAC agent for AntBulletEnv-v0 to create a dataset for GAIL, I get a dimension mismatch. I'm working in this repository and slightly modified the enjoy.py script to set up training.

Code example

import os
import sys
import argparse
import importlib
import warnings

# numpy warnings because of tensorflow
warnings.filterwarnings("ignore", category=FutureWarning, module='tensorflow')
warnings.filterwarnings("ignore", category=UserWarning, module='gym')

import gym
import utils.import_envs  # pytype: disable=import-error
import numpy as np
import pickle
import stable_baselines
from stable_baselines.common import set_global_seeds
from stable_baselines.common.vec_env import VecNormalize, VecFrameStack, VecEnv

from stable_baselines.gail import generate_expert_traj
from stable_baselines.gail import ExpertDataset
from stable_baselines import PPO2, SAC, GAIL

from utils import ALGOS, create_test_env, get_latest_run_id, get_saved_hyperparams, find_saved_model
from utils.utils import StoreDict

# Fix for breaking change in v2.6.0
sys.modules['stable_baselines.ddpg.memory'] = stable_baselines.common.buffers
stable_baselines.common.buffers.Memory = stable_baselines.common.buffers.ReplayBuffer

def evaluate():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', help='environment ID', type=str, default='AntBulletEnv-v0')
    parser.add_argument('-f', '--folder', help='Log folder', type=str, default='trained_agents')
    parser.add_argument('--algo', help='RL Algorithm', default='sac',
                        type=str, required=False, choices=list(ALGOS.keys()))
    parser.add_argument('-n', '--n-timesteps', help='number of timesteps', default=1000,
                        type=int)
    parser.add_argument('--n-envs', help='number of environments', default=1,
                        type=int)
    parser.add_argument('--exp-id', help='Experiment ID (default: -1, no exp folder, 0: latest)', default=-1,
                        type=int)
    parser.add_argument('--log-rollouts', help='Save Expert Trajectory Data', default=None,
                        type=str)
    parser.add_argument('--n-episodes', help='Number of episodes to roll out', default=100,
                        type=int)
    parser.add_argument('--verbose', help='Verbose mode (0: no output, 1: INFO)', default=1,
                        type=int)
    parser.add_argument('--no-render', action='store_true', default=True,
                        help='Do not render the environment (useful for tests)')
    parser.add_argument('--deterministic', action='store_true', default=False,
                        help='Use deterministic actions')
    parser.add_argument('--stochastic', action='store_true', default=False,
                        help='Use stochastic actions (for DDPG/DQN/SAC)')
    parser.add_argument('--load-best', action='store_true', default=False,
                        help='Load best model instead of last model if available')
    parser.add_argument('--norm-reward', action='store_true', default=False,
                        help='Normalize reward if applicable (trained with VecNormalize)')
    parser.add_argument('--seed', help='Random generator seed', type=int, default=np.random.randint(0,1000))
    parser.add_argument('--reward-log', help='Where to log reward', default='', type=str)
    parser.add_argument('--gym-packages', type=str, nargs='+', default=[], help='Additional external Gym environment package modules to import (e.g. gym_minigrid)')
    parser.add_argument('--env-kwargs', type=str, nargs='+', action=StoreDict, help='Optional keyword argument to pass to the env constructor')
    args = parser.parse_args()

    # Go through custom gym packages to let them register in the global registry
    for env_module in args.gym_packages:
        importlib.import_module(env_module)

    env_id = args.env
    algo = args.algo
    folder = args.folder

    if args.exp_id == 0:
        args.exp_id = get_latest_run_id(os.path.join(folder, algo), env_id)
        print('Loading latest experiment, id={}'.format(args.exp_id))

    # Sanity checks
    if args.exp_id > 0:
        log_path = os.path.join(folder, algo, '{}_{}'.format(env_id, args.exp_id))
    else:
        log_path = os.path.join(folder, algo)

    assert os.path.isdir(log_path), "The {} folder was not found".format(log_path)

    model_path = find_saved_model(algo, log_path, env_id, load_best=args.load_best)

    if algo in ['dqn', 'ddpg', 'sac', 'td3']:
        args.n_envs = 1

    set_global_seeds(args.seed)

    is_atari = 'NoFrameskip' in env_id

    stats_path = os.path.join(log_path, env_id)
    hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=args.norm_reward, test_mode=True)

    log_dir = args.reward_log if args.reward_log != '' else None

    env_kwargs = {} if args.env_kwargs is None else args.env_kwargs

    env = create_test_env(env_id, n_envs=args.n_envs, is_atari=is_atari,
                          stats_path=stats_path, seed=args.seed, log_dir=log_dir,
                          should_render=not args.no_render,
                          hyperparams=hyperparams, env_kwargs=env_kwargs)

    # ACER raises errors because the environment passed must have
    # the same number of environments as the model was trained on.
    load_env = None if algo == 'acer' else env

    model = ALGOS[algo].load(model_path, env=load_env)

    generate_expert_traj(model, 'Ant_Test', n_episodes=10)
    dataset = ExpertDataset(expert_path='Ant_Test.npz', traj_limitation=-1, batch_size=128)

    model = GAIL('MlpPolicy', 'AntBulletEnv-v0', dataset, verbose=1)

    # Note: in practice, you need to train for 1M steps to have a working policy
    model.learn(total_timesteps=1.5e6)
    model.save("AntGAIL")

    model = GAIL.load("AntGAIL")

    env = gym.make('AntBulletEnv-v0')
    for i in range(10):
        obs = env.reset()
        rew = []
        done = False
        while not done:
            action, _states = model.predict(obs)
            obs, rewards, done, info = env.step(action)
            rew.append(rewards)
        print(np.sum(rew))

if __name__ == '__main__':

    evaluate()

When I run this, I get the following error:

pybullet build time: Jun  2 2020 06:47:43
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/stable_baselines/sac/policies.py:194: flatten (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.flatten instead.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/stable_baselines/common/tf_layers.py:57: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.dense instead.
WARNING:tensorflow:From /home/zrobertson/.local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From /home/zrobertson/.local/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
actions (10000, 8)
obs (10000, 29)
rewards (10000,)
episode_returns (10,)
episode_starts (10000,)
actions (10000, 8)
obs (10000, 29)
rewards (10000,)
episode_returns (10,)
episode_starts (10000,)
Total trajectories: -1
Total transitions: 10000
Average returns: 3490.275921516167
Std for returns: 19.19584521063349
Creating environment from the given name, wrapped in a DummyVecEnv.
pybullet build time: Jun  2 2020 06:47:43
********** Iteration 0 ************
Optimizing Policy...
sampling
done in 2.682 seconds
computegrad
done in 0.106 seconds
conjugate_gradient
      iter residual norm  soln norm
         0      0.124          0
         1     0.0697     0.0323
         2     0.0524       0.31
         3     0.0111      0.764
         4     0.0856       1.41
         5      0.011       1.47
         6    0.00497       2.42
         7     0.0125       3.12
         8     0.0023       3.13
         9    0.00331       3.67
        10    0.00095       3.72
done in 0.202 seconds
Expected: 0.087 Actual: 0.060
Stepsize OK!
vf
done in 0.059 seconds
sampling
done in 2.462 seconds
computegrad
done in 0.005 seconds
conjugate_gradient
      iter residual norm  soln norm
         0      0.306          0
         1     0.0414     0.0311
         2        0.2     0.0867
         3      0.063       0.32
         4      0.277      0.437
         5     0.0418      0.967
         6     0.0204      0.989
         7      0.388       1.64
         8     0.0101       1.92
         9      0.314       2.19
        10    0.00671       2.88
done in 0.017 seconds
Expected: 0.084 Actual: 0.069
Stepsize OK!
vf
done in 0.029 seconds
sampling
done in 2.611 seconds
computegrad
done in 0.004 seconds
conjugate_gradient
      iter residual norm  soln norm
         0      0.166          0
         1     0.0871     0.0361
         2      0.102      0.302
         3      0.166       0.37
         4       0.32      0.809
         5     0.0234      0.916
         6     0.0473      0.942
         7     0.0232       1.02
         8      0.131       1.08
         9     0.0434       1.63
        10     0.0287       1.68
done in 0.022 seconds
Expected: 0.076 Actual: 0.062
Stepsize OK!
vf
done in 0.041 seconds
Optimizing Discriminator...
generator_loss |   expert_loss |       entropy |  entropy_loss | generator_acc |    expert_acc
Traceback (most recent call last):
  File "/home/zrobertson/Atom_Projects/Python/rl-baselines-zoo/enjoy_noise_GAIL.py", line 140, in <module>
    evaluate()
  File "/home/zrobertson/Atom_Projects/Python/rl-baselines-zoo/enjoy_noise_GAIL.py", line 121, in evaluate
    model.learn(total_timesteps=1.5e6)
  File "/usr/local/lib/python3.6/dist-packages/stable_baselines/gail/model.py", line 54, in learn
    return super().learn(total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps)
  File "/usr/local/lib/python3.6/dist-packages/stable_baselines/trpo_mpi/trpo_mpi.py", line 458, in learn
    self.reward_giver.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0))
  File "<__array_function__ internals>", line 6, in concatenate
ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 28 and the array at index 1 has size 29

Process finished with exit code 1
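The traceback points at `np.concatenate((ob_batch, ob_expert), 0)` in the TRPO learn loop: observations from the learner's environment (28 columns) are stacked row-wise with observations from the expert dataset (29 columns), and NumPy requires every axis other than the concatenation axis to match. The sketch below reproduces that failure with dummy arrays (the row counts are illustrative, not taken from the run above) and adds a hypothetical `check_obs_dim` pre-flight helper that would fail before training starts with a clearer message.

```python
import numpy as np

# Illustrative shapes: the learner's env yields 28-dim observations,
# while the expert dataset was recorded with 29-dim observations.
ob_batch = np.zeros((64, 28), dtype=np.float32)
ob_expert = np.zeros((128, 29), dtype=np.float32)


def check_obs_dim(expert_obs, env_obs_dim):
    """Hypothetical helper: fail fast if expert data does not match the env."""
    if expert_obs.shape[1] != env_obs_dim:
        raise ValueError(
            "expert observations have {} features but the environment "
            "produces {}".format(expert_obs.shape[1], env_obs_dim)
        )


try:
    # Same call as in the traceback: stacking along axis 0 needs equal widths.
    np.concatenate((ob_batch, ob_expert), 0)
except ValueError as err:
    print("concatenate failed:", err)

try:
    check_obs_dim(ob_expert, ob_batch.shape[1])
except ValueError as err:
    print("pre-check failed:", err)
```

With a real Gym environment, `env_obs_dim` would come from `env.observation_space.shape[0]`.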

System Info

Additional context

This is a general problem: the observation for this version of Ant has size 29 for the SAC agent, whereas the environment's actual observation size is 28. The code works with A2C, for example. However, the SAC agent reaches a much higher reward, so I'd like to use it.

araffin commented 4 years ago

Hello, please take a look at the SAC hyperparameters: the agent uses a time feature wrapper (hence 29 dimensions).
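For illustration, a time feature wrapper of this kind appends one scalar to every observation, commonly the fraction of the episode's time limit that remains. The sketch below is a minimal NumPy approximation of that idea, assuming a 1000-step time limit; `append_time_feature` is a hypothetical helper, not the zoo's actual wrapper code.

```python
import numpy as np


def append_time_feature(obs, current_step, max_steps=1000):
    """Hypothetical helper: append the remaining-time fraction to an observation.

    The feature is 1.0 at the start of an episode and decreases towards 0.0
    as current_step approaches max_steps, so a 28-dim observation becomes 29-dim.
    """
    time_feature = 1.0 - current_step / max_steps
    return np.concatenate((obs, [time_feature]))


raw_obs = np.zeros(28, dtype=np.float32)  # stand-in for an AntBulletEnv-v0 observation
wrapped = append_time_feature(raw_obs, current_step=250)
print(raw_obs.shape, wrapped.shape)  # (28,) (29,)
print(wrapped[-1])                   # 0.75
```

Because the extra value is concatenated at the end in this sketch, the first 28 entries are the unmodified environment observation.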

zrobertson466920 commented 4 years ago

I did a search and couldn't find 'time feature wrapper'. I wrote code to remove the 29th feature. Is the wrapper's feature appended to the end of the observation? Is there an easier or more correct solution?

P.S. Since you seem to be the owner: is there a place/link where I can see which hyperparameters to use for Ant with GAIL? Thanks!

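import numpy as np

# Load the expert trajectories recorded with the time feature wrapper (29-dim obs)
temp_dict = dict(np.load('Ant_Test.npz'))
# Keep only the first 28 columns, i.e. drop the appended time feature
temp_dict['obs'] = temp_dict['obs'][:, :28]
# Overwrite Ant_Test.npz with the trimmed observations
np.savez('Ant_Test', actions=temp_dict['actions'],
         episode_returns=temp_dict['episode_returns'],
         rewards=temp_dict['rewards'],
         obs=temp_dict['obs'],
         episode_starts=temp_dict['episode_starts'])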
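A slightly more defensive variant of the snippet above is sketched below. It assumes, as the snippet does, that the extra feature is the last column of `obs` and that exactly one column was added; it checks that assumption before slicing, copies every other array unchanged, and writes to a separate file instead of overwriting the original. `strip_last_obs_feature` and the output filename are hypothetical.

```python
import numpy as np


def strip_last_obs_feature(in_path, out_path, expected_dim=28):
    """Hypothetical helper: drop the appended time feature from an expert .npz file.

    Assumes the extra feature is the last column of 'obs' and that exactly
    one column was added on top of the environment's own observation.
    """
    with np.load(in_path) as npz:
        data = {key: npz[key] for key in npz.files}
    obs = data["obs"]
    if obs.shape[1] != expected_dim + 1:
        raise ValueError(
            "expected {} columns, got {}".format(expected_dim + 1, obs.shape[1])
        )
    data["obs"] = obs[:, :expected_dim]
    # actions, rewards, episode_returns and episode_starts are saved unchanged
    np.savez(out_path, **data)
    return data["obs"].shape
```

For example, `strip_last_obs_feature('Ant_Test.npz', 'Ant_Test_28.npz')` followed by `ExpertDataset(expert_path='Ant_Test_28.npz', ...)`.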
araffin commented 4 years ago

> I did a search and couldn't find 'time feature wrapper'. I wrote code to remove the 29th feature. Is the wrapper's feature appended to the end of the observation? Is there an easier or more correct solution?

Apparently you did not search much; see #79 for more information.

> P.S. Since you seem to be the owner: is there a place/link where I can see which hyperparameters to use for Ant with GAIL? Thanks!

I haven't worked much with GAIL, so I cannot really help you with that one.

zrobertson466920 commented 4 years ago

To be clear, I searched the documentation. My issue is that the documentation does not explain this wrapper; I had to dig through the code to find this out. So: is removing the 29th feature equivalent to removing the wrapper? From the code, the wrapper appears to work by concatenation, which leads me to believe that my change to the observations restores the original setup. Thanks!
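If the feature is concatenated at the end, slicing it off recovers the raw observations exactly. The sketch below checks this on a short synthetic episode, assuming a wrapper that appends a remaining-time feature; `append_time_feature`, the 5-step horizon, and the random observations are all stand-ins. It only shows that the observation part is restored; it says nothing about the fact that the SAC expert chose its actions while seeing the time feature.

```python
import numpy as np


def append_time_feature(obs, current_step, max_steps):
    """Hypothetical stand-in for a wrapper that concatenates a remaining-time feature."""
    return np.concatenate((obs, [1.0 - current_step / max_steps]))


rng = np.random.default_rng(0)
max_steps = 5
raw_episode = rng.normal(size=(max_steps, 28))  # stand-in raw observations, one row per step
wrapped_episode = np.stack(
    [append_time_feature(o, t, max_steps) for t, o in enumerate(raw_episode)]
)

stripped = wrapped_episode[:, :28]  # same slice as obs[:, :28] in the snippet above
print(wrapped_episode.shape)        # (5, 29)
print(np.array_equal(stripped, raw_episode))  # True
```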