ray-project / ray


[RLlib] Attribute error when trying to compute action after training Multi Agent PPO with New API Stack #44475

Open Dr-IceCream opened 5 months ago

Dr-IceCream commented 5 months ago

What happened + What you expected to happen

After training multi-agent PPO with the new API stack, following the how-to-use-the-new-api-stack guide, I tried to compute actions:

    saved_algorithm = Algorithm.from_checkpoint(
        checkpoint=algorithm_path,
        policy_ids={"controlled_vehicle_0", "controlled_vehicle_1"},
        policy_mapping_fn=lambda agent_id, episode, **kwargs: f"controlled_vehicle_{agent_id}",
    )
    print("saved_algorithm type:", type(saved_algorithm))
    # Evaluate the model
    obs, info = env.reset()
    print("obs:", obs)
    actions = {}
    for agent_id, agent_obs in obs.items():
        policy_id = f"controlled_vehicle_{agent_id}"
        action = saved_algorithm.get_policy(policy_id).compute_single_action(agent_obs)
        actions[agent_id] = action
    print("action", actions)

but I get the error message:

AttributeError: 'MultiAgentEnvRunner' object has no attribute 'get_policy'

I also tried another approach, action = saved_algorithm.compute_single_action(agent_obs, policy_id), but got the same error message: AttributeError: 'MultiAgentEnvRunner' object has no attribute 'get_policy'. I have seen a similar issue in #40312. Are these two issues the same?

The detailed error message is as follows:

    Traceback (most recent call last):
      File "test_evaluate.py", line 151, in <module>
        evaluate_agent(saved_algorithm, env)
      File "test_evaluate.py", line 112, in evaluate_agent
        action = saved_algorithm.get_policy(policy_id).compute_single_action(agent_obs)
      File "C:\Users\Ice Cream\miniconda3\envs\env_highway\lib\site-packages\ray\util\tracing\tracing_helper.py", line 467, in _resume_span
        return method(self, *_args, **_kwargs)
      File "C:\Users\Ice Cream\miniconda3\envs\env_highway\lib\site-packages\ray\rllib\algorithms\algorithm.py", line 2051, in get_policy
        return self.workers.local_worker().get_policy(policy_id)
    AttributeError: 'MultiAgentEnvRunner' object has no attribute 'get_policy'
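The call chain in the traceback can be reproduced without Ray. The following is a minimal sketch, assuming only what the traceback shows: get_policy() forwards to the local worker, and the local worker object has no such method. All class names here are hypothetical stand-ins, not RLlib classes.

```python
# Hypothetical stand-ins, NOT RLlib classes: they only mimic the call chain
# visible in the traceback.
class FakeEnvRunner:
    """Plays the role of MultiAgentEnvRunner: it exposes no get_policy()."""


class FakeWorkerSet:
    def __init__(self):
        self._local = FakeEnvRunner()

    def local_worker(self):
        return self._local


class FakeAlgorithm:
    def __init__(self):
        self.workers = FakeWorkerSet()

    def get_policy(self, policy_id):
        # Same shape as the line quoted from algorithm.py in the traceback.
        return self.workers.local_worker().get_policy(policy_id)


def try_get_policy(algo, policy_id):
    """Return the error message raised by get_policy(), or None on success."""
    try:
        algo.get_policy(policy_id)
    except AttributeError as err:
        return str(err)
    return None


message = try_get_policy(FakeAlgorithm(), "controlled_vehicle_0")
print(message)
```

The sketch shows that the error comes from the delegation target, not from the checkpoint or the policy ID: any call that is routed through the local worker's get_policy() would fail the same way.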

Before calling this method, I also printed the relevant information, and this part looks normal:

    saved_algorithm type: <class 'ray.rllib.algorithms.ppo.ppo.PPO'>
    saved_algorithm.get_config() <ray.rllib.algorithms.ppo.ppo.PPOConfig object at 0x0000014B0E4D4370>

This output was produced by the following code:

    print("saved_algorithm type:", type(saved_algorithm))
    print("saved_algorithm.get_config()", saved_algorithm.get_config())

Versions / Dependencies

Ray 2.10.0, Python 3.8.18, Windows 11

Reproduction script

The code used for training is as follows:

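    # Imports assumed for Ray 2.10; create_env and folder_path are defined
    # elsewhere in the author's script.
    import os

    from ray import tune
    from ray.train import RunConfig
    from ray.tune.registry import register_env
    from ray.rllib.algorithms.ppo import PPOConfig
    from ray.rllib.env.multi_agent_env_runner import MultiAgentEnvRunner

    register_env("ray_dict_highway_env", create_env)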
    config = (
        PPOConfig().environment(env="ray_dict_highway_env")
        .experimental(_enable_new_api_stack=True)
        .rollouts(env_runner_cls=MultiAgentEnvRunner)
        .resources(
            num_learner_workers=0,
            num_gpus_per_learner_worker=1,
            num_cpus_for_local_worker=1,
        )
        .training(model={"uses_new_env_runners": True})
        .multi_agent(
            policies={
                "controlled_vehicle_0",
                "controlled_vehicle_1"
            },
            policy_mapping_fn=lambda agent_id, episode, **kwargs: f"controlled_vehicle_{agent_id}",
        )
        .framework("torch")
    )
    current_script_directory = os.path.dirname(os.path.abspath(__file__))
    ray_result_path = os.path.join(current_script_directory, folder_path, "ray_results")
    tuner = tune.Tuner(
        "PPO",
        run_config=RunConfig(
            storage_path=ray_result_path,
            name="2-agent-PPO",
            stop={"timesteps_total": 5e5}
        ),
        param_space=config.to_dict() 
    )
    results = tuner.fit()

And the code for loading checkpoints:

    from ray.rllib.algorithms.algorithm import Algorithm

    algorithm_path = r"D:\DRL_Project\DRL_highway\experiments\hw-fast-ma-dict-v0_rllib-mappo\2024-04-01_01-28\ray_results\2-agent-PPO\PPO_ray_dict_highway_env_1c7ab_00000_0_2024-04-01_01-28-38\checkpoint_000000"
    saved_algorithm = Algorithm.from_checkpoint(
        checkpoint=algorithm_path,
        policy_ids={"controlled_vehicle_0", "controlled_vehicle_1"},
        policy_mapping_fn=lambda agent_id, episode, **kwargs: f"controlled_vehicle_{agent_id}",
    )

Issue Severity

High: It blocks me from completing my task.

Dr-IceCream commented 5 months ago

It seems I have provisionally solved this issue by referring to the solutions in #40312 and the comments in the ray.rllib.core.rl_module code, replacing the original code with the following:

    import torch

    # Evaluate the model
    obs, info = env.reset()
    print("obs:", obs)
    actions = {}
    for agent_id, agent_obs in obs.items():
        # Determine the policy ID for the current agent using the policy mapping function
        policy_id = f"controlled_vehicle_{agent_id}"
        # Compute actions for each agent
        rl_module = saved_algorithm.get_module(policy_id)
        fwd_ins = {"obs": torch.Tensor([agent_obs])}
        fwd_outputs = rl_module.forward_inference(fwd_ins)
        action_dist_class = rl_module.get_inference_action_dist_cls()
        action_dist = action_dist_class.from_logits(
            fwd_outputs["action_dist_inputs"]
        )
        action = action_dist.sample()[0].numpy()
        actions[agent_id] = action
    # actions = saved_algorithm.compute_actions(obs)
    print("actions: ", actions)

The output is as follows:

actions: {0: array(4, dtype=int64), 1: array(3, dtype=int64)}

It seems to be working.
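Note that the snippet above draws each action with action_dist.sample(), so the result is stochastic. The following is a minimal sketch, using NumPy only and made-up logits, of how sampling from a categorical distribution differs from always taking the highest-logit action. The helper names are hypothetical, and the thread does not establish whether this difference has any bearing on evaluation results.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)


def sample_action(logits, rng):
    """Stochastic choice: draw an action index from the categorical distribution."""
    probs = softmax(logits)
    return int(rng.choice(len(probs), p=probs))


def greedy_action(logits):
    """Deterministic choice: pick the highest-logit action."""
    return int(np.argmax(logits))


# Made-up logits for a 5-action discrete space (illustrative values only).
example_logits = np.array([0.1, 0.3, 0.2, 0.9, 1.0])
rng = np.random.default_rng(0)

sampled = [sample_action(example_logits, rng) for _ in range(1000)]
greedy = greedy_action(example_logits)
greedy_share = sampled.count(greedy) / len(sampled)

print("greedy action:", greedy)
print("share of samples equal to the greedy action:", greedy_share)
```

With logits this close together, only a fraction of the sampled actions match the greedy one, so a sampled rollout and a greedy rollout can behave differently from the same network outputs.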

However, when I used a similar approach to evaluate over multiple episodes, the results were significantly worse than the metrics reported during training: episode_len_mean dropped from about 29 to 7, and episode_reward_mean decreased from about 42 to 10. The recorded videos also show that the agent does follow its own strategy π, but performs relatively poorly. I suspect that the steps I used to compute actions directly through the rl_module differ from those actually used during training, but I am not clear on how actions are computed during training. Did I do something wrong in this part?
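For comparing against the training metrics, it helps to pin down how the evaluation numbers are computed. The following is a minimal sketch of a multi-episode evaluation loop, assuming a dict-based multi-agent environment whose step() returns an "__all__" flag in its terminated and truncated dicts. ToyTwoAgentEnv and constant_policy are hypothetical placeholders, and summing rewards over agents is an assumption about how the reported metric is aggregated.

```python
class ToyTwoAgentEnv:
    """Hypothetical two-agent environment that ends after a fixed number of steps."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return {0: 0.0, 1: 0.0}, {}

    def step(self, actions):
        self.t += 1
        obs = {agent_id: float(self.t) for agent_id in actions}
        rewards = {agent_id: 1.0 for agent_id in actions}
        done = self.t >= self.horizon
        return obs, rewards, {"__all__": done}, {"__all__": False}, {}


def evaluate(env, compute_actions, num_episodes):
    """Run full episodes and return (mean episode length, mean summed reward)."""
    lengths, returns = [], []
    for _ in range(num_episodes):
        obs, _ = env.reset()
        length, total, done = 0, 0.0, False
        while not done:
            actions = compute_actions(obs)
            obs, rewards, terminateds, truncateds, _ = env.step(actions)
            length += 1
            total += sum(rewards.values())  # assumption: sum over all agents
            done = terminateds["__all__"] or truncateds["__all__"]
        lengths.append(length)
        returns.append(total)
    return sum(lengths) / len(lengths), sum(returns) / len(returns)


def constant_policy(obs):
    """Placeholder policy: every agent always picks action 0."""
    return {agent_id: 0 for agent_id in obs}


len_mean, reward_mean = evaluate(ToyTwoAgentEnv(horizon=5), constant_policy, num_episodes=3)
print("episode_len_mean:", len_mean)
print("episode_reward_mean:", reward_mean)
```

In a real evaluation, compute_actions would wrap the per-agent rl_module calls, and the environment would be the one used for training, so that any remaining gap can be attributed to the action computation rather than to the bookkeeping.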

I also still wonder why I can't directly use saved_algorithm.get_policy(policy_id).compute_single_action(agent_obs) or saved_algorithm.compute_single_action(agent_obs, policy_id) to compute actions. Does this mean that this call will have a new syntax in the new API stack?

sortiz-hub commented 3 months ago

I am facing a similar issue with the SingleAgentEnvRunner (see screenshot attached).

[Attached screenshot: issue]

vilmire commented 1 month ago

I don't know why this issue is P2. It should be P0.

grizzlybearg commented 1 month ago

This issue still persists. Hey @simonsays1980 @sven1977, is there any plan to fix this?

adadelta commented 1 week ago

I'm unfortunately having this issue too:

    ppo_agents = Algorithm.from_checkpoint(checkpoint=checkpoint_path)
    actions = ppo_agents.compute_actions(observations=observations)

Results in:

AttributeError: 'MultiAgentEnvRunner' object has no attribute 'get_policy'

Applying a fix similar to what @Dr-IceCream did, which follows #40312, gives very degraded results, to the point where it's unusable :(

Since we've had a similar P1 issue, should this be upgraded? @simonsays1980 @sven1977