rahuldey91 opened this issue 3 years ago
Can you show the test code? I usually wrap the batch data as a torch_geometric.data.batch.Batch data type, and it supports any dimensionality of x.
As mentioned by @zcaicaros, GATConv supports mini-batch computation by wrapping each data object into a batch via torch_geometric.data.DataLoader. However, it does not support static graph computation yet (different feature matrices, single edge_index) without replicating edge_index via DataLoader. This is a current limitation, as GATConv needs to learn an attention coefficient for each edge in the mini-batch.
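A minimal sketch of the mini-batch route described above, with made-up sizes (DataLoader batches the Data objects and offsets each copy of edge_index automatically):

```python
import torch
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv

# Hypothetical sizes: 4 graphs, 10 nodes each, 16 features, 30 edges per graph.
data_list = [
    Data(x=torch.randn(10, 16), edge_index=torch.randint(high=10, size=(2, 30)))
    for _ in range(4)
]
loader = DataLoader(data_list, batch_size=4)

layer = GATConv(in_channels=16, out_channels=16)
for batch in loader:
    # batch.x is [4 * 10, 16]; batch.edge_index holds each graph's edges
    # shifted by that graph's node offset, so edges never cross graphs.
    out = layer(batch.x, batch.edge_index)  # -> [40, 16]
```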
Thanks, I was able to get this to work either by (i) wrapping all graphs using the Batch object, or (ii) by manually concatenating the batch samples along the node axis so as to form a BN x C sized input while simultaneously combining all the B edge_indices as mentioned above (which is what the Batch object does internally). It would be great if this difference for GATConv were mentioned somewhere in the documentation, but I appreciate your help.
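A sketch of option (ii) done by hand, with illustrative sizes; it builds the same flattened input that Batch.from_data_list would, but with pure tensor ops and no per-graph Python loop:

```python
import torch
from torch_geometric.nn import GATConv

B, N, C, E = 8, 207, 16, 1200                    # illustrative sizes
x = torch.randn(B, N, C)                         # B feature matrices
edge_index = torch.randint(high=N, size=(2, E))  # one static graph

# Stack the B node-feature matrices into one (B*N) x C matrix and tile
# edge_index B times, offsetting copy i by i * N so each copy indexes a
# disjoint block of nodes (this is what Batch does internally).
x_flat = x.reshape(B * N, C)
offsets = torch.arange(B).repeat_interleave(E) * N   # [B*E]
edge_index_flat = edge_index.repeat(1, B) + offsets  # [2, B*E]

layer = GATConv(in_channels=C, out_channels=C)
out = layer(x_flat, edge_index_flat).reshape(B, N, C)
```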
Good idea. I tried to introduce this information into our newly introduced GNN Cheatsheet table, see here.
That sounds good. Thanks.
@rahuldey91 what are the steps to reproduce the solution in (i)? What I did is: created a Data object for each graph along with the (static) edge_index, and concatenated them in a list. The issue is that the check in line 202 fails; it tries to go on the else, and it errors on line 206 (because obviously it's not a tuple of src/dst). Did I miss something? Thanks.
Your procedure looks correct. What does your final Batch look like that you input into GAT? Can you clarify why the check in line 202 should fail? If x is not a Tensor, what else is it in your case? :)
I'm working on traffic data. The shape of a batch of input data is [batch_size, seq_len, n_nodes, d_hidden]:

```python
x = torch.randn((8, 12, 207, 16))
# 1200 random edges (node indices sampled in [0, 206)), in sparse format.
edge_index = torch.randint(high=206, size=(2, 1200))
# Combine the batch and seq-len dims into one - new shape is [batch_size * seq_len, n_nodes, d_hidden].
x = einops.rearrange(x, 'b l n f -> (b l) n f')
# Convert from a 3D tensor to a list of 2D tensors (each is a graph of shape [n_nodes, d_hidden]).
x = list(x)
# Build a list of Data objects, each holding one item from the list above and the (same) edge_index.
x = [Data(x=x_, edge_index=edge_index) for x_ in x]
x = Batch.from_data_list(x)
layer = GATConv(in_channels=16, out_channels=16)
result = layer(x, edge_index=edge_index)
```

x is not a tensor but a DataBatch, so it goes on the "else" in line 205.
Here's an approach that worked for me, in case anyone wants to accomplish something similar. Not sure if this is the most elegant way to get the result, but it works:

```python
x = torch.randn((8, 12, 207, 16))
edge_index = torch.randint(high=206, size=(2, 1200))
x = einops.rearrange(x, 'b l n f -> (b l) n f')
layer = GATConv(in_channels=16, out_channels=16)
# Apply the layer to each 2D graph separately and stack the results.
result = torch.stack([layer(graph, edge_index=edge_index) for graph in x], dim=0)
```
The x in layer needs to correspond to the node feature matrix of your data/batch object:

```python
data_list = [Data(x=x_, edge_index=edge_index) for x_ in x]
batch = Batch.from_data_list(data_list)
layer = GATConv(in_channels=16, out_channels=16)
result = layer(batch.x, edge_index=batch.edge_index)
```
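For the shapes in the snippet above, batch.x comes out as [(8*12)*207, 16] = [19872, 16]; with GATConv's default heads=1 the output keeps that leading dimension, so result can be reshaped back to [8, 12, 207, 16] afterwards.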
Makes sense. Thank you for pointing out the mistake!

Edit: For anyone looking for performance, the torch.stack approach may be faster than the Batch one, depending on the data. I did some experiments with data of shape (384, N, 32) with N ranging from 100 to 500, with 10*N edges, and for N > 100 the torch.stack approach was faster (by up to 2x).

Later edit: This is very inefficient. If you just merge the other dimension(s) into the batch dimension, apply GAT, and then split the dimensions again, you get a much faster result (~10x in my case).
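A sketch of that merge-apply-split pattern for the traffic example above (sizes taken from the earlier snippet; the ~10x figure is the commenter's, not reproduced here):

```python
import torch
import einops
from torch_geometric.nn import GATConv

B, L, N, C, E = 8, 12, 207, 16, 1200             # sizes from the snippet above
x = torch.randn(B, L, N, C)
edge_index = torch.randint(high=N, size=(2, E))

# Merge every leading dimension into one big node axis and offset a tiled
# edge_index once up front, so GATConv runs a single time over B*L graphs.
G = B * L
x_flat = einops.rearrange(x, 'b l n c -> (b l n) c')
offsets = torch.arange(G).repeat_interleave(E) * N
edge_index_flat = edge_index.repeat(1, G) + offsets

layer = GATConv(in_channels=C, out_channels=C)
out = layer(x_flat, edge_index_flat)
out = einops.rearrange(out, '(b l n) c -> b l n c', b=B, l=L, n=N)
```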
> As mentioned by @zcaicaros, GATConv supports mini-batch computation by wrapping each data object into a batch via torch_geometric.data.DataLoader. However, it does not support static graph computation yet (different feature matrices, single edge_index) without replicating edge_index via DataLoader. This is a current limitation, as GATConv needs to learn an attention coefficient for each edge in the mini-batch.

How can the above problem be solved? I don't understand what the comments are suggesting.
> Makes sense. Thank you for pointing out the mistake!
>
> Edit: For anyone looking for performance, the torch.stack approach may be faster than the Batch one, depending on the data. I did some experiments with data of shape (384, N, 32) with N ranging from 100 to 500, with 10*N edges, and for N > 100 the torch.stack approach was faster (by up to 2x).
>
> Later edit: This is very inefficient. If you just merge the other dimension(s) into the batch dimension, apply GAT, and then split the dimensions again, you get a much faster result (~10x in my case).
In my experience, the (ii) approach mentioned by @rahuldey91 in https://github.com/pyg-team/pytorch_geometric/issues/2844#issuecomment-878758615 is faster :)
> or (ii) by manually concatenating the batch-samples along the node-axis so as to form a BN x C sized input while simultaneously combining all the B edge_indices

> data_list = [Data(x=x_, edge_index=edge_index) for x_ in x]
> batch = Batch.from_data_list(data_list)
> layer = GATConv(in_channels=16, out_channels=16)
> result = layer(batch.x, edge_index=batch.edge_index)

This is inefficient because it involves a for loop. Is there a more efficient way?
I am running a GNN network on a mesh. The inputs are of size B x N x C, where B is the batch size, N is the number of input nodes, and C is the number of channels per node. This input works well with other kinds of conv layers like GCNConv and ChebConv, but GATConv throws an error:

'Static graphs not supported in GATConv'

Its forward code looks like this (see the sketch after this post). So it seems like it is expecting the input x to be of dimensionality two, which is not the case with my input. I have the same issue with GATv2Conv, which fixes the static attention problem of GATConv. So, does GATConv not support multiple graph inputs as a mini-batch? Or is there something I am missing here? Please help.
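The forward snippet referenced in the issue body did not survive extraction; in PyG releases from around that time, the relevant check in GATConv.forward looked roughly like this (an approximate reconstruction, not verbatim; the line numbers 202/205/206 cited in the thread fall in this region):

```python
# Approximate excerpt from GATConv.forward (PyG ~1.7/2.0); not verbatim.
H, C = self.heads, self.out_channels
if isinstance(x, Tensor):
    # This is the check that rejects 3D (static-graph) inputs.
    assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
    x_l = x_r = self.lin_l(x).view(-1, H, C)
else:
    # Bipartite case: x is expected to be a (src, dst) tuple of 2D tensors.
    x_l, x_r = x[0], x[1]
    assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
```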