facebookresearch / CompilerGym

Reinforcement learning environments for compiler and program optimization tasks
https://compilergym.ai/
MIT License

RLlib training with DGL graph #786

Closed: dejangrubisic closed this issue 1 year ago

dejangrubisic commented 1 year ago

❓ Questions and Help

I am trying to train a graph-based model with RLlib that expects a DGL graph. I created a backend with CompilationSession, registered the observation space as a ByteSequenceSpace, and was able to return the pickled DGL graph as an observation in CompilerGym's base example. The problem arises when I want to use RLlib for training: it seems that the sequence kinds of spaces don't implement the sample() method.

env = make_env()
ray.rllib.utils.check_env(env) # I think this needs to pass for successful training.

It will raise NotImplementedError at compiler_gym/spaces/sequence.py(121), in sample().

  1. What is the right way to send non-tensor data, such as a DGL graph, to Ray?

  2. It might also be OK to send a dictionary of tensors (a transformed version of the DGL graph) from the backend, but I couldn't figure out how to register an observation space that returns such a dictionary. Any suggestions?

Additional Context

ObservationSpace(
    name="dgl_pickle",
    space=Space(
        byte_sequence=ByteSequenceSpace(length_range=Int64Range(min=0)),
    ),
),

CompilerGym 0.2.5 ray 2.2.0

ChrisCummins commented 1 year ago

Hey Dejan, good questions.

It will raise NotImplementedError at compiler_gym/spaces/sequence.py(121), in sample().

Could you try monkey patching it?

env = make_env()
env.observation_space.sample = lambda *args, **kwargs: make_some_plausible_data()

If you're building from source you could also try patching Sequence.sample() to generate plausible data. I would happily accept a PR for that.
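
For reference, here is a rough sketch of what such a patch might look like (illustrative only, and it assumes that Sequence stores the size_range tuple passed to its constructor):

import numpy as np

from compiler_gym.spaces import Sequence


def _sample_bytes(self):
    # Draw a plausible length from the space's size range (capping the upper
    # bound when the range is open-ended) and fill it with random byte values.
    lo, hi = self.size_range
    hi = hi if hi is not None else lo + 1024
    n = np.random.randint(lo, hi + 1)
    return np.random.randint(0, 256, size=n, dtype=np.uint8).tobytes()


Sequence.sample = _sample_bytes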

What is the right way to send non-tensor data, such as a DGL graph, to Ray?

Sorry, I don't have any experience with that. You may want to ask on the ray forums.

It might also be OK to send a dictionary of tensors (a transformed version of the DGL graph) from the backend, but I couldn't figure out how to register an observation space that returns such a dictionary. Any suggestions?

Here is an example of an observation that uses a dict (of scalars, but it could be composed into a dict of sequences, a nested dict, etc.). Is that the kind of thing you're after?
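
Something along these lines, perhaps (a rough sketch rather than the linked example; the "runtime" and "size" keys are invented, and it assumes the Dict and Scalar constructors from compiler_gym.spaces):

from compiler_gym.spaces import Dict, Scalar

# A dict-of-scalars observation space. Each entry could equally be a Sequence
# or another nested Dict.
space = Dict(
    {
        "runtime": Scalar(name="runtime", min=0, max=None, dtype=float),
        "size": Scalar(name="size", min=0, max=None, dtype=int),
    },
    name="my_dict_observation",
)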

Cheers, Chris

dejangrubisic commented 1 year ago

Thanks for the answer!

Here is a sketch of the solution I found. We can define a new entry in derived_observation_spaces when we register the environment, and define a "translate" function that converts the pickle into a dictionary.

This is still not a complete solution for working with Ray, since the Sequence class doesn't have self._shape defined, but I am working on it and will come back when I figure it out.

import importlib
import pickle

import numpy as np

from compiler_gym.spaces import Sequence
from compiler_gym.spaces import Dict as DictSpace  # assumption: DictSpace here is compiler_gym.spaces.Dict
from compiler_gym.util.registration import register


def pickle_to_dict(base_observation):
    graph = pickle.loads(base_observation)
    return graph.to_dict()


# rewardClass and datasets below are placeholders for your own reward class
# and dataset module names.
register(
    id="env_name-v0",
    entry_point="compiler_gym.service.client_service_compiler_env:ClientServiceCompilerEnv",
    kwargs={  # check ClientServiceCompilerEnv for possible arguments
        "service": "path-to-example_service.py",
        "rewards": [
            rewardClass(),
        ],
        "datasets": [
            importlib.import_module(f"env_name_service.agent_py.datasets.{dataset}").Dataset()
            for dataset in datasets
        ],
        "derived_observation_spaces": [
            {
                "id": "obs_name",
                "base_id": "obs_name_pickle",
                "space": DictSpace(
                    {
                        key: Sequence(name=key, size_range=(0, None), dtype=np.ndarray)
                        for key in ["key1", "key2", "key3"]
                    },
                    name="obs_name",
                ),
                "translate": lambda base_observation: pickle_to_dict(base_observation),
            },
        ],
    },
)
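
Once registered, the derived observation can then be fetched by name from the environment; a minimal usage sketch with the placeholder ids from above (assuming the module that calls register() has already been imported):

import gym

env = gym.make("env_name-v0")
env.reset()
# Fetch the derived dict observation registered under "obs_name".
obs = env.observation["obs_name"]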
ChrisCummins commented 1 year ago

That looks nice. Albeit as nice as it can get while having to jump through the hoops of the client/service env approach 🙂 There is definitely room for improvement in the CompilerGym backend to better support use cases like yours.

dejangrubisic commented 1 year ago

Here is the final solution that I found.

Here is our observation space and observation function from the CompilationSession (backend).

import pickle

from compiler_gym.service import CompilationSession
from compiler_gym.service.proto import ByteSequenceSpace, ByteTensor, Event, Int64Range, ObservationSpace, Space


class ProgramlCompilationSession(CompilationSession):
    ...
    # Here we define an observation space suitable for a pickled graph, as an
    # entry in the session's observation_spaces list.
    observation_spaces = [
        ObservationSpace(
            name="programl_pickle",
            space=Space(
                byte_sequence=ByteSequenceSpace(length_range=Int64Range(min=0)),
            ),
        ),
    ]


class Profiler:
    ...
    # Here we return the observation we made.
    def get_observation(self) -> Event:
        g_programl = self.programl_get_graph(self.llvm_path)
        pickled = pickle.dumps(g_programl)
        return Event(byte_tensor=ByteTensor(shape=[len(pickled)], value=pickled))

When we register the environment, we can define a derived observation space to map it to the format RLlib expects (maybe we could reformat the observation directly in get_observation; I didn't check that yet).

import importlib

import numpy as np

from compiler_gym.spaces import Sequence
from compiler_gym.util.registration import register

max_pickle_size = int(1e5)  # int, so it can be used as an array size below


def pad_pickle(base_observation):
    # RLlib requires fixed-size tensors, so we pad the pickled observation
    # with zeros and store its original size in the last element, so that we
    # can recreate the pickle later.
    orig_size = len(base_observation)
    assert orig_size < max_pickle_size
    padded = np.append(base_observation, np.zeros(max_pickle_size - orig_size))
    padded[-1] = orig_size
    return padded

register(
    id="env_name-v0",
    entry_point="compiler_gym.service.client_service_compiler_env:ClientServiceCompilerEnv",
    kwargs={  # check ClientServiceCompilerEnv for possible arguments
        "service": "path-to-example_service.py",
        "rewards": [
            rewardClass(),
        ],
        "datasets": [
            importlib.import_module(f"env_name_service.agent_py.datasets.{dataset}").Dataset()
            for dataset in datasets
        ],
        "derived_observation_spaces": [
            {
                "id": "programl",
                "base_id": "programl_pickle",
                # Change Sequence (from CompilerGym) __init__ to accept shape,
                # i.e. set self._shape = shape.
                "space": Sequence(
                    name="pickled", size_range=(0, None), dtype=np.int32, shape=[max_pickle_size]
                ),
                "translate": lambda base_observation: pad_pickle(base_observation),
            },
        ],
    },
)

Once this is done, we can access our observation in the custom model in RLlib:


import pickle

import numpy as np
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class TorchCustomModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        ...

    def forward(self, input_dict, state, seq_lens):
        self.device = next(self.parameters()).device

        batch_size = input_dict["obs"].shape[0]
        if input_dict["obs"].any():
            # Handle real input observations: each row is a zero-padded pickle
            # whose original length is stored in its last element.
            pickle_tensor = np.array(input_dict["obs"].numpy(), dtype=np.int32)

            # Strip the padding from each row and unpickle the graph.
            graphs = np.apply_along_axis(
                lambda row: pickle.loads(row[: row[-1]].astype(np.int8)),
                axis=1,
                arr=pickle_tensor,
            )

Now you can access your graphs in RLlib.
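
From there, one possible next step (a sketch only; it assumes the rows unpickle into DGL graphs, that self.gnn was defined in __init__, and that node features live under a hypothetical "feat" key) is to batch the graphs and feed them to a GNN:

import dgl

# Batch the per-row DGL graphs recovered above and move them to the model's
# device before running a GNN over them. "feat" is a placeholder feature key.
batched = dgl.batch(list(graphs)).to(self.device)
node_embeddings = self.gnn(batched, batched.ndata["feat"])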