ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] build_for_inference() in env_runner_v2.py creates an empty state_out_1, leading to an initialization failure #42978

Open QLs-Learning-Bot opened 7 months ago

QLs-Learning-Bot commented 7 months ago

What happened + What you expected to happen

I have an RNN model inheriting from ModelV2 that worked well with Ray 2.2.

In Ray 2.9, I disabled the new API stack as required: config.experimental(_enable_new_api_stack=False).build()

The error output is attached below. I tried to debug it myself, without success. I noticed that "sample_batches_by_policy" did not contain "state_out_1" when running ray/rllib/evaluation/env_runner_v2.py. When the next function, build_for_inference, is called at line 326 of ray/rllib/connectors/agent/view_requirement.py, self.view_requirements produces "state_in_1" as an empty list, which eventually causes the IndexError. Viewed in debug mode, self.view_requirements['state_in_1'] looks like this:

ViewRequirement(data_col='state_out_1', space=Box(-1.0, 1.0, (256,), float32), shift=-1, index=None, batch_repeat_value=20, used_for_compute_actions=True, used_for_training=True, shift_arr=array([-1]))
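One plausible reading of this output: the requirement for "state_in_1" is fed from "state_out_1" (shift=-1), and such a requirement appears to be created for every tensor returned by get_initial_state(). In the reproduction script below, get_initial_state() returns two tensors while forward_rnn() returns only one, so nothing would ever be written to "state_out_1". The following is a simplified pure-Python sketch of that failure mode under this assumption; build_state_buffers and first_item are hypothetical stand-ins and do not reproduce RLlib's actual collector logic.

```python
from typing import Dict, List, Sequence


def build_state_buffers(
    num_declared_states: int, step_outputs: Sequence[Sequence[float]]
) -> Dict[str, List[float]]:
    """Hypothetical collector: one buffer per declared state, filled only
    with the states the model actually returns at each step."""
    buffers: Dict[str, List[float]] = {
        f"state_out_{i}": [] for i in range(num_declared_states)
    }
    for states in step_outputs:
        for i, s in enumerate(states):
            buffers[f"state_out_{i}"].append(s)
    return buffers


def first_item(buffer: List[float]) -> float:
    # Same access pattern as `v[0]` in _to_float_np_array.
    return buffer[0]


# Two states declared (like get_initial_state), one returned per step (like forward_rnn).
buffers = build_state_buffers(2, [[0.0], [0.1], [0.2]])
print({k: len(v) for k, v in buffers.items()})  # {'state_out_0': 3, 'state_out_1': 0}

try:
    first_item(buffers["state_out_1"])
except IndexError as err:
    print(f"IndexError: {err}")  # IndexError: list index out of range
```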

Here is the error information:

2024-02-05 05:07:41,742 ERROR tune_controller.py:1374 -- Trial task failed for trial PPO_MultiAgentArena_v3_85c05_00000
Traceback (most recent call last):
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
             ^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/_private/worker.py", line 2624, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(IndexError): ray::PPO.train() (pid=987409, ip=10.47.57.189, actor_id=da518257234fa0c302d5fd4d01000000, repr=PPO)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/tune/trainable/trainable.py", line 342, in train
    raise skipped from exception_cause(skipped)
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/tune/trainable/trainable.py", line 339, in train
    result = self.step()
             ^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 852, in step
    results, train_iter_ctx = self._run_one_training_iteration()
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 3042, in _run_one_training_iteration
    results = self.training_step()
              ^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo.py", line 407, in training_step
    train_batch = synchronous_parallel_sample(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/execution/rollout_ops.py", line 83, in synchronous_parallel_sample
    sample_batches = worker_set.foreach_worker(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 705, in foreach_worker
    handle_remote_call_result_errors(remote_results, self._ignore_worker_failures)
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 78, in handle_remote_call_result_errors
    raise r.get()
ray.exceptions.RayTaskError(IndexError): ray::RolloutWorker.apply() (pid=987409, ip=10.47.57.189, actor_id=d64b201bd95cea973cd5da4701000000, repr=<ray.rllib.evaluation.rollout_worker._modify_class.<locals>.Class object at 0x7fd832842e10>)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/utils/actor_manager.py", line 189, in apply
    raise e
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/utils/actor_manager.py", line 178, in apply
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/execution/rollout_ops.py", line 84, in <lambda>
    lambda w: w.sample(), local_worker=False, healthy_only=True
              ^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py", line 694, in sample
    batches = [self.input_reader.next()]
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/sampler.py", line 91, in next
    batches = [self.get_data()]
               ^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/sampler.py", line 276, in get_data
    item = next(self._env_runner)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 344, in run
    outputs = self.step()
              ^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 370, in step
    active_envs, to_eval, outputs = self._process_observations(
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 637, in _process_observations
    processed = policy.agent_connectors(acd_list)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/connectors/agent/pipeline.py", line 41, in __call__
    ret = c(ret)
          ^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/connectors/connector.py", line 265, in __call__
    return [self.transform(d) for d in acd_list]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/connectors/connector.py", line 265, in <listcomp>
    return [self.transform(d) for d in acd_list]
            ^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/connectors/agent/view_requirement.py", line 118, in transform
    sample_batch = agent_collector.build_for_inference()
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/agent_collector.py", line 366, in build_for_inference
    self._cache_in_np(np_data, data_col)
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/agent_collector.py", line 613, in _cache_in_np
    cache_dict[key] = [_to_float_np_array(d) for d in self.buffers[key]]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/agent_collector.py", line 613, in <listcomp>
    cache_dict[key] = [_to_float_np_array(d) for d in self.buffers[key]]
                       ^^^^^^^^^^^^^^^^^^^^^
  File "/home/lime/miniconda3/envs/ray29/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/agent_collector.py", line 32, in _to_float_np_array
    if torch and torch.is_tensor(v[0]):
                                 ~^^^
IndexError: list index out of range

Versions / Dependencies

Ray==2.9.1 Python==3.11

Reproduction script

The simple RNN model, mySimpleRNN.py:

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_torch
from typing import Callable

torch, nn = try_import_torch()
class AnotherTorchRNNModel(RecurrentNetwork, nn.Module):
    def __init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            rnn_hidden_size=256,
            l2_lambda = 3,
            l2_lambda_inp=0,
            device="cuda"
    ):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.rnn_hidden_size = model_config["custom_model_config"]["rnn_hidden_size"]
        self.l2_lambda = model_config["custom_model_config"]["l2_lambda"]
        self.l2_lambda_inp = model_config["custom_model_config"]["l2_lambda_inp"]

        # Build the Module from 0fc + RNN + 2xfc (action + value outs).
        # self.fc1 = nn.Linear(self.obs_size, self.fc_size)
        self.rnn = nn.RNN(self.obs_size, self.rnn_hidden_size, batch_first=True, nonlinearity='relu')
        self.action_branch = nn.Linear(self.rnn_hidden_size, num_outputs)
        self.value_branch = nn.Linear(self.rnn_hidden_size, 1)
        # Holds the current "base" output (before logits layer).
        self._features = None
        self.l2_loss = None
        self.l2_loss_inp = None
        self.original_loss = None

        self.activations = {}
        self.hooks = []
        self.device = device

    @override(ModelV2)
    def get_initial_state(self):
        # Place hidden states on same device as model.
        h = [
            self.rnn.weight_ih_l0.new(1, self.rnn_hidden_size).zero_().squeeze(0),
            self.rnn.weight_ih_l0.new(1, self.rnn_hidden_size).zero_().squeeze(0),
        ]
        return h

    @override(ModelV2)
    def value_function(self):
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds inputs (B x T x ..) through the RNN unit.
        Returns the resulting outputs as a sequence (B x T x ...).
        The RNN features are stored in self._features in (B x T x ...)
        shape and consumed by value_function().
        Returns:
            NN Outputs (B x T x ...) as sequence.
            The state batch as a list with a single item (the h-state).
        """
        x = inputs
        y = torch.unsqueeze(state[0], 0)
        self._features, h = self.rnn(x, y)
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0)]

    @override(ModelV2)
    def custom_loss(self, policy_loss, loss_inputs):

        l2_lambda = self.l2_lambda
        l2_reg = torch.tensor(0.).to(self.device)
        # l2_reg += torch.norm(self.rnn.weight_hh_l0.data)
        l2_reg += torch.norm(self.rnn.weight_hh_l0).to(self.device)

        l2_lambda_inp = self.l2_lambda_inp
        l2_reg_inp = torch.tensor(0.).to(self.device)
        l2_reg_inp += torch.norm(self.rnn.weight_ih_l0).to(self.device)

        self.l2_loss = l2_lambda * l2_reg
        self.l2_loss_inp = l2_lambda_inp * l2_reg_inp
        self.original_loss = policy_loss

        assert self.l2_loss.requires_grad, "l2 loss no gradient"
        assert self.l2_loss_inp.requires_grad, "l2 loss no gradient"

        custom_loss = self.l2_loss + self.l2_loss_inp

        # depending on input add loss
        total_loss = [p_loss + custom_loss for p_loss in policy_loss]

        return total_loss

    def metrics(self):
        metrics = {
            "weight_loss": self.l2_loss.item(),
            # TODO Nguyen figure out if good or not
            "original_loss": self.original_loss[0].item(),
        }
        # You can print them to the command line here; with Torch models they are somehow not reported to the logger.
        # print(metrics)
        return metrics
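If the mismatch between get_initial_state() (two tensors) and forward_rnn() (one tensor) is indeed the cause, a model can be checked for it before training. The helper below is a hypothetical, framework-agnostic sketch, not part of RLlib: it compares the number of state tensors declared by get_initial_state() with the number returned by the forward pass. Since nn.RNN, unlike nn.LSTM, carries a single hidden state, returning one tensor from get_initial_state() would presumably be the minimal change for the model above, though that is untested here.

```python
from typing import Any, Callable, List, Sequence


def assert_state_arity(
    get_initial_state: Callable[[], Sequence[Any]],
    run_forward: Callable[[List[Any]], Sequence[Any]],
) -> int:
    """Hypothetical check: the forward pass must return as many state
    tensors as get_initial_state() declares. Returns that count."""
    initial = list(get_initial_state())
    returned = list(run_forward(initial))
    if len(returned) != len(initial):
        raise ValueError(
            f"get_initial_state() declares {len(initial)} state(s), "
            f"but the forward pass returns {len(returned)}"
        )
    return len(initial)


# Stand-ins mirroring mySimpleRNN.py, using plain lists instead of torch tensors.
def declares_two() -> List[List[float]]:
    return [[0.0] * 4, [0.0] * 4]


def declares_one() -> List[List[float]]:
    return [[0.0] * 4]


def returns_one(state: List[Any]) -> List[Any]:
    return [state[0]]  # like `return action_out, [torch.squeeze(h, 0)]`


try:
    assert_state_arity(declares_two, returns_one)
except ValueError as err:
    print(err)  # get_initial_state() declares 2 state(s), but the forward pass returns 1

print(assert_state_arity(declares_one, returns_one))  # 1
```

For the real model, run_forward would wrap a call to forward_rnn() and return the state list it produces; building suitably shaped dummy inputs is left out of this sketch.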

The testCartPole.py script, modified to use the model above:

import argparse
import os
import random
import importlib
import ray
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.models import ModelCatalog
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved

tf1, tf, tfv = try_import_tf()

parser = argparse.ArgumentParser()

parser.add_argument("--num-agents", type=int, default=4)
parser.add_argument("--num-policies", type=int, default=2)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.",
)
parser.add_argument(
    "--stop-iters", type=int, default=200, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
)
parser.add_argument(
    "--stop-reward", type=float, default=300.0, help="Reward at which we stop training."
)
# os.environ["RLLIB_ENABLE_RL_MODULE"] = "False"
if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None, local_mode=True)

    # Register the models to use.
    # Each policy can have a different configuration (including custom model).

    def get_model(
        model_file,
        fc_size=200,
        rnn_hidden_size=256,
        max_seq_len=20,
        l2_curr=3,
        l2_inp=0,
        device="cuda",
        **_,
    ):
        md = importlib.import_module(model_file)
        myModel = getattr(md, "AnotherTorchRNNModel")
        modelName = "rnn_noFC"
        ModelCatalog.register_custom_model(modelName, myModel)
        print(f"Model Registered {model_file}.")
        model_dict = {
            "custom_model": modelName,
            "max_seq_len": max_seq_len,
            "custom_model_config": {
                "fc_size": fc_size,
                "rnn_hidden_size": rnn_hidden_size,
                "l2_lambda": l2_curr,
                "l2_lambda_inp": l2_inp,
                "device": device,  # or 'cuda'
            },
        }
        return model_dict

    model_dict = get_model("mySimpleRNN")
    def gen_policy(i):
        config = PPOConfig.overrides(
            model=model_dict,
            gamma=random.choice([0.95, 0.99]),
        )
        return PolicySpec(config=config)

    # Setup PPO with an ensemble of "num_policies" different policies.
    policies = {"policy_{}".format(i): gen_policy(i) for i in range(args.num_policies)}
    policy_ids = list(policies.keys())

    def policy_mapping_fn(agent_id, episode, worker, **kwargs):
        pol_id = random.choice(policy_ids)
        return pol_id

    config = (
        PPOConfig().experimental( _enable_new_api_stack=False)
        .environment(MultiAgentCartPole, env_config={"num_agents": args.num_agents})
        .framework(args.framework)
        .training(num_sgd_iter=10)
        .multi_agent(policies=policies, policy_mapping_fn=policy_mapping_fn)
    )

    # config.model.update(model_dict)
    config.experimental( _enable_new_api_stack=False).build()
    stop = {
        "episode_reward_mean": args.stop_reward,
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }
    checkpoint_config = air.CheckpointConfig(
        checkpoint_frequency=5,
        # num_to_keep=100,
        checkpoint_at_end=True,
    )

    results = tune.Tuner(
        "PPO",
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop=stop,
            verbose=1,
            checkpoint_config=checkpoint_config,
            local_dir="/home/lime/Documents/CartPoleTest",
        ),
    ).fit()

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
    ray.shutdown()

Issue Severity

High: It blocks me from completing my task.

JiangpengLI86 commented 5 months ago

I have also encountered this issue. :(

Any solution?

YuriyKortev commented 4 months ago

I encountered the same issue, except that 'key' was 'state_out_0'. Is there a workaround?