pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Can't run Flex-Attention on CPU - NoValidChoicesError during autotuneSelectAlgorithm #136525

Open cathalobrien opened 1 week ago

cathalobrien commented 1 week ago

🐛 Describe the bug

I'm trying to run FlexAttention on a CPU, but I get an error. It seems to occur during autotuning algorithm selection on the first iteration. There are no valid choices, so it throws an error and suggests adding ATEN to 'max_autotune_gemm_backends', but ATEN is there by default, as is CPP. I tried disabling autotuning with torch._inductor.config.max_autotune = False (because I gather it is not available on CPU yet), but that didn't help.
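A quick way to confirm which backends are enabled is to inspect the setting itself. The sketch below assumes the value is a comma-separated string such as "ATEN,TRITON,CPP", as read from torch._inductor.config.max_autotune_gemm_backends. The helper name `enabled_backends` is hypothetical and is not part of PyTorch.

```python
def enabled_backends(setting: str) -> set[str]:
    """Split a comma-separated backend list into a set of upper-case names."""
    return {name.strip().upper() for name in setting.split(",") if name.strip()}


# Assumed value for illustration; on a real install, read it from
# torch._inductor.config.max_autotune_gemm_backends instead.
backends = enabled_backends("ATEN,TRITON,CPP")
print("ATEN" in backends, "CPP" in backends)
```

If ATEN is present and the error still appears, as reported here, the backend list is probably not the cause.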

Is flex attn supported on CPU?

The error message is the same as in https://github.com/pytorch/pytorch/issues/135206. Has the fix been merged into nightly yet? I'm running today's nightly CPU build.

Repro

Small source code change

Note: to get here I had to edit _get_default_config_fwd() in _inductor/kernel/flex_attention.py, because otherwise the call to "torch.cuda.get_device_capability()" raised: AssertionError: Torch not compiled with CUDA enabled. So I wrapped the capability checks in "torch.cuda.is_available()":

    if torch.cuda.is_available():
        if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0):  # H100
            if dtype == torch.float32:
                default_config = (64, 64, 4, 3)
            else:
                default_config = (128, 64, 4, 3)
            default_config = _h100_default_config.get((dtype, head_dim), default_config)
        elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0):  # A100
            if dtype == torch.float32:
                default_config = (64, 64, 4, 3)
            else:
                default_config = (128, 64, 4, 3)
            default_config = _a100_default_config.get((dtype, head_dim), default_config)
    else:  # modest hardware or extremely large head_dim
        if dtype == torch.float32:
            default_config = (32, 16, 4, 3)
        else:
            default_config = (64, 32, 4, 3)
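One caveat with the patch as shown: on a CUDA machine where neither inner branch matches (compute capability below 8.0, or head_dim above 256), no branch in this snippet assigns default_config a tuple. A minimal sketch of a variant that always yields a value follows. It is an illustration, not the upstream code: `pick_default_config` is a hypothetical name, the capability is passed in as a tuple (None on CPU-only builds) rather than queried from torch.cuda, and the per-GPU lookup tables (_h100_default_config, _a100_default_config) are omitted.

```python
from typing import Optional, Tuple

Config = Tuple[int, int, int, int]


def pick_default_config(is_float32: bool, head_dim: int,
                        capability: Optional[Tuple[int, int]] = None) -> Config:
    # Start from the "modest hardware" fallback so every path returns a tuple.
    config: Config = (32, 16, 4, 3) if is_float32 else (64, 32, 4, 3)
    # Upgrade only when a GPU is present and matches the A100/H100 branches.
    if capability is not None and head_dim <= 256 and capability >= (8, 0):
        config = (64, 64, 4, 3) if is_float32 else (128, 64, 4, 3)
    return config


print(pick_default_config(True, 64))           # CPU-only build
print(pick_default_config(False, 64, (9, 0)))  # H100-class GPU
```

Starting from the fallback and upgrading it keeps the CPU path and the older-GPU path on the same code, so neither can leave the config unset.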

repro

import torch
import torch._inductor.config

from torch.nn.attention.flex_attention import flex_attention, create_block_mask

import functools

if __name__ == "__main__":

    seq_lens= {
        "o96" : 40320,
        "o32" : 5248,
    }

    num_channels=256 # 1024 works, 256 doesn't at fp16 on 16 heads
    num_heads=4
    head_dim= num_channels // num_heads

    B, H, SEQ_LEN, HEAD_DIM = 1, num_heads, seq_lens['o96'], head_dim
    WINDOW_SIZE = 512
    PRECISION=torch.float32 
    DEVICE="cpu"

    FORWARD_ONLY=True

    def make_tensor():
        return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=DEVICE, dtype=PRECISION, requires_grad=True)

    q, k, v = make_tensor(), make_tensor(), make_tensor()
    gradOut = torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=DEVICE, dtype=PRECISION)

    def sliding_window(b, h, q_idx, kv_idx):
        return torch.abs(q_idx - kv_idx) <= WINDOW_SIZE

    block_mask = create_block_mask(
        sliding_window, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, _compile=True, device=DEVICE
    )

    # Bind the block mask once so it is not rebuilt on every call.
    attention = functools.partial(flex_attention, block_mask=block_mask)

    # torch.compile is lazy: compilation actually happens on the first call below.
    attention = torch.compile(attention)
    print("Compiled correctly")

    # errors here, during initial autotuning
    out = attention(q, k, v, block_mask=block_mask)
    print(f"Shape of output tensor: {list(out.shape)}")

    if (not FORWARD_ONLY):
        out.backward(gradOut, retain_graph=True)
        print(f"Shape of output tensor after bw: {list(out.shape)}")
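For a sense of scale, the sliding-window mask in this repro is very sparse at block granularity. The sketch below is plain Python and does not call PyTorch. It assumes a block size of 128, which would give 40320 / 128 = 315 blocks per axis, consistent with the [1, 1, 315] and [1, 1, 315, 315] index buffers in the error message. The helper name `kv_blocks_for_q_block` is hypothetical.

```python
def kv_blocks_for_q_block(q_block, seq_len, block_size, window):
    """Return the range of KV block indices a query block can attend to."""
    q_lo = q_block * block_size
    q_hi = min(q_lo + block_size, seq_len) - 1
    # |q_idx - kv_idx| <= window, clamped to the sequence bounds.
    kv_lo = max(q_lo - window, 0)
    kv_hi = min(q_hi + window, seq_len - 1)
    return range(kv_lo // block_size, kv_hi // block_size + 1)


SEQ_LEN, BLOCK_SIZE, WINDOW_SIZE = 40320, 128, 512  # BLOCK_SIZE is an assumption
num_blocks = -(-SEQ_LEN // BLOCK_SIZE)  # ceiling division
active = sum(len(kv_blocks_for_q_block(i, SEQ_LEN, BLOCK_SIZE, WINDOW_SIZE))
             for i in range(num_blocks))
print(num_blocks, active, f"{active / num_blocks**2:.1%}")
```

Under these assumptions an interior query block touches 9 KV blocks and an edge block touches 5, so only a few percent of the 315 x 315 block grid is active.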

error message

Traceback (most recent call last):
  File "/lus/h2resw01/hpcperm/naco/aifs/anemoi/my_anemoi-models/mini_flexAttn_sw.py", line 58, in <module>
    out = attention(q, k, v, block_mask=block_mask)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1298, in __call__
    return self._torchdynamo_orig_callable(
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1093, in __call__
    result = self._inner_convert(
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 530, in __call__
    return _compile(
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 933, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 675, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 708, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 220, in _fn
    return fn(*args, **kwargs)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 643, in transform
    tracer.run()
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2776, in run
    super().run()
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 979, in run
    while self.step():
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 891, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2967, in RETURN_VALUE
    self._return(inst)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2952, in _return
    self.output.compile_subgraph(
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/__init__.py", line 2235, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1562, in compile_fx
    return aot_autograd(
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1076, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1061, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 523, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 760, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 586, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1388, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1459, in _fw_compiler_base
    return inner_compile(
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 480, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 666, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1425, in load
    compiled_graph = compile_fx_fn(
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 575, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 864, in fx_codegen_and_compile
    graph.run(*example_inputs)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/graph.py", line 780, in run
    return super().run(*args)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/fx/interpreter.py", line 146, in run
    self.env[node] = self.run_node(node)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1357, in run_node
    result = super().run_node(n)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/fx/interpreter.py", line 203, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1023, in call_function
    raise LoweringException(e, target, args, kwargs).with_traceback(
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1020, in call_function
    out = lowerings[target](*args, **kwargs)  # type: ignore[index]
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 363, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/kernel/flex_attention.py", line 915, in flex_attention
    autotune_select_algorithm(
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1729, in autotune_select_algorithm
    return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
  File "/perm/naco/conda/envs/aifs-cpu/lib/python3.10/site-packages/torch/_inductor/select_algorithm.py", line 1209, in __call__
    raise NoValidChoicesError(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice.
  target: flex_attention
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float16, size=[1, 4, 40320, 64], stride=[10321920, 2580480, 64, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='primals_2', layout=FixedLayout('cpu', torch.float16, size=[1, 4, 40320, 64], stride=[10321920, 2580480, 64, 1]))
  ))
  args[2]: TensorBox(StorageBox(
    InputBuffer(name='primals_3', layout=FixedLayout('cpu', torch.float16, size=[1, 4, 40320, 64], stride=[10321920, 2580480, 64, 1]))
  ))
  args[3]: Subgraph(name='sdpa_score0', graph_module=<lambda>(), graph=None)
  args[4]: (TensorBox(StorageBox(
    InputBuffer(name='primals_4', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 315], stride=[315, 315, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_5', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 315, 315], stride=[99225, 99225, 315, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_6', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 315], stride=[315, 315, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_7', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 315, 315], stride=[99225, 99225, 315, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_8', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 315], stride=[315, 315, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_9', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 315, 315], stride=[99225, 99225, 315, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_10', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 315], stride=[315, 315, 1]))
  )), TensorBox(StorageBox(
    InputBuffer(name='primals_11', layout=FixedLayout('cpu', torch.int32, size=[1, 1, 315, 315], stride=[99225, 99225, 315, 1]))
  )), 128, 128, Subgraph(name='sdpa_mask0', graph_module=<lambda>(), graph=None))
  args[5]: 0.125
  args[6]: {'ROWS_GUARANTEED_SAFE': False, 'PRESCALE_QK': False, 'OUTPUT_LOGSUMEXP': True}
  args[7]: ()
  args[8]: ()

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
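The global flag above suppresses every compile error in the process. A narrower alternative is to wrap a single callable so that only it falls back. The sketch below is generic Python with hypothetical names (`with_eager_fallback`, `compiled_fn`, `eager_fn`). It does not by itself make FlexAttention work on CPU: `eager_fn` must be an implementation that actually supports the device.

```python
def with_eager_fallback(compiled_fn, eager_fn):
    """Call compiled_fn; after its first failure, use eager_fn from then on."""
    state = {"use_eager": False}

    def wrapper(*args, **kwargs):
        if not state["use_eager"]:
            try:
                return compiled_fn(*args, **kwargs)
            except Exception as exc:  # e.g. BackendCompilerFailed
                print(f"compiled path failed ({type(exc).__name__}); using eager")
                state["use_eager"] = True
        return eager_fn(*args, **kwargs)

    return wrapper
```

Remembering the failure avoids paying the compile attempt again on every call.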

Versions

(aifs-cpu) [naco@ac1-3010 my_anemoi-models]$ python collect_env.py
Collecting environment information...
PyTorch version: 2.6.0.dev20240923+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Red Hat Enterprise Linux release 8.8 (Ootpa) (x86_64)
GCC version: (ECMWF) 14.1.0
Clang version: 15.0.7 (Red Hat 15.0.7-1.module+el8.8.0+17939+b58878af)
CMake version: version 3.30.2
Libc version: glibc-2.28

Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-4.18.0-477.43.1.el8_8.x86_64-x86_64-with-glibc2.28
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 256
On-line CPU(s) list: 0-255
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 2
NUMA node(s): 8
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7742 64-Core Processor
Stepping: 0
CPU MHz: 2250.000
CPU max MHz: 2250.0000
CPU min MHz: 1500.0000
BogoMIPS: 4500.00
Virtualization: AMD-V
L1d cache: 32K
L1i cache: 32K
L2 cache: 512K
L3 cache: 16384K
NUMA node0 CPU(s): 0-15,128-143
NUMA node1 CPU(s): 16-31,144-159
NUMA node2 CPU(s): 32-47,160-175
NUMA node3 CPU(s): 48-63,176-191
NUMA node4 CPU(s): 64-79,192-207
NUMA node5 CPU(s): 80-95,208-223
NUMA node6 CPU(s): 96-111,224-239
NUMA node7 CPU(s): 112-127,240-255
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sme sev sev_es

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] optree==0.10.0
[pip3] pytorch-lightning==2.4.0
[pip3] torch==2.6.0.dev20240923+cpu
[pip3] torch_geometric==2.4.0
[pip3] torchaudio==2.5.0.dev20240829+cpu
[pip3] torchmetrics==1.4.1
[pip3] torchvision==0.20.0.dev20240923+cpu
[pip3] triton==3.0.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] pytorch-lightning 2.4.0 pypi_0 pypi
[conda] torch 2.6.0.dev20240923+cpu pypi_0 pypi
[conda] torch-geometric 2.4.0 pypi_0 pypi
[conda] torchaudio 2.5.0.dev20240829+cpu pypi_0 pypi
[conda] torchmetrics 1.4.1 pypi_0 pypi
[conda] torchvision 0.20.0.dev20240923+cpu pypi_0 pypi
[conda] triton 3.0.0 pypi_0 pypi

cc @ezyang @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng

Chillee commented 1 week ago

We don't support FlexAttention on CPUs today. cc: @jgong5

jgong5 commented 1 week ago

> We don't support FlexAttention on CPUs today. cc: @jgong5

Right, and we have a plan to support that. It is assigned to @Valentine233.

drisspg commented 1 week ago

That being said, until we support this, we should make the error message better. I will put up a PR.
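As an illustration of what a clearer failure could look like, here is a minimal sketch of an early device check. It is purely hypothetical: the function name, the supported-device set, and the message wording are assumptions made for this example, not the contents of any actual PR.

```python
SUPPORTED_DEVICE_TYPES = {"cuda"}  # assumption for illustration only


def check_flex_attention_device(device_type: str) -> None:
    """Raise a descriptive error before lowering reaches autotuning."""
    if device_type not in SUPPORTED_DEVICE_TYPES:
        raise NotImplementedError(
            f"FlexAttention is not supported on device type '{device_type}'; "
            f"supported device types: {sorted(SUPPORTED_DEVICE_TYPES)}"
        )


check_flex_attention_device("cuda")  # passes silently
```

Failing early with the device type in the message would point users at the real limitation instead of at max_autotune_gemm_backends.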