pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Result of torch.mm(a, b) does not match result of torch.mm(a[:part, :], b) #108521

Open head-with-nothing opened 1 year ago

head-with-nothing commented 1 year ago

🐛 Describe the bug

I tried to use torch.mm to compute a block matrix multiplication piece by piece instead of computing the whole result at once, but found that the results of the two computations are not close. For example, let $a \in \mathbb{R}^{m \times k}$ be split by rows into a1 = a[:m//2, :] and a2 = a[m//2:, :], and let $b \in \mathbb{R}^{k \times n}$. Then torch.mm(a, b) should equal torch.cat((torch.mm(a1, b), torch.mm(a2, b)), dim=0), but in practice the two do not match.
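In exact arithmetic the expectation is right: splitting $a$ by rows into $a_1$ and $a_2$ splits the product the same way, and every entry is a single dot product over the inner dimension ($k$ in the script):

$$
a b = \begin{bmatrix} a_1 \\ a_2 \end{bmatrix} b = \begin{bmatrix} a_1 b \\ a_2 b \end{bmatrix},
\qquad
c_{ij} = \sum_{l=1}^{k} a_{il} \, b_{lj}.
$$

In floating point, however, each entry is $\hat c_{ij} = \mathrm{fl}\big(\sum_{l} a_{il} b_{lj}\big)$, and because floating-point addition is not associative, $\hat c_{ij}$ depends on the order and grouping in which the $k$ products are accumulated. A GEMM library such as cuBLAS may select a different kernel (different tiling, or a different split of the inner dimension) for an $i \times k$ operand than for an $m \times k$ one, so the two computations can disagree in the last bits even though both are correct to within rounding error.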

The following code reproduces the problem.

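```python
import torch

def test(m, n, k, dtype, rtol, atol):
    # Random inputs, created directly on the GPU
    a = torch.randn(m, k, dtype=dtype, device="cuda")
    b = torch.randn(k, n, dtype=dtype, device="cuda")

    # Full product, computed once ...
    c = torch.mm(a, b)
    # ... compared against the product of every row prefix a[:i, :] with b
    for i in range(1, m + 1):
        d = torch.mm(a[:i, :], b)
        if not torch.allclose(d, c[:i, :], rtol=rtol, atol=atol):
            print(f'Not match, m={i} {n=} {k=} {dtype=}')

# (rtol, atol, dtype) for each precision under test
dtypes = [(1e-3, 1e-5, torch.float16), (1e-3, 1e-5, torch.bfloat16), (1e-5, 1e-8, torch.float32)]
for rtol, atol, dtype in dtypes:
    for m in [4, 8, 16]:
        for n in [256, 512]:
            for k in [256, 512]:
                test(m, n, k, dtype, rtol, atol)
```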
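The order dependence can be reproduced without a GPU. The sketch below is a pure-Python emulation (standard library only), not what cuBLAS actually does: to_fp32, sum_sequential_fp32 and sum_pairwise_fp32 are hypothetical helpers, and the pairwise order merely stands in for whatever regrouping a blocked or split-k kernel might use. It sums the same 512 float32 products in two orders, rounding to float32 after every addition, and checks both results against the standard recursive-summation error bound $\gamma_{k-1} \sum_l |x_l|$ with $\gamma_{k-1} = (k-1)u / (1-(k-1)u)$ and $u = 2^{-24}$.

```python
import math
import random
import struct

def to_fp32(x: float) -> float:
    # Round a Python float (binary64) to the nearest float32 value and back
    return struct.unpack('<f', struct.pack('<f', x))[0]

def sum_sequential_fp32(xs):
    # Left-to-right accumulation, rounding to float32 after every addition
    acc = 0.0
    for x in xs:
        acc = to_fp32(acc + x)
    return acc

def sum_pairwise_fp32(xs):
    # Tree (pairwise) accumulation: same terms, different grouping
    if len(xs) == 1:
        return xs[0]
    mid = len(xs) // 2
    return to_fp32(sum_pairwise_fp32(xs[:mid]) + sum_pairwise_fp32(xs[mid:]))

rng = random.Random(0)  # fixed seed so the run is reproducible
k = 512                 # inner dimension, as in the issue's script
# float32 products a[i, l] * b[l, j] for one output entry, inputs drawn like torch.randn
terms = [to_fp32(rng.gauss(0.0, 1.0) * rng.gauss(0.0, 1.0)) for _ in range(k)]

exact = math.fsum(terms)  # correctly rounded reference sum of the same terms
seq = sum_sequential_fp32(terms)
tree = sum_pairwise_fp32(terms)

u = 2.0 ** -24                                  # unit roundoff of float32
gamma = (k - 1) * u / (1 - (k - 1) * u)
bound = gamma * math.fsum(abs(t) for t in terms)
print(f"sequential={seq!r} pairwise={tree!r} "
      f"difference={abs(seq - tree):.3g} bound={bound:.3g}")
```

The two sums typically differ in their last bits while both stay inside the bound; that is the same kind of discrepancy the script reports, just without a GPU in the loop.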

Versions

PyTorch version: 2.0.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.27.2
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.10.16.3-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060 Ti
Nvidia driver version: 536.67
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 6
On-line CPU(s) list: 0-5
Vendor ID: GenuineIntel
Model name: 12th Gen Intel(R) Core(TM) i5-12400
CPU family: 6
Model: 151
Thread(s) per core: 2
Core(s) per socket: 3
Socket(s): 1
Stepping: 5
BogoMIPS: 4991.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 144 KiB (3 instances)
L1i cache: 96 KiB (3 instances)
L2 cache: 3.8 MiB (3 instances)
L3 cache: 18 MiB (1 instance)
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.25.2
[pip3] torch==2.0.1
[pip3] triton==2.0.0
[pip3] tritonclient==2.36.0
[conda] Could not collect

ezyang commented 1 year ago

I spot-checked some of the outputs from your script and they look fine; your tolerances are probably too strict.
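To put "too strict" in numbers: the script uses rtol=1e-3 for both float16 and bfloat16, yet bfloat16 keeps only 7 explicit mantissa bits, so one unit in the last place (ulp) at 1.0 is 2^-7 ≈ 7.8e-3, almost eight times rtol. The pure-Python sketch below (standard library only; round_fp16, round_bf16, round_fp32 and close are hypothetical helpers that emulate the formats with struct, ignoring NaN and overflow) applies the elementwise test that torch.allclose documents, |actual - expected| <= atol + rtol * |expected|, to a one-ulp disagreement at 1.0 under the script's tolerances.

```python
import struct

def round_fp16(x: float) -> float:
    # struct's 'e' format is IEEE 754 binary16 (float16)
    return struct.unpack('<e', struct.pack('<e', x))[0]

def round_bf16(x: float) -> float:
    # bfloat16 = upper 16 bits of a float32, round to nearest with ties to even
    # (NaN and overflow are not handled in this sketch)
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack('<f', struct.pack('<I', bits))[0]

def round_fp32(x: float) -> float:
    return struct.unpack('<f', struct.pack('<f', x))[0]

def close(actual: float, expected: float, rtol: float, atol: float) -> bool:
    # Same formula torch.allclose documents: |input - other| <= atol + rtol * |other|
    return abs(actual - expected) <= atol + rtol * abs(expected)

# name: (explicit mantissa bits, rtol, atol, rounding helper); tolerances from the issue
FORMATS = {
    "float16": (10, 1e-3, 1e-5, round_fp16),
    "bfloat16": (7, 1e-3, 1e-5, round_bf16),
    "float32": (23, 1e-5, 1e-8, round_fp32),
}

for name, (mant_bits, rtol, atol, rnd) in FORMATS.items():
    one_ulp_above_1 = 1.0 + 2.0 ** -mant_bits      # next representable value after 1.0
    assert rnd(one_ulp_above_1) == one_ulp_above_1  # sanity check: exactly representable
    print(f"{name:>8}: eps={2.0 ** -mant_bits:.3g}, "
          f"1-ulp difference passes allclose: {close(one_ulp_above_1, 1.0, rtol, atol)}")
```

With these tolerances a single rounding difference at 1.0 already fails the bfloat16 check, and float16 passes only narrowly (2^-10 ≈ 9.77e-4 against a threshold of about 1.01e-3). Accumulation-order differences in a 512-term dot product can amount to an ulp or more, so mismatches at these tolerances are not by themselves evidence of a bug.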

tringwald commented 1 year ago

Setting the environment variable CUBLAS_WORKSPACE_CONFIG=:16:8 is a quick way to fix most of the accuracy issues in the given code. Note that this will most likely degrade performance^1. With the env variable set, most absolute differences went down to the 1e-5 region.
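A minimal sketch of applying that suggestion from Python, assuming the variable is set before PyTorch performs its first CUDA matrix multiplication (setting it before import torch is the simplest way to guarantee that). Per the cuBLAS documentation the value has the form :SIZE:COUNT, i.e. COUNT workspace buffers of SIZE KiB, so :16:8 requests eight 16 KiB buffers. The tensor shapes below are arbitrary, and the GPU part is skipped when PyTorch or CUDA is unavailable.

```python
import os

# Set before torch touches cuBLAS; ":16:8" = 8 workspace buffers of 16 KiB each
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

try:
    import torch
except ImportError:  # lets the snippet run (and skip the GPU check) without PyTorch
    torch = None

if torch is not None and torch.cuda.is_available():
    a = torch.randn(16, 512, dtype=torch.float16, device="cuda")
    b = torch.randn(512, 256, dtype=torch.float16, device="cuda")
    c = torch.mm(a, b)         # full product
    d = torch.mm(a[:5, :], b)  # product of a row prefix
    # Largest absolute disagreement between the two computations of the same rows
    print("max abs diff:", (d - c[:5, :]).abs().max().item())
```

Equivalently, the variable can be set in the shell when launching the reproduction, e.g. CUBLAS_WORKSPACE_CONFIG=:16:8 python repro.py, where repro.py is a hypothetical name for the script from the issue.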

The root cause of the problem, however, is the use of torch.randn. It samples values from a normal distribution, which is (theoretically) unbounded, so both very large and very small numbers will be sampled, which is the worst-case scenario for floating-point math. Switching to torch.rand, which samples uniformly from [0, 1), makes the code above pass without any mismatches.
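A related way to see why torch.randn inputs are hard on an rtol-based comparison is cancellation. With torch.rand every product a_il * b_lj is non-negative, so nothing cancels; with torch.randn the products have random signs, so many output entries end up much smaller than the sum of the absolute values of their terms. Rounding error in a sum scales with sum(|x|), so the relative error of an entry is amplified by the ratio sum(|x|) / |sum(x)|. The standard-library sketch below computes that ratio for simulated dot products of length 512; summation_condition_number and dot_terms are hypothetical helpers, and the sampling only mimics the two PyTorch functions.

```python
import math
import random

def summation_condition_number(terms):
    # sum(|x|) / |sum(x)|: how much rounding error is amplified relative to the result
    total = math.fsum(terms)
    return math.inf if total == 0.0 else math.fsum(abs(t) for t in terms) / abs(total)

rng = random.Random(0)  # fixed seed so the run is reproducible
k, entries = 512, 256   # inner dimension from the issue; number of output entries sampled

def dot_terms(sample):
    # The k products a[i, l] * b[l, j] that are accumulated into one output entry
    return [sample() * sample() for _ in range(k)]

cond_randn = [summation_condition_number(dot_terms(lambda: rng.gauss(0.0, 1.0)))
              for _ in range(entries)]
cond_rand = [summation_condition_number(dot_terms(rng.random)) for _ in range(entries)]

print(f"normal inputs (like torch.randn): worst condition number {max(cond_randn):.1f}")
print(f"uniform [0, 1) inputs (like torch.rand): worst condition number {max(cond_rand):.1f}")
```

With uniform [0, 1) inputs the ratio is exactly 1 for every entry, while with normal inputs the worst sampled entries typically sit far above 1, and those are the entries most likely to fail an rtol-based check.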