ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] RandomPolicy does not handle nested spaces #30221

Closed · rusu24edward closed this issue 1 year ago

rusu24edward commented 2 years ago

What happened + What you expected to happen

I would like to use RLlib's RandomPolicy for some agents in a multi-agent training setup. I have created two policies, one to train and the other that acts randomly. I get the following error:

Failure # 1 (occurred at 2022-10-03_08-20-06)
Traceback (most recent call last):
  File "/Users/rusu1/.virtual_envs/v_abmarl/lib/python3.7/site-packages/ray/tune/ray_trial_executor.py", line 901, in get_next_executor_event
    future_result = ray.get(ready_future)
  File "/Users/rusu1/.virtual_envs/v_abmarl/lib/python3.7/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "/Users/rusu1/.virtual_envs/v_abmarl/lib/python3.7/site-packages/ray/worker.py", line 1809, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(IndexError): ray::PGTrainer.train() (pid=6262, ip=127.0.0.1, repr=PGTrainer)
  File "/Users/rusu1/.virtual_envs/v_abmarl/lib/python3.7/site-packages/ray/tune/trainable.py", line 349, in train
    result = self.step()
  File "/Users/rusu1/.virtual_envs/v_abmarl/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 1093, in step
    raise e
  File "/Users/rusu1/.virtual_envs/v_abmarl/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 1074, in step
    step_attempt_results = self.step_attempt()
  File "/Users/rusu1/.virtual_envs/v_abmarl/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 1155, in step_attempt
    step_results = self._exec_plan_or_training_iteration_fn()
  File "/Users/rusu1/.virtual_envs/v_abmarl/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 2172, in _exec_plan_or_training_iteration_fn
    results = self.training_iteration()
  File "/Users/rusu1/.virtual_envs/v_abmarl/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 1420, in training_iteration
    new_sample_batches = synchronous_parallel_sample(self.workers)
  File "/Users/rusu1/.virtual_envs/v_abmarl/lib/python3.7/site-packages/ray/rllib/execution/rollout_ops.py", line 74, in synchronous_parallel_sample
    return [worker_set.local_worker().sample()]
  File "/Users/rusu1/.virtual_envs/v_abmarl/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 815, in sample
    batches = [self.input_reader.next()]
  File "/Users/rusu1/.virtual_envs/v_abmarl/lib/python3.7/site-packages/ray/rllib/evaluation/sampler.py", line 116, in next
    batches = [self.get_data()]
  File "/Users/rusu1/.virtual_envs/v_abmarl/lib/python3.7/site-packages/ray/rllib/evaluation/sampler.py", line 289, in get_data
    item = next(self._env_runner)
  File "/Users/rusu1/.virtual_envs/v_abmarl/lib/python3.7/site-packages/ray/rllib/evaluation/sampler.py", line 722, in _env_runner
    clip_actions=clip_actions,
  File "/Users/rusu1/.virtual_envs/v_abmarl/lib/python3.7/site-packages/ray/rllib/evaluation/sampler.py", line 1248, in _process_policy_eval_results
    env_id: int = eval_data[i].env_id
IndexError: list index out of range

We isolated the problem in this discussion post. The issue is the loop over the observation batch. A single observation in my simulation looks like this:

{
    'position': ...,
    'left': ...,
    'right': ...
}

Unfortunately, this loop in RandomPolicy.compute_actions generates three actions because it iterates over the entries of the dictionary, even though this is just a single observation. So it seems the function struggles with nested spaces, as illustrated below.
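
To make the failure mode concrete, here is a standalone illustration (plain Python, not RLlib source) of the mechanism: iterating over a dict yields its keys, so a single nested observation gets miscounted as three observations.

# Standalone illustration of the bug's mechanism (not RLlib source code).
single_obs = {'position': 0.0, 'left': 1.0, 'right': 2.0}  # ONE observation

# A loop like "for _ in obs_batch" iterates the dict's keys...
fake_actions = [f'action for {key}' for key in single_obs]
print(len(fake_actions))  # 3 -- one action per key, not one per observation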

I can currently work around this (a sketch of the workaround follows below), but I will soon reach the point where I need to rely on random and heuristic policies for some of the agents.
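
One way to work around it is a small RandomPolicy subclass along these lines. This is only a sketch against Ray 2.0's Policy API; the class name and the batch-size heuristic are illustrative, not code from RLlib:

from ray.rllib.examples.policy.random_policy import RandomPolicy

class NestedSafeRandomPolicy(RandomPolicy):
    """Hypothetical workaround: sample one action per observation, even
    when a single observation is a (nested) dict."""

    def compute_actions(self, obs_batch, *args, **kwargs):
        # Treat a bare dict as one observation instead of iterating over
        # its keys; otherwise fall back to the batch length.
        batch_size = 1 if isinstance(obs_batch, dict) else len(obs_batch)
        return [self.action_space.sample() for _ in range(batch_size)], [], {}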

Versions / Dependencies

This is Ray 2.0.0, Python 3.7.9, and macOS 10.15.7.

Reproduction script

from abmarl.examples import MultiCorridor
from abmarl.managers import TurnBasedManager
from abmarl.external import MultiAgentWrapper
from ray.rllib.examples.policy.random_policy import RandomPolicy

# Wrap the Abmarl simulation so RLlib can use it as a multi-agent env
sim = MultiAgentWrapper(TurnBasedManager(MultiCorridor()))

sim_name = "MultiCorridor"
from ray.tune.registry import register_env
register_env(sim_name, lambda sim_config: sim)

# Arbitrarily distinguish between two policies
ref_agent = sim.sim.agents['agent0']
policies = {
    'corridor_0': (None, ref_agent.observation_space, ref_agent.action_space, {}),
    'corridor_1': (RandomPolicy, ref_agent.observation_space, ref_agent.action_space, {}),
    # 'corridor_1': (None, ref_agent.observation_space, ref_agent.action_space, {}),
}

# Arbitrarily choose between the two policies
def policy_mapping_fn(agent_id):
    return f'corridor_{int(agent_id[-1]) % 2}'

# Experiment parameters
params = {
    'experiment': {
        'title': f'{sim_name}',
        'sim_creator': lambda config=None: sim,
    },
    'ray_tune': {
        'run_or_experiment': 'PG',
        'checkpoint_freq': 50,
        'checkpoint_at_end': True,
        'stop': {
            'episodes_total': 2000,
        },
        'verbose': 2,
        'config': {
            # --- Simulation ---
            'disable_env_checking': True,
            'env': sim_name,
            'horizon': 200,
            'env_config': {},
            # --- Multiagent ---
            'multiagent': {
                'policies': policies,
                'policy_mapping_fn': policy_mapping_fn,
            },
            # "lr": 0.0001,
            # --- Parallelism ---
            # Number of workers per experiment: int
            "num_workers": 0,
            # Number of simulations that each worker starts: int
            "num_envs_per_worker": 1, # This must be 1 because we are not "threadsafe"
        },
    }
}

if __name__ == "__main__":
    # Create output directory and save to params
    import os
    import time
    home = os.path.expanduser("~")
    output_dir = os.path.join(
        home, 'abmarl_results/{}_{}'.format(
            params['experiment']['title'], time.strftime('%Y-%m-%d_%H-%M')
        )
    )
    params['ray_tune']['local_dir'] = output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Copy this configuration file to the output directory
    import shutil
    shutil.copy(os.path.join(os.getcwd(), __file__), output_dir)

    # Initialize and run ray
    import ray
    from ray import tune
    ray.init()
    tune.run(**params['ray_tune'])
    ray.shutdown()

Issue Severity

Medium: It is a significant difficulty but I can work around it.

sven1977 commented 1 year ago

Hey @rusu24edward , thanks for filing this. The above PR fixes the problem and should be merged very soon into master.
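
For context, the general pattern for this kind of fix is to infer the batch size from a flattened leaf of the observation batch rather than from the top-level object. A rough sketch using dm-tree (an RLlib dependency); the helper name is illustrative, not the PR's code:

import numpy as np
import tree  # dm-tree, shipped as an RLlib dependency

def batch_size_of(obs_batch):
    # Read the batch dimension off the first leaf array. Plain
    # len(obs_batch) would count dict keys for nested spaces.
    leaves = tree.flatten(obs_batch)
    return len(leaves[0]) if leaves else 0

# A batch of 2 observations from a nested (dict) space:
obs_batch = {
    'position': np.zeros((2, 3)),
    'left': np.zeros((2,)),
    'right': np.zeros((2,)),
}
print(batch_size_of(obs_batch))  # -> 2, even though the dict has 3 keys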