ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] How to support gymnasium graph obs space? #45290

Open Panhaolin2001 opened 6 months ago

Panhaolin2001 commented 6 months ago

Description

No response

Use case

No response

Panhaolin2001 commented 6 months ago

My observation space is:

Graph(node_space=Box(low=float('-inf'), high=float('inf'), shape=(self.feature_dim,)), edge_space=None)

My custom model is:

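import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

class GCN(TorchModelV2, torch.nn.Module):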
    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **customized_model_kwargs):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        torch.nn.Module.__init__(self)

        self.input_dim = model_config['custom_model_config']['input_dim']
        self.output_dim = model_config['custom_model_config']['output_dim']

        self.conv1 = GCNConv(self.input_dim, 128)
        self.conv2 = GCNConv(128, 128)
        self.conv3 = GCNConv(128, 128)
        self.line1 = torch.nn.Linear(128, 128)
        self.line2 = torch.nn.Linear(128, 64)
        self.line3 = torch.nn.Linear(64, self.output_dim)
        self.act = torch.nn.ReLU()

    def forward(self, input_dict, state, seq_lens):
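        # Index 0 takes only the first graph of the batch. edge_links is
        # (num_edges, 2) and is transposed to GCNConv's (2, num_edges) layout.
        x, edge_index = input_dict["obs"].nodes[0], input_dict["obs"].edge_links[0].t()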
        x = F.relu(self.conv1(x, edge_index))
        # x1 = gap(x)
        x = F.relu(self.conv2(x, edge_index))
        # x2 = gap(x)
        x = F.relu(self.conv3(x, edge_index))
        # x3 = gap(x)
        # x = x1 + x2 + x3
        x = self.line1(x)
        x = self.act(x)
        x = self.line2(x)
        x = self.act(x)
        x = self.line3(x)
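        # Summing over nodes yields a (output_dim,) tensor; RLlib normally
        # expects [batch_size, num_outputs] from forward().
        return x.sum(dim=0), state

    def value_function(self):
        # Placeholder; RLlib expects one value estimate per batch item ([batch_size]).
        return torch.zeros([1])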

I disabled the preprocessor API and the dummy-batch loss initialization with these options:

algo["_disable_preprocessor_api"] = True
algo["_disable_initialize_loss_from_dummy_batch"] = True

But I still get an error:

Failure # 1 (occurred at 2024-05-13_12-10-36)
ray::PPO.train() (pid=4121512, ip=10.240.145.6, actor_id=de9af3fd0a5dbc461709b06601000000, repr=PPO)
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 342, in train
    raise skipped from exception_cause(skipped)
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 339, in train
    result = self.step()
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 853, in step
    results, train_iter_ctx = self._run_one_training_iteration()
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 2854, in _run_one_training_iteration
    results = self.training_step()
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/algorithms/ppo/ppo.py", line 429, in training_step
    train_batch = synchronous_parallel_sample(
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/execution/rollout_ops.py", line 82, in synchronous_parallel_sample
    sample_batches = [worker_set.local_worker().sample()]
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/evaluation/rollout_worker.py", line 696, in sample
    batches = [self.input_reader.next()]
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/evaluation/sampler.py", line 92, in next
    batches = [self.get_data()]
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/evaluation/sampler.py", line 277, in get_data
    item = next(self._env_runner)
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 344, in run
    outputs = self.step()
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 370, in step
    active_envs, to_eval, outputs = self._process_observations(
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 688, in _process_observations
    self._handle_done_episode(
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 838, in _handle_done_episode
    self._build_done_episode(env_id, is_done, outputs)
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/evaluation/env_runner_v2.py", line 727, in _build_done_episode
    episode.postprocess_episode(
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/evaluation/episode_v2.py", line 322, in postprocess_episode
    post_batch = policy.postprocess_trajectory(post_batch, other_batches, self)
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 215, in postprocess_trajectory
    return compute_gae_for_sample_batch(
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/evaluation/postprocessing.py", line 188, in compute_gae_for_sample_batch
    sample_batch = compute_bootstrap_value(sample_batch, policy)
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/evaluation/postprocessing.py", line 302, in compute_bootstrap_value
    vf_preds = np.array(sample_batch[SampleBatch.VF_PREDS])
  File "/home/haolin/.local/lib/python3.10/site-packages/ray/rllib/policy/sample_batch.py", line 927, in __getitem__
    value = dict.__getitem__(self, key)
KeyError: 'vf_preds'

How can I solve this problem? Thanks!

simonsays1980 commented 6 months ago

@Panhaolin2001 Thanks for filing this issue. At this point in time we support neither Graph observation spaces nor Graph action spaces. If you can convert the space into a Dict/Tuple space, RLlib will take care of it by flattening.
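One way to get a Dict-compatible observation from a variable-size graph is to pad it to a fixed maximum size and record the true counts. The sketch below uses only NumPy and assumes hypothetical limits MAX_NODES and MAX_EDGES and a hypothetical helper pad_graph (none of these come from RLlib or gymnasium). Each returned array has a fixed shape, so the result could be described by a gymnasium Dict of Box spaces.

```python
import numpy as np

# Hypothetical upper bounds: choose values that cover your largest graph.
MAX_NODES = 8
MAX_EDGES = 16


def pad_graph(nodes, edge_links, max_nodes=MAX_NODES, max_edges=MAX_EDGES):
    """Pad a variable-size graph into fixed-shape arrays (hypothetical helper).

    nodes:      float array of shape (num_nodes, feature_dim)
    edge_links: int array of shape (num_edges, 2), or None for no edges
    """
    num_nodes, feature_dim = nodes.shape
    if edge_links is None:
        edge_links = np.zeros((0, 2), dtype=np.int64)
    num_edges = edge_links.shape[0]
    if num_nodes > max_nodes or num_edges > max_edges:
        raise ValueError("graph exceeds the configured maximum size")

    # Copy the real rows; the remaining rows stay zero.
    padded_nodes = np.zeros((max_nodes, feature_dim), dtype=np.float32)
    padded_nodes[:num_nodes] = nodes
    padded_edges = np.zeros((max_edges, 2), dtype=np.int64)
    padded_edges[:num_edges] = edge_links

    # Every value has a fixed shape, so the whole dict can be described
    # by a Dict of Box spaces.
    return {
        "nodes": padded_nodes,
        "edge_links": padded_edges,
        "num_nodes": np.array([num_nodes], dtype=np.int64),
        "num_edges": np.array([num_edges], dtype=np.int64),
    }


# Example: a 3-node graph with 4 features per node and 2 edges.
obs = pad_graph(np.ones((3, 4), dtype=np.float32), np.array([[0, 1], [1, 2]]))
```

Padded edge rows are [0, 0], so a model consuming this layout has to slice edge_links[:num_edges] and nodes[:num_nodes] before running graph convolutions; otherwise the padding turns into spurious self-loops on node 0.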

Panhaolin2001 commented 6 months ago

@Panhaolin2001 Thanks for filing this issue. At this point of time we do not support Graph observation nor action spaces. If you can convert the space into a Dict/Tuple space RLlib will take care of it by flattening.

Thanks, but I set both options to True.

algo["_disable_preprocessor_api"] = True
algo["_disable_initialize_loss_from_dummy_batch"] = True

With these settings, my custom model receives the raw GraphInstance data from the env. Is the error raised because GraphInstance data can't be batched at training time? Is there any way to handle this through custom batching and train_batch methods?
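The thread does not establish whether batching is what causes the vf_preds KeyError. Still, if a custom model needs to process several graphs in one forward pass, one common approach is to merge them into a single disjoint graph: concatenate the node features, shift each graph's edge indices by the number of nodes that precede it, and keep a vector mapping each node to its graph. This mirrors how PyTorch Geometric builds mini-batches, and its global pooling functions accept such a batch vector. The helper batch_graphs below is hypothetical and NumPy-only; it assumes each graph is given as a (nodes, edge_links) pair.

```python
import numpy as np


def batch_graphs(graphs):
    """Merge graphs into one disjoint graph (hypothetical helper).

    graphs: list of (nodes, edge_links) pairs, nodes of shape (n_i, F) and
    edge_links of shape (e_i, 2).
    Returns stacked node features (sum n_i, F), a (2, sum e_i) edge index
    with node ids shifted per graph, and a per-node graph-id vector.
    """
    all_nodes, all_edges, graph_ids = [], [], []
    offset = 0
    for graph_id, (nodes, edge_links) in enumerate(graphs):
        all_nodes.append(nodes)
        # Shift node ids so graph k's edges point at graph k's rows.
        all_edges.append(edge_links + offset)
        graph_ids.append(np.full(len(nodes), graph_id, dtype=np.int64))
        offset += len(nodes)
    x = np.concatenate(all_nodes, axis=0)
    edge_index = np.concatenate(all_edges, axis=0).T
    batch = np.concatenate(graph_ids)
    return x, edge_index, batch


# Example: a 3-node graph and a 2-node graph, 4 features per node.
g1 = (np.ones((3, 4), dtype=np.float32), np.array([[0, 1], [1, 2]]))
g2 = (np.ones((2, 4), dtype=np.float32), np.array([[0, 1]]))
x, edge_index, batch = batch_graphs([g1, g2])
```

Here edge_index comes out as [[0, 1, 3], [1, 2, 4]] and batch as [0, 0, 0, 1, 1]: the second graph's edge 0->1 becomes 3->4 because three nodes precede it.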

Or do you have plans to support Graph spaces?