rusty1s / pytorch_sparse

PyTorch Extension Library of Optimized Autograd Sparse Matrix Operations
MIT License

GPU version of spspmm does not seem to work on RTX 3090 but works normally on P100 #205

Closed wrccrwx closed 2 years ago

wrccrwx commented 2 years ago

Hi!

Here is a minimal case that reproduces the problem (though I may have missed something or used the function incorrectly):

import torch
from torch_sparse import spspmm

indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]]).to("cuda")
valueA = torch.Tensor([1, 2, 3, 4, 5]).to("cuda")

indexB = torch.tensor([[0, 2], [1, 0]]).to("cuda")
valueB = torch.Tensor([2, 4]).to("cuda")

# A is 3 x 3 and B is 3 x 2, hence the size arguments 3, 3, 2
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)

When I use spspmm on the GPU in Colab with a Tesla P100, it works fine. Today I ran the same code on an RTX 3090, and it fails with either an out-of-memory error or: RuntimeError: Trying to create tensor with negative dimension -1833421600: [-1833421600] (a large, seemingly random negative value, which may be an integer overflow).
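The overflow guess can be made concrete with a quick check. The sketch below is only an illustration, not a diagnosis: it assumes the size was computed as a value above 2^31 - 1 and then stored in a signed 32-bit integer, which nothing in this thread confirms. Since the value is described as random, wraparound is only one possible explanation. `wrap_int32` is a hypothetical helper name.

```python
import ctypes

INT32_MAX = 2**31 - 1


def wrap_int32(n: int) -> int:
    """Return n as it would read after being stored in a signed 32-bit integer."""
    # ctypes integer types do no overflow checking, so the value wraps around.
    return ctypes.c_int32(n).value


# The negative size taken from the error message.
reported = -1833421600

# Under the wraparound assumption, the smallest non-negative value that
# would read back as `reported` is reported + 2**32.
candidate = reported + 2**32

print(candidate)                          # 2461545696
print(candidate > INT32_MAX)              # True
print(wrap_int32(candidate) == reported)  # True
```

If that assumption held, the requested size would have been about 2.46 billion elements, which would also be consistent with the out-of-memory error seen on other runs.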

In Colab, I use

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install -q git+https://github.com/rusty1s/pytorch_geometric.git

to install pytorch_sparse; the PyTorch version there is 1.10.0+cu111.

On my local machine, I use conda install pyg -c pyg -c conda-forge, and the PyTorch build is 1.10.1 (py3.9_cuda11.3_cudnn8_0). My CUDA version is 11.1.

When I use the CPU version on my local machine, it works fine.
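Because the CPU path works, the expected result of the minimal case can also be checked by hand. The following pure-Python sketch multiplies the two COO matrices from the snippet above without torch or a GPU. `coo_matmul` is a hypothetical helper written for this illustration and is not part of torch_sparse.

```python
from collections import defaultdict


def coo_matmul(index_a, value_a, index_b, value_b):
    """Multiply two sparse matrices given in COO form ([rows, cols], values)."""
    # Group the entries of B by row so each entry of A can find its partners.
    rows_b = defaultdict(list)
    for r, c, v in zip(index_b[0], index_b[1], value_b):
        rows_b[r].append((c, v))

    # C[i, j] += A[i, k] * B[k, j] for every stored pair sharing the index k.
    acc = defaultdict(float)
    for i, k, a in zip(index_a[0], index_a[1], value_a):
        for j, b in rows_b.get(k, []):
            acc[(i, j)] += a * b

    # Return the nonzeros sorted by (row, col).
    keys = sorted(acc)
    index_c = [[i for i, _ in keys], [j for _, j in keys]]
    value_c = [acc[key] for key in keys]
    return index_c, value_c


# Same data as indexA/valueA and indexB/valueB in the report.
index_a = [[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]]
value_a = [1.0, 2.0, 3.0, 4.0, 5.0]
index_b = [[0, 2], [1, 0]]
value_b = [2.0, 4.0]

index_c, value_c = coo_matmul(index_a, value_a, index_b, value_b)
print(index_c)  # [[0, 1, 2], [0, 1, 1]]
print(value_c)  # [8.0, 6.0, 8.0]
```

A correct spspmm run on the same inputs should agree with these three nonzero entries, so any GPU output that differs points at the GPU path rather than at the inputs.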

Thank you.

rusty1s commented 2 years ago

This looks to be related to https://github.com/rusty1s/pytorch_sparse/issues/205. I will answer in this issue.

LCHJ commented 2 years ago

RuntimeError: Trying to create tensor with negative dimension -1387719423: [-1387719423]

from torch_sparse import SparseTensor
z_sparse = SparseTensor.from_dense(z)
adj_re = z_sparse.matmul(z_sparse.t())

I encountered the same error when I used torch_sparse to multiply two sparse matrices. My guess is that the operation needs more memory, because the error does not occur if my input is sufficiently small or if the GPU has more memory. But the important question is: how do I fix the current error?
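The memory guess can be sanity-checked with a rough estimate. The sketch below assumes that z is a dense n-by-d matrix with no zero entries, that the result is stored in COO form with two 64-bit indices and one 32-bit value per nonzero, and that temporary buffers used by the kernel can be ignored. The function name and the example size n = 50,000 are hypothetical; the actual shape of z is not given in this thread.

```python
def dense_gram_coo_bytes(n: int, index_bytes: int = 8, value_bytes: int = 4) -> int:
    """Rough size of z @ z.T in COO form when all n * n entries are nonzero."""
    nnz = n * n
    # Each nonzero stores a row index, a column index and one value.
    return nnz * (2 * index_bytes + value_bytes)


n = 50_000  # hypothetical number of rows in z
nnz = n * n

print(nnz)                      # 2500000000
print(nnz > 2**31 - 1)          # True
print(dense_gram_coo_bytes(n))  # 50000000000
```

Under these assumptions the product alone would need about 50 GB, more than the memory of either GPU mentioned here, and its nonzero count would no longer fit in a signed 32-bit counter. For a matrix of that size the product is effectively dense, so a sparse format brings no saving.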

rusty1s commented 2 years ago

Can you check whether the sparse-sparse matrix multiplication from PyTorch works for you? This might be a good workaround:

z_sparse = z_sparse.to_torch_sparse_coo_tensor()
out = z_sparse @ z_sparse.t()