pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

torch.Tensor.is_sparse returns false for non-COO sparse tensors #101385

Open alexisthual opened 1 year ago

alexisthual commented 1 year ago

🐛 Describe the bug

In my humble opinion, torch.Tensor.is_sparse should return True for all sparse tensor layouts. Currently, I think it only does so for COO tensors. You can reproduce this with the following snippet:

import torch

a = torch.tensor([[1, 2, 3], [0, 0, 7]])
b = a.to_sparse().to_sparse_csc()

b.is_sparse

I think that this function is the culprit: https://github.com/pytorch/pytorch/blob/7dd8e08817ee59c926922409062e25f30408469b/torch/_linalg_utils.py#L11-L19

Checking for all possible sparse layouts should probably be the default behavior. Maybe checking that "sparse" appears in the layout name would work:

import torch

a = torch.tensor([[1, 2, 3], [0, 0, 7]])

b = a.to_sparse().to_sparse_coo()
print("COO", b.is_sparse, "sparse" in str(b.layout))

b = a.to_sparse().to_sparse_csc()
print("CSC", b.is_sparse, "sparse" in str(b.layout))

b = a.to_sparse().to_sparse_csr()
print("CSR", b.is_sparse, "sparse" in str(b.layout))

b = a.to_sparse().to_sparse_bsc(blocksize=1)
print("BSC", b.is_sparse, "sparse" in str(b.layout))

b = a.to_sparse().to_sparse_bsr(blocksize=1)
print("BSR", b.is_sparse, "sparse" in str(b.layout))

yields

COO True True
CSC False True
CSR False True
BSC False True
BSR False True
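For illustration, a layout-set-based check could look like the sketch below. This is not PyTorch's actual implementation, and the helper name `is_sparse_any` is made up for this example; it simply compares `Tensor.layout` against the sparse layouts torch exposes.

```python
import torch

# All sparse layouts currently exposed by torch (as of 2.0).
SPARSE_LAYOUTS = {
    torch.sparse_coo,
    torch.sparse_csr,
    torch.sparse_csc,
    torch.sparse_bsr,
    torch.sparse_bsc,
}

def is_sparse_any(t: torch.Tensor) -> bool:
    """Return True for every sparse layout, not just COO."""
    return t.layout in SPARSE_LAYOUTS
```

With a helper like this, the CSC/CSR/BSC/BSR cases above would all report True, matching the layout-name check.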

Versions

Collecting environment information...
PyTorch version: 2.0.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.27

Python version: 3.10.10 | packaged by conda-forge | (main, Mar 24 2023, 20:08:06) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-4.15.0-112-generic-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.7.64
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Tesla V100-DGXS-32GB
GPU 1: Tesla V100-DGXS-32GB
GPU 2: Tesla V100-DGXS-32GB
GPU 3: Tesla V100-DGXS-32GB

Nvidia driver version: 450.51.05
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
CPU(s):                          40
On-line CPU(s) list:             0-39
Thread(s) per core:              2
Core(s) per socket:              20
Socket(s):                       1
NUMA node(s):                    1
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           79
Model name:                      Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz
Stepping:                        1
CPU MHz:                         1898.795
CPU max MHz:                     3600.0000
CPU min MHz:                     1200.0000
BogoMIPS:                        4397.21
Virtualization:                  VT-x
L1d cache:                       32K
L1i cache:                       32K
L2 cache:                        256K
L3 cache:                        51200K
NUMA node0 CPU(s):               0-39
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti intel_ppin ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a rdseed adx smap intel_pt xsaveopt cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts md_clear flush_l1d

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.3
[pip3] torch==2.0.0
[pip3] torchaudio==2.0.1
[pip3] torchvision==0.15.1
[pip3] triton==2.0.0
[conda] cudatoolkit-dev           11.7.0               h1de0b5d_6    conda-forge
[conda] numpy                     1.24.3                   pypi_0    pypi
[conda] torch                     2.0.0                    pypi_0    pypi
[conda] torchaudio                2.0.1                    pypi_0    pypi
[conda] torchvision               0.15.1                   pypi_0    pypi
[conda] triton                    2.0.0                    pypi_0    pypi

cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer

vadimkantorov commented 1 year ago

An alternative (if it makes sense): is_sparse could be deprecated in favor of is_dense and more specific is_sparse_* checks, since the currently supported ops are very specific to each sparse layout.
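A minimal sketch of what such an API could look like, assuming layout-based checks. The helper names `is_dense` and `is_sparse_csr` here are hypothetical, illustrating the suggestion rather than anything torch provides:

```python
import torch

def is_dense(t: torch.Tensor) -> bool:
    # Dense tensors use the strided layout.
    return t.layout == torch.strided

def is_sparse_csr(t: torch.Tensor) -> bool:
    # True only for the CSR layout specifically.
    return t.layout == torch.sparse_csr
```

Per-layout predicates like these would make it explicit which layout a given op path supports, instead of one catch-all is_sparse flag.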