ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] Support Graph input batches #23886

Closed gjoliver closed 1 year ago

gjoliver commented 2 years ago

Description

Multiple users have requested support for feeding GNN data (node and edge lists) as sample batches into a custom RLlib model.

https://discuss.ray.io/t/rllib-variable-length-observation-spaces-without-padding/726
https://discuss.ray.io/t/working-with-graph-neural-networks-varying-state-space/5730/2

The recommended method today is to use a Repeated observation space with the RepeatedValues input type. However, RLlib internally converts these into N-dimensional tensors and automatically pads the input data to the maximum size, which is inefficient for many of these use cases.
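To make the padding cost concrete, here is a minimal pure-Python sketch. It assumes each observation is a list of per-node feature vectors and that, as with a `Repeated` space, every observation is padded to a fixed `max_len`. The helper names `pad_to_max_len` and `padding_fraction` are hypothetical; this illustrates the idea and is not RLlib's internal padding code.

```python
from typing import List, Sequence, Tuple

# A node feature vector; each graph observation is a variable-length list of these.
Feature = List[float]


def pad_to_max_len(
    graphs: Sequence[Sequence[Feature]], max_len: int, feat_dim: int
) -> Tuple[List[List[Feature]], List[int]]:
    """Zero-pad every graph's node list to max_len rows (hypothetical helper).

    Returns the padded batch plus each graph's true node count, which a model
    would need as a mask to ignore the padding rows.
    """
    padded: List[List[Feature]] = []
    lengths: List[int] = []
    for nodes in graphs:
        if len(nodes) > max_len:
            raise ValueError(f"graph has {len(nodes)} nodes, max_len is {max_len}")
        lengths.append(len(nodes))
        zeros = [[0.0] * feat_dim for _ in range(max_len - len(nodes))]
        padded.append([list(node) for node in nodes] + zeros)
    return padded, lengths


def padding_fraction(lengths: Sequence[int], max_len: int) -> float:
    """Fraction of node slots in the padded batch that hold padding, not data."""
    total_slots = len(lengths) * max_len
    return (total_slots - sum(lengths)) / total_slots


if __name__ == "__main__":
    # Three small graphs (2, 3 and 1 nodes, 2 features each) padded to max_len=50.
    graphs = [
        [[1.0, 0.0], [0.0, 1.0]],
        [[1.0, 1.0], [0.5, 0.5], [0.2, 0.8]],
        [[0.3, 0.7]],
    ]
    batch, lengths = pad_to_max_len(graphs, max_len=50, feat_dim=2)
    print(len(batch), len(batch[0]), lengths)  # 3 50 [2, 3, 1]
    print(f"padding fraction: {padding_fraction(lengths, 50):.2f}")  # 0.96
```

With three graphs of 2, 3 and 1 nodes padded to `max_len=50`, 96% of the node slots hold padding, which is the kind of waste described above when graph sizes vary widely.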

Ideally, we would create something like a SimpleSampleBatch and a DoNothingSampler that simply collect the list of GNN input data and hand it directly to the model without modifying the data at all.
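One way to picture the proposed pass-through behaviour is the minimal sketch below. `PassThroughGraphBatch` and `model_forward` are hypothetical names, not RLlib classes. The sketch assumes a graph is a pair of node-feature rows and an edge list, and only shows that batching can be plain list concatenation, so each graph reaches the model with its original size.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

# A graph as (node_features, edge_list); edges are (src, dst) node indices.
Graph = Tuple[List[List[float]], List[Tuple[int, int]]]


@dataclass
class PassThroughGraphBatch:
    """Hypothetical batch that stores raw graphs as-is: no padding, no reshaping."""

    graphs: List[Graph] = field(default_factory=list)
    rewards: List[float] = field(default_factory=list)

    def add(self, graph: Graph, reward: float) -> None:
        # Keep the original graph object; nothing is converted to a tensor here.
        self.graphs.append(graph)
        self.rewards.append(reward)

    def concat(self, other: "PassThroughGraphBatch") -> "PassThroughGraphBatch":
        # Concatenating batches is plain list concatenation, so graph sizes never need to match.
        return PassThroughGraphBatch(self.graphs + other.graphs, self.rewards + other.rewards)

    def __len__(self) -> int:
        return len(self.graphs)


def model_forward(batch: PassThroughGraphBatch) -> List[int]:
    # Stand-in for a custom model: it receives the untouched graphs and, here,
    # just reports each graph's node count to show the sizes survived batching.
    return [len(nodes) for nodes, _ in batch.graphs]
```

The design choice is to defer any tensor conversion to the model itself, which is where a GNN library would build its own batched representation.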

Use case

Using a GNN as the embedding layer for a variable number of agents in a multi-agent env appears to be gaining popularity. We should support this use case seamlessly and let users train a policy that contains GNN layers end to end.
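What makes a GNN attractive here is that message passing followed by pooling maps any number of agents to a fixed-size embedding, independent of agent order. The pure-Python sketch below shows one mean-aggregation layer plus mean pooling. The weights, the function names, and the choice of mean aggregation are illustrative assumptions; a real model would use a framework such as PyTorch or TensorFlow with learned weights.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def matvec(w: Matrix, x: Sequence[float]) -> List[float]:
    # Plain matrix-vector product: one output per row of w.
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def mean_rows(rows: Sequence[Sequence[float]], dim: int) -> List[float]:
    # Element-wise mean of a list of vectors; zeros if the list is empty.
    if not rows:
        return [0.0] * dim
    return [sum(col) / len(rows) for col in zip(*rows)]


def gnn_layer(
    h: List[List[float]], edges: Sequence[Tuple[int, int]], w_self: Matrix, w_neigh: Matrix
) -> List[List[float]]:
    """h_i' = relu(W_self h_i + W_neigh * mean of h_j over incoming edges j -> i)."""
    neighbours: List[List[List[float]]] = [[] for _ in h]
    for src, dst in edges:
        neighbours[dst].append(h[src])
    in_dim = len(h[0])
    out = []
    for i, h_i in enumerate(h):
        agg = mean_rows(neighbours[i], in_dim)
        pre = [a + b for a, b in zip(matvec(w_self, h_i), matvec(w_neigh, agg))]
        out.append([max(0.0, v) for v in pre])  # ReLU
    return out


def embed_agents(
    agent_feats: List[List[float]], edges: Sequence[Tuple[int, int]], w_self: Matrix, w_neigh: Matrix
) -> List[float]:
    """Message passing, then mean pooling: a fixed-size embedding for any number of agents."""
    h = gnn_layer(agent_feats, edges, w_self, w_neigh)
    return mean_rows(h, len(w_self))
```

Because the pooled output has length `len(w_self)` regardless of how many agents are present, a downstream policy head can keep a fixed input size while the number of agents varies.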

ArchieGertsman commented 1 year ago

Any updates on this? I've implemented a custom Gymnasium environment with a Dict observation space that includes a Graph space. I can train a GNN model built on PyG using an asynchronous PG algorithm I implemented myself. It would be nice to use RLlib instead of reinventing the wheel, so that I can easily explore other algorithms.

trahxam commented 1 year ago

Further to Archie's request, are there any updates on this? I'd also find it very useful to have a way of batching graphs without writing a custom implementation just for that purpose.
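A common way to batch graphs without padding is the disjoint-union approach that PyTorch Geometric's mini-batching uses: concatenate node features, shift each graph's edge indices by a running node offset, and record which graph each node came from. The pure-Python sketch below illustrates that idea under the assumption that graphs are (node features, edge list) pairs. `batch_graphs` is a hypothetical helper, not an RLlib or PyG API.

```python
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]


def batch_graphs(
    graphs: Sequence[Tuple[Sequence[Sequence[float]], Sequence[Edge]]],
) -> Tuple[List[List[float]], List[Edge], List[int]]:
    """Merge graphs into one disconnected graph (disjoint union).

    Returns concatenated node features, edges with indices shifted by each
    graph's node offset, and a `batch` vector mapping every node to its graph.
    """
    x: List[List[float]] = []
    edges: List[Edge] = []
    batch: List[int] = []
    offset = 0
    for graph_id, (nodes, graph_edges) in enumerate(graphs):
        x.extend(list(n) for n in nodes)
        # Shift edge endpoints so they index into the concatenated node list.
        edges.extend((src + offset, dst + offset) for src, dst in graph_edges)
        batch.extend([graph_id] * len(nodes))
        offset += len(nodes)
    return x, edges, batch
```

Because the result is one large graph with no edges between the original graphs, a single message-passing pass processes the whole batch, and the `batch` vector lets a pooling step produce one output per original graph.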

kouroshHakha commented 1 year ago

We now have an externally contributed repo that covers these examples.

https://github.com/kk-55/tf-gnn-example-for-rllib

nikikotecha commented 1 year ago

Are there any examples that use a GNN in PyTorch as a custom model in Ray RLlib (multi-agent)? The number of agents does not change.

Panhaolin2001 commented 2 months ago

> are there any examples which use GNN in PyTorch as a custom model in ray rllib (multi agent). number of agents not changing

Did you solve this problem?