ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] Algorithm cannot be created in driver #45890

Open lonsdale8734 opened 5 months ago

lonsdale8734 commented 5 months ago

What happened + What you expected to happen

When num_learners is set to > 0 with the new API stack, the Algorithm cannot be created in a driver running outside the Ray cluster. Is this a bug or by design?

Running python -W ignore::DeprecationWarning demo.py raises the following error:

ray/_private/worker.py", line 534, in should_capture_child_tasks_in_placement_group
    return self.core_worker.should_capture_child_tasks_in_placement_group()
           ^^^^^^^^^^^^^^^^
AttributeError: 'Worker' object has no attribute 'core_worker'

Running with ray job submit --address http://clusterip:8265 --working-dir . -- python demo.py works fine.
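
For reference, the same submission can also be done programmatically with the Ray Jobs SDK; this is a minimal sketch (not part of the original report) that assumes the same dashboard address and working directory:

from ray.job_submission import JobSubmissionClient

# Connect to the cluster's Job API (the dashboard address used above).
client = JobSubmissionClient("http://clusterip:8265")

# Submit demo.py with the current directory uploaded as the working dir,
# equivalent to the `ray job submit` CLI command above.
job_id = client.submit_job(
    entrypoint="python demo.py",
    runtime_env={"working_dir": "."},
)
print(job_id)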

Versions / Dependencies

python==3.11.9
ray[all]==2.24.0

Reproduction script

import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print

def is_running_in_ray_job():
    import os

    return "RAY_JOB_CONFIG_JSON_ENV_VAR" in os.environ

if is_running_in_ray_job():
    ray.init()
else:
    ray.init(address="ray://clusterip:10001")

algo = (
    PPOConfig()
    # enable new api stack
    .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)
    .env_runners(num_env_runners=1)
    .resources(num_gpus=0)
    .learners(num_learners=1)  # cause the crash
    .environment(env="CartPole-v1")
    .build()
)

# for i in range(10):
#     result = algo.train()
#     print(pretty_print(result))
#
#     if i % 5 == 0:
#         checkpoint_dir = algo.save().checkpoint.path
#         print(f"Checkpoint saved in directory {checkpoint_dir}")

Issue Severity

Medium: It is a significant difficulty but I can work around it.

lonsdale8734 commented 5 months ago

Workaround: build and train the Algorithm inside a @ray.remote task so it executes on the cluster instead of in the external driver:

import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print

ray.init(address="ray://clusterip:10001")

@ray.remote
def main():
    algo = (
        PPOConfig()
        # enable new api stack
        .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)
        .env_runners(num_env_runners=1)
        .resources(num_gpus=0)
        .learners(num_learners=1)
        .environment(env="CartPole-v1")
        .build()
    )

    for i in range(10):
        result = algo.train()
        print(pretty_print(result))

        if i % 5 == 0:
            checkpoint_dir = algo.save().checkpoint.path
            print(f"Checkpoint saved in directory {checkpoint_dir}")

f = main.remote()
ray.get(f)

simonsays1980 commented 5 months ago

@lonsdale8734 Thanks for raising this issue. Your script runs fine locally, so there are no errors in the RLlib logic. When running remotely on a cluster via the Ray Job API, you should ensure that the dashboard port 8265 is reachable.
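
A quick way to check reachability of the relevant ports from the submitting machine could look like the following sketch (not part of the original discussion; "clusterip" is a placeholder for the head-node address):

import socket

def port_is_open(host: str, port: int, timeout: float = 3.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within the timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False

print(port_is_open("clusterip", 8265))   # Ray dashboard / Job API port
print(port_is_open("clusterip", 10001))  # Ray Client port used by ray.init(address="ray://...")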

@anyscalesam Even though the main logic uses RLlib, the error appears to be in Ray Core rather than in RLlib.