They shouldn't consume more memory than GCN or GAT for sure. How big is your actual edge_index matrix and how many nodes does your graph contain?
Hi,
I also have this problem. SAGEConv runs out of memory, while GCNConv and GATConv work fine. Why does SAGEConv consume more memory than GCNConv and GATConv? My dataset contains 9k nodes and 15k edges, and the feature dimension is ~6k. I understand the graph is a little large, but I don't want to perform neighbour sampling. Do you have any suggestions?
A feature dimension of 6k looks pretty big to me. Why not project your features into a lower-dimensional space beforehand?
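For example, a minimal sketch of pre-projecting the raw features before message passing (the dimensions and the `pre_proj` name are illustrative, not part of the library):

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class Net(torch.nn.Module):
    def __init__(self, in_dim=6000, proj_dim=256, hidden_dim=128, num_classes=7):
        super().__init__()
        # Node-wise projection of the ~6k-dimensional raw features:
        self.pre_proj = torch.nn.Linear(in_dim, proj_dim)
        self.conv1 = SAGEConv(proj_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.pre_proj(x))           # [N, 6000] -> [N, 256], independent of edges
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)
```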
Yes, we can do that. I’m just curious why SAGEConv consumes more memory than GATConv. I think if GATConv works, then SAGEConv should work as well.
My understanding of the issue is that SAGEConv uses two weight matrices, while GATConv and GCNConv only use one.
Oh, I see. Thanks. It seems that the implementation is a little different from the original paper: it does not use concatenation-based aggregation.
The implementation is the same, just more efficient. Instead of learning a single weight matrix of shape [2 * in_channels, out_channels], we use two weight matrices of shape [in_channels, out_channels]. This lets us avoid the expensive concatenation of source and destination node features (by replacing it with transform + sum).
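As a quick standalone illustration of why the two parameterizations are equivalent (toy sizes, not the library code):

```python
import torch

in_channels, out_channels, num_edges = 8, 4, 10
x_dst = torch.randn(num_edges, in_channels)  # destination ("self") features per edge
x_src = torch.randn(num_edges, in_channels)  # source ("neighbor") features per edge

# Single weight matrix applied to the concatenation, as in the paper:
W = torch.randn(2 * in_channels, out_channels)
out_concat = torch.cat([x_dst, x_src], dim=-1) @ W

# Two separate weight matrices, transform + sum, no concatenation materialized:
W_dst, W_src = W[:in_channels], W[in_channels:]
out_split = x_dst @ W_dst + x_src @ W_src

assert torch.allclose(out_concat, out_split, atol=1e-5)
```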
Yes. You are right. The implementation is the same. I guess the large memory consumption is caused by some intermediate representations. It’s not caused by the number of weight matrices, since GATConv with 8 heads (concat = False) works.
This is very interesting. I don't see any reason why this should happen in SAGEConv. Do you know which operation exactly causes OOM problems?
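One generic way to narrow it down is to check PyTorch's peak-memory counter around individual calls (here `conv`, `x`, and `edge_index` are placeholders for your layer and inputs):

```python
import torch

torch.cuda.reset_peak_memory_stats()
out = conv(x, edge_index)   # the call under suspicion, e.g. a single SAGEConv layer
torch.cuda.synchronize()
print(f"peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
```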
Sorry for the late reply. After checking the code, I find that it might be caused by the propagation in SAGEConv involving a pair of tensors. See the code below: https://github.com/pyg-team/pytorch_geometric/blob/893aca527033888df1fbfa7207b6bc34f020ff4a/torch_geometric/nn/conv/sage_conv.py#L116-L124
This causes huge memory consumption when the input feature dimension is large.
Interesting. Notably, we only do shallow copies of node features into tuples, so there shouldn't be an increase in memory. GATConv also utilizes tuples and does not have the blow-up.
Thanks for pointing this out. Yes, shallow copies do not increase memory consumption. After checking your implementation carefully, I find that it is actually caused by the __lift__ function in MessagePassing when performing propagate. The code is as follows:
https://github.com/pyg-team/pytorch_geometric/blob/f8ab880ab2b475e10597b9353ffab6e1270da766/torch_geometric/nn/conv/message_passing.py#L186-L189
The returned tensor does not share storage with the original tensor because of the index_select operation, so memory consumption suddenly increases for high-dimensional input features.
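To make that concrete with the sizes mentioned above (9k nodes, 15k edges, ~6k features): lifting node features to edges materializes a copy of shape [num_edges, feature_dim], so it scales with the number of edges, not nodes. A rough sketch:

```python
import torch

num_nodes, num_edges, feat_dim = 9_000, 15_000, 6_000
x = torch.randn(num_nodes, feat_dim)                    # [N, F] node features
edge_index = torch.randint(num_nodes, (2, num_edges))   # [2, E]

# What __lift__ effectively does for one side of each edge:
x_j = x.index_select(0, edge_index[0])                  # new storage, shape [E, F]

print(x_j.shape)                                        # torch.Size([15000, 6000])
print(x_j.element_size() * x_j.nelement() / 1e6, "MB")  # 360.0 MB in float32
```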
Although the __lift__ function is also used in the other graph convolution operators, the key difference is that GCNConv and GATConv apply the linear transformation first and only then call propagate, so they do not hit this OOM issue. The SAGEConv implementation, however, calls propagate first and applies the linear transformation afterwards. In this situation, keeping two high-dimensional feature matrices on the GPU becomes unaffordable.
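A back-of-the-envelope comparison of the two orderings (float32, per lifted tensor; out_dim here is just an assumed hidden size):

```python
num_edges, in_dim, out_dim = 15_000, 6_000, 128
bytes_per_float = 4

lift_raw = num_edges * in_dim * bytes_per_float            # propagate first (SAGEConv ordering)
lift_transformed = num_edges * out_dim * bytes_per_float   # transform first (GCNConv/GATConv ordering)

print(f"lift raw features:         {lift_raw / 1e6:.1f} MB")          # 360.0 MB
print(f"lift transformed features: {lift_transformed / 1e6:.1f} MB")  # 7.7 MB
```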
I tried to fix it inside the __lift__ function by releasing some unused variables, but it still hits the OOM issue.
Ah, this is a nice finding. Thanks for sharing! I don't think we can release any variables here, as they are part of the computation graph. The only option is to use a smaller out_channels argument?
Yeah, you are right. But I am a little confused by the project argument in the __init__ function of SAGEConv. According to the code,
https://github.com/pyg-team/pytorch_geometric/blob/f8ab880ab2b475e10597b9353ffab6e1270da766/torch_geometric/nn/conv/sage_conv.py#L88-L92
https://github.com/pyg-team/pytorch_geometric/blob/f8ab880ab2b475e10597b9353ffab6e1270da766/torch_geometric/nn/conv/sage_conv.py#L113-L120
it seems that this operation does not actually perform a projection, since its output dimension is in_channels[0]; it is just a nonlinear transformation without dimensionality reduction. Could the OOM issue be solved by projecting x into a lower-dimensional space, i.e., using out_channels?
Note that project is False by default. However, you are right, the first transformation will not do a dimensionality reduction.
Yes, I understand that project is False by default. We could conduct dimensionality reduction in project, and users could decide whether to use it or not. That would solve the OOM issue.
You can close this issue now. Thanks for your nice work on this wonderful GNN library.
I am trying to run the following 5 algorithms on my custom dataset (one single graph):
In all cases the task is node classification.
Three of them run perfectly fine, but when I replace the layer of my network with either SAGEConv or GraphConv, I get a memory error saying it is trying to allocate 667,946,000,000 bytes (~668 GB).
Here is my small network:
and here is the error:
Is there any reason why these two layer-types use so much memory? Is there something I can do to solve this problem?
EDIT: In the case of SAGEConv, I managed to run it by using the SparseTensor functionality. However, it seems GraphConv does not work with SparseTensors.
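For reference, a minimal sketch of the SparseTensor route mentioned in the EDIT (assuming torch-sparse is installed; `data` and `conv` are placeholders for your graph and layer):

```python
import torch_geometric.transforms as T

# Replace data.edge_index with a transposed SparseTensor stored as data.adj_t:
data = T.ToSparseTensor()(data)

out = conv(data.x, data.adj_t)   # instead of conv(data.x, data.edge_index)
```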