ethanabrooks opened this issue 2 years ago
Thanks for creating this issue. Can you clarify what it might look like to batch already-batched data? Wouldn't there be an inherent limitation of only being able to batch equally-sized data?
Hi Rusty. Yes, our expectation is that the batched data would be padded to a maximum size with a designated padding index.
I should add that the assumption of batched data is actually woven into many places in many reinforcement learning algorithms, so overriding methods in `SubprocVecEnv` would not be sufficient. On-policy algorithms have to store "rollouts" of experience between gradient updates, and to keep access to these rollouts space-efficient, observations are generally stacked. Off-policy algorithms generally maintain a replay buffer, which again requires stacking of observations.
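To make the stacking constraint concrete, here is a minimal sketch of how such padded graph observations might be stored (a hypothetical helper, not library code; zero-padding for node features and `-1` as the edge padding index, matching the convention mentioned above):

```python
import torch

def pad_graph(x, edge_index, max_nodes, max_edges):
    # Pad node features with zeros and edge indices with -1 so that
    # observations from differently-sized graphs stack into one tensor.
    x_pad = torch.zeros(max_nodes, x.shape[1])
    x_pad[: x.shape[0]] = x
    e_pad = torch.full((2, max_edges), -1, dtype=torch.long)
    e_pad[:, : edge_index.shape[1]] = edge_index
    return x_pad, e_pad

# Two observations of different sizes stack after padding:
x1, e1 = torch.randn(2, 4), torch.tensor([[0], [1]])
x2, e2 = torch.randn(3, 4), torch.tensor([[0, 1], [1, 2]])
obs = [pad_graph(x, e, max_nodes=3, max_edges=2) for x, e in [(x1, e1), (x2, e2)]]
x_batch = torch.stack([o[0] for o in obs])  # [2, 3, 4]
e_batch = torch.stack([o[1] for o in obs])  # [2, 2, 2]
```

Without the padding, `torch.stack` would fail on the mismatched node and edge counts, which is exactly the equal-size limitation raised above.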
I feel the easiest solution might be to write your own batching logic for this. It is hard to provide a general solution here that is applicable across many different scenarios. Would you agree? Happy to help if you need it.
I have looked closely at the source code for `from_data_list` -- specifically this: https://github.com/pyg-team/pytorch_geometric/blob/be9e4af760dbb7c201515cdaff8edacea14d7e3d/torch_geometric/data/collate.py#L13. I am not entirely sure what assumptions are being made about `out.stores`. What properties need to hold for us to end up with a healthy `Batch()` object at the end?
The current `collate` and `separate` functions are quite complex, since they also allow for recursive collating and support different data formats such as `Data` and `HeteroData`. You can probably achieve something simpler in your case, e.g. (untested):
```python
# x:   [batch_size, num_nodes, num_features]
# adj: [batch_size, num_nodes, num_nodes]
def collate(x, adj):
    batch_size, num_nodes, num_features = x.shape
    x = x.view(batch_size * num_nodes, num_features)
    # nonzero() on the batched adjacency yields one (graph, src, dst) triple per edge:
    graph, source, target = adj.nonzero().t()
    # Offset node indices so each graph occupies its own disjoint index range:
    inc = graph * num_nodes
    edge_index = torch.stack([source + inc, target + inc], dim=0)
    return x, edge_index
```
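To sanity-check the dense-to-sparse step, here is a tiny hand-worked batch (self-contained; the per-edge node-index offset is taken from the batch column that `nonzero()` returns):

```python
import torch

# Two 2-node graphs: graph 0 has edge 0 -> 1, graph 1 has edge 1 -> 0.
adj = torch.tensor([[[0, 1], [0, 0]],
                    [[0, 0], [1, 0]]])
graph, src, dst = adj.nonzero().t()          # batch column gives per-edge offsets
edge_index = torch.stack([src, dst]) + graph * adj.shape[1]
print(edge_index)  # tensor([[0, 3], [1, 2]])
```

Graph 1's edge `1 -> 0` becomes `3 -> 2` in the combined node numbering, i.e. the disjoint union that PyG expects.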
Hi @rusty1s, a member of my team came up with a solution to this based on what you wrote. We will share it shortly. Thanks for your help!
For `x`, we kept the padding in the reshape, since the padded nodes will not induce message passing: they are disjoint from the rest of the graphs. For `edge_index`, we used the `inc` tensor you suggested, and we removed the padded edges with boolean indexing. A similar approach was used for `edge_attr`.
The code that we used is as follows:
```python
import torch

# x.shape          = (batch_size, max_num_nodes, embedding_size)
# edge_index.shape = (batch_size, 2, max_num_edges), padded with -1
# edge_attr.shape  = (batch_size, max_num_edges, embedding_size)
def collate(x, edge_index, edge_attr):
    b, num_nodes, num_features = x.shape
    # Padded entries are -1, so edge_index + 1 is nonzero exactly at real edges:
    num_edges = torch.count_nonzero(edge_index + 1, dim=2)[:, 0]
    x = x.reshape(b * num_nodes, num_features)
    # Offset each graph's node indices into its own disjoint range:
    inc = torch.arange(0, b * num_nodes, num_nodes, device=edge_index.device).reshape(b, 1, 1)
    edge_index = (edge_index + inc).transpose(-2, -1)
    max_num_edges = edge_index.shape[1]
    num_features = edge_attr.shape[-1]
    # Boolean mask that keeps only the first num_edges[i] rows of graph i:
    grid = torch.arange(max_num_edges, device=edge_index.device).repeat(b, 2, 1).transpose(-2, -1)
    mask = grid < num_edges.reshape(b, 1, 1)
    edge_index = edge_index[mask].reshape(-1, 2).transpose(-2, -1)
    grid = torch.arange(max_num_edges, device=edge_attr.device).repeat(b, num_features, 1).transpose(-2, -1)
    mask = grid < num_edges.reshape(b, 1, 1)
    edge_attr = edge_attr[mask].reshape(-1, num_features)
    return x, edge_index, edge_attr
```
For `separate()`, since we only needed the nodes, we just reshaped `x` back to its original shape.
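That inverse step can be sketched as follows (a minimal version, assuming the shapes from the `collate` snippet above; the function name is just illustrative):

```python
import torch

def separate(x, batch_size, max_num_nodes):
    # Invert the node reshape from collate(); padded rows simply come
    # back in place, which is fine when only the nodes are needed.
    return x.reshape(batch_size, max_num_nodes, -1)

x = torch.randn(2 * 3, 8)       # 2 graphs, 3 padded nodes each, 8 features
y = separate(x, 2, 3)
print(y.shape)                  # torch.Size([2, 3, 8])
```

Because the reshape is a pure view change, no information about node order or padding is lost on the round trip.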
Thanks for your help in this issue!
🚀 The feature, motivation and pitch
PyG currently offers only one method for creating `Batch` objects, namely `from_data_list`. This method takes a list of `BaseData` objects as input. However, there are settings, such as reinforcement learning, where it is commonplace to process data that is already batched, e.g. using `stable_baselines3`'s `SubprocVecEnv`. In these settings, it becomes necessary to unbatch the data and subsequently re-batch it using `from_data_list`. It would be helpful if there were a method that accepted already batched data in order to create `Batch` objects.

Alternatives
Currently our best alternatives are un-batching and re-batching, as described above, or overriding the methods in `SubprocVecEnv`.

Additional context
I am willing to work on a Pull Request for this.