Closed DenBuzz closed 1 month ago
I got an example of the error by modifying the most up-to-date mobilenet_rlm example.
The changes I made were:
- overriding _forward_exploration to also include the sampled actions in the output
- setting env_runner_cls=SingleAgentEnvRunner in the rollouts config
"""
This example shows how to take full control over what models and action distribution
are being built inside an RL Module. With this pattern, we can bypass a Catalog and
explicitly define our own models within a given RL Module.
"""
# __sphinx_doc_begin__
import gymnasium as gym
import numpy as np
from ray.rllib.core.columns import Columns
from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner
from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.examples.envs.classes.random_env import RandomEnv
from ray.rllib.models.torch.torch_distributions import TorchCategorical
from ray.rllib.examples._old_api_stack.models.mobilenet_v2_encoder import (
    MobileNetV2EncoderConfig,
    MOBILENET_INPUT_SHAPE,
)
from ray.rllib.core.models.configs import ActorCriticEncoderConfig
class MobileNetTorchPPORLModule(PPOTorchRLModule):
    """A PPO RLModule with MobileNet v2 as the encoder.

    The idea behind this model is to demonstrate how we can bypass the catalog to
    take full control over what models and action distribution are being built.
    In this example, we do this to modify an existing RLModule with a custom encoder.
    """

    def setup(self):
        mobilenet_v2_config = MobileNetV2EncoderConfig()
        # Since we want to use PPO, which is an actor-critic algorithm, we need to
        # use an ActorCriticEncoderConfig to wrap the base encoder config.
        actor_critic_encoder_config = ActorCriticEncoderConfig(
            base_encoder_config=mobilenet_v2_config
        )
        self.encoder = actor_critic_encoder_config.build(framework="torch")
        mobilenet_v2_output_dims = mobilenet_v2_config.output_dims

        pi_config = MLPHeadConfig(
            input_dims=mobilenet_v2_output_dims,
            output_layer_dim=2,
        )
        vf_config = MLPHeadConfig(
            input_dims=mobilenet_v2_output_dims, output_layer_dim=1
        )
        self.pi = pi_config.build(framework="torch")
        self.vf = vf_config.build(framework="torch")
        self.action_dist_cls = TorchCategorical

    def _forward_exploration(self, batch, **kwargs):
        # Run the default forward pass, then sample actions from the resulting
        # distribution and add them to the output (this is the modification).
        output = super()._forward_exploration(batch, **kwargs)
        action_dist = self.get_exploration_action_dist_cls().from_logits(
            output[Columns.ACTION_DIST_INPUTS]
        )
        output[Columns.ACTIONS] = action_dist.sample()
        return output

    def output_specs_exploration(self):
        return super().output_specs_exploration() + [Columns.ACTIONS]
config = (
    PPOConfig()
    .experimental(_enable_new_api_stack=True)
    .rl_module(
        rl_module_spec=SingleAgentRLModuleSpec(module_class=MobileNetTorchPPORLModule)
    )
    .environment(
        RandomEnv,
        env_config={
            "action_space": gym.spaces.Discrete(2),
            # Test a simple Image observation space.
            "observation_space": gym.spaces.Box(
                0.0,
                1.0,
                shape=MOBILENET_INPUT_SHAPE,
                dtype=np.float32,
            ),
        },
    )
    .rollouts(num_rollout_workers=0, env_runner_cls=SingleAgentEnvRunner)
    # The following training settings make it so that a training iteration is very
    # quick. This is just for the sake of this example. PPO will not learn properly
    # with these settings!
    .training(train_batch_size=32, sgd_minibatch_size=16, num_sgd_iter=1)
)
config.build().train()
# __sphinx_doc_end__
Unless the intent is for the "action_logp" key to be added somewhere else, I think this could be a fixed version:
def _get_actions(self, data, sa_rl_module, explore):
    # ACTION_DIST_INPUTS field returned by `forward_exploration|inference()` ->
    # Create a new action distribution object.
    if Columns.ACTION_DIST_INPUTS in data:
        if explore:
            action_dist_class = sa_rl_module.get_exploration_action_dist_cls()
        else:
            action_dist_class = sa_rl_module.get_inference_action_dist_cls()
        action_dist = action_dist_class.from_logits(
            data[Columns.ACTION_DIST_INPUTS],
        )
        if not explore:
            action_dist = action_dist.to_deterministic()

        # Sample actions from the distribution.
        if Columns.ACTIONS not in data:
            data[Columns.ACTIONS] = action_dist.sample()

        # For convenience and if possible, compute action logp from distribution
        # and add to output.
        if Columns.ACTION_LOGP not in data:
            data[Columns.ACTION_LOGP] = action_dist.logp(data[Columns.ACTIONS])
@DenBuzz Thanks for raising this issue. There are indeed some inconsistencies between docs and code here. However, the code above does not run, and I could not reproduce the errors using our rllib/examples/rl_modules/autoregressive_actions_rlm.py while commenting the logps out. While the get_actions connector piece still does not calculate logps, they are computed during loss computation.
Is there any reason why logps need to be used during exploration? Otherwise changing the docs accordingly will suffice here imo.
Thanks for taking a look! I'm happy to see that there's a new example for an autoregressive rl module.
My understanding of logp values is that they are not required for exploration, but they are referenced during the loss computation. Specifically, they must be present in the batch for the logp_ratio computation on line 73 in PPOTorchLearner. The only way those logp values will be present is if they are returned during forward_exploration along with the actions and action_dist_inputs, or if they are added later by something like the get_actions connector. And in the context of an autoregressive model that returns sampled actions under the "actions" key, the get_actions connector will not attempt to add those logp values.
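To make that dependency concrete, here is a minimal, self-contained sketch (plain PyTorch, not RLlib code; the tensor values are made up) of the kind of ratio computation that fails when action_logp never makes it into the train batch:

import torch
from torch.distributions import Categorical

# Pretend this is a train batch produced during sampling.
batch = {
    "actions": torch.tensor([0, 1, 1]),
    # Logps recorded at sampling time (the missing piece in this issue).
    "action_logp": torch.tensor([-0.69, -0.35, -1.20]),
    # Logits produced by the current policy.
    "action_dist_inputs": torch.randn(3, 2),
}

curr_dist = Categorical(logits=batch["action_dist_inputs"])
# This lookup is what raises a KeyError if "action_logp" was never added.
logp_ratio = torch.exp(curr_dist.log_prob(batch["actions"]) - batch["action_logp"])
print(logp_ratio)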
I was also able to recreate the issue using the autoregressive_actions_rlm.py example.
First, that example is not properly configured to use the new API stack. When running it, I receive the warning:
"You have setup a RLModuleSpec (via calling config.rl_module(...)), but have not enabled the new API stack. To enable it, call config.api_stack(enable_rl_module_and_learner=True)"
So I added .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True) after line 78 to make sure that the autoregressive rl_module was actually getting used and that we were using the new env runners and connectors.
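For reference, a sketch of that config change in isolation (assuming the example builds its config with PPOConfig; the rest of the example's config chain is omitted here):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    # Opt into the new API stack so the custom RLModule, env runners, and
    # connectors are actually used.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
)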
Once the module was being used, I got errors saying the compute_values method was not implemented, because the module had implemented _compute_values by mistake instead. So I deleted the underscore on lines 145 and 282 in autoregressive_actions_rlm.py.
And with that, the example was training with the correct module and without errors.
Finally, to reproduce the error, I removed the ACTION_LOGP key from the exploration_specs and train_specs and commented out line 235, where those values were being added.
I can attach the whole files if needed.
I think the solution I proposed before should still solve this problem. Just because actions are returned from the forward pass doesn't mean the get_actions connector should skip trying to compute the action logps. However, given that there's a new example that explicitly adds those values during forward_exploration, maybe that's the intended pattern and we only need to update the docs accordingly?
What happened + What you expected to happen
I'm working with an autoregressive model and trying to leverage the new API stack. Because of the nature of autoregressive models, I need to sample actions during the model's forward passes to use as inputs to later stages of the model. I figured this was fine because I can simply return those actions in the "actions" key alongside the "action_dist_inputs" that were used to compute them. The RLModule documentation here seems to suggest that is totally fine and that "action_logp" would be computed automatically later. However, when trying to train, I get "action_logp" key errors when PPOTorchLearner attempts to compute the logp_ratio during compute_loss_for_module. I believe the issue stems from ray/rllib/connectors/module_to_env/get_actions.py, where, if "actions" are already present in the data, there is an early return that skips any attempt to compute "action_logp" values.
My plan to work around this right now is to manually compute the log_probs during the forward pass and add them to the output. Or better yet, try using a custom connector?
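For illustration, a sketch of what that workaround might look like (this is a method of a custom PPO RLModule such as the MobileNetTorchPPORLModule shown above; the surrounding class definition is assumed, and column/method names follow the new API stack):

from ray.rllib.core.columns import Columns

def _forward_exploration(self, batch, **kwargs):
    output = super()._forward_exploration(batch, **kwargs)
    # Build the exploration distribution from the logits the default forward
    # pass already produced.
    action_dist = self.get_exploration_action_dist_cls().from_logits(
        output[Columns.ACTION_DIST_INPUTS]
    )
    output[Columns.ACTIONS] = action_dist.sample()
    # Workaround: record the behavior logp here so the learner finds it in the
    # train batch and the logp_ratio computation does not fail.
    output[Columns.ACTION_LOGP] = action_dist.logp(output[Columns.ACTIONS])
    return output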
Versions / Dependencies
This is present in the 2.10 release as well as the current master. I'm running python 3.11 on Manjaro.
Reproduction script
Well it's hard to build a quick example of this without building a custom module from scratch. I will look around through examples to see if I can find a quick way to reproduce.
Issue Severity
Low: It annoys or frustrates me.