ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] Connectors API get_actions does not compute action_logp when actions are present. #44662

Closed. DenBuzz closed this 1 month ago.

DenBuzz commented 7 months ago

What happened + What you expected to happen

I'm working with an autoregressive model and trying to leverage the new API stack. Because of the nature of autoregressive models, I need to sample actions during the model's forward pass so they can be used as inputs to later stages of the model. I figured this was fine because I can simply return those actions under the "actions" key alongside the "action_dist_inputs" that were used to compute them. The RLModule documentation seems to suggest that this is totally fine and that "action_logp" would be computed automatically later. However, when training I get "action_logp" KeyErrors when PPOTorchLearner attempts to compute the logp_ratio in compute_loss_for_module.

I believe the issue stems from ray/rllib/connectors/module_to_env/get_actions.py, where, if "actions" are already present in the data, there is an early return that skips any attempt to compute "action_logp" values:

    def _get_actions(self, data, sa_rl_module, explore):
        # Action have already been sampled -> Early out.
        if Columns.ACTIONS in data:
            return

        # ACTION_DIST_INPUTS field returned by `forward_exploration|inference()` ->
        # Create a new action distribution object.
        action_dist = None
        if Columns.ACTION_DIST_INPUTS in data:
            if explore:
                action_dist_class = sa_rl_module.get_exploration_action_dist_cls()
            else:
                action_dist_class = sa_rl_module.get_inference_action_dist_cls()
            action_dist = action_dist_class.from_logits(
                data[Columns.ACTION_DIST_INPUTS],
            )
            if not explore:
                action_dist = action_dist.to_deterministic()

            # Sample actions from the distribution.
            actions = action_dist.sample()
            data[Columns.ACTIONS] = actions

            # For convenience and if possible, compute action logp from distribution
            # and add to output.
            if Columns.ACTION_LOGP not in data:
                data[Columns.ACTION_LOGP] = action_dist.logp(actions)
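
To make the effect of that early return concrete without RLlib, here is a minimal sketch that mirrors its control flow. The plain string keys stand in for RLlib's `Columns` constants and `ToyCategorical` is a hypothetical stand-in for RLlib's distribution classes; only the branching is meant to match the quoted method.

```python
import math

# Hypothetical stand-ins for RLlib's `Columns` constants.
ACTIONS = "actions"
ACTION_DIST_INPUTS = "action_dist_inputs"
ACTION_LOGP = "action_logp"


class ToyCategorical:
    """Hypothetical minimal categorical distribution built from logits."""

    def __init__(self, logits):
        # Numerically stable log-softmax.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        self.log_probs = [x - log_z for x in logits]

    def sample(self):
        # Argmax instead of random sampling keeps the sketch deterministic.
        return max(range(len(self.log_probs)), key=self.log_probs.__getitem__)

    def logp(self, action):
        return self.log_probs[action]


def get_actions_current(data):
    """Mirrors the early-out branching of the quoted `_get_actions`."""
    if ACTIONS in data:
        return  # Early out: ACTION_LOGP is never filled in.
    if ACTION_DIST_INPUTS in data:
        dist = ToyCategorical(data[ACTION_DIST_INPUTS])
        actions = dist.sample()
        data[ACTIONS] = actions
        if ACTION_LOGP not in data:
            data[ACTION_LOGP] = dist.logp(actions)


# Module returns only logits -> actions and logp are both added.
plain = {ACTION_DIST_INPUTS: [0.0, 1.0]}
get_actions_current(plain)
print(sorted(plain))  # ['action_dist_inputs', 'action_logp', 'actions']

# Module already sampled actions (autoregressive case) -> no logp.
autoregressive = {ACTION_DIST_INPUTS: [0.0, 1.0], ACTIONS: 0}
get_actions_current(autoregressive)
print(sorted(autoregressive))  # ['action_dist_inputs', 'actions']
```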

My plan to work around this for now is to manually compute the log-probs myself during the forward pass and add them to the output. Or, better yet, try using a custom connector?
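
A minimal sketch of the forward-pass workaround, assuming a hypothetical two-stage categorical head and plain dict keys in place of RLlib's `Columns` constants (no RLlib or torch calls are used). Each stage is sampled and its log-probability accumulated in the same pass, so the joint log p(a1) + log p(a2 | a1) can be returned next to the actions.

```python
import math
import random


def log_softmax(logits):
    # Numerically stable log-softmax over a list of floats.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sample_categorical(logits, rng):
    """Samples an index and returns it together with its log-probability."""
    log_probs = log_softmax(logits)
    probs = [math.exp(lp) for lp in log_probs]
    action = rng.choices(range(len(probs)), weights=probs)[0]
    return action, log_probs[action]


def autoregressive_forward(obs_logits, rng):
    """Hypothetical two-stage autoregressive head.

    Stage 1 samples a1 from observation-dependent logits; stage 2's logits
    depend on a1 (a toy one-hot bias toward index a1). The joint log-prob
    log p(a1) + log p(a2 | a1) is returned alongside the actions.
    """
    a1, logp1 = sample_categorical(obs_logits, rng)
    stage2_logits = [float(a1 == i) for i in range(3)]  # toy conditioning on a1
    a2, logp2 = sample_categorical(stage2_logits, rng)
    return {"actions": (a1, a2), "action_logp": logp1 + logp2}


rng = random.Random(0)
out = autoregressive_forward([0.0, 0.5, 1.0], rng)
print(out)
```

Computing the log-prob at sampling time avoids having to rebuild the joint distribution later from partial inputs.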

Versions / Dependencies

This is present in the 2.10 release as well as in the current master. I'm running Python 3.11 on Manjaro.

Reproduction script

Well, it's hard to build a quick example of this without writing a custom module from scratch. I'll look through the examples to see if I can find a quick way to reproduce it.

Issue Severity

Low: It annoys or frustrates me.

DenBuzz commented 7 months ago

I was able to reproduce the error by modifying the most up-to-date mobilenet_rlm example.

The changes I made were:

    """
    This example shows how to take full control over what models and action distribution
    are being built inside an RL Module. With this pattern, we can bypass a Catalog and
    explicitly define our own models within a given RL Module.
    """

    # __sphinx_doc_begin__
    import gymnasium as gym
    import numpy as np

    from ray.rllib.core.columns import Columns
    from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner

    from ray.rllib.algorithms.ppo.ppo import PPOConfig
    from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
    from ray.rllib.core.models.configs import MLPHeadConfig
    from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
    from ray.rllib.examples.envs.classes.random_env import RandomEnv
    from ray.rllib.models.torch.torch_distributions import TorchCategorical
    from ray.rllib.examples._old_api_stack.models.mobilenet_v2_encoder import (
        MobileNetV2EncoderConfig,
        MOBILENET_INPUT_SHAPE,
    )
    from ray.rllib.core.models.configs import ActorCriticEncoderConfig

    class MobileNetTorchPPORLModule(PPOTorchRLModule):
        """A PPORLModule with mobilenet v2 as an encoder.

        The idea behind this model is to demonstrate how we can bypass catalog to
        take full control over what models and action distribution are being built.
        In this example, we do this to modify an existing RLModule with a custom encoder.
        """

        def setup(self):
            mobilenet_v2_config = MobileNetV2EncoderConfig()
            # Since we want to use PPO, which is an actor-critic algorithm, we need to
            # use an ActorCriticEncoderConfig to wrap the base encoder config.
            actor_critic_encoder_config = ActorCriticEncoderConfig(
                base_encoder_config=mobilenet_v2_config
            )

            self.encoder = actor_critic_encoder_config.build(framework="torch")
            mobilenet_v2_output_dims = mobilenet_v2_config.output_dims

            pi_config = MLPHeadConfig(
                input_dims=mobilenet_v2_output_dims,
                output_layer_dim=2,
            )

            vf_config = MLPHeadConfig(
                input_dims=mobilenet_v2_output_dims, output_layer_dim=1
            )

            self.pi = pi_config.build(framework="torch")
            self.vf = vf_config.build(framework="torch")

            self.action_dist_cls = TorchCategorical

        def _forward_exploration(self, batch, **kwargs):
            output = super()._forward_exploration(batch, **kwargs)
            dist_class = self.get_exploration_action_dist_cls().from_logits(
                output[Columns.ACTION_DIST_INPUTS]
            )
            output[Columns.ACTIONS] = dist_class.sample()
            return output

        def output_specs_exploration(self):
            return super().output_specs_exploration() + [Columns.ACTIONS]

    config = (
        PPOConfig()
        .experimental(_enable_new_api_stack=True)
        .rl_module(
            rl_module_spec=SingleAgentRLModuleSpec(module_class=MobileNetTorchPPORLModule)
        )
        .environment(
            RandomEnv,
            env_config={
                "action_space": gym.spaces.Discrete(2),
                # Test a simple Image observation space.
                "observation_space": gym.spaces.Box(
                    0.0,
                    1.0,
                    shape=MOBILENET_INPUT_SHAPE,
                    dtype=np.float32,
                ),
            },
        )
        .rollouts(num_rollout_workers=0, env_runner_cls=SingleAgentEnvRunner)
        # The following training settings make it so that a training iteration is very
        # quick. This is just for the sake of this example. PPO will not learn properly
        # with these settings!
        .training(train_batch_size=32, sgd_minibatch_size=16, num_sgd_iter=1)
    )

    config.build().train()
    # __sphinx_doc_end__

DenBuzz commented 7 months ago

Unless the intent is for the "action_logp" key to be added somewhere else, I think this could be a fixed version:

    def _get_actions(self, data, sa_rl_module, explore):
        # ACTION_DIST_INPUTS field returned by `forward_exploration|inference()` ->
        # Create a new action distribution object.
        if Columns.ACTION_DIST_INPUTS in data:
            if explore:
                action_dist_class = sa_rl_module.get_exploration_action_dist_cls()
            else:
                action_dist_class = sa_rl_module.get_inference_action_dist_cls()
            action_dist = action_dist_class.from_logits(
                data[Columns.ACTION_DIST_INPUTS],
            )
            if not explore:
                action_dist = action_dist.to_deterministic()

            # Sample actions from the distribution, unless the module already
            # provided them.
            if Columns.ACTIONS not in data:
                data[Columns.ACTIONS] = action_dist.sample()

            # For convenience and if possible, compute action logp from distribution
            # and add to output.
            if Columns.ACTION_LOGP not in data:
                data[Columns.ACTION_LOGP] = action_dist.logp(data[Columns.ACTIONS])

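
The control flow of that proposed fix can be checked without RLlib using a sketch that mirrors it. The string keys replacing `Columns` and the `ToyCategorical` class are hypothetical stand-ins; what matters is that ACTION_LOGP now gets filled in for actions the module already sampled.

```python
import math

# Hypothetical stand-ins for RLlib's `Columns` constants.
ACTIONS = "actions"
ACTION_DIST_INPUTS = "action_dist_inputs"
ACTION_LOGP = "action_logp"


class ToyCategorical:
    """Hypothetical minimal categorical distribution built from logits."""

    def __init__(self, logits):
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        self.log_probs = [x - log_z for x in logits]

    def sample(self):
        # Argmax instead of random sampling keeps the sketch deterministic.
        return max(range(len(self.log_probs)), key=self.log_probs.__getitem__)

    def logp(self, action):
        return self.log_probs[action]


def get_actions_fixed(data):
    """Mirrors the branching of the proposed `_get_actions` fix."""
    if ACTION_DIST_INPUTS in data:
        dist = ToyCategorical(data[ACTION_DIST_INPUTS])
        # Only sample if the module did not already provide actions.
        if ACTIONS not in data:
            data[ACTIONS] = dist.sample()
        # Compute logp for whichever actions ended up in the batch.
        if ACTION_LOGP not in data:
            data[ACTION_LOGP] = dist.logp(data[ACTIONS])


# Actions pre-sampled by the module are kept, and their logp is added.
batch = {ACTION_DIST_INPUTS: [0.0, 1.0], ACTIONS: 0}
get_actions_fixed(batch)
print(batch)
```

This assumes the distribution built from ACTION_DIST_INPUTS describes the full action the module returned; if it only covers part of an autoregressive action, the logp computed here would not match the module's sampling.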
simonsays1980 commented 4 months ago

@DenBuzz Thanks for raising this issue. There are indeed some inconsistencies between the docs and the code here. The code above does not run, however, and I could not reproduce the error using our rllib/examples/rl_modules/autoregressive_actions_rlm.py with the logps commented out. Although the get_actions connector piece still does not calculate logps, they are computed during loss computation.

Is there any reason why logps need to be used during exploration? Otherwise, changing the docs accordingly should suffice here, imo.

DenBuzz commented 4 months ago

Thanks for taking a look! I'm happy to see that there's a new example for an autoregressive rl module.

My understanding is that logp values are not required for exploration, but they are referenced during the loss computation. Specifically, they must be present in the batch for the logp_ratio computation on line 73 of PPOTorchLearner. The only way those logp values end up in the batch is if they are returned from forward_exploration along with the actions and action_dist_inputs, or if they are added later by something like the get_actions connector. For an autoregressive model that returns sampled actions under the "actions" key, however, the get_actions connector will not attempt to add those logp values.
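
Why the loss needs them can be seen in a sketch of the standard PPO clipped surrogate: the probability ratio exp(logp_new - logp_old) requires the log-probs recorded at sampling time. This is a generic plain-Python illustration, not PPOTorchLearner's code; the `clip_param=0.2` default and the per-sample list form are assumptions.

```python
import math


def ppo_clipped_surrogate(batch, new_logp, advantages, clip_param=0.2):
    """Per-sample PPO clipped surrogate objective (to be maximized).

    batch["action_logp"] holds the log-probs recorded when the actions were
    sampled; if the key is missing, the lookup below raises a KeyError.
    """
    old_logp = batch["action_logp"]
    objective = []
    for lp_new, lp_old, adv in zip(new_logp, old_logp, advantages):
        # Probability ratio pi_new(a|s) / pi_old(a|s), computed in log space.
        ratio = math.exp(lp_new - lp_old)
        clipped_ratio = min(max(ratio, 1.0 - clip_param), 1.0 + clip_param)
        objective.append(min(ratio * adv, clipped_ratio * adv))
    return objective


batch = {"action_logp": [math.log(0.5), math.log(0.5)]}
print(ppo_clipped_surrogate(batch, [math.log(0.6), math.log(0.2)], [1.0, -1.0]))
```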


I was also able to recreate the issue using the autoregressive_actions_rlm.py example.

First, that example is not properly configured to use the new API stack. When running it, I receive the warning "You have setup a RLModuleSpec (via calling config.rl_module(...)), but have not enabled the new API stack. To enable it, call config.api_stack(enable_rl_module_and_learner=True)"

So I added .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True) after line 78 to make sure that the autoregressive RL module was actually being used and that the new env runners and connectors were in use.

Once the module was being used, I got errors saying that the compute_values method was not implemented, because the module had mistakenly implemented _compute_values instead. So I deleted the underscore on lines 145 and 282 in autoregressive_actions_rlm.py.
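
Why a stray underscore produces that error: Python looks up `compute_values` by exact name, so a subclass method named `_compute_values` is just an unrelated extra method, and the base class's unimplemented version is still the one called. A plain-Python sketch with hypothetical class names (not RLlib's actual base classes):

```python
class HypotheticalValueModule:
    """Stand-in for a base class that declares a public `compute_values` hook."""

    def compute_values(self, batch):
        raise NotImplementedError("compute_values is not implemented")


class BrokenModule(HypotheticalValueModule):
    # Defines a *different* method; the inherited compute_values still raises.
    def _compute_values(self, batch):
        return [0.0 for _ in batch]


class FixedModule(HypotheticalValueModule):
    # Same body under the public name, so it overrides the base method.
    def compute_values(self, batch):
        return [0.0 for _ in batch]


print(FixedModule().compute_values([1, 2]))  # [0.0, 0.0]
try:
    BrokenModule().compute_values([1, 2])
except NotImplementedError as err:
    print("BrokenModule:", err)
```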

And with that, the example was training with the correct module and without errors.

Finally, to reproduce the error, I removed the ACTION_LOGP key from the exploration_specs and train_specs and commented out line 235, where those values were being added.

I can attach the full files if needed.


I think the solution I proposed earlier should still solve this problem: just because actions are returned from the forward pass doesn't mean the get_actions connector should skip computing the action logps. However, given that the new example explicitly adds those values during forward_exploration, maybe that's the intended pattern and we only need to update the docs accordingly?