salesforce / warp-drive

Extremely Fast End-to-End Deep Multi-Agent Reinforcement Learning Framework on a GPU (JMLR 2022)
BSD 3-Clause "New" or "Revised" License

Creating a 4D custom environment from Gridworld 2D env #60

Closed Mshz2 closed 1 year ago

Mshz2 commented 2 years ago

Dear all, I am new to reinforcement learning, but I am fascinated with WarpDrive. I was wondering if you could help me build a custom env for my little study project. The story of my env is as follows: I want to create a 4D gym environment on a 468x225x182x54 grid (which means 1,034,888,400 unique cells). Every cell in this space has a unique value, and my agent (e.g. a rabbit) can jump anywhere in this space and set cells to zero (a cell is "burned" after its points are collected by the rabbit). The agent is rewarded based on the reduction of the environment's overall points (e.g. 2000) caused by zeroing cell values. Which cells have more points is unknown to the agent but fixed; it is the task of the agent to find this out by jumping so that it burns the higher-valued cells before the episode ends. I thought my action space could be defined as

class CustomEnv(gym.Env):
    def __init__(self):
           self.action_space = gym.spaces.MultiDiscrete([468, 225, 182, 54])

For example

 print(CustomEnv().action_space.sample())
[172 54 101 37]

where my agent collects the reward at location [172 54 101 37], and all values at this cell are now zero. When the game starts the agent jumps into this 4D space (I assume it is better to start the first step at a fixed position with a no-op action, so that no values are zeroed on the first step, and during policy training the agent learns to begin with an action that yields a globally better reward). Furthermore, I want the step function to work like this: the rabbit makes a jump, then the reward is returned. The returned state is the 4D space with the same shape, but with values changed by the zeroing from the previous action.
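For example, here is a minimal toy sketch (plain gym, not WarpDrive) of the dynamics I have in mind; all names are placeholders and a tiny grid shape is used, since the real grid would need roughly a billion cells:

import numpy as np
import gym

class ToyRabbitEnv(gym.Env):
    # Toy version of the intended dynamics; the real shape would be (468, 225, 182, 54).
    def __init__(self, shape=(4, 5, 6, 3), seed=0):
        self.shape = shape
        self.action_space = gym.spaces.MultiDiscrete(list(shape))
        self.rng = np.random.default_rng(seed)
        self.world = None

    def reset(self):
        # Cell values are fixed per episode but unknown to the agent.
        self.world = self.rng.integers(0, 5, size=self.shape, dtype=np.int32)
        return self.world.copy()

    def step(self, action):
        x, y, z, k = action
        collected = int(self.world[x, y, z, k])  # points at the chosen cell
        self.world[x, y, z, k] = 0               # "burn" the cell
        reward = collected                       # reward = reduction of the total points
        done = False                             # episode-length / finish-point logic omitted
        return self.world.copy(), reward, done, {}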

However, I don't know how I should define my observation space, and I would really appreciate your help.
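Two rough sketches of what the observation space could look like (plain gym spaces with placeholder names, not WarpDrive-specific; I am not sure either is right):

import numpy as np
from gym import spaces

# (a) Naive option: the observation is the full grid of cell values (0..4).
#     At the real shape this is ~1 GB per observation, so it is only shown
#     to make the idea concrete; in practice a compressed view is needed.
full_observation_space = spaces.Box(
    low=0, high=4, shape=(468, 225, 182, 54), dtype=np.int8
)

# (b) Cheaper option: observe only the last jump and the points it collected
#     (illustrative names, not a WarpDrive convention).
local_observation_space = spaces.Dict(
    {
        "last_jump": spaces.MultiDiscrete([468, 225, 182, 54]),
        "collected_points": spaces.Box(low=0, high=4, shape=(1,), dtype=np.int8),
    }
)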

So far, for example, this is what I get if I modify your gridworld example env:

import numpy as np
from gym import spaces
from gym.utils import seeding

# seeding code from https://github.com/openai/gym/blob/master/gym/utils/seeding.py
from warp_drive.utils.constants import Constants
from warp_drive.utils.data_feed import DataFeed
from warp_drive.utils.gpu_environment_context import CUDAEnvironmentContext

_OBSERVATIONS = Constants.OBSERVATIONS
_ACTIONS = Constants.ACTIONS
_REWARDS = Constants.REWARDS

# Our custom field: the 4D space of size 468x225x182x54 is represented here by
# four 1D arrays (one per dimension), each cell holding a random value in [0, 5).
RabbitField_World = [
    np.random.randint(0, 5, 468),
    np.random.randint(0, 5, 225),
    np.random.randint(0, 5, 182),
    np.random.randint(0, 5, 54),
]
# Total number of points in the field at the start of the game (fixed).
RabbitField_World_Fixed_Points = sum(arr.sum() for arr in RabbitField_World)
_LOC_X = "cells_dim_x"
_LOC_Y = "cells_dim_y"
_LOC_Z = "cells_dim_z"
_LOC_K = "cells_dim_k"

def burning(dim_world, jump_pos):
    # Zero out ("burn") the chosen cells. Work on a copy so that the previous
    # timestep stored in the global state is not modified in place.
    dim_world = dim_world.copy()
    dim_world[jump_pos] = 0
    return dim_world

class RabbitField:
    """
    The game of tag on a 4D 468x225x182x54 plane.
    There are a number of agents (Rabbits) trying to minimize the plane overall points.
    A cell might have a value from range of 0 to 5. An agent jumps on a cell and collects 
        the point of it, and the value of the cell becomes zero.
    The reward will be the remaining points in the 4D plane.
    """

    def __init__(
        self,
        num_agents=1,
        grid_dim_one=468,
        grid_dim_two=225,
        grid_dim_three=182,
        grid_dim_four=54,
        episode_length=100,
        starting_cells_x=RabbitField_World[0],
        starting_cells_y=RabbitField_World[1],
        starting_cells_z=RabbitField_World[2],
        starting_cells_k=RabbitField_World[3],
        finish_point = 1000,
        seed=None,
        step_cost_for_agent=0.01,
        use_full_observation=True,
        env_backend="cpu"
    ):
        """
        :param num_agents (int): the total number of rabbits. In this env,
            num_agents can be 1 (a single rabbit per env) or more.
        :param grid_dim_one/two/three/four (int): the sizes of the four
            dimensions of the world.
        :param episode_length (int): episode length.
        :param starting_cells_x ([ndarray], optional): starting cell values along
            the x axis of the 4D grid.
        :param starting_cells_y ([ndarray], optional): starting cell values along
            the y axis of the 4D grid.
        :param starting_cells_z ([ndarray], optional): starting cell values along
            the z axis of the 4D grid.
        :param starting_cells_k ([ndarray], optional): starting cell values along
            the k axis of the 4D grid.
        :param finish_point (int): the collected-points threshold that ends the game.
        :param seed: seeding parameter.
        :param step_cost_for_agent (float): penalty for each jump the rabbit makes.
        :param use_full_observation (bool): boolean indicating whether to
            include all the agents' data in the observation or
            just the nearest neighbor's. Defaults to True.
        """
        assert num_agents > 0
        self.num_agents = num_agents

        assert episode_length > 0
        self.episode_length = episode_length

        self.grid_dim_one = grid_dim_one
        self.grid_dim_two = grid_dim_two
        self.grid_dim_three = grid_dim_three
        self.grid_dim_four = grid_dim_four

        # Seeding
        self.np_random = np.random
        if seed is not None:
            self.seed(seed)

        self.starting_cells_x = starting_cells_x
        self.starting_cells_y = starting_cells_y
        self.starting_cells_z = starting_cells_z
        self.starting_cells_k = starting_cells_k

        # Each action picks one cell index per dimension of the 4D world.
        self.step_actions = [grid_dim_one, grid_dim_two, grid_dim_three, grid_dim_four]

        # Defining observation and action spaces
        self.observation_space = None  # Note: this will be set via the env_wrapper

        self.action_space = {
            agent_id: spaces.MultiDiscrete(self.step_actions)
            for agent_id in range(self.num_agents)
        }

        # These will be set during reset (see below)
        self.timestep = None
        self.global_state = None

        # For reward computation
        self.step_cost_for_agent = step_cost_for_agent
        self.finish_point = finish_point  # a fixed reward threshold, defined by us, that ends the game
        self.reward_penalty = np.zeros(self.num_agents)
        self.use_full_observation = use_full_observation
        # Agent types are required by generate_observation(); all rabbits share type 0.
        self.agent_type = {agent_id: 0 for agent_id in range(self.num_agents)}

        self.env_backend = env_backend

    name = "RabbitField"

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def set_global_state(self, key=None, value=None, t=None, dtype=None):
        assert key is not None
        if dtype is None:
            dtype = np.int32

        # If the key is new, allocate one row per timestep; each row is as long as
        # the provided value (here, the cell array of one dimension), falling back
        # to one entry per agent when no value is given.
        if key not in self.global_state:
            width = value.shape[0] if value is not None else self.num_agents
            self.global_state[key] = np.zeros(
                (self.episode_length + 1, width), dtype=dtype
            )

        if t is not None and value is not None:
            assert isinstance(value, np.ndarray)
            assert value.shape[0] == self.global_state[key].shape[1]

            self.global_state[key][t] = value

    def update_state(self, actions_x, actions_y, actions_z, actions_k):
        loc_x_prev_t = self.global_state[_LOC_X][self.timestep - 1]
        loc_y_prev_t = self.global_state[_LOC_Y][self.timestep - 1]
        loc_z_prev_t = self.global_state[_LOC_Z][self.timestep - 1]
        loc_k_prev_t = self.global_state[_LOC_K][self.timestep - 1]

        loc_x_curr_t = burning(loc_x_prev_t, actions_x)
        loc_y_curr_t = burning(loc_y_prev_t, actions_y)
        loc_z_curr_t = burning(loc_z_prev_t, actions_z)
        loc_k_curr_t = burning(loc_k_prev_t, actions_k)

        self.set_global_state(key=_LOC_X, value=loc_x_curr_t, t=self.timestep)
        self.set_global_state(key=_LOC_Y, value=loc_y_curr_t, t=self.timestep)
        self.set_global_state(key=_LOC_Z, value=loc_z_curr_t, t=self.timestep)
        self.set_global_state(key=_LOC_K, value=loc_k_curr_t, t=self.timestep)

        # Our RabbitField custom reward from collecting points: the more overall
        # value the 4D field has lost so far, the larger the reward.
        self.reward_collection = RabbitField_World_Fixed_Points - (
            sum(loc_x_curr_t) + sum(loc_y_curr_t) + sum(loc_z_curr_t) + sum(loc_k_curr_t)
        )
        # The game ends once enough points have been collected.
        tag = self.reward_collection >= self.finish_point

        reward = self.reward_collection
        rew = {agent_id: reward for agent_id in range(self.num_agents)}

        return rew, tag

    def generate_observation(self):
        obs = {}
        if self.use_full_observation:
            common_obs = None
            for feature in [
                _LOC_X,
                _LOC_Y,
                _LOC_Z,
                _LOC_K,
            ]:
                if common_obs is None:
                    common_obs = self.global_state[feature][self.timestep]
                else:
                    common_obs = np.vstack(
                        (common_obs, self.global_state[feature][self.timestep])
                    )
            normalized_common_obs = common_obs 

            agent_types = np.array(
                [self.agent_type[agent_id] for agent_id in range(self.num_agents)]
            )

            for agent_id in range(self.num_agents):
                agent_indicators = np.zeros(self.num_agents)
                agent_indicators[agent_id] = 1
                obs[agent_id] = np.concatenate(
                    [
                        np.vstack(
                            (normalized_common_obs, agent_types, agent_indicators)
                        ).reshape(-1),
                        np.array([float(self.timestep) / self.episode_length]),
                    ]
                )
        else:
            for agent_id in range(self.num_agents):
                feature_list = []
                for feature in [
                    _LOC_X,
                    _LOC_Y,
                    _LOC_Z,
                    _LOC_K,
                ]:
                    feature_list.append(
                        self.global_state[feature][self.timestep][agent_id]
                    )
                if agent_id < self.num_agents - 1:
                    for feature in [
                        _LOC_X,
                        _LOC_Y,
                        _LOC_Z,
                        _LOC_K,
                    ]:
                        feature_list.append(
                            self.global_state[feature][self.timestep][-1]
                        )
                else:
                    dist_array = None
                    for feature in [
                        _LOC_X,
                        _LOC_Y,
                        _LOC_Z,
                        _LOC_K,
                    ]:
                        if dist_array is None:
                            dist_array = np.square(
                                self.global_state[feature][self.timestep][:-1]
                                - self.global_state[feature][self.timestep][-1]
                            )
                        else:
                            dist_array += np.square(
                                self.global_state[feature][self.timestep][:-1]
                                - self.global_state[feature][self.timestep][-1]
                            )
                    min_agent_id = np.argmin(dist_array)
                    for feature in [
                        _LOC_X,
                        _LOC_Y,
                        _LOC_Z,
                        _LOC_K,
                    ]:
                        feature_list.append(
                            self.global_state[feature][self.timestep][min_agent_id]
                        )
                feature_list += [
                    self.agent_type[agent_id],
                    float(self.timestep) / self.episode_length,
                ]
                obs[agent_id] = np.array(feature_list)
        return obs

    def reset(self):
        # Reset time to the beginning
        self.timestep = 0

        # Re-initialize the global state
        self.global_state = {}
        self.set_global_state(
            key=_LOC_X, value=self.starting_cells_x, t=self.timestep, dtype=np.int32
        )
        self.set_global_state(
            key=_LOC_Y, value=self.starting_cells_y, t=self.timestep, dtype=np.int32
        )
        self.set_global_state(
            key=_LOC_Z, value=self.starting_cells_z, t=self.timestep, dtype=np.int32
        )
        self.set_global_state(
            key=_LOC_K, value=self.starting_cells_k, t=self.timestep, dtype=np.int32
        )
        return self.generate_observation()

    def step(
        self,
        actions=None,
    ):
        self.timestep += 1
        assert isinstance(actions, dict)
        assert len(actions) == self.num_agents

        actions_x = np.array(
            [
                actions[agent_id][0]
                for agent_id in range(self.num_agents)
            ]
        )
        actions_y = np.array(
            [
                actions[agent_id][1]
                for agent_id in range(self.num_agents)
            ]
        )
        actions_z = np.array(
            [
                actions[agent_id][2]
                for agent_id in range(self.num_agents)
            ]
        )
        actions_k = np.array(
            [
                actions[agent_id][3]
                for agent_id in range(self.num_agents)
            ]
        )

        rew, tag = self.update_state(actions_x, actions_y, actions_z, actions_k)
        obs = self.generate_observation()
        done = {"__all__": self.timestep >= self.episode_length or tag}
        info = {}

        return obs, rew, done, info

class CUDARabbitField(RabbitField, CUDAEnvironmentContext):
    """
    CUDA version of the RabbitField environment.
    Note: this class subclasses the Python environment class RabbitField,
    and also the CUDAEnvironmentContext.
    """

    def get_data_dictionary(self):
        data_dict = DataFeed()
        for feature in [
            _LOC_X,
            _LOC_Y,
            _LOC_Z,
            _LOC_K,
        ]:
            data_dict.add_data(
                name=feature,
                data=self.global_state[feature][0],
                save_copy_and_apply_at_reset=True,
                log_data_across_episode=True,
            )
        data_dict.add_data_list(
            [
                ("finish_point", self.finish_point),
                ("step_cost_for_agent", self.step_cost_for_agent),
                ("use_full_observation", self.use_full_observation),
            ]
        )
        return data_dict

    def get_tensor_dictionary(self):
        tensor_dict = DataFeed()
        return tensor_dict

    def step(self, actions=None):
        self.timestep += 1
        args = [
            _LOC_X,
            _LOC_Y,
            _LOC_Z,
            _LOC_K,
            _ACTIONS,
            "_done_",
            _REWARDS,
            _OBSERVATIONS,
            "finish_point",
            "step_cost_for_agent",
            "use_full_observation",
            "_timestep_",
            ("episode_length", "meta"),
        ]
        if self.env_backend == "pycuda":
            self.cuda_step(
                *self.cuda_step_function_feed(args),
                block=self.cuda_function_manager.block,
                grid=self.cuda_function_manager.grid,
            )
        elif self.env_backend == "numba":
            self.cuda_step[
                self.cuda_function_manager.grid, self.cuda_function_manager.block
            ](*self.cuda_step_function_feed(args))
        else:
            raise Exception("CUDARabbitField expects env_backend = 'pycuda' or 'numba' ")

Emerald01 commented 2 years ago

Hi @Mshz2

WarpDrive is mostly designed for running multi-agent reinforcement learning, i.e., each GPU thread works in parallel on one agent. In this sense, a single-agent environment like yours may not be the best fit. Although WarpDrive is able to run a single agent, it is not optimal for that case by design; we are going to include a single-agent adaptor later, though. Besides, your 4D grid size seems extremely large, roughly ~1 GB per environment, so if you want to run many replicas in parallel, memory would be a serious constraint.
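
As a rough back-of-the-envelope check of that estimate (assuming one byte per cell):

cells = 468 * 225 * 182 * 54   # 1,034,888,400 cells
bytes_per_cell = 1             # e.g. int8 values in [0, 5)
print(f"{cells:,} cells ~= {cells * bytes_per_cell / 1e9:.2f} GB per environment copy")
# float32 state or per-cell observations multiply this by 4, and every parallel
# environment replica on the GPU needs its own copy.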

Regards,

Mshz2 commented 2 years ago

Thanks for your response. My problem could also be framed as four 1D games. How about assigning one agent per dimension, with the agents taking actions one after another? Do you think that would reduce the load? For my case, a single environment is sufficient.

Emerald01 commented 2 years ago

I believe it would be much better if you use 4 agents, each taking care of one dimension individually.
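
For instance, a minimal sketch of that layout (placeholder names, not part of the WarpDrive API): each of the four agents gets its own Discrete action space over one dimension, and a joint jump is the combination of the four chosen indices.

from gym import spaces

# One agent per dimension: agent 0 picks the x index, agent 1 the y index, etc.
# This keeps each agent's action space small and maps naturally onto the
# one-thread-per-agent layout. DIM_SIZES is an illustrative placeholder.
DIM_SIZES = [468, 225, 182, 54]

action_space = {
    agent_id: spaces.Discrete(dim_size)
    for agent_id, dim_size in enumerate(DIM_SIZES)
}

# A joint "jump" is then the tuple of the four agents' actions, e.g.
# actions = {0: 172, 1: 54, 2: 101, 3: 37}  -> burn the cell at [172, 54, 101, 37]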