ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[rllib] Recurrent torch models cause error with A2C #7693

Closed pmacalpine closed 4 years ago

pmacalpine commented 4 years ago

What is the problem?

When I try and use a recurrent torch model with A2C I get the following list index out of range error due to a state passed to the model's forward function being empty:

2020-03-21 23:38:55,276 ERROR trial_runner.py:513 -- Trial A2C_CartPole-v1_00000: Error processing event.
Traceback (most recent call last):
  File "/data/home/patmac/ray/python/ray/tune/trial_runner.py", line 459, in _process_trial
    result = self.trial_executor.fetch_result(trial)
  File "/data/home/patmac/ray/python/ray/tune/ray_trial_executor.py", line 381, in fetch_result
    result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)
  File "/data/home/patmac/ray/python/ray/worker.py", line 1511, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(IndexError): ray::A2C.train() (pid=17292, ip=172.16.226.199)
  File "python/ray/_raylet.pyx", line 445, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 423, in ray._raylet.execute_task.function_executor
  File "/data/home/patmac/ray/python/ray/rllib/agents/trainer.py", line 502, in train
    raise e
  File "/data/home/patmac/ray/python/ray/rllib/agents/trainer.py", line 491, in train
    result = Trainable.train(self)
  File "/data/home/patmac/ray/python/ray/tune/trainable.py", line 256, in train
    result = self._train()
  File "/data/home/patmac/ray/python/ray/rllib/agents/trainer_template.py", line 146, in _train
    return self._train_exec_impl()
  File "/data/home/patmac/ray/python/ray/rllib/agents/trainer_template.py", line 178, in _train_exec_impl
    res = next(self.train_exec_impl)
  File "/data/home/patmac/ray/python/ray/util/iter.py", line 635, in __next__
    return next(self.built_iterator)
  File "/data/home/patmac/ray/python/ray/util/iter.py", line 619, in set_restore_context
    for item in it:
  File "/data/home/patmac/ray/python/ray/util/iter.py", line 645, in apply_foreach
    for item in it:
  File "/data/home/patmac/ray/python/ray/util/iter.py", line 684, in apply_filter
    for item in it:
  File "/data/home/patmac/ray/python/ray/util/iter.py", line 645, in apply_foreach
    for item in it:
  File "/data/home/patmac/ray/python/ray/util/iter.py", line 716, in apply_flatten
    for item in it:
  File "/data/home/patmac/ray/python/ray/util/iter.py", line 669, in add_wait_hooks
    item = next(it)
  File "/data/home/patmac/ray/python/ray/util/iter.py", line 645, in apply_foreach
    for item in it:
  File "/data/home/patmac/ray/python/ray/util/iter.py", line 645, in apply_foreach
    for item in it:
  File "/data/home/patmac/ray/python/ray/util/iter.py", line 645, in apply_foreach
    for item in it:
  File "/data/home/patmac/ray/python/ray/util/iter.py", line 395, in base_iterator
    yield ray.get(futures, timeout=timeout)
ray.exceptions.RayTaskError(IndexError): ray::RolloutWorker.par_iter_next() (pid=17293, ip=172.16.226.199)
  File "python/ray/_raylet.pyx", line 445, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 423, in ray._raylet.execute_task.function_executor
  File "/data/home/patmac/ray/python/ray/util/iter.py", line 957, in par_iter_next
    return next(self.local_it)
  File "/data/home/patmac/ray/python/ray/util/iter.py", line 619, in set_restore_context
    for item in it:
  File "/data/home/patmac/ray/python/ray/rllib/evaluation/rollout_worker.py", line 251, in gen_rollouts
    yield self.sample()
  File "/data/home/patmac/ray/python/ray/rllib/evaluation/rollout_worker.py", line 492, in sample
    batches = [self.input_reader.next()]
  File "/data/home/patmac/ray/python/ray/rllib/evaluation/sampler.py", line 53, in next
    batches = [self.get_data()]
  File "/data/home/patmac/ray/python/ray/rllib/evaluation/sampler.py", line 96, in get_data
    item = next(self.rollout_provider)
  File "/data/home/patmac/ray/python/ray/rllib/evaluation/sampler.py", line 338, in _env_runner
    callbacks, soft_horizon, no_done_at_end)
  File "/data/home/patmac/ray/python/ray/rllib/evaluation/sampler.py", line 487, in _process_observations
    outputs.append(episode.batch_builder.build_and_reset(episode))
  File "/data/home/patmac/ray/python/ray/rllib/evaluation/sample_batch_builder.py", line 199, in build_and_reset
    self.postprocess_batch_so_far(episode)
  File "/data/home/patmac/ray/python/ray/rllib/evaluation/sample_batch_builder.py", line 153, in postprocess_batch_so_far
    pre_batch, other_batches, episode)
  File "/data/home/patmac/ray/python/ray/rllib/policy/torch_policy_template.py", line 110, in postprocess_trajectory
    convert_to_non_torch_type(other_agent_batches), episode)
  File "/data/home/patmac/ray/python/ray/rllib/agents/a3c/a3c_torch_policy.py", line 46, in add_advantages
    last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1])
  File "/data/home/patmac/ray/python/ray/rllib/agents/a3c/a3c_torch_policy.py", line 74, in _value
    _ = self.model({"obs": torch.Tensor([obs]).to(self.device)}, [], [1])
  File "/data/home/patmac/ray/python/ray/rllib/models/modelv2.py", line 150, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "examples/torch_rnn.py", line 59, in forward
    h_in = hidden_state[0].reshape(-1, self.rnn_hidden_dim)
IndexError: list index out of range

It seems problematic that [] is being passed as the state in a3c_torch_policy.py. This may be related to #7206 and recurrent models not yet being supported with torch, but it would be really nice if they were supported.
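The failure mode can be shown without Ray or torch. The sketch below is only an illustration: `ToyRecurrentModel`, `forward_unguarded`, `forward_guarded`, and `_step` are hypothetical names, not RLlib's API. It assumes a model that indexes `hidden_state[0]` the way the reported `forward` does, and shows one possible model-side workaround, falling back to `get_initial_state()` when the caller supplies an empty state list.

```python
# Toy stand-in for a recurrent model. All names are illustrative assumptions,
# not part of RLlib.
class ToyRecurrentModel:
    def __init__(self, hidden_dim):
        self.hidden_dim = hidden_dim

    def get_initial_state(self):
        # A single zero-filled hidden vector, like the one GRU state in the report.
        return [[0.0] * self.hidden_dim]

    def forward_unguarded(self, obs, hidden_state):
        # Same access pattern as `hidden_state[0]` in the report:
        # raises IndexError when hidden_state is [].
        h_in = hidden_state[0]
        return self._step(obs, h_in)

    def forward_guarded(self, obs, hidden_state):
        # Possible workaround: use the initial state when none is supplied.
        if not hidden_state:
            hidden_state = self.get_initial_state()
        return self._step(obs, hidden_state[0])

    def _step(self, obs, h_in):
        # Placeholder recurrence: add the observation sum to every hidden unit.
        total = sum(obs)
        h_out = [h + total for h in h_in]
        return h_out, [h_out]


model = ToyRecurrentModel(hidden_dim=4)

# Calling with an empty state list reproduces the IndexError.
try:
    model.forward_unguarded([1.0, 2.0], [])
    unguarded_error = None
except IndexError as exc:
    unguarded_error = exc

# The guarded variant substitutes the zero state and succeeds.
output, new_state = model.forward_guarded([1.0, 2.0], [])
```

A guard like this only hides the symptom inside the custom model. Whether the value estimate computed from a zero state is what A2C should use for bootstrapping is a separate question that the sketch does not answer.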

Ray version and other system information (Python version, TensorFlow version, OS):

Ray: 0.9.0.dev0 (revision 89d959fd6ac206a1a3b5a6cb151d19290df6a235)
Python: 3.7.2
torch: 1.4.0
OS: Ubuntu 16.04.6 LTS

Reproduction (REQUIRED)

Please provide a script that can be run to reproduce the issue. The script should have no external library dependencies (i.e., use fake or mock data / environments):

import ray
from ray import tune
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.misc import SlimFC, normc_initializer
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils import try_import_torch
from ray.rllib.utils.annotations import override

# try_import_torch() returns (torch, nn); only nn is used below.
_, nn = try_import_torch()

class RNNModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.obs_size = _get_size(obs_space)
        self.rnn_hidden_dim = model_config["lstm_cell_size"]
        self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim)
        self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim)

        self._logits = SlimFC(
            self.rnn_hidden_dim, num_outputs, initializer=nn.init.xavier_uniform_)
        self._value_branch = SlimFC(
            self.rnn_hidden_dim, 1, initializer=normc_initializer())
        self._cur_value = None

    @override(TorchModelV2)
    def get_initial_state(self):
        # make hidden states on same device as model
        return [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]

    @override(TorchModelV2)
    def forward(self, input_dict, hidden_state, seq_lens):
        x = nn.functional.relu(self.fc1(input_dict["obs_flat"].float()))
        h_in = hidden_state[0].reshape(-1, self.rnn_hidden_dim)
        h = self.rnn(x, h_in)

        logits = self._logits(h)
        self._cur_value = self._value_branch(h).squeeze(1)
        return logits, [h]

    @override(TorchModelV2)
    def value_function(self):
        assert self._cur_value is not None, "must call forward() first"
        return self._cur_value

def _get_size(obs_space):
    return get_preprocessor(obs_space)(obs_space).size

if __name__ == "__main__":
    ray.init()
    ModelCatalog.register_custom_model("rnn_model", RNNModel)
    tune.run(
        "A2C",
        stop={
            "timesteps_total": 200000000,
        },
        config={
            "env": "CartPole-v1",
            "model": {
                "custom_model": "rnn_model",
                "custom_options": {
                    "lstm_cell_size": 64,
                }
            },
            "num_workers": 4,  # parallelism
            "use_pytorch": True,
        },
    )

If we cannot run your script, we cannot fix your issue.

pmacalpine commented 4 years ago

As of the Ray 0.8.5 release, RLlib has better support for recurrent torch models, so I am closing this issue.

Arthur-Null commented 4 years ago

I am facing the same problem with Ray 0.8.5:

File ray/rllib/models/modelv2.py, line 164, in __call__
    res = self.forward(restored, state or [], seq_lens)

It seems [] is being passed as the state here.
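For anyone reading the quoted line: `state or []` replaces any falsy state (`None` or an empty list) with an empty list, so it never supplies a usable recurrent state by itself. The sketch below uses plain Python lists and hypothetical helper names (`normalize_state`, `state_or_initial`, `toy_initial_state`). It is not RLlib code and not a proposed patch, only a comparison of the quoted expression with a fallback to an initial state.

```python
def normalize_state(state):
    """Same expression as the quoted line: falsy states become an empty list."""
    return state or []


def state_or_initial(state, get_initial_state):
    """Hypothetical alternative: fall back to the model's initial state."""
    return state if state else get_initial_state()


def toy_initial_state():
    # Stand-in for a model's get_initial_state(); one 2-unit zero vector.
    return [[0.0, 0.0]]


# None and [] both collapse to [], which a recurrent forward() cannot index.
results = {
    "none": normalize_state(None),
    "empty": normalize_state([]),
    "given": normalize_state([[1.0, 2.0]]),
}

# With the fallback, an empty state is replaced and a real state is kept.
fallback = state_or_initial([], toy_initial_state)
kept = state_or_initial([[1.0, 2.0]], toy_initial_state)
```

So if the caller passes `[]`, as `_value` does in the first traceback, the model still receives `[]` regardless of what `get_initial_state()` would return.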