cathalobrien opened this issue 1 week ago
We don't support FlexAttention on CPUs today. cc: @jgong5
Right, and we have a plan to support it; assigned to @Valentine233.
That being said, until we support this we should improve the error message. I will put up a PR.
🐛 Describe the bug
I'm trying to run flex attention on a CPU but I'm getting an error. It seems to happen during autotuning algorithm selection on the initial iteration: there are no possible choices, so it throws an error and suggests adding aten to 'max_autotune_gemm_backends', but aten is there by default, as is CPP. I tried disabling autotuning with torch._inductor.config.max_autotune = False (because I gather autotuning is not available on CPU yet), but that didn't help.
Is flex attention supported on CPU?
The error message is the same as in https://github.com/pytorch/pytorch/issues/135206 — has the fix been merged into nightly yet? I'm running today's nightly CPU build.
Repro
Small source code change
Note: to get here I had to edit _get_default_config_fwd() in _inductor/kernel/flex_attention.py, because otherwise I was getting an error at "torch.cuda.get_device_capability()": AssertionError: Torch not compiled with CUDA enabled. So I wrapped that call in a "torch.cuda.is_available()" check.
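The workaround looked roughly like this (a sketch of the guard, not the exact diff to _get_default_config_fwd()):

```python
import torch

# Guard the CUDA capability query so CPU-only builds don't hit
# "AssertionError: Torch not compiled with CUDA enabled".
if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability()
else:
    capability = None  # fall through to a CPU/default config path
```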
repro
error message
Versions
(aifs-cpu) [naco@ac1-3010 my_anemoi-models]$ python collect_env.py
Collecting environment information...
PyTorch version: 2.6.0.dev20240923+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Red Hat Enterprise Linux release 8.8 (Ootpa) (x86_64)
GCC version: (ECMWF) 14.1.0
Clang version: 15.0.7 (Red Hat 15.0.7-1.module+el8.8.0+17939+b58878af)
CMake version: version 3.30.2
Libc version: glibc-2.28
Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-4.18.0-477.43.1.el8_8.x86_64-x86_64-with-glibc2.28
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 256
On-line CPU(s) list: 0-255
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 2
NUMA node(s): 8
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7742 64-Core Processor
Stepping: 0
CPU MHz: 2250.000
CPU max MHz: 2250.0000
CPU min MHz: 1500.0000
BogoMIPS: 4500.00
Virtualization: AMD-V
L1d cache: 32K
L1i cache: 32K
L2 cache: 512K
L3 cache: 16384K
NUMA node0 CPU(s): 0-15,128-143
NUMA node1 CPU(s): 16-31,144-159
NUMA node2 CPU(s): 32-47,160-175
NUMA node3 CPU(s): 48-63,176-191
NUMA node4 CPU(s): 64-79,192-207
NUMA node5 CPU(s): 80-95,208-223
NUMA node6 CPU(s): 96-111,224-239
NUMA node7 CPU(s): 112-127,240-255
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] optree==0.10.0
[pip3] pytorch-lightning==2.4.0
[pip3] torch==2.6.0.dev20240923+cpu
[pip3] torch_geometric==2.4.0
[pip3] torchaudio==2.5.0.dev20240829+cpu
[pip3] torchmetrics==1.4.1
[pip3] torchvision==0.20.0.dev20240923+cpu
[pip3] triton==3.0.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] pytorch-lightning 2.4.0 pypi_0 pypi
[conda] torch 2.6.0.dev20240923+cpu pypi_0 pypi
[conda] torch-geometric 2.4.0 pypi_0 pypi
[conda] torchaudio 2.5.0.dev20240829+cpu pypi_0 pypi
[conda] torchmetrics 1.4.1 pypi_0 pypi
[conda] torchvision 0.20.0.dev20240923+cpu pypi_0 pypi
[conda] triton 3.0.0 pypi_0 pypi
cc @ezyang @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng