Hi, currently in the llama export, SDPA is dispatched through a custom optimized operator, `sdpa_with_kv_cache`; this is because it is optimized with kv_cache operations, which makes it faster than the XNNPACK SDPA. For Llama, we currently only delegate the dynamically quantized linear operations, which we do via `to_backend(XnnpackDynamicallyQuantizedPartitioner)`. We have broader operator support for statically quantized and fp32 operators; these can be accessed through `to_backend(XnnpackPartitioner)`. If you want to try delegating more operators, you can find the `to_backend(XnnpackDynamicallyQuantizedPartitioner)` call and run `to_backend(XnnpackPartitioner)` right after it, as in the sketch below. However, I believe the llama export path performs an `sdpa_with_kv_cache` replacement, so XNNPACK does not delegate that op.
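A minimal sketch of that chaining, assuming `edge` is the `EdgeProgramManager` returned by `to_edge()` in the export flow, and that the partitioners live at the module path below (true for the executorch version in this issue, but worth double-checking against your checkout):

```python
# Sketch only: chain a second, broader partitioner after the dynamic-quant one
# so additional fp32 / statically quantized ops can be delegated to XNNPACK.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackDynamicallyQuantizedPartitioner,
    XnnpackPartitioner,
)

# `edge` is assumed to be an EdgeProgramManager produced by to_edge(...).
edge = edge.to_backend(XnnpackDynamicallyQuantizedPartitioner())
# Delegate whatever XNNPACK-supported ops remain undelegated:
edge = edge.to_backend(XnnpackPartitioner())
```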
For general SDPA support, it is currently supported through `to_backend(XnnpackPartitioner)` (only when the mask is rank 2; this is an XNNPACK constraint).
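To make the rank-2 constraint concrete, here is a small hedged sketch; the `Attention` module and all shapes are made up for illustration:

```python
import torch
import torch.nn.functional as F

class Attention(torch.nn.Module):
    def forward(self, q, k, v, mask):
        # attn_mask is rank 2: (query_len, key_len). A higher-rank mask
        # (e.g. per-batch or per-head) hits the XNNPACK constraint above.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Example shapes: (batch, heads, seq_len, head_dim) for q/k/v.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
mask = torch.zeros(16, 16)  # rank-2 additive mask
out = Attention()(q, k, v, mask)
```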
Hi Max, thanks for your response.
> this is because it is optimized with kv_cache operations, which makes it faster than the XNNPACK SDPA
Is the custom `sdpa_with_kv_cache` faster than XNNPACK's SDPA on the aarch64 platform? I profiled the latency of the custom `sdpa_with_kv_cache` on the SA8295, and it was significantly higher than MNN's (a highly efficient and lightweight deep-learning framework). This is why I'm considering dispatching SDPA to XNNPACK: I want to find out whether SDPA in XNNPACK outperforms MNN.
Thank you!
I noticed that the GEMM kernel used in the custom `sdpa_with_kv_cache` falls into the naive `gemm_impl` implementation instead of the optimized version provided by CPU BLAS. In `kernels/optimized/CMakeLists.txt`, the `ET_BUILD_WITH_BLAS` compile definition is set (passed to the compiler as `-DET_BUILD_WITH_BLAS`), yet I'm not clear on why the BLAS branch of the GEMM isn't being triggered.
Additionally, is there a CPU BLAS library available for aarch64? I am interested in running inference on a Qualcomm SA8295.
🐛 Describe the bug
I'm currently working on dispatching the SDPA operation to XNNPACK. To accomplish this, I've added `torch.nn.functional.scaled_dot_product_attention` to the `SUPPORTED_DYN_QUANT_LINEAR_MODULES` list in the `backends/xnnpack/partition/configs.py` file, as shown in the code block below.
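(The snippet itself did not survive extraction; the following is a hedged reconstruction of the described edit, in which the pre-existing entries of the list are assumed, not quoted.)

```python
# backends/xnnpack/partition/configs.py -- reconstructed sketch, not verbatim
import torch

SUPPORTED_DYN_QUANT_LINEAR_MODULES = [
    torch.nn.Linear,               # assumed pre-existing entry
    torch.nn.functional.linear,    # assumed pre-existing entry
    # The line added for this experiment:
    torch.nn.functional.scaled_dot_product_attention,
]
```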
I attempted to run the llama example using the following command:
Unfortunately, an error occurred. Please find the full backtrace attached below.
I believe the SDPA can be integrated with XNNPACK, but I'm unsure of the correct approach. Could you please offer guidance on how to do this?
Versions
Collecting environment information...
PyTorch version: 2.4.0a0+git9afe4ec
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.0
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.5.0-14-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 36
On-line CPU(s) list: 0-35
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 18
Socket(s): 1
Stepping: 7
CPU max MHz: 4500.0000
CPU min MHz: 1200.0000
BogoMIPS: 6000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts vnmi avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 576 KiB (18 instances)
L1i cache: 576 KiB (18 instances)
L2 cache: 18 MiB (18 instances)
L3 cache: 24.8 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-35
Vulnerability Gather data sampling: Vulnerable: No microcode
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT vulnerable
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] executorch==0.3.0a0
[pip3] numpy==1.26.4
[pip3] torch==2.4.0a0+git9afe4ec
[pip3] torchao==0.1
[pip3] torchaudio==2.4.0.dev20240618+cpu
[pip3] torchsr==1.0.4
[pip3] torchvision==0.20.0.dev20240618+cpu
[conda] Could not collect