ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] Restoring an algorithm from_checkpoint expects same number of rollout workers available #36761

Open mzat-msft opened 1 year ago

mzat-msft commented 1 year ago

What happened + What you expected to happen

When I train an algorithm with Tune, specifying for example num_tune_samples=10, and then try to restore the best algorithm using Algorithm.from_checkpoint(), Ray tries to get 10 CPUs from the machine. If the machine does not have enough CPUs available, it keeps throwing this warning and never restores the algorithm:

(autoscaler +2m24s) Warning: The following resource request cannot be scheduled right now: {'CPU': 1.0}. This is likely due to all cluster resources being claimed by actors. Consider creating fewer actors or adding more nodes to this Ray cluster.

I would expect this to be portable and to work on any machine I bring the checkpoints to.
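
For context, a minimal sketch of the kind of Tune setup that produces such a checkpoint (hypothetical: the original algorithm and config are not shown here, so this assumes PPO, uses CartPole-v1 as a stand-in env, and maps num_tune_samples onto tune.TuneConfig(num_samples=...)):

from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig

# Hypothetical training config; the 10 rollout workers get recorded in the
# checkpoint's stored config and are requested again at restore time.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rollouts(num_rollout_workers=10)
)

tuner = tune.Tuner(
    "PPO",
    param_space=config.to_dict(),
    tune_config=tune.TuneConfig(num_samples=10),
    run_config=air.RunConfig(
        stop={"training_iteration": 1},
        checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True),
    ),
)
results = tuner.fit()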

Versions / Dependencies

Observed with ray==2.3.0 and tensorflow==2.11.1 on Linux, but I believe it is a common issue.

Reproduction script

from gymnasium import Env, spaces
from ray.rllib.algorithms.algorithm import Algorithm
from ray.tune.registry import register_env

# Minimal stand-in env: only the observation/action spaces matter when restoring.
class DummyEnv(Env):
    def __init__(self, env_config, observation_space=None, action_space=None):
        if observation_space is None:
            raise TypeError("observation_space cannot be of type None.")
        self.observation_space = observation_space

        if action_space is None:
            raise TypeError("action_space cannot be of type None.")
        self.action_space = action_space

    def step(self, action):
        return self.observation_space.sample(), 0, False, False, {}

    def reset(self, *, seed=None, options=None):
        return self.observation_space.sample(), {}

def restore_agent(
    observation_space,
    action_space,
    checkpoint_path,
    name_env="sim_env",
):
    # Re-register the env under the same name used during training so the
    # restored algorithm can recreate its workers.
    register_env(name_env, lambda conf: DummyEnv(conf, observation_space, action_space))
    return Algorithm.from_checkpoint(checkpoint_path)

agent = restore_agent(
    spaces.Discrete(2),
    spaces.Discrete(2),
    'checkpoints/folder/with-10-rollout-workers',
    'name-of-environment-used-for-training',
)

Issue Severity

High: It blocks me from completing my task.

avnishn commented 1 year ago

Something you can do here is directly restore the RLModule that is inside the policy instead, either for training or for inference.

Here are some tests that act as pretty good documentation for the new recommended way of restoring trained policies/RLModules:

https://sourcegraph.com/github.com/ray-project/ray/-/blob/release/rllib_tests/checkpointing_tests/test_e2e_rl_module_restore.py?L176
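
As a rough sketch of a related, lighter-weight route (not taken from those tests; it assumes Policy.from_checkpoint is available in your Ray version and that you only need inference), you can load a single trained policy straight from the algorithm checkpoint without starting any rollout workers:

from ray.rllib.policy.policy import Policy

# Hypothetical path: the algorithm checkpoint directory produced during training.
checkpoint_path = "checkpoints/folder/with-10-rollout-workers"

# Given an algorithm checkpoint dir, this returns a dict mapping policy IDs
# to restored Policy objects, without building the full Algorithm.
policies = Policy.from_checkpoint(checkpoint_path)
policy = policies["default_policy"]

# Single-observation inference; the observation must match the training observation space.
action, _, _ = policy.compute_single_action(policy.observation_space.sample())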

Let me know if something like this works for you. Thanks :)

avnishn commented 1 year ago

related: https://github.com/ray-project/ray/issues/36830

mzat-msft commented 1 year ago

Hi, thanks for your suggestion.

Are you basically suggesting to rebuild the algorithm config, overriding the number of workers, and then use Algorithm.restore() to load the weights? IIUC, is this equivalent to what I implemented here? https://github.com/Azure/plato/commit/cfba87d06ed0a0d883a39f70e0d393cd0c812391
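
For reference, a minimal sketch of that approach (hypothetical: it assumes PPO and the DummyEnv/"sim_env" setup from the reproduction script above, not the exact code from that commit):

from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env

# Re-register the training env under its original name (DummyEnv from the script above).
register_env("sim_env", lambda conf: DummyEnv(conf, spaces.Discrete(2), spaces.Discrete(2)))

# Rebuild the training config, but with fewer (here: zero) remote rollout workers.
config = (
    PPOConfig()
    .environment("sim_env")
    .rollouts(num_rollout_workers=0)
)
algo = config.build()

# Load the trained weights/state from the checkpoint created on the larger machine.
algo.restore("checkpoints/folder/with-10-rollout-workers")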

sven1977 commented 10 months ago

Yes, we need to fix this. :) We will (in the near future) go back to requiring the user to always bring along their (original or changed) configs when restoring.

For now, as a workaround, the following hack should work:

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.utils.checkpoints import get_checkpoint_info

# `checkpoint` is the path to the algorithm checkpoint directory.
# Instead of calling .from_checkpoint directly, do this procedure:
checkpoint_info = get_checkpoint_info(checkpoint)
state = Algorithm._checkpoint_info_to_algorithm_state(
    checkpoint_info=checkpoint_info,
    policy_ids=None,
    policy_mapping_fn=None,
    policies_to_train=None,
)

state["config"] = ...  # drop in your own, altered (num_rollout_workers?) AlgorithmConfig (not the old config dict!) object here.

algo = Algorithm.from_state(state)

# This `algo` should now have/require fewer rollout workers.
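
Where the drop-in config could look like this (hypothetical: assumes the checkpoint came from PPO and that the "sim_env" name from the reproduction script is already registered; adjust to whatever you actually trained with):

from ray.rllib.algorithms.ppo import PPOConfig

# Hypothetical altered config: same algorithm/env as training, but no remote rollout workers.
state["config"] = (
    PPOConfig()
    .environment("sim_env")
    .rollouts(num_rollout_workers=0)
)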