pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[Quantization-aware training] Can not quantize nn.MultiheadAttention module #118165

Open taoddiao opened 8 months ago

taoddiao commented 8 months ago

🐛 Describe the bug

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True)

    def forward(self, x):
        x = self.attn(x, x, x, need_weights=False, is_causal=False)[0] 
        return x

model = Model()
print(model)

model.train(True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig("x86")
torch.ao.quantization.prepare_qat(model, inplace=True)
# AssertionError:  qat.Linear.from_float only works for Linear

It seems that NonDynamicallyQuantizableLinear is mapped to nnqat.Linear but cannot pass the assert. Removing or changing the assert condition makes prepare_qat run normally, but I wonder whether this affects the quantization accuracy.
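For reference, here is a minimal, untested sketch of getting the same effect without editing the PyTorch source: map NonDynamicallyQuantizableLinear to a small subclass of nnqat.Linear whose type check accepts it, and pass that mapping to prepare_qat. The class name RelaxedQATLinear is made up for illustration, and since this mirrors loosening the assert, the same accuracy question applies.

import torch
import torch.nn as nn
import torch.ao.nn.qat as nnqat
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.ao.quantization.quantization_mappings import get_default_qat_module_mappings

# Hypothetical QAT Linear whose from_float type check accepts the out_proj
# type used inside nn.MultiheadAttention instead of requiring plain nn.Linear.
class RelaxedQATLinear(nnqat.Linear):
    _FLOAT_MODULE = NonDynamicallyQuantizableLinear

model = Model()
model.train(True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig("x86")

# Copy the default QAT mapping and route the MHA projection to the relaxed class.
mapping = dict(get_default_qat_module_mappings())
mapping[NonDynamicallyQuantizableLinear] = RelaxedQATLinear
torch.ao.quantization.prepare_qat(model, mapping=mapping, inplace=True)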

Model after modifying the assert condition:

Model(
  (attn): QuantizableMultiheadAttention(
    (out_proj): Linear(
      in_features=256, out_features=256, bias=True
      (activation_post_process): FusedMovingAvgObsFakeQuantize(
        fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
        (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
      )
    )
    (linear_Q): Linear(
      in_features=256, out_features=256, bias=True
      (activation_post_process): FusedMovingAvgObsFakeQuantize(
        fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
        (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
      )
    )
    (linear_K): Linear(
      in_features=256, out_features=256, bias=True
      (activation_post_process): FusedMovingAvgObsFakeQuantize(
        fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
        (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
      )
    )
    (linear_V): Linear(
      in_features=256, out_features=256, bias=True
      (activation_post_process): FusedMovingAvgObsFakeQuantize(
        fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
        (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
      )
    )
    (q_scaling_product): FloatFunctional(
      (activation_post_process): FusedMovingAvgObsFakeQuantize(
        fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
        (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
      )
    )
    (quant_attn_output): QuantStub(
      (activation_post_process): FusedMovingAvgObsFakeQuantize(
        fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
        (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
      )
    )
    (quant_attn_output_weights): QuantStub(
      (activation_post_process): FusedMovingAvgObsFakeQuantize(
        fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
        (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
      )
    )
    (dequant_q): DeQuantStub()
    (dequant_k): DeQuantStub()
    (dequant_v): DeQuantStub()
  )
)

Versions

Collecting environment information...
PyTorch version: 2.1.0
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.31

Python version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-4.15.0-213-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Tesla V100-SXM2-32GB
GPU 1: Tesla V100-SXM2-32GB
GPU 2: Tesla V100-SXM2-32GB
GPU 3: Tesla V100-SXM2-32GB

Nvidia driver version: 470.182.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 40
On-line CPU(s) list: 0-39
Thread(s) per core: 1
Core(s) per socket: 40
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8255C CPU @ 2.50GHz
Stepping: 5
CPU MHz: 2500.000
BogoMIPS: 5000.00
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.3 MiB
L1i cache: 1.3 MiB
L2 cache: 160 MiB
L3 cache: 35.8 MiB
NUMA node0 CPU(s): 0-39
Vulnerability Itlb multihit: KVM: Vulnerable
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable; SMT Host state unknown
Vulnerability Meltdown: Vulnerable
Vulnerability Mmio stale data: Vulnerable
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable, STIBP: disabled, PBRSB-eIBRS: Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 arat avx512_vnni

Versions of relevant libraries:
[pip3] numpy==1.26.0
[pip3] torch==2.1.0
[pip3] torchaudio==2.1.0
[pip3] torchelastic==0.2.2
[pip3] torchvision==0.16.0
[pip3] triton==2.1.0
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] mkl 2023.1.0 h213fc3f_46343
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl_fft 1.3.8 py310h5eee18b_0
[conda] mkl_random 1.2.4 py310hdb19cb5_0
[conda] numpy 1.26.0 py310h5f9d8c6_0
[conda] numpy-base 1.26.0 py310hb5e798b_0
[conda] pytorch 2.1.0 py3.10_cuda11.8_cudnn8.7.0_0 pytorch
[conda] pytorch-cuda 11.8 h7e8668a_5 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchaudio 2.1.0 py310_cu118 pytorch
[conda] torchelastic 0.2.2 pypi_0 pypi
[conda] torchtriton 2.1.0 py310 pytorch
[conda] torchvision 0.16.0 py310_cu118 pytorch

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel

ecolss commented 4 months ago

Any advice? To do QAT on models with an MHA layer, are we required to define a custom module for MHA?

taoddiao commented 4 months ago

Any advice? To do QAT on models with an MHA layer, are we required to define a custom module for MHA?

According to the latest docs (PyTorch 2.3), MHA is still not supported for QAT.
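Until that support lands, one eager-mode fallback is to leave the attention module in float and quantize the rest of the model by clearing its qconfig before prepare_qat. A minimal sketch, where the surrounding Block module and its fc layer are made up for illustration:

import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True)
        self.fc = nn.Linear(256, 256)

    def forward(self, x):
        x = self.attn(x, x, x, need_weights=False)[0]
        return self.fc(x)

model = Block()
model.train(True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig("x86")
# A submodule whose qconfig is None is skipped by eager-mode prepare_qat,
# so the MHA stays in float while fc still gets fake-quant observers.
model.attn.qconfig = None
torch.ao.quantization.prepare_qat(model, inplace=True)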