ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] New API Stack: Action masking not working with wrapper, default encoder config issue #44780

Open PhilippWillms opened 3 months ago

PhilippWillms commented 3 months ago

What happened + What you expected to happen

The new RLlib API stack seems to have problems with observation wrappers, which are quite handy for action masking models. Unlike #44452, it is now the SingleAgentEnvRunner that finds no default encoder config for the 2D Box space. This error message sounds familiar; however, on the stable (old) RLlib API stack this was overcome by using a wrapper. On the new stack, the wrapper does not seem to be recognized correctly.
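For context, the Dict observation in this setup carries an "action_mask" entry next to the actual "observations", and an action-masking model uses the mask to suppress invalid actions. A common way to do this is to replace the logits of masked actions with the most negative representable float before building the action distribution. This is only a sketch of that general technique, not RLlib's exact implementation, and the helper name `mask_logits` is hypothetical:

```python
import torch


def mask_logits(logits: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: keep logits where action_mask == 1, push the rest to float min."""
    # Most negative finite value for the logits' dtype; softmax maps it to 0.
    very_negative = torch.full_like(logits, torch.finfo(logits.dtype).min)
    return torch.where(action_mask.bool(), logits, very_negative)


# Usage: one observation, three actions, action 1 already used (mask == 0).
logits = torch.tensor([[1.0, 2.0, 3.0]])
action_mask = torch.tensor([[1, 0, 1]], dtype=torch.int8)
probs = torch.softmax(mask_logits(logits, action_mask), dim=-1)
```

Using a finite minimum instead of `-inf` avoids NaNs if, by mistake, every action in a row were masked.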

2024-04-16 21:50:28,071 ERROR actor_manager.py:517 -- Ray error, taking actor 1 out of service. The actor died because of an error raised in its creation task, ray::SingleAgentEnvRunner.__init__() (pid=3964, ip=192.168.178.26, actor_id=aff7d1bf298e2b06dcd9d5c801000000, repr=<ray.rllib.env.single_agent_env_runner.SingleAgentEnvRunner object at 0x0000024793D75F60>)
  File "python\ray\_raylet.pyx", line 1887, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 1828, in ray._raylet.execute_task.function_executor
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\_private\function_manager.py", line 691, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
    return method(self, *_args, **_kwargs)
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\rllib\env\single_agent_env_runner.py", line 82, in __init__
    self.module: RLModule = module_spec.build()
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\rllib\core\rl_module\rl_module.py", line 102, in build
    module = self.module_class(module_config)
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\rllib\core\rl_module\rl_module.py", line 399, in new_init
    previous_init(self, *args, **kwargs)
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\rllib\core\rl_module\rl_module.py", line 399, in new_init
    previous_init(self, *args, **kwargs)
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\rllib\examples\rl_modules\classes\action_masking_rlm.py", line 29, in __init__
    super().__init__(config)
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\rllib\core\rl_module\rl_module.py", line 399, in new_init
    previous_init(self, *args, **kwargs)
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\rllib\core\rl_module\rl_module.py", line 399, in new_init
    previous_init(self, *args, **kwargs)
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\rllib\core\rl_module\torch\torch_rl_module.py", line 86, in __init__
    RLModule.__init__(self, *args, **kwargs)
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\rllib\core\rl_module\rl_module.py", line 391, in __init__
    self.setup()
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\rllib\algorithms\ppo\ppo_rl_module.py", line 20, in setup
    catalog = self.config.get_catalog()
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\rllib\core\rl_module\rl_module.py", line 211, in get_catalog
    return self.catalog_class(
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\rllib\algorithms\ppo\ppo_catalog.py", line 69, in __init__
    super().__init__(
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\rllib\core\models\catalog.py", line 112, in __init__
    self._determine_components_hook()
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\rllib\core\models\catalog.py", line 132, in _determine_components_hook
    self._encoder_config = self._get_encoder_config(
  File "C:\Users\Philipp\anaconda3\envs\torch-gpu-310\lib\site-packages\ray\rllib\core\models\catalog.py", line 368, in _get_encoder_config
    raise ValueError(
ValueError: No default encoder config for obs space=Box(0.0, 1.0, (3, 4), float32), lstm=False and attention=False found. 2D Box spaces are not supported. They should be either flattened to a 1D Box space or enhanced to be a 3D box space.

Versions / Dependencies

Python==3.10.13
ray==3.0.0.dev0
gymnasium==0.28.1
numpy==1.24.4
torch==2.2.1+cu121
tensorflow==2.16.1
Windows 11

Reproduction script

import logging
from typing import OrderedDict
from typing import Tuple

import gymnasium
import numpy as np
import ray
from gymnasium.spaces import Box
from gymnasium.spaces import Dict
from gymnasium.spaces import Discrete
from gymnasium.wrappers import TransformObservation
from ray.rllib.algorithms import PPOConfig
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.examples.rl_modules.classes.action_masking_rlm import TorchActionMaskRLM
from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner
from ray.tune.registry import register_env

logger = logging.getLogger()
logger.setLevel("WARN")

class MyRealObsWrapper(TransformObservation):
    """Special Wrapper needed for new RLlib API stack."""

    def __init__(self, env):  # noqa: D107
        super().__init__(env, self.__transform)

    def __transform(self, orig_obs):  # noqa: D107
        new_obs = orig_obs
        for b in new_obs.keys():
            if b not in ["static_features"]:
                new_obs[b] = np.reshape(new_obs[b], -1)
        # Important to update the observation space, otherwise the RLlib algorithms will not work
        self.observation_space["observations"] = Box(
            0, 1, (len(new_obs["observations"]),)
        )
        return new_obs

class MyEnv(gymnasium.Env):
    """Simple custom environment with nested obs space and action masking."""

    def __init__(self, *args, **kwargs):  # noqa: D107
        # print("Init method called.")
        self.action_space = Discrete(3)
        self.observation_space = Dict(
            {
                "action_mask": Box(
                    low=0, high=1, shape=(self.action_space.n,), dtype=np.int8
                ),
                "observations": Box(
                    low=0.0,
                    high=1.0,
                    shape=(3, 4),
                    dtype=np.float32,
                ),
                # "static_features": Dict(...)
            }
        )
        self.episode_done = False
        self._action_max_helper = np.ones(self.action_space.n, dtype=np.int8)
        self.state = np.zeros((3, 4), dtype=np.float32)

    def step(
        self, action: int
    ) -> Tuple[OrderedDict, float, bool, bool, dict]:  # noqa: D102
        # print(f"Step function called with action {action}.")
        # Error handling for invalid action
        if action < 0 or action >= self.action_space.n:
            e_string = f"Action [{action}] is not valid! Size of the action space: [{self.action_space.n}]."
            raise Exception(e_string)
        if self._action_max_helper[action] == 0:
            e_string = f"Action [{action}] is not valid as chosen already in episode !"
            raise Exception(e_string)

        some_dict = {}
        if action not in some_dict.keys():
            some_dict[action] = 1
            # logger.warning("Action key added to dict.")
        # print(f"Existing value in dict: {some_dict[action]}")

        reward = 0 - action
        self.state[action][0] = 1
        self._action_max_helper[action] = 0
        if all(self._action_max_helper[k] == 0 for k in range(3)):
            self.episode_done = True
        # print(f"State after step: {self.state}.")
        return self._get_state_repr(), reward, self.episode_done, False, {}

    def _get_state_repr(self) -> OrderedDict:
        return {
            "action_mask": self._action_max_helper,
            "observations": self.state,
        }

    def reset(
        self, *, seed=None, options=None
    ) -> Tuple[OrderedDict, dict]:  # noqa: D102
        # print("Reset method called.")
        self.episode_done = False
        # Initial state representation = shape of the obs space.
        self.state = np.zeros((3, 4), dtype=np.float32)
        # Initial action mask = all actions are allowed.
        self._action_max_helper = np.ones(self.action_space.n, dtype=np.int8)
        return self._get_state_repr(), {}

def env_creator(env_config):
    """Create the environment with a wrapper."""
    env = MyEnv()
    env = MyRealObsWrapper(env)
    return env

# Use classic API to register environment
register_env("myenv_wrapped", env_creator)

if __name__ == "__main__":
    rlm_spec = SingleAgentRLModuleSpec(module_class=TorchActionMaskRLM)

    # Algorithm Config, but with the latest RLlib API
    config = (
        PPOConfig()
        .environment("myenv_wrapped")
        # We need to disable preprocessing of observations, because preprocessing
        # would flatten the observation dict of the environment.
        .experimental(_disable_preprocessor_api=True, _enable_new_api_stack=True)
        .framework("torch")
        .rollouts(env_runner_cls=SingleAgentEnvRunner)
        .resources(num_gpus=1, num_cpus_per_worker=2, num_gpus_per_worker=0.3)
        .rl_module(rl_module_spec=rlm_spec)
        .training(lr=1e-3, train_batch_size=50, sgd_minibatch_size=10, model={"uses_new_env_runners": True})
    )

    algo = config.build()

    # run manual training loop and print results after each iteration
    for i in range(10):
        result = algo.train()
        print(f"Training iteration: {i+1} done")
        # pprint(result)

    ray.shutdown()

Issue Severity

High: It blocks me from completing my task.

AliceHZhu commented 2 months ago

Got the error below with version 2.9.2, which might be due to the same issue. Which version works with a correct encoder config that supports a MultiDiscrete obs space?

ValueError: No default encoder config for obs space=MultiDiscrete([2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2]), lstm=False and attention=False found.

PhilippWillms commented 1 month ago

@simonsays1980, @sven1977: Any update on this issue? Could you give an outlook on when it will be fixed? Thank you very much!

grizzlybearg commented 3 days ago

@simonsays1980 It appears that the pull request https://github.com/ray-project/ray/pull/46146 doesn't fully resolve the observation wrapper issue. Nonetheless, I've built my work (action masking with the new API) on that pull request, so thank you for your contribution. @PhilippWillms, you're correct, the issue at https://github.com/ray-project/ray/issues/46631 is very similar to this one. Have you found a solution?

PhilippWillms commented 3 days ago

No, I am still using the old API stack for now.

grizzlybearg commented 3 days ago

I'm exploring the possibility of using a custom CNNEncoder, as it seems to be the only viable option. The most promising implementation so far is close to the following:

import torch
import torch.nn as nn
import numpy as np
from gymnasium.spaces import Box

# Define the observation space
window_size = 1000
n_features = 50
box = Box(
    low=-np.inf,
    high=np.inf,
    shape=(window_size, n_features),
    dtype=np.float32,
)

class ConvObservationEncoder(nn.Module):
    """1D CNN over the window axis: (batch, window_size, n_features) -> (batch, output_dim)."""

    def __init__(self, input_shape, output_dim):
        super().__init__()
        window_size, n_features = input_shape
        # Treat the features as input channels and convolve along the window (time) axis.
        self.conv1 = nn.Conv1d(in_channels=n_features, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        # kernel_size=3 with padding=1 keeps the window length, hence 128 * window_size inputs.
        self.linear = nn.Linear(128 * window_size, output_dim)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # Change shape to (batch_size, n_features, window_size)
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.flatten(x)
        return self.linear(x)

# Example usage
input_shape = box.shape
output_dim = 128  # Example output dimension
encoder = ConvObservationEncoder(input_shape, output_dim)

# Create a dummy observation
observation = np.random.rand(*input_shape).astype(np.float32)
observation_tensor = torch.tensor(observation).unsqueeze(0)  # Add batch dimension

# Encode the observation
encoded_observation = encoder(observation_tensor)
print(encoded_observation.shape)  # Should print: torch.Size([1, 128])

In this code:

This approach can be more efficient and effective than flattening the observation into a 1D vector, especially if your observations have spatial or temporal dependencies, as in my case. Of course, I'll have to make further modifications to make it fit RLlib's new API stack.
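The input size of the encoder's `nn.Linear` layer (128 output channels times the window length) relies on both `Conv1d` layers preserving the window length. Using the output-length formula from the PyTorch `Conv1d` documentation with kernel size k = 3, padding p = 1, stride s = 1 and the default dilation d = 1, for window_size = 1000:

```latex
L_{\text{out}} = \left\lfloor \frac{L_{\text{in}} + 2p - d\,(k - 1) - 1}{s} \right\rfloor + 1
             = \left\lfloor \frac{1000 + 2 \cdot 1 - 1 \cdot (3 - 1) - 1}{1} \right\rfloor + 1
             = 999 + 1 = 1000
```

So the flattened feature vector has 128 × 1000 = 128,000 entries. A different padding or stride would change L_out and require adjusting the linear layer accordingly.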