ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] - `Algorithm.add_module` does not use the `module_state` argument. #46247

Open simonsays1980 opened 1 month ago

simonsays1980 commented 1 month ago

What happened + What you expected to happen

What happened

Calling `Algorithm.add_module` with a `module_state` argument does not apply the passed-in module state. Instead, the module is loaded or built directly from the passed-in `SingleAgentRLModuleSpec`. This results in an error about missing network weights due to the inference-only design.

```
trying to add module: th-rlm-16
2024-06-24 11:26:56,456 ERROR actor_manager.py:185 -- Worker exception caught during `apply()`: Error(s) in loading state_dict for PPOTorchRLModule:
        Missing key(s) in state_dict: "encoder.actor_encoder.net.mlp.0.weight", "encoder.actor_encoder.net.mlp.0.bias", "encoder.actor_encoder.net.mlp.2.weight", "encoder.actor_encoder.net.mlp.2.bias", "encoder.critic_encoder.net.mlp.0.weight", "encoder.critic_encoder.net.mlp.0.bias", "encoder.critic_encoder.net.mlp.2.weight", "encoder.critic_encoder.net.mlp.2.bias", "vf.net.mlp.0.weight", "vf.net.mlp.0.bias". 
        Unexpected key(s) in state_dict: "encoder.encoder.net.mlp.0.weight", "encoder.encoder.net.mlp.0.bias", "encoder.encoder.net.mlp.2.weight", "encoder.encoder.net.mlp.2.bias". 
Traceback (most recent call last):
  File "/home/thwu/miniconda3/envs/rlmodule/lib/python3.10/site-packages/ray/rllib/utils/actor_manager.py", line 181, in apply
    return func(self, *args, **kwargs)
  File "/home/thwu/miniconda3/envs/rlmodule/lib/python3.10/site-packages/ray/rllib/env/env_runner_group.py", line 555, in _set_weights
    env_runner.set_weights(_weights, global_vars)
  File "/home/thwu/miniconda3/envs/rlmodule/lib/python3.10/site-packages/ray/util/tracing/tracing_helper.py", line 467, in _resume_span
    return method(self, *_args, **_kwargs)
  File "/home/thwu/miniconda3/envs/rlmodule/lib/python3.10/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 672, in set_weights
    self.module.set_state(weights)
  File "/home/thwu/miniconda3/envs/rlmodule/lib/python3.10/site-packages/ray/rllib/core/rl_module/marl_module.py", line 334, in set_state
    self._rl_modules[module_id].set_state(state)
  File "/home/thwu/miniconda3/envs/rlmodule/lib/python3.10/site-packages/ray/rllib/core/rl_module/torch/torch_rl_module.py", line 73, in set_state
    self.load_state_dict(state_dict)
  File "/home/thwu/miniconda3/envs/rlmodule/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for PPOTorchRLModule:
        Missing key(s) in state_dict: "encoder.actor_encoder.net.mlp.0.weight", "encoder.actor_encoder.net.mlp.0.bias", "encoder.actor_encoder.net.mlp.2.weight", "encoder.actor_encoder.net.mlp.2.bias", "encoder.critic_encoder.net.mlp.0.weight", "encoder.critic_encoder.net.mlp.0.bias", "encoder.critic_encoder.net.mlp.2.weight", "encoder.critic_encoder.net.mlp.2.bias", "vf.net.mlp.0.weight", "vf.net.mlp.0.bias". 
        Unexpected key(s) in state_dict: "encoder.encoder.net.mlp.0.weight", "encoder.encoder.net.mlp.0.bias", "encoder.encoder.net.mlp.2.weight", "encoder.encoder.net.mlp.2.bias".
```

What you expected to happen

A module state should be loadable into a module when calling `Algorithm.add_module`, with the Learner's module being `inference_only=False` and the EnvRunner's module being `inference_only=True`. Ideally, the module state comes from a Learner's module (because it holds all networks).
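For illustration, a full learner state could in principle be reduced to an inference-only one by dropping critic-side weights and renaming the actor-encoder keys. This is only a sketch based on the key names in the error above; the helper name and exact key mapping are assumptions, not RLlib's actual implementation:

```python
def to_inference_only_state(full_state: dict) -> dict:
    """Hypothetical helper: strip critic/value keys, rename actor-encoder keys."""
    inference_state = {}
    for key, value in full_state.items():
        # Critic-encoder and value-function weights are not needed for inference.
        if key.startswith("encoder.critic_encoder.") or key.startswith("vf."):
            continue
        # The inference-only module stores the actor encoder as "encoder.encoder.*".
        if key.startswith("encoder.actor_encoder."):
            key = "encoder.encoder." + key[len("encoder.actor_encoder."):]
        inference_state[key] = value
    return inference_state
```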

Versions / Dependencies

Python 3.11, Ray master

Reproduction script

```python
from ray.rllib.algorithms import algorithm as rllib_algorithm
from ray.rllib.algorithms.ppo.torch import ppo_torch_rl_module
from ray.rllib.core.rl_module import rl_module

# `serialization` is a project-local helper module wrapping stored module
# checkpoints; `serialization.ModuleSpec.load_module()` returns an RLModule.


def add_trained_modules(
    algorithm: rllib_algorithm.Algorithm,
    module_specs: list[serialization.ModuleSpec],
    evaluation_workers: bool = True,
):
    """Add a list of ModuleSpec to an RLlib Algorithm."""
    for module_spec in module_specs:
        # if algorithm.get_module(module_spec.name):
        #     continue

        module = module_spec.load_module()
        model_config_dict = module.config.model_config_dict
        model_config_dict["_inference_only"] = False
        print(f"trying to add module: {module_spec.name}")
        algorithm.add_module(
            module_spec.name,
            rl_module.SingleAgentRLModuleSpec(
                module_class=ppo_torch_rl_module.PPOTorchRLModule,
                observation_space=module.config.observation_space,
                action_space=module.config.action_space,
                model_config_dict=model_config_dict,  # module.config.model_config_dict,
                catalog_class=module.config.catalog_class,
            ),
            module_state=module.get_state(inference_only=False),
            evaluation_workers=evaluation_workers,
        )
        print(f"added module: {module_spec.name}")
```

Issue Severity

Medium: It is a significant difficulty but I can work around it.

simonsays1980 commented 3 weeks ago

@sven1977 Is this related to your cleanup PR with the Checkpointable?