huggingface / optimum-quanto

A pytorch quantization backend for optimum
Apache License 2.0

fix(library): disable int_mm for CPU #309

Closed dacorvo closed 2 months ago

dacorvo commented 2 months ago

The CPU implementation still has accuracy issues with pytorch 2.4.1.

maktukmak commented 2 months ago

@dacorvo, I could not reproduce this issue. Could you give me more details?

dacorvo commented 2 months ago

When you do a matmul between two qint8 QTensors, quanto tries to use torch._int_mm whenever it is available. However, when the op runs on CPU, the output mismatch is larger than when it runs on CUDA.
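To make the idea of "disable int_mm for CPU" concrete, here is a minimal sketch of a device-guarded int8 matmul. It is not the actual change in this PR: the helper name int8_matmul is hypothetical, and the fallback (widening both operands to int32 before torch.mm) is just one plausible choice.

```python
import torch


def int8_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Multiply two 2D int8 matrices and return an int32 result.

    Hypothetical helper for illustration only, not quanto's code.
    """
    # Only take the fused integer kernel on CUDA, where the thread reports
    # matching results. Note that torch._int_mm is a private op and may
    # impose shape constraints that this sketch does not check.
    if a.device.type == "cuda" and hasattr(torch, "_int_mm"):
        return torch._int_mm(a, b)
    # Elsewhere, widen to int32 first so the accumulation is done on
    # 32-bit integers by the regular integer matmul.
    return torch.mm(a.to(torch.int32), b.to(torch.int32))


if __name__ == "__main__":
    a = torch.randint(-128, 128, (32, 64), dtype=torch.int8)
    b = torch.randint(-128, 128, (64, 32), dtype=torch.int8)
    c = int8_matmul(a, b)
    # Compare against a float64 reference, which is exact at these magnitudes.
    print(torch.equal(c.double(), a.double() @ b.double()))
```

The fallback trades speed for predictability: it gives up any int8-specific kernel on CPU in exchange for a plain int32 accumulation.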

maktukmak commented 2 months ago

I am comparing the results with torch 2.4.1 using the following test:

import pytest  
import torch  

@pytest.mark.parametrize("device", ['cpu', 'cuda'])
@pytest.mark.parametrize("m", [32, 64])
@pytest.mark.parametrize("k", [32, 64])
@pytest.mark.parametrize("n", [32, 64])
@pytest.mark.parametrize("use_transpose_a", [True, False])
@pytest.mark.parametrize("use_transpose_b", [True, False])
@pytest.mark.parametrize("non_contig_type", [0, 1, 2])
def test__int_mm_cpu(device, m, k, n, use_transpose_a, use_transpose_b, non_contig_type):

    # non_contig_type:
    # 0: the whole data buffer is contiguous (can be transposed)
    # 1: stride of one dimension is 1, but the whole buffer is not contiguous
    # 2: Neither stride is 1

    def genf_int_float(x, y, use_transpose, non_contig_type):
        if use_transpose:
            x, y = y, x
        if non_contig_type != 0:
            y = y * 2
        x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device)
        x_float = x_int8.to(torch.float32)
        if non_contig_type == 1:
            x_int8 = x_int8[:, : y // 2]
            x_float = x_float[:, : y // 2]
        elif non_contig_type == 2:
            x_int8 = x_int8[:, ::2]
            x_float = x_float[:, ::2]
        if use_transpose:
            return x_int8.t(), x_float.t()
        return x_int8, x_float

    if non_contig_type != 0 and (m == 0 or k == 0):
        return
    a_int8, a_float = genf_int_float(m, k, use_transpose_a, non_contig_type)
    b_int8, b_float = genf_int_float(k, n, use_transpose_b, non_contig_type)
    c_int32 = torch._int_mm(a_int8, b_int8)
    assert torch.equal(c_int32.float(), torch.mm(a_float, b_float))
    c_int32_result = c_int32.new_empty(c_int32.size())
    torch._int_mm(a_int8, b_int8, out=c_int32_result)
    assert torch.equal(c_int32_result.float(), torch.mm(a_float, b_float))

and all tests pass. My system has a V100 GPU and a Xeon Platinum 8380 CPU. Could you give me a test that fails so I can take a look?

dacorvo commented 2 months ago

Just increase the range of your random integers and you will see all your CPU tests failing, while the CUDA tests are still passing.

 x_int8 = torch.randint(-128, 128, (x, y), dtype=torch.int8, device=device)

This looks like there is an overflow in the accumulator on CPU.
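A quick back-of-the-envelope check shows why widening the random range could expose an accumulator problem that the original range hides. The sketch below computes the worst-case magnitude of a single dot product for both ranges and compares it with signed 16-bit and 32-bit limits. The thread does not establish which intermediate width, if any, is involved on the failing CPU; the 16-bit limit is included only as an assumed point of comparison, and worst_case_dot is a hypothetical helper.

```python
INT16_MAX = 2**15 - 1
INT32_MAX = 2**31 - 1


def worst_case_dot(low: int, high: int, k: int) -> int:
    """Largest possible |sum| of k products of integers drawn from [low, high).

    The half-open range follows the torch.randint(low, high, ...) convention.
    """
    max_abs = max(abs(low), abs(high - 1))
    return k * max_abs * max_abs


if __name__ == "__main__":
    for low, high in [(-10, 10), (-128, 128)]:
        for k in (32, 64):
            bound = worst_case_dot(low, high, k)
            print(
                f"range [{low}, {high}) k={k}: worst case {bound}, "
                f"fits int16: {bound <= INT16_MAX}, "
                f"fits int32: {bound <= INT32_MAX}"
            )
```

With values in [-10, 10) and k = 64 the bound is 6400, which fits in 16 bits, so any narrow intermediate would go unnoticed. With the full int8 range the bound reaches 524288 at k = 32, far beyond 16 bits but still well inside int32.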

maktukmak commented 2 months ago

The test is still passing on my CPU. Which CPU do you run the test on? Can you copy-paste the lscpu output?

dacorvo commented 2 months ago

It is an AMD EPYC 7R32

Architecture:            x86_64       
  CPU op-mode(s):        32-bit, 64-bit
  Address sizes:         48 bits physical, 48 bits virtual                                                                                                                           
  Byte Order:            Little Endian                                                    
CPU(s):                  8                                                                
  On-line CPU(s) list:   0-7                                                                                                                                                         
Vendor ID:               AuthenticAMD                                                                                                                                                
  Model name:            AMD EPYC 7R32      
    CPU family:          23          
    Model:               49              
    Thread(s) per core:  2                    
    Core(s) per socket:  4                    
    Socket(s):           1                    
    Stepping:            0                    
    BogoMIPS:            5599.64              
    Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save rdpid
Virtualization features:                      
  Hypervisor vendor:     KVM                  
  Virtualization type:   full                 
Caches (sum of all):                          
  L1d:                   128 KiB (4 instances)                                              
  L1i:                   128 KiB (4 instances)                                              
  L2:                    2 MiB (4 instances)  
  L3:                    16 MiB (1 instance)  
NUMA:                                         
  NUMA node(s):          1                    
  NUMA node0 CPU(s):     0-7                  
Vulnerabilities:                              
  Gather data sampling:  Not affected         
  Itlb multihit:         Not affected         
  L1tf:                  Not affected         
  Mds:                   Not affected         
  Meltdown:              Not affected         
  Mmio stale data:       Not affected         
  Retbleed:              Mitigation; untrained return thunk; SMT enabled with STIBP protection                                                                                          
  Spec rstack overflow:  Vulnerable: Safe RET, no microcode                                 
  Spec store bypass:     Mitigation; Speculative Store Bypass disabled via prctl            
  Spectre v1:            Mitigation; usercopy/swapgs barriers and __user pointer sanitization                                                                                           
  Spectre v2:            Mitigation; Retpolines; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected                                             
  Srbds:                 Not affected         
  Tsx async abort:       Not affected

So if I can detect the CPU vendor from Python, I could change the condition and allow _int_mm only on Intel CPUs.
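One way to detect the vendor from Python, sketched here under the assumption of an x86 Linux host, is to read the vendor_id field from /proc/cpuinfo (the same value lscpu reports as "Vendor ID"). The function names are hypothetical and this is not what quanto does; on other platforms the helper simply returns None.

```python
from pathlib import Path
from typing import Optional


def cpu_vendor_from_cpuinfo(cpuinfo_text: str) -> Optional[str]:
    """Return the first vendor_id value found in /proc/cpuinfo-style text."""
    for line in cpuinfo_text.splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip() == "vendor_id":
            return value.strip()
    return None


def current_cpu_vendor(path: str = "/proc/cpuinfo") -> Optional[str]:
    """Read the vendor of the running CPU, or None if it cannot be determined."""
    try:
        return cpu_vendor_from_cpuinfo(Path(path).read_text())
    except OSError:
        # File missing or unreadable (e.g. non-Linux platforms).
        return None


def is_intel_cpu() -> bool:
    """True only when the vendor string is positively identified as Intel."""
    return current_cpu_vendor() == "GenuineIntel"


if __name__ == "__main__":
    print(current_cpu_vendor(), is_intel_cpu())
```

Treating an unknown vendor as "not Intel" keeps the check conservative: the integer kernel would only be enabled where the vendor is positively identified.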

maktukmak commented 2 months ago

torch.backends.mkldnn.is_available() should indicate if you can perform _int_mm on Intel CPUs or not, according to https://oneapi-src.github.io/oneDNN/v1.0/dev_guide_int8_computations.html.