ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] RolloutWorker still using ModelCatalog with RL Module API #40353

Open gresavage opened 9 months ago

gresavage commented 9 months ago

What happened + What you expected to happen

What happened

When using the new RL Module API, "custom_model_config" is required to be empty. This forces the user to move any custom config settings into the main "model" config. However, doing so revealed that ray.rllib.evaluation.rollout_worker.RolloutWorker still uses ray.rllib.models.catalog.ModelCatalog.get_preprocessor_for_space to get the preprocessor, rather than ray.rllib.core.models.catalog.Catalog.get_preprocessor.

Since ModelCatalog.get_preprocessor_for_space strictly enforces that options contains only keys from MODEL_DEFAULTS, this creates a paradox: with the RL Module API enabled, including custom model configuration keys in the "model" dictionary raises an error, yet the user also cannot put those keys under a "custom_model_config" dictionary.
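A toy, self-contained model of the conflict may make it easier to see. The names `MODEL_DEFAULTS_SUBSET`, `validate_model_options`, `check_rl_module_config` and `try_config` below are hypothetical stand-ins, not RLlib code: the validator only mirrors the unknown-key check described above, and the second check mirrors the RL Module API requirement that `custom_model_config` be empty.

```python
# Hypothetical, simplified stand-ins for the two checks described above.
# This is NOT RLlib code; it only models the two conflicting constraints.

# A tiny subset of allowed keys, standing in for RLlib's MODEL_DEFAULTS.
MODEL_DEFAULTS_SUBSET = {
    "_disable_preprocessor_api": False,
    "fcnet_hiddens": [256, 256],
    "custom_model_config": {},
}


def validate_model_options(options: dict) -> None:
    """Mimics the strict check: every key must already exist in the defaults."""
    for key in options:
        if key not in MODEL_DEFAULTS_SUBSET:
            raise ValueError(f"Unknown config key `{key}`")


def check_rl_module_config(options: dict) -> None:
    """Mimics the RL Module API rule: custom_model_config must be empty."""
    if options.get("custom_model_config"):
        raise ValueError("custom_model_config must be empty with the RL Module API")


def try_config(options: dict) -> str:
    """Runs both checks and reports which one (if any) rejects the config."""
    try:
        check_rl_module_config(options)
        validate_model_options(options)
    except ValueError as err:
        return str(err)
    return "ok"


# Option A: custom key at the top level of the model dict -> unknown-key error.
print(try_config({"my_custom_key": 1}))
# Option B: custom key nested under custom_model_config -> RL Module error.
print(try_config({"custom_model_config": {"my_custom_key": 1}}))
```

Under these assumptions, both placements of the custom key are rejected, which is exactly the dead end described in the report.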

What I expect to happen

RolloutWorker and the entire RLlib stack should enforce consistent logic when using the RL Module API. In this specific case RolloutWorker._get_complete_policy_specs_dict should either:

  1. use ray.rllib.core.models.catalog.Catalog.get_preprocessor to fetch the preprocessor
  2. add a check around line 1782 of that same method to use ray.rllib.core.models.catalog.Catalog.get_preprocessor when the RL Module API is enabled, and ModelCatalog.get_preprocessor_for_space otherwise, until the old ModelV2 API is fully deprecated.
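Option 2 amounts to a simple branch on the API flag. The sketch below uses hypothetical stub functions `legacy_get_preprocessor` and `rl_module_get_preprocessor` in place of `ModelCatalog.get_preprocessor_for_space` and `Catalog.get_preprocessor`; their signatures, return values, and the assumption that the RL Module path tolerates extra model keys are illustrative only and are not RLlib's actual behavior.

```python
from typing import Any, Callable, Dict

# Hypothetical allowed keys, standing in for MODEL_DEFAULTS (not the real dict).
KNOWN_MODEL_KEYS = {"_disable_preprocessor_api", "fcnet_hiddens", "custom_model_config"}


def legacy_get_preprocessor(obs_space: Any, options: Dict[str, Any]) -> str:
    """Stub for the old-stack path: strict about unknown model keys."""
    unknown = set(options) - KNOWN_MODEL_KEYS
    if unknown:
        raise ValueError(f"Unknown config keys: {sorted(unknown)}")
    return f"legacy-preprocessor({obs_space})"


def rl_module_get_preprocessor(obs_space: Any, options: Dict[str, Any]) -> str:
    """Stub for the RL Module path: assumed to tolerate extra model keys."""
    return f"rl-module-preprocessor({obs_space})"


def get_preprocessor(
    obs_space: Any, options: Dict[str, Any], enable_rl_module_api: bool
) -> str:
    """Picks the catalog based on the API flag, as option 2 proposes."""
    getter: Callable[[Any, Dict[str, Any]], str] = (
        rl_module_get_preprocessor if enable_rl_module_api else legacy_get_preprocessor
    )
    return getter(obs_space, options)


model_config = {"fcnet_hiddens": [64], "my_custom_key": True}
print(get_preprocessor("Box(8,)", model_config, enable_rl_module_api=True))
```

Keeping the legacy branch intact means old-stack users keep the strict validation until ModelV2 is removed, while RL Module users are no longer blocked by it.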

Versions / Dependencies

Ray 2.7 Python 3.10

Reproduction script

Should be pretty simple to replicate: just enable preprocessing and the RL Module API (by setting _enable_rl_module_api=True) and pass a model configuration dict containing keys that are not part of MODEL_DEFAULTS.

I can provide a working script later if necessary

Issue Severity

Medium: It is a significant difficulty but I can work around it.

simonsays1980 commented 7 months ago

I think I can reproduce this with the following script (ray==2.7.1):

from tqdm import trange
from ray.rllib.algorithms.ppo import PPOConfig

env_name = "LunarLanderContinuous-v2"
config = (
    PPOConfig()
    .environment(env_name, env_config={})
    .framework("torch")
    .debugging(seed=0)
    .rl_module(
        _enable_rl_module_api=True,
    )
    # .experimental(_enable_new_api_stack=True)
    .training(
        model={
            "_disable_preprocessor_api": False,
            # Not a key in MODEL_DEFAULTS; this is what triggers the error.
            "unused_additional_model_key": True,
        },
        _enable_learner_api=True,
    )
)

algo = config.build()

iterator = trange(2)
for epoch in iterator:
    result = algo.train()
    iterator.set_postfix(
        {
            "reward_max": result["episode_reward_max"],
            "reward_mean": result["episode_reward_mean"],
        }
    )

Please verify @gresavage

sven1977 commented 6 months ago

I can confirm this is a bug on our end.

We'll provide a fix (and clean up the described conundrum thereby :) ). In the meantime, as a quick workaround, could you simply add the extra key(s) to your MODEL_DEFAULTS?

At least the repro script above runs fine with this hack:

In rllib/models/catalog.py:

MODEL_DEFAULTS: ModelConfigDict = {
    # Workaround: register the extra custom key so the strict check accepts it.
    "unused_additional_model_key": 1,

    # Experimental flag.
    # If True, user specified no preprocessor to be created
    # (via config._disable_preprocessor_api=True). If True, observations
    # will arrive in model as they are returned by the env.
    "_disable_preprocessor_api": False,
    ....

sven1977 commented 6 months ago

Actually, this should be even easier: could you try disabling the preprocessor API (_disable_preprocessor_api=True) in your experimental settings?

    config.experimental(
        _enable_new_api_stack=True,
        _disable_preprocessor_api=True,
    )

sven1977 commented 6 months ago

Then it should really NOT go through any preprocessor code anymore and the error should disappear.
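The two workarounds suggested in this thread can be modeled with a toy lookup. This is a hypothetical sketch, not RLlib code: `DEFAULTS` and `maybe_build_preprocessor` are invented stand-ins, and it assumes, as the comment above suggests, that disabling the preprocessor API skips the preprocessor lookup (and therefore its key validation) entirely.

```python
from typing import Any, Dict, Optional

# Hypothetical stand-in for MODEL_DEFAULTS (not the real RLlib dict).
DEFAULTS: Dict[str, Any] = {"_disable_preprocessor_api": False, "fcnet_hiddens": [256, 256]}


def maybe_build_preprocessor(
    options: Dict[str, Any], disable_preprocessor_api: bool
) -> Optional[str]:
    """Toy model of the lookup: skipped entirely when preprocessing is disabled."""
    if disable_preprocessor_api:
        # No preprocessor code runs, so no key validation happens either.
        return None
    for key in options:
        if key not in DEFAULTS:
            raise ValueError(f"Unknown config key `{key}`")
    return "preprocessor"


model = {"unused_additional_model_key": True}

# Workaround 1: disable the preprocessor API, so the strict check never runs.
print(maybe_build_preprocessor(model, disable_preprocessor_api=True))

# Workaround 2 (the MODEL_DEFAULTS hack): register the extra key as a default.
DEFAULTS["unused_additional_model_key"] = 1
print(maybe_build_preprocessor(model, disable_preprocessor_api=False))
```

In this toy, workaround 2 only works because the lookup reads the same dict that was mutated; in a real multi-process Ray setup, editing the source file (as suggested above) is what makes the change visible to every worker.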

Either way, this does not change the fact that we have to clean up the model configuration process.