ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0
32.93k stars 5.58k forks source link

[RLlib] unexpected sample_batch in policy.postprocess_trajectory() #46660

Open ErwinLiYH opened 1 month ago

ErwinLiYH commented 1 month ago

What happened + What you expected to happen

I am writing a custom policy with a `postprocess_trajectory` function to post-process infos. However, after the first training iteration, the infos in the raw sample batch passed to `postprocess_trajectory` are abnormal: the first info dict becomes the following, while the info dicts from the second to the last are correct.

{
    0: {
        'agent0':{
            # info dict from env
        }
    }
}

I wrote a dummy env that returns info like:

{
        "SLA_violation": 0,
        "bad_SE": 0,
        "changed_RB": 1
}

After the first training iteration, the infos of the sample batch passed to `postprocess_trajectory` are:

array([{0: {'agent0': {'SLA_violation': 0, 'bad_SE': 0, 'changed_RB': 1}}},        
       {'SLA_violation': 0, 'bad_SE': 0, 'changed_RB': 1},                         
       {'SLA_violation': 0, 'bad_SE': 0, 'changed_RB': 1},                         
       {'SLA_violation': 0, 'bad_SE': 0, 'changed_RB': 1},                         
       {'SLA_violation': 0, 'bad_SE': 0, 'changed_RB': 1},                         
       {'SLA_violation': 0, 'bad_SE': 0, 'changed_RB': 1},                         
       {'SLA_violation': 0, 'bad_SE': 0, 'changed_RB': 1},                         
       {'SLA_violation': 0, 'bad_SE': 0, 'changed_RB': 1},                         
       {'SLA_violation': 0, 'bad_SE': 0, 'changed_RB': 1},
           .........
       ]

This is unexpected behaviour. The code to reproduce the problem is below.

Versions / Dependencies

Ray: 2.32.0 Python: 3.10.14 Ubuntu 20.04

Reproduction script

import gymnasium as gym
from gymnasium import spaces
import numpy as np

import ray
from ray.tune.registry import register_env
from ray.rllib.algorithms.dqn import DQN, DQNConfig
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.algorithms.dqn.dqn_tf_policy import postprocess_nstep_and_prio
from ray.rllib.utils.annotations import override
from ray.rllib.algorithms.algorithm import Algorithm

from ray.rllib.policy.sample_batch import SampleBatch

from ray import tune, air

import pickle
import copy

class DummyEnv(gym.Env):
    """Minimal single-agent env used to reproduce the RLlib infos issue.

    The observation is the current step count as a 1-element float32 array.
    Each step yields reward 1.0 and a random-free transition; the episode
    terminates once ``n`` steps have been taken. ``reset`` and ``step`` both
    return the same constant info contents.

    Args:
        n: Episode length in steps.
    """

    def __init__(self, n=100):
        super().__init__()
        self.n = n
        self.current_step = 0
        self.action_space = spaces.Discrete(2)
        # Observations reach ``n``; keep the original [0, 100] box but widen
        # it when n > 100 so every observation stays inside the space.
        self.observation_space = spaces.Box(
            low=np.array([0]), high=np.array([max(n, 100)]), dtype=np.float32
        )
        self.test_info = {
            "SLA_violation": 0,
            "bad_SE": 0,
            "changed_RB": 1,
        }

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``."""
        # Seed the base env's RNG as the Gymnasium reset contract expects.
        super().reset(seed=seed)
        self.current_step = 0
        observation = np.array([0.0], dtype=np.float32)
        # Hand out a fresh copy so no two timesteps share one mutable dict.
        return observation, dict(self.test_info)

    def step(self, action):
        """Advance one step.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
        """
        self.current_step += 1
        observation = np.array([self.current_step], dtype=np.float32)
        reward = 1.0
        terminated = self.current_step >= self.n
        truncated = False
        return observation, reward, terminated, truncated, dict(self.test_info)

# Initialize Ray at import time, before the env is registered and Tune runs below.
ray.init()

def postprocess_traj(policy, sample_batch, other_agent_batches=None, episode=None):
    """Apply DQN n-step/priority postprocessing, then validate the infos.

    Each entry of ``sample_batch[SampleBatch.INFOS]`` is expected to be a dict
    holding the ``bad_SE``, ``changed_RB`` and ``SLA_violation`` keys. If any
    entry is malformed, the pre-postprocessing batch and the postprocessed
    batch are pickled to ``raw_batch.pkl`` and ``batch.pkl`` in the working
    directory for debugging, and an error is raised. Batches without infos,
    with placeholder (non-dict) infos, or without actions are not checked.

    Args:
        policy: Policy whose trajectory is being postprocessed.
        sample_batch: Trajectory batch to postprocess.
        other_agent_batches: Forwarded to ``postprocess_nstep_and_prio``.
        episode: Forwarded to ``postprocess_nstep_and_prio``.

    Returns:
        The postprocessed batch.

    Raises:
        RuntimeError: If any info entry is not a dict with all expected keys.
    """
    required_keys = ("bad_SE", "changed_RB", "SLA_violation")

    # Untouched copy, so the debug dump shows exactly what RLlib passed in.
    raw_b = copy.deepcopy(sample_batch)
    # Keep the returned batch rather than relying on in-place mutation.
    sample_batch = postprocess_nstep_and_prio(
        policy, sample_batch, other_agent_batches, episode
    )

    infos = sample_batch.get(SampleBatch.INFOS)
    if (
        infos is None
        or len(infos) == 0
        # RLlib's API-check dummy batches carry placeholder (non-dict) infos.
        or not isinstance(infos[0], dict)
        # Evaluation batches contain no actions; skip them.
        or SampleBatch.ACTIONS not in sample_batch
    ):
        return sample_batch

    bad_indices = [
        i
        for i, info in enumerate(infos)
        if not isinstance(info, dict) or any(k not in info for k in required_keys)
    ]
    if bad_indices:
        # Persist both batches so the malformed entries can be inspected offline.
        with open("raw_batch.pkl", "wb") as f:
            pickle.dump(raw_b, f)
        with open("batch.pkl", "wb") as f:
            pickle.dump(sample_batch, f)
        raise RuntimeError(
            "Error in postprocess_trajectory: malformed infos at indices "
            f"{bad_indices}\n{raw_b[SampleBatch.INFOS]}"
        )
    return sample_batch

# Torch DQN policy whose trajectory postprocessing is delegated to
# ``postprocess_traj``.
postprocess_policy = DQNTorchPolicy.with_updates(
    postprocess_fn=postprocess_traj,
    name="postprocess_policy",
)

class MyDQN(DQN):
    """DQN variant that uses ``postprocess_policy`` as its default policy (torch only)."""

    @classmethod
    @override(Algorithm)
    def get_default_policy_class(cls, config):
        """Return the policy class for ``config``.

        Args:
            config: Algorithm config; only its ``"framework"`` entry is read.

        Returns:
            ``postprocess_policy`` when the framework is ``"torch"``.

        Raises:
            ValueError: For any other framework, since the custom policy is
                only defined for torch.
        """
        framework = config["framework"]
        if framework != "torch":
            raise ValueError(
                f"MyDQN only supports framework='torch', got {framework!r}."
            )
        return postprocess_policy

def test_env_creator(env_config):
    """Create the ``DummyEnv`` registered as ``"test_env"``.

    Args:
        env_config: RLlib env config. An optional ``"n"`` entry sets the
            episode length; it defaults to 100, the previously hard-coded value.

    Returns:
        A new ``DummyEnv`` instance.
    """
    n = (env_config or {}).get("n", 100)
    return DummyEnv(n)

# Expose the dummy env to RLlib under the name "test_env".
register_env("test_env", test_env_creator)

# DQN config for MyDQN: torch, a single env runner collecting 100-step
# truncated fragments, 1 GPU, and one evaluation episode (explore=False)
# after every training iteration.
config = (
    DQNConfig(algo_class=MyDQN)
    .environment("test_env")
    .framework("torch")
    .env_runners(
        num_env_runners=1,
        num_envs_per_env_runner=1,
        num_cpus_per_env_runner=1,
        rollout_fragment_length=100,
        batch_mode="truncate_episodes",
    )
    .resources(num_gpus=1)
    .training(gamma=0, train_batch_size=64, lr=1e-3)
    .evaluation(
        evaluation_interval=1,
        evaluation_duration=1,
        evaluation_duration_unit="episodes",
        evaluation_sample_timeout_s=6000,
        evaluation_num_env_runners=1,
        evaluation_parallel_to_training=False,
        evaluation_config=dict(env="test_env", explore=False),
    )
    .reporting(keep_per_episode_custom_metrics=True)
)

# Override the exploration epsilon settings in place
# (initial 0.5, final 0.1, over 120000 timesteps).
config.exploration_config.update(
    dict(initial_epsilon=0.5, final_epsilon=0.1, epsilon_timesteps=120000)
)

# Run training through Tune, stopping after 20 iterations.
res = tune.Tuner(
    MyDQN,
    param_space=config.to_dict(),
    run_config=air.RunConfig(stop=dict(training_iteration=20)),
).fit()

I also tested it in the multi-agent (MARL) setting; there, the first info dict becomes:

{
  "0": {
    "agent_0": {
          # agent_0 info dict
    },
    "agent_1": {
          # agent_1 info dict
    },
    "agent_2": {
          # agent_2 info dict
    }
  }
}

and the complete infos array of the input sample batch is:

array([{0: {'agent_0': {'SLA_violation': 1, 'bad_SE': 1, 'changed_RB': False}, 'agent_1': {'SLA_violation': 0, 'bad_SE': 1, 'changed_RB': False}, 'agent_2': {'SLA_violation': 0, 'bad_SE': 1, 'changed_RB': False}}},
       {'SLA_violation': 0, 'bad_SE': 0, 'changed_RB': True},
       {'SLA_violation': 0, 'bad_SE': 0, 'changed_RB': False},
       {'SLA_violation': 0, 'bad_SE': 1, 'changed_RB': True},
       {'SLA_violation': 1, 'bad_SE': 0, 'changed_RB': True},
       {'SLA_violation': 0, 'bad_SE': 0, 'changed_RB': True},
       {'SLA_violation': 0, 'bad_SE': 0, 'changed_RB': True},
       {'SLA_violation': 0, 'bad_SE': 1, 'changed_RB': True},
       {'SLA_violation': 0, 'bad_SE': 0, 'changed_RB': True},
       {'SLA_violation': 0, 'bad_SE': 1, 'changed_RB': True},
       {'SLA_violation': 0, 'bad_SE': 1, 'changed_RB': True},
       {'SLA_violation': 1, 'bad_SE': 0, 'changed_RB': True}
       .........
]

It seems the first info dict has not been converted to the per-agent SampleBatch format; it still has the multi-agent (MultiAgentBatch-style) nesting `{env_id: {agent_id: info}}`.

Issue Severity

High: It blocks me from completing my task.

ErwinLiYH commented 1 month ago

I have located the bug. It is caused by the function `__process_resetted_obs_for_eval` of class `EnvRunnerV2` in `ray/rllib/evaluation/env_runner_v2.py`, which does not handle the raw info dict from the reset operation correctly. This function processes the obs and infos returned by the reset operation, but it only extracts the per-agent obs from a structure like:

{
  env_id: {
    agent_id: agent_obs
    ......
  },
}

The info dict needs to be extracted from this structure as well. Should I create a pull request to fix it?