pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Add `from_data` builder method to `Batch` that accepts batched data. #5083

Open ethanabrooks opened 2 years ago

ethanabrooks commented 2 years ago

🚀 The feature, motivation and pitch

PyG currently offers only one method for creating Batch objects, namely from_data_list, which takes a list of BaseData objects as input. However, in some settings, such as reinforcement learning, it is commonplace to process data that is already batched, e.g. when using stable_baselines3's SubprocVecEnv. In these settings, the data must first be unbatched and then re-batched with from_data_list. It would be helpful to have a method that creates Batch objects directly from already-batched data.

Alternatives

Currently, our best alternatives are un-batching and re-batching, as described above, or overriding the methods in SubprocVecEnv.
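For reference, the un-batching half of that workaround might look roughly like the sketch below. It assumes (hypothetically) that graphs are padded to a common size, with padded nodes at the end of each graph, padded edge_index columns filled with -1, and the number of real nodes per graph stored alongside. `unbatch_padded` is an illustrative name, not a PyG function.

```python
import torch

# Hypothetical helper, not a PyG API. Assumed layout:
#   x:          [batch_size, max_num_nodes, num_features], padded nodes at the end
#   edge_index: [batch_size, 2, max_num_edges], padded columns filled with -1
#   num_nodes:  [batch_size], number of real nodes in each graph
def unbatch_padded(x, edge_index, num_nodes):
    graphs = []
    for i in range(x.size(0)):
        n = int(num_nodes[i])
        ei = edge_index[i]
        # Keep only the columns whose source index is not the padding value.
        ei = ei[:, ei[0] != -1]
        graphs.append((x[i, :n], ei))
    return graphs
```

Each returned pair could then be wrapped in `torch_geometric.data.Data(x=..., edge_index=...)` and the resulting list passed to `Batch.from_data_list`, which is the per-step overhead this issue is about.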

Additional context

I am willing to work on a Pull Request for this.

rusty1s commented 2 years ago

Thanks for creating this issue. Can you clarify what batching "batched data" would look like? Wouldn't there be an inherent limitation of only being able to batch equally-sized data?

ethanabrooks commented 2 years ago

Hi Rusty. Yes, our expectation is that the batched data would be padded to a maximum size with a designated padding index.
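A minimal sketch of the padded layout described here, assuming node features are zero-padded and edge_index columns are padded with -1 (the padding value the collate code later in this thread appears to rely on). `pad_graphs` and `PAD` are hypothetical names chosen for illustration.

```python
import torch

PAD = -1  # assumed padding index for edge_index (hypothetical convention)

def pad_graphs(graphs, max_num_nodes, max_num_edges):
    """Stack variable-size graphs into fixed-size tensors.

    graphs: list of (x [n_i, F], edge_index [2, e_i]) tuples.
    Returns x [B, max_num_nodes, F] (zero-padded) and
    edge_index [B, 2, max_num_edges] (padded with PAD).
    """
    first_x = graphs[0][0]
    b, num_features = len(graphs), first_x.size(1)
    # new_zeros keeps the dtype/device of the input features.
    x = first_x.new_zeros((b, max_num_nodes, num_features))
    edge_index = torch.full((b, 2, max_num_edges), PAD, dtype=torch.long)
    for i, (xi, ei) in enumerate(graphs):
        x[i, : xi.size(0)] = xi
        edge_index[i, :, : ei.size(1)] = ei
    return x, edge_index
```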

I should add that the assumption of batched data is woven into many reinforcement learning algorithms, so overriding methods in SubprocVecEnv would not be sufficient. On-policy algorithms have to store "rollouts" of experience between gradient updates, and to keep these rollouts space-efficient, observations are generally stacked. Off-policy algorithms generally maintain a replay buffer, which again requires stacking observations.

rusty1s commented 2 years ago

I feel the easiest solution might be to write your own batching function for this. It is hard to provide a general solution that applies to many different scenarios. Would you agree? Happy to help if you need it.

ethanabrooks commented 2 years ago

I have looked closely at the source code for from_data_list -- specifically this: https://github.com/pyg-team/pytorch_geometric/blob/be9e4af760dbb7c201515cdaff8edacea14d7e3d/torch_geometric/data/collate.py#L13. I am not entirely sure what assumptions are being made about out.stores. What properties need to be true for us to end up with a healthy Batch() object at the end?
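As a rough answer for plain homogeneous graphs: node-level tensors are concatenated along dimension 0, `edge_index` is concatenated along its last dimension after shifting each graph's indices by the number of nodes placed before it, and a `batch` vector (plus `ptr` offsets) records which graph each node belongs to. The sketch below is a hypothetical, heavily simplified stand-in for `collate`, not PyG's implementation; the real function additionally consults `Data.__cat_dim__` and `Data.__inc__` per attribute and handles nested and heterogeneous data.

```python
import torch

def collate_list(graphs):
    """Hypothetical simplified collate for homogeneous graphs.

    graphs: list of (x [n_i, F], edge_index [2, e_i]) tuples.
    """
    xs, edge_indices, batch = [], [], []
    offset = 0
    for i, (xi, ei) in enumerate(graphs):
        n = xi.size(0)
        xs.append(xi)
        # Shift by the number of nodes already placed; for Data, __inc__
        # returns num_nodes for keys containing "index".
        edge_indices.append(ei + offset)
        batch.append(torch.full((n,), i, dtype=torch.long))
        offset += n
    x = torch.cat(xs, dim=0)
    # For Data, __cat_dim__ returns -1 for keys containing "index".
    edge_index = torch.cat(edge_indices, dim=-1)
    batch = torch.cat(batch)
    # Nodes ptr[i]:ptr[i + 1] belong to graph i.
    sizes = torch.tensor([g[0].size(0) for g in graphs])
    ptr = torch.cat([torch.zeros(1, dtype=torch.long), sizes.cumsum(0)])
    return x, edge_index, batch, ptr
```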

rusty1s commented 2 years ago

The current collate and separate functions are quite complex, since they also allow for recursive collating and support different data formats such as Data and HeteroData. You can probably achieve something simpler in your case, e.g. (untested):

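```python
import torch

# x:   [batch_size, num_nodes, num_features]
# adj: [batch_size, num_nodes, num_nodes]

def collate(x, adj):
    batch_size, num_nodes, num_features = x.shape
    x = x.reshape(batch_size * num_nodes, num_features)

    # Each row of adj.nonzero() is (graph, source, target).
    index = adj.nonzero()
    # Shift the node indices of graph i by i * num_nodes.
    inc = torch.arange(0, num_nodes * batch_size, num_nodes, device=adj.device)
    edge_index = index[:, 1:] + inc[index[:, 0]].view(-1, 1)
    edge_index = edge_index.t()  # [2, num_edges]

    return x, edge_index
```
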
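Flattening to `[batch_size * num_nodes, num_features]` also fixes which graph every node belongs to, which pooling layers need as a `batch` vector. A small sketch assuming the dense `[batch_size, num_nodes, num_features]` layout; the helper names are hypothetical, and `global_sum_pool` only mirrors what PyG's `global_add_pool(x, batch)` computes.

```python
import torch

def padded_batch_vector(batch_size: int, num_nodes: int, device=None) -> torch.Tensor:
    """Graph assignment for nodes flattened from [batch_size, num_nodes, F]."""
    # Node k of the flattened tensor belongs to graph k // num_nodes.
    return torch.arange(batch_size, device=device).repeat_interleave(num_nodes)

def global_sum_pool(x: torch.Tensor, batch: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Sum node features per graph. Padded nodes are included, so this is
    # only exact if their features are zero.
    out = x.new_zeros(batch_size, x.size(1))
    return out.index_add_(0, batch, x)
```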
ethanabrooks commented 2 years ago

Hi @rusty1s, a member of my team came up with a solution based on what you wrote. We will share it shortly. Thanks for your help!

yikai5518 commented 2 years ago

For x, we reshaped the tensor without removing the padding, since the padded nodes are disconnected from the rest of the graphs and therefore do not exchange messages with the real nodes. For edge_index, we used the inc tensor you suggested and removed the padded edges with boolean indexing. We used a similar approach for edge_attr.
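The claim that padded nodes do not interfere can be checked with a toy sum aggregation (a hypothetical stand-in, not a PyG layer): a node with no incident edges neither sends nor receives messages, so the outputs for the real nodes are unchanged.

```python
import torch

def sum_aggregate(x, edge_index):
    # Each target node (edge_index[1]) receives the sum of its sources' features.
    out = torch.zeros_like(x)
    return out.index_add_(0, edge_index[1], x[edge_index[0]])

# Two real nodes (0, 1) plus one padded node (2) with no incident edges.
x = torch.tensor([[1.0], [2.0], [100.0]])
edge_index = torch.tensor([[0, 1], [1, 0]])

with_padding = sum_aggregate(x, edge_index)
without_padding = sum_aggregate(x[:2], edge_index)
```

Layers that add self-loops would let a padded node see its own features, but that still does not connect it to the real nodes.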

The code that we used is as follows:

```python
import torch

# x.shape          = [batch_size, max_num_nodes, embedding_size]
# edge_index.shape = [batch_size, 2, max_num_edges], padded with -1
# edge_attr.shape  = [batch_size, max_num_edges, embedding_size]

def collate(x, edge_index, edge_attr):
    b, num_nodes, num_features = x.shape
    # Count real edges per graph: padding entries are -1, so edge_index + 1 == 0.
    num_edges = torch.count_nonzero(edge_index + 1, dim=2)[:, 0]
    x = x.reshape(b * num_nodes, num_features)

    # Shift node indices of graph i by i * num_nodes; result is [b, max_num_edges, 2].
    inc = torch.arange(0, b * num_nodes, num_nodes, device=edge_index.device).reshape(b, 1, 1)
    edge_index = (edge_index + inc).transpose(-2, -1)

    max_num_edges = edge_index.shape[1]
    num_edge_features = edge_attr.shape[-1]

    # Keep the first num_edges[i] edges of graph i (assumes padding is trailing).
    grid = torch.arange(max_num_edges, device=edge_index.device).repeat(b, 2, 1).transpose(-2, -1)
    mask = grid < num_edges.reshape(b, 1, 1)
    edge_index = edge_index[mask].reshape(-1, 2).transpose(-2, -1)  # [2, total_edges]

    grid = torch.arange(max_num_edges, device=edge_attr.device).repeat(b, num_edge_features, 1).transpose(-2, -1)
    mask = grid < num_edges.reshape(b, 1, 1)
    edge_attr = edge_attr[mask].reshape(-1, num_edge_features)  # [total_edges, num_edge_features]

    return x, edge_index, edge_attr
```
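A possibly simpler variant of the same idea builds one `[batch_size, max_num_edges]` mask directly from the padding value and applies it to `edge_index` and `edge_attr` with a single boolean index each. This is a sketch, not the code the team used; `collate_masked` is a hypothetical name, and it assumes the padding index is -1 and every real edge has a non-negative source index.

```python
import torch

def collate_masked(x, edge_index, edge_attr, pad=-1):
    """Hypothetical variant: x [b, n, F], edge_index [b, 2, E], edge_attr [b, E, F_edge]."""
    b, num_nodes, num_features = x.shape
    x = x.reshape(b * num_nodes, num_features)

    # A column is a real edge iff its source index is not the padding value.
    mask = edge_index[:, 0, :] != pad                          # [b, E]
    inc = torch.arange(b, device=edge_index.device).view(b, 1, 1) * num_nodes
    edge_index = (edge_index + inc).transpose(1, 2)            # [b, E, 2]
    edge_index = edge_index[mask].t()                          # [2, total_edges]
    edge_attr = edge_attr[mask]                                # [total_edges, F_edge]
    return x, edge_index, edge_attr
```

Because the mask is taken from the padding value itself rather than from a per-graph edge count, this variant does not require the padded columns to be trailing.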

For separate(), since we only needed the nodes, we just reshaped x back to the original shape.
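A sketch of that reverse step, with an optional mask marking the real nodes of each graph. `separate_x` and `num_real_nodes` are hypothetical names, and the mask assumes padded nodes come last within each graph.

```python
import torch

def separate_x(x, batch_size, num_real_nodes=None):
    # Inverse of x.reshape(batch_size * num_nodes, F): back to [batch_size, num_nodes, F].
    out = x.reshape(batch_size, -1, x.size(-1))
    if num_real_nodes is None:
        return out, None
    # mask[i, j] is True for the real (non-padding) nodes of graph i.
    mask = torch.arange(out.size(1), device=x.device) < num_real_nodes.view(-1, 1)
    return out, mask
```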

Thanks for your help in this issue!