ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] Algorithm.training_step fails after episode termination because `"__all__"` is considered a module name #42396

Closed garymm closed 8 months ago

garymm commented 10 months ago

What happened + What you expected to happen

I'm trying to get a very basic RL Module working based on the examples and tests in the repo and I hit this issue.

When you run the attached reproduction script, it fails with:

Traceback (most recent call last):
  File "/Users/garymm/src/garymm/rl/./rl/ray/rllib_modules.py", line 139, in <module>
    train()
  File "/Users/garymm/src/garymm/rl/./rl/ray/rllib_modules.py", line 127, in train
    result = algo.train()
  File "/Users/garymm/mambaforge/envs/rl/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 342, in train
    raise skipped from exception_cause(skipped)
  File "/Users/garymm/mambaforge/envs/rl/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 339, in train
    result = self.step()
  File "/Users/garymm/mambaforge/envs/rl/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 852, in step
    results, train_iter_ctx = self._run_one_training_iteration()
  File "/Users/garymm/mambaforge/envs/rl/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 3042, in _run_one_training_iteration
    results = self.training_step()
  File "/Users/garymm/mambaforge/envs/rl/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 1657, in training_step
    self.workers.sync_weights(
  File "/Users/garymm/mambaforge/envs/rl/lib/python3.10/site-packages/ray/rllib/evaluation/worker_set.py", line 409, in sync_weights
    weights = weights_src.get_weights(policies)
  File "/Users/garymm/mambaforge/envs/rl/lib/python3.10/site-packages/ray/rllib/core/learner/learner_group.py", line 466, in get_weights
    state = self._learner.get_module_state(module_ids)
  File "/Users/garymm/mambaforge/envs/rl/lib/python3.10/site-packages/ray/rllib/core/learner/learner.py", line 806, in get_module_state
    module_states = self.module.get_state(module_ids)
  File "/Users/garymm/mambaforge/envs/rl/lib/python3.10/site-packages/ray/rllib/core/rl_module/marl_module.py", line 280, in get_state
    return {
  File "/Users/garymm/mambaforge/envs/rl/lib/python3.10/site-packages/ray/rllib/core/rl_module/marl_module.py", line 281, in <dictcomp>
    module_id: self._rl_modules[module_id].get_state()
KeyError: '__all__'

I'm not sure if this is the right fix, but if I change this in Algorithm.training_step:

self.workers.sync_weights(
    from_worker_or_learner_group=from_worker_or_trainer,
    policies=list(train_results.keys()),
    global_vars=global_vars,
)

to:

self.workers.sync_weights(
    from_worker_or_learner_group=from_worker_or_trainer,
    policies=[p for p in train_results if p != "__all__"],
    global_vars=global_vars,
)

it seems to fix the problem. I'm not at all confident that's right, though; I'm very new to this code base.
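
The failure mode and the filter can be illustrated without Ray. This is only a sketch with plain dicts standing in for RLlib objects: `rl_modules`, `train_results`, `get_state`, and the module ID "default_policy" are assumptions made for illustration, not RLlib's actual data structures.

```python
# Stand-in for the multi-agent module's internal dict of sub-modules.
# "default_policy" is an assumed module ID; the weights are dummy values.
rl_modules = {"default_policy": {"weights": [0.1, 0.2]}}

# Stand-in for the results of a learner update: per-module stats plus an
# aggregate entry stored under the special "__all__" key.
train_results = {
    "default_policy": {"total_loss": 0.7},
    "__all__": {"total_loss": 0.7},
}


def get_state(module_ids):
    """Mimic the dict comprehension shown at the bottom of the traceback."""
    return {module_id: rl_modules[module_id] for module_id in module_ids}


# Passing every key of train_results reproduces the failure.
try:
    get_state(list(train_results.keys()))
    failed_key = None
except KeyError as err:
    failed_key = err.args[0]

# Filtering out the aggregate key first avoids the lookup that fails.
module_ids = [p for p in train_results if p != "__all__"]
state = get_state(module_ids)

print("KeyError on:", failed_key)
print("module IDs passed after filtering:", module_ids)
```

The lookup fails only because "__all__" is a key in the results but not a real module, which is why dropping it before the call is enough in this toy setup.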

I'm not sure how to work around this without modifying Ray. I'll post if I figure out a work-around.

Versions / Dependencies

2.9.0

Reproduction script

import argparse
import sys
from typing import Any, Mapping

import ray
import torch
from ray.rllib.algorithms import Algorithm, AlgorithmConfig
from ray.rllib.core.models.specs.typing import SpecType
from ray.rllib.core.learner.learner import LearnerHyperparameters
from ray.rllib.core.learner.torch.torch_learner import TorchLearner
from ray.rllib.core.rl_module.rl_module import (
    ModuleID,
    RLModule,
    RLModuleConfig,
    SingleAgentRLModuleSpec,
)
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.models.torch.torch_distributions import TorchCategorical
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.typing import TensorType

from rl.ray import log_filter

class MyModule(TorchRLModule):
    def __init__(self, config: RLModuleConfig) -> None:
        super().__init__(config)

    def setup(self):
        input_dim = self.config.observation_space.shape[0]
        hidden_dim = self.config.model_config_dict["fcnet_hiddens"][0]
        output_dim = self.config.action_space.n

        self.policy = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, output_dim),
        )

        self.input_dim = input_dim

    def get_train_action_dist_cls(self):
        return TorchCategorical

    def get_exploration_action_dist_cls(self):
        return TorchCategorical

    def get_inference_action_dist_cls(self):
        return TorchCategorical

    @override(RLModule)
    def output_specs_exploration(self) -> SpecType:
        return [SampleBatch.ACTION_DIST_INPUTS]

    @override(RLModule)
    def output_specs_inference(self) -> SpecType:
        return [SampleBatch.ACTION_DIST_INPUTS]

    @override(RLModule)
    def output_specs_train(self) -> SpecType:
        return [SampleBatch.ACTION_DIST_INPUTS]

    @override(RLModule)
    def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]:
        with torch.no_grad():
            return self._forward_train(batch)

    @override(RLModule)
    def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]:
        with torch.no_grad():
            return self._forward_train(batch)

    @override(RLModule)
    def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]:
        action_logits = self.policy(batch["obs"])
        return {SampleBatch.ACTION_DIST_INPUTS: action_logits}

class MyLearner(TorchLearner):
    def compute_loss_for_module(
        self,
        module_id: ModuleID,
        hps: LearnerHyperparameters,
        batch: NestedDict,
        fwd_out: Mapping[str, TensorType],
    ) -> Mapping[str, Any]:
        action_dist_inputs = fwd_out[SampleBatch.ACTION_DIST_INPUTS]
        action_dist_class = self._module[module_id].get_train_action_dist_cls()
        action_dist = action_dist_class.from_logits(action_dist_inputs)
        loss = -torch.mean(action_dist.logp(batch[SampleBatch.ACTIONS]))

        return loss

# Copied from ray/rllib/core/testing/bc_algorithm.py to remove TF dep
class MyAlgorithmConfig(AlgorithmConfig):
    def __init__(self, algo_class=None):
        super().__init__(algo_class=algo_class or MyAlgorithm)

    def get_default_rl_module_spec(self):
        return SingleAgentRLModuleSpec(module_class=MyModule)

    def get_default_learner_class(self):
        return MyLearner

class MyAlgorithm(Algorithm):
    @classmethod
    def get_default_policy_class(cls, config: AlgorithmConfig):
        return TorchPolicyV2

def train():
    algo_config: AlgorithmConfig = (
        MyAlgorithmConfig()
        .experimental(_enable_new_api_stack=True)
        .environment("CartPole-v1")
        .training(model={"fcnet_hiddens": [32, 32]})
    )

    algo: Algorithm = algo_config.build()
    _NUM_STEPS = 10
    for _ in range(_NUM_STEPS):
        result = algo.train()
    print("")
    print(result)

if __name__ == "__main__":
    sys.stderr = log_filter.filtered_stderr
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--ray_addr", type=str)
    args = arg_parser.parse_args()
    ray.init(address=args.ray_addr)
    print(ray.available_resources())
    train()

Issue Severity

High: It blocks me from completing my task.

garymm commented 9 months ago

Seems __all__ may be added here in Learner.compile_results: https://github.com/ray-project/ray/blob/40223ff75a31c4c3fc490923f9578964102cbc70/rllib/core/learner/learner.py#L752

I'm not really sure what __all__ is supposed to be for, so I'm not sure where the right place to filter it out is. CC @sven1977
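
One plausible reading, sketched below, is that the results dict mixes per-module entries with a single aggregate entry. This is an assumption, not a description of what Learner.compile_results actually stores: `compile_results_sketch`, the "total_loss" field, and the module IDs are hypothetical.

```python
ALL_MODULES = "__all__"  # the key reported by the KeyError in the traceback


def compile_results_sketch(per_module_loss):
    """Hypothetical: build per-module stats, then add one aggregate entry."""
    results = {
        module_id: {"total_loss": loss}
        for module_id, loss in per_module_loss.items()
    }
    # The aggregate entry shares the dict with the real module IDs, so any
    # caller that iterates over the keys sees it as if it were a module.
    results[ALL_MODULES] = {"total_loss": sum(per_module_loss.values())}
    return results


results = compile_results_sketch({"policy_a": 0.5, "policy_b": 0.25})
keys = list(results.keys())
print(keys)
```

Under that assumption, anything that treats `results.keys()` as a list of module IDs has to skip the aggregate key itself.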

sven1977 commented 8 months ago

Hey @garymm, thanks for raising this issue! You are absolutely right, this is causing a problem and needs a fix. We usually don't run anything with the Algorithm's default implementation of training_step (let alone multi-agent stuff), so this slipped through.

In PPO's training_step method, we do something like:

policies_to_update = set(train_results.keys()) - {ALL_MODULES}  # <- ALL_MODULES == "__all__"

and then pass that as policies into the sync_weights call. This is similar to your suggestion.
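
A standalone sketch of that set-difference idiom next to the list-comprehension variant from the original report. The contents of `train_results` and the module IDs are made up for illustration; only the `ALL_MODULES == "__all__"` value comes from the thread.

```python
ALL_MODULES = "__all__"

# Dummy results: two hypothetical modules plus the aggregate entry.
train_results = {
    "policy_a": {"total_loss": 0.5},
    "policy_b": {"total_loss": 0.25},
    ALL_MODULES: {"total_loss": 0.75},
}

# Set-difference form, as in the PPO snippet above.
policies_to_update = set(train_results.keys()) - {ALL_MODULES}

# List-comprehension form, as in the original suggestion.
filtered_list = [p for p in train_results if p != ALL_MODULES]

# Both select the same module IDs; only the container type differs.
same_ids = policies_to_update == set(filtered_list)
print(sorted(policies_to_update), filtered_list, same_ids)
```

The set form does not preserve insertion order, while the list form does. Which one a caller should pass depends on what the receiving function accepts, which this sketch does not check.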

We'll provide a fix-PR ...

In the meantime, you can also take a look at this currently-in-review PR, which brings self-play and league-based self-play into the new API stack, including example scripts (for PPO): https://github.com/ray-project/ray/pull/43276

But this PR will not fix your problem. I'll create a new one.

sven1977 commented 8 months ago

PR in review: https://github.com/ray-project/ray/pull/43316