pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Abnormal memory usage of SageConv and GraphConv #2074

Closed: codyseally closed this issue 2 years ago

codyseally commented 3 years ago

I am trying to run the following 5 algorithms on my custom dataset (one single graph):

- GCNConv
- SAGEConv
- GATConv
- GraphConv
- HypergraphConv

In all cases the task is node classification.

Three of them run perfectly fine, but when I replace the layer of my network with either SAGEConv or GraphConv, I get a memory error saying I am trying to allocate 667946000000 bytes (~668 GB).

Here is my small network:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (GCNConv, SAGEConv, GATConv, GraphConv,
                                HypergraphConv)

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        #torch.manual_seed(12345)
        # dataset_num_features / dataset_num_classes come from my dataset.
        #self.conv1 = GCNConv(dataset_num_features, 16)
        #self.conv1 = SAGEConv(dataset_num_features, 16)     # Does NOT work.
        #self.conv1 = GATConv(dataset_num_features, 16, heads=10, concat=False)
        self.conv1 = GraphConv(dataset_num_features, 16, aggr='mean')  # Does NOT work.
        #self.conv1 = HypergraphConv(dataset_num_features, 16, use_attention=True, heads=5, concat=False)

        self.lin = nn.Linear(16, dataset_num_classes)  # 16 = output size of conv1

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)   # single message-passing layer
        x = x.relu()
        #x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.8, training=self.training)
        x = self.lin(x)                 # per-node class logits
        return x

and here is the error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-c43cd0da8520> in <module>
      9     data.to(device)
     10     optimizer.zero_grad()
---> 11     out = model(data)
     12     loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask].squeeze())
     13     loss.backward()

~/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

<ipython-input-6-b86f676abb44> in forward(self, data)
     18         x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
     19 
---> 20         x = self.conv1(x, edge_index)#, edge_weight=edge_attr.squeeze())
     21         x = x.relu()
     22         #x = global_mean_pool(x, batch)

~/anaconda3/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

~/anaconda3/envs/py38/lib/python3.8/site-packages/torch_geometric/nn/conv/graph_conv.py in forward(self, x, edge_index, edge_weight, size)
     60 
     61         # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)
---> 62         out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
     63                              size=size)
     64         out = self.lin_l(out)

~/anaconda3/envs/py38/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py in propagate(self, edge_index, size, **kwargs)
    231         # Otherwise, run both functions in separation.
    232         elif isinstance(edge_index, Tensor) or not self.fuse:
--> 233             coll_dict = self.__collect__(self.__user_args__, edge_index, size,
    234                                          kwargs)
    235 

~/anaconda3/envs/py38/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py in __collect__(self, args, edge_index, size, kwargs)
    155                 if isinstance(data, Tensor):
    156                     self.__set_size__(size, dim, data)
--> 157                     data = self.__lift__(data, edge_index,
    158                                          j if arg[-2:] == '_j' else i)
    159 

~/anaconda3/envs/py38/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py in __lift__(self, src, edge_index, dim)
    125         if isinstance(edge_index, Tensor):
    126             index = edge_index[dim]
--> 127             return src.index_select(self.node_dim, index)
    128         elif isinstance(edge_index, SparseTensor):
    129             if dim == 1:

RuntimeError: [enforce fail at CPUAllocator.cpp:64] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 667946000000 bytes. Error code 12 (Cannot allocate memory)

Is there any reason why these two layer-types use so much memory? Is there something I can do to solve this problem?

EDIT: In the case of SAGEConv, I managed to run it by using the SparseTensor functionality. However, it seems GraphConv does not work with SparseTensors.
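
For reference, here is roughly what the SparseTensor workaround looks like (a minimal sketch, reusing the names from the snippet above and assuming torch_sparse is installed):

# Sketch of the SparseTensor workaround.
# ToSparseTensor replaces data.edge_index with a sparse adjacency data.adj_t.
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv

data = T.ToSparseTensor()(data)
conv = SAGEConv(dataset_num_features, 16)
out = conv(data.x, data.adj_t)   # pass the sparse adjacency in place of edge_index

With the sparse adjacency, the aggregation can run as a sparse matrix multiplication instead of materializing per-edge features, which is presumably why it fits in memory here.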

rusty1s commented 3 years ago

They shouldn't consume more memory than GCN or GAT for sure. How big is your actual edge_index matrix and how many nodes does your graph contain?

cszhangzhen commented 2 years ago

Hi,

I also have this problem. SAGEConv runs out of memory, while GCNConv and GATConv work fine. Why does SAGEConv consume more memory than GCNConv and GATConv? My dataset contains 9k nodes, 15k edges, and the feature dimension is ~6k. I understand the graph is a little large, but I don't want to perform neighbour sampling. Do you have any suggestions?

rusty1s commented 2 years ago

A feature dimension of 6k looks pretty big to me. Why not project your features into a lower-dimensional space beforehand?
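
Something along these lines (just a sketch of the idea, with made-up module and dimension names):

# Sketch: reduce the ~6k-dimensional input features *before* message passing,
# so that propagate only ever lifts low-dimensional tensors.
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class ProjectedSAGE(torch.nn.Module):  # name is just for illustration
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.proj = torch.nn.Linear(in_channels, hidden_channels)  # e.g. 6000 -> 128
        self.conv = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.proj(x))            # dimensionality reduction happens here
        return self.conv(x, edge_index)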

cszhangzhen commented 2 years ago

Yes, we can do that. I’m just curious why SAGEConv consumes more memory than GATConv. I think if GATConv works, then SAGEConv should work as well.

rusty1s commented 2 years ago

My understanding of the issue is that SAGEConv uses two weight matrices, while GATConv and GCNConv only use one.

cszhangzhen commented 2 years ago

Oh, I see. Thanks. It seems that the implementation is a little different from the original paper. It does not use concat aggregation.

rusty1s commented 2 years ago

The implementation is the same, just more efficient. Instead of learning a single weight matrix of shape [2 * in_channels, out_channels], we use two weight matrices with shape [in_channels, out_channels]. This helps us to avoid the expensive concatenation of source and destination node features (by replacing it with transform + sum).
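
To make the equivalence concrete, here is a tiny numerical check (an illustration, not library code):

# Splitting a [2 * in_channels, out_channels] weight matrix into two
# [in_channels, out_channels] matrices gives the same result as transforming
# the concatenated features: cat([aggr, root]) @ W == aggr @ W_l + root @ W_r.
import torch

in_channels, out_channels = 4, 3
root = torch.randn(5, in_channels)       # destination node features
aggr = torch.randn(5, in_channels)       # aggregated neighbor features
W = torch.randn(2 * in_channels, out_channels)
W_l, W_r = W[:in_channels], W[in_channels:]

out_concat = torch.cat([aggr, root], dim=-1) @ W
out_split = aggr @ W_l + root @ W_r
assert torch.allclose(out_concat, out_split, atol=1e-6)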

cszhangzhen commented 2 years ago

Yes. You are right. The implementation is the same. I guess the large memory consumption is caused by some intermediate representations. It’s not caused by the number of weight matrices, since GATConv with 8 heads (concat = False) works.

rusty1s commented 2 years ago

This is very interesting. I don't see any reason why this should happen in SAGEConv. Do you know which operation exactly causes OOM problems?

cszhangzhen commented 2 years ago

Sorry for the late reply. After checking the code, I found that it might be caused by the propagation in SAGEConv operating on a pair of tensors. See the code below: https://github.com/pyg-team/pytorch_geometric/blob/893aca527033888df1fbfa7207b6bc34f020ff4a/torch_geometric/nn/conv/sage_conv.py#L116-L124

This causes huge memory consumption when the input feature dimension is large.

rusty1s commented 2 years ago

Interesting. Notably, we only do shallow copies of node features into tuples, so there shouldn't be an increase in memory. GATConv also utilizes tuples and does not have the blow-up.

cszhangzhen commented 2 years ago

Thanks for pointing this out. Yes, shallow copies do not increase memory consumption. After checking your implementation carefully, I found that it is actually caused by the __lift__ function in MessagePassing when performing propagate. The code is as follows: https://github.com/pyg-team/pytorch_geometric/blob/f8ab880ab2b475e10597b9353ffab6e1270da766/torch_geometric/nn/conv/message_passing.py#L186-L189

The returned tensor does not share storage with the original tensor because of the index_select operation, so memory consumption suddenly increases for high-dimensional input features.

Although the __lift__ function is also used in the other graph convolution operators, the key difference is that GCNConv and GATConv apply the linear transformation first and then call propagate, so they do not run into this OOM issue. SAGEConv, however, calls propagate first and then applies the linear transformation. In that situation, keeping two high-dimensional lifted feature matrices becomes unaffordable for the GPU.
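
A rough back-of-the-envelope estimate with my graph's sizes (just an illustration, not measured):

# The per-edge "lift" (index_select) materializes a [num_edges, dim] tensor.
# Transform-first layers (GCNConv/GATConv) lift out_channels-wide features,
# while SAGEConv lifts the raw in_channels-wide features before transforming.
num_edges, in_channels, out_channels = 15_000, 6_000, 16
bytes_per_float32 = 4

lift_after_transform = num_edges * out_channels * bytes_per_float32   # ~1 MB
lift_before_transform = num_edges * in_channels * bytes_per_float32   # ~360 MB per lifted tensor

Autograd keeps these intermediates around, so this adds up quickly across layers (and would be far larger for the graph in the original report).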

I tried to fix it in the __lift__ function by releasing some unused variables, but the OOM issue remains.

rusty1s commented 2 years ago

Ah, this is a nice finding. Thanks for sharing! I don't think we can release any variables here as they are part of the computation graph. The only option is to use a smaller out_channels argument?

cszhangzhen commented 2 years ago

Yeah, you are right. But I am a little confused by the project argument in the __init__ function of SAGEConv. According to the code, https://github.com/pyg-team/pytorch_geometric/blob/f8ab880ab2b475e10597b9353ffab6e1270da766/torch_geometric/nn/conv/sage_conv.py#L88-L92 https://github.com/pyg-team/pytorch_geometric/blob/f8ab880ab2b475e10597b9353ffab6e1270da766/torch_geometric/nn/conv/sage_conv.py#L113-L120

it seems that this operation does not actually perform a projection, since its output dimension is in_channels[0]. It is really a non-linear transformation without any dimensionality reduction.
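
A tiny shape check of my reading of those lines (an illustration with my dimensions, not the library source):

# With project=True, the extra Linear maps in_channels -> in_channels,
# so it adds a non-linearity but keeps the (large) input dimensionality.
import torch

num_nodes, in_channels = 9000, 6000
x = torch.randn(num_nodes, in_channels)
lin = torch.nn.Linear(in_channels, in_channels)   # what the linked lines construct
print(lin(x).relu().shape)                        # torch.Size([9000, 6000]) -- no reduction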

Could the OOM issue be solved by projecting x into a low-dimensional space here, i.e., by using out_channels instead?

rusty1s commented 2 years ago

Note that project is False by default. However, you are right, the first transformation will not do a dimensionality reduction.

cszhangzhen commented 2 years ago

Yes, I understand that project is False by default. We could perform dimensionality reduction in project and let users decide whether to use it. That would solve the OOM issue.

You can close this issue now. Thanks for your nice work on this wonderful GNN library.