hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License
4.16k stars 725 forks

Cannot evaluate if trained using more than 1 env [Custom env (Unity)] #1015

Closed mily20001 closed 4 years ago

mily20001 commented 4 years ago

Hi, I'm using a custom environment created with UnityEnvironment and UnityToGymWrapper. To create the env I adapted code from cmd_util.py; it looks like this:

import os

from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
from gym_unity.envs import UnityToGymWrapper
from stable_baselines.bench import Monitor
from stable_baselines.common.vec_env import SubprocVecEnv


def make_unity_env(env_directory, num_env, visual, log_path, start_index=0):
    def make_env(rank):
        def _init():
            engine_configuration_channel = EngineConfigurationChannel()
            unity_env = UnityEnvironment(env_directory, worker_id=rank,
                                         side_channels=[engine_configuration_channel])
            env = UnityToGymWrapper(unity_env, uint8_visual=True, flatten_branched=True)
            engine_configuration_channel.set_configuration_parameters(time_scale=3.0, width=84, height=84,
                                                                      quality_level=0)
            env = Monitor(env, os.path.join(log_path, str(rank)) if log_path is not None else None,
                          allow_early_resets=True)
            return env

        return _init

    if visual:
        return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
    else:
        # Only visual observations are handled in this example
        raise NotImplementedError

I'm using visual observations with size 84x84. I have a problem when using more than one environment (num_env). Training works (the results are not perfect, but there is clear progress compared with a random agent), but if I then try to evaluate the model on a single env (or use EvalCallback) I get this error:

Traceback (most recent call last):
  File "(...)/miniconda3/envs/stable_baselines/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 956, in run
    run_metadata_ptr)
  File "(...)/miniconda3/envs/stable_baselines/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1156, in _run
    (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (1, 84, 84, 1) for Tensor 'input/Ob:0', which has shape '(8, 84, 84, 1)'

This particular model was trained using 8 envs, and the error is essentially the same for any other env count (except a single env):

Cannot feed value of shape (1, 84, 84, 1) for Tensor 'input/Ob:0', which has shape '(<NUM_ENV>, 84, 84, 1)'

I've tried changing SubprocVecEnv to DummyVecEnv, but the error was the same. I've tested the same scenario with the CartPole gym environment and it worked fine.

System Info

mily20001 commented 4 years ago

Here you can find a full example: https://github.com/mily20001/stable_baselines_unity/blob/master/example.tar.lzma It contains the Unity environment, the Python script, and a conda .yml file with dependencies. To run:

tar -xvf example.tar.lzma
cd example
conda env create -f env.yml
conda activate stable_baselines
python ./train_unity.py

After a few seconds of training, eval_callback kicks in and everything crashes.

Everything should run on any not-too-old Linux machine.

If you need any other information, just let me know; I'll do my best to answer.

Miffyli commented 4 years ago

Please attach the full training code next time so we do not have to download and extract packages :). Usually the problem can be solved by looking over the code related to stable-baselines.

Looking at the code, it seems you are using an LSTM policy, which requires the same number of environments during evaluation as were used during training (mentioned in the docs here).
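The workaround mentioned in the docs can be sketched with plain NumPy: pad the single evaluation observation up to the training batch size with zeros and read back only the first action. This is just a sketch under the assumptions from this issue (8 training envs, 84x84x1 uint8 observations); the commented-out `model.predict` call stands in for the user's trained model.

```python
import numpy as np

# Sketch of the zero-padding workaround for recurrent (LSTM) policies.
# Assumptions from this issue: model trained with 8 envs, 84x84x1 visual
# observations; `model` and `env` are left as comments since they depend
# on the user's setup.
n_train_envs = 8
obs_shape = (84, 84, 1)

single_obs = np.zeros(obs_shape, dtype=np.uint8)  # stand-in for env.reset()

# Pad the batch dimension up to the number of training envs:
padded_obs = np.zeros((n_train_envs,) + obs_shape, dtype=single_obs.dtype)
padded_obs[0] = single_obs

# actions, state = model.predict(padded_obs, state=state)
# env.step(actions[0])  # only the first row is a real observation
```

Only row 0 of `padded_obs` carries real data; the zero rows exist purely to satisfy the fixed batch dimension of the recurrent policy's placeholder.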

araffin commented 4 years ago

You can find a solution here (you will need to create a custom EvalCallback): https://github.com/hill-a/stable-baselines/issues/166#issuecomment-502350843

Probably a duplicate of https://github.com/hill-a/stable-baselines/issues/714#issuecomment-592462749

mily20001 commented 4 years ago

Ok, now it works, thank you :) What do you think about extending the default eval callback so that it checks the self.model.policy.recurrent property and, based on that, pads observations with zeros (up to self.model.n_envs)? I'm open to creating a PR for that if you are interested.
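A minimal sketch of what that extension might look like. `pad_for_recurrent` is a hypothetical helper name (not part of stable-baselines); the `recurrent` and `n_envs` attributes referenced in the comment are the ones named in the proposal above.

```python
import numpy as np


def pad_for_recurrent(obs, n_envs):
    """Hypothetical helper for the proposed EvalCallback change: pad a
    smaller observation batch up to `n_envs` rows with zeros, so that a
    recurrent policy trained with `n_envs` environments accepts it."""
    obs = np.asarray(obs)
    if obs.shape[0] >= n_envs:
        return obs
    padded = np.zeros((n_envs,) + obs.shape[1:], dtype=obs.dtype)
    padded[: obs.shape[0]] = obs
    return padded


# Inside the callback the check could look something like:
# if self.model.policy.recurrent:
#     obs = pad_for_recurrent(obs, self.model.n_envs)
demo = pad_for_recurrent(np.ones((1, 84, 84, 1), dtype=np.uint8), 8)
```

With a single eval env, `demo` has the shape the recurrent policy expects, and only its first row holds real data.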

Miffyli commented 4 years ago

That'd alleviate the headaches of many users, so it sounds like a good suggestion :). Feel free to create a PR for it, unless @araffin has anything against this.