intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

[Accuracy] `PassManager::run failed` in Inductor HuggingFace inference accuracy check #412

Closed: ESI-SYD closed this issue 7 months ago

ESI-SYD commented 7 months ago

Error message:

Testing model AllenaiLongformerBase
Test amp with dt: torch.bfloat16
loading model: 0it [00:02, ?it/s]
xpu  eval  AllenaiLongformerBase              
skipping cudagraphs for unknown reason
ERROR:common:backend='inductor' raised:
RuntimeError: PassManager::run failed

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
Traceback (most recent call last):
  File "/home/pytorch/benchmarks/dynamo/common.py", line 2145, in check_accuracy
    new_result = optimized_model_iter_fn(model_copy, example_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/home/pytorch/benchmarks/dynamo/common.py", line 1909, in run_n_iterations
    self.model_iter_fn(mod, inputs, collect_outputs=False)
  File "/home/pytorch/benchmarks/dynamo/huggingface.py", line 550, in forward_pass
    return mod(**inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/longformer/modeling_longformer.py", line 1835, in forward
    outputs = self.longformer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/longformer/modeling_longformer.py", line 1738, in forward
    encoder_outputs = self.encoder(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 490, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 641, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 133, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 389, in _convert_frame_assert
    return _compile(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 569, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 491, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 458, in transform
    tracer.run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2074, in run
    super().run()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 439, in wrapper
    self.output.compile_subgraph(self, reason=reason)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 857, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 957, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1024, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1009, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/backends/inductor.py", line 9, in inductor
    return compile_fx(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1150, in compile_fx
    return aot_autograd(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3891, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 3429, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2212, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2392, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1573, in aot_dispatch_base
    compiled_fw = compiler(fw_module, flat_args)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 924, in fw_compiler_freezing
    optimized_function = inner_compile(
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 80, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/debug.py", line 228, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 54, in newFunction
    return old_func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 341, in compile_fx_inner
    compiled_graph: CompiledFxGraph = fx_codegen_and_compile(
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 565, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 970, in compile_to_fn
    res = self.compile_to_module()
  File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 189, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 941, in compile_to_module
    mod = PyCodeCache.load_by_key_path(key, path, linemap=linemap)
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1143, in load_by_key_path
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_root/xw/cxwvz3wbualzm5vh4femvd5f63atoldhgprfgnz5fiosbaipgb3f.py", line 74, in <module>
    async_compile.wait(globals())
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1431, in wait
    scope[key] = result.result()
  File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 1290, in result
    self.future.result()
  File "/opt/conda/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/opt/conda/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: PassManager::run failed

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

TorchDynamo optimized model failed to run because of following error
fail_to_run

Reproduce:

cd /path/to/pytorch
wget -O inductor_xpu_test.sh https://raw.githubusercontent.com/intel/intel-xpu-backend-for-triton/main/.github/scripts/inductor_xpu_test.sh
pip install pandas
bash inductor_xpu_test.sh huggingface amp_bf16 inference accuracy xpu 1 static 1 0 AllenaiLongformerBase

Version:

root@a4bf01946f13:/home# python collect_env.py 
Collecting environment information...
PyTorch version: 2.1.0a0+git8a1575b
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      52 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             224
On-line CPU(s) list:                0-223
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Platinum 8480+
CPU family:                         6
Model:                              143
Thread(s) per core:                 2
Core(s) per socket:                 56
Socket(s):                          2
Stepping:                           8
CPU max MHz:                        3800.0000
CPU min MHz:                        800.0000
BogoMIPS:                           4000.00
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          5.3 MiB (112 instances)
L1i cache:                          3.5 MiB (112 instances)
L2 cache:                           224 MiB (112 instances)
L3 cache:                           210 MiB (2 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-55,112-167
NUMA node1 CPU(s):                  56-111,168-223
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] bert-pytorch==0.0.1a4
[pip3] clip-anytorch==2.6.0
[pip3] CoCa-pytorch==0.1.0
[pip3] dalle2-pytorch==1.14.2
[pip3] ema-pytorch==0.3.3
[pip3] flake8==7.0.0
[pip3] functorch==1.14.0a0+b71aa0b
[pip3] intel-extension-for-pytorch==2.1.10+git99b4297
[pip3] mypy==1.8.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.5
[pip3] onnx==1.15.0
[pip3] open-clip-torch==2.24.0
[pip3] pytorch-warmup==0.1.1
[pip3] rotary-embedding-torch==0.3.3
[pip3] torch==2.1.0a0+git59f7c41
[pip3] torch-fidelity==0.3.0
[pip3] torch_geometric==2.4.0
[pip3] torchaudio==2.2.0a0+02586da
[pip3] torchbench==0.1
[pip3] torchdata==0.7.1
[pip3] torchmetrics==1.0.3
[pip3] torchmultimodal==0.1.0b0
[pip3] torchrec==0.6.0
[pip3] torchtext==0.17.0a0+2c5e344
[pip3] torchvision==0.18.0a0+806dba6
[pip3] triton==3.0.0
[pip3] vector_quantize_pytorch==1.12.17
[conda] bert-pytorch              0.0.1a4                   dev_0    <develop>
[conda] blas                      1.0                         mkl  
[conda] clip-anytorch             2.6.0                    pypi_0    pypi
[conda] coca-pytorch              0.1.0                    pypi_0    pypi
[conda] dalle2-pytorch            1.14.2                   pypi_0    pypi
[conda] ema-pytorch               0.3.3                    pypi_0    pypi
[conda] functorch                 1.14.0a0+b71aa0b          pypi_0    pypi
[conda] intel-extension-for-pytorch 2.1.10+git99b4297          pypi_0    pypi
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] mkl-service               2.4.0           py310h5eee18b_1  
[conda] mkl_fft                   1.3.8           py310h5eee18b_0  
[conda] mkl_random                1.2.4           py310hdb19cb5_0  
[conda] numpy                     1.23.5                   pypi_0    pypi
[conda] open-clip-torch           2.24.0                   pypi_0    pypi
[conda] pytorch-warmup            0.1.1                    pypi_0    pypi
[conda] rotary-embedding-torch    0.3.3                    pypi_0    pypi
[conda] torch                     2.1.0a0+git59f7c41          pypi_0    pypi
[conda] torch-fidelity            0.3.0                    pypi_0    pypi
[conda] torch-geometric           2.4.0                    pypi_0    pypi
[conda] torchaudio                2.2.0a0+02586da          pypi_0    pypi
[conda] torchbench                0.1                       dev_0    <develop>
[conda] torchdata                 0.7.1                    pypi_0    pypi
[conda] torchmetrics              1.0.3                    pypi_0    pypi
[conda] torchmultimodal           0.1.0b0                  pypi_0    pypi
[conda] torchrec                  0.6.0                    pypi_0    pypi
[conda] torchtext                 0.17.0a0+2c5e344          pypi_0    pypi
[conda] torchvision               0.18.0a0+806dba6          pypi_0    pypi
[conda] triton                    3.0.0                    pypi_0    pypi
[conda] vector-quantize-pytorch   1.12.17                  pypi_0    pypi

triton: https://github.com/intel/intel-xpu-backend-for-triton/commit/97ac4f91d149a3392d6e14f5d39aa4953fb6c56e

quintinwang5 commented 7 months ago

For AllenaiLongformerBase, the error `failed to legalize operation 'tt.reduce' that was explicitly marked illegal` is caused by the lowering inserting `nvvm.redux.sync`, which is an illegal operation for this target:

  %269 = "llvm.mlir.constant"() <{value = -1 : i32}> : () -> i32
  %270 = "nvvm.redux.sync"(%268, %269) <{kind = #nvvm<redux_kind or>}> : (i32, i32) -> i32
  %271 = "llvm.mlir.constant"() <{value = 0 : i32}> : () -> i32
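The IR above shows `nvvm.redux.sync` with kind `or` and member mask `-1` (all lanes). The plain-Python model below is only an illustration, not the backend's code. It shows what that op computes: every lane receives the bitwise OR of all lanes' 32-bit values. It also shows how a target without such an instruction could get the same result from a butterfly of xor-shuffles. The function names are hypothetical. The lane count is left as a parameter because an XPU sub-group is not necessarily 32 lanes wide.

```python
from functools import reduce
import operator

MASK32 = 0xFFFFFFFF  # values are treated as 32-bit, like the i32 operands in the IR


def redux_sync_or(values):
    """Model of `nvvm.redux.sync` with kind `or` and a full member mask (-1):
    every lane receives the bitwise OR of all lanes' 32-bit values."""
    result = reduce(operator.or_, (v & MASK32 for v in values), 0)
    return [result] * len(values)


def butterfly_or(values):
    """Same all-lane result built only from xor-shuffles: at each step, lane i
    combines its value with the value currently held by lane i ^ offset."""
    n = len(values)
    if n == 0 or n & (n - 1):
        raise ValueError("lane count must be a power of two")
    vals = [v & MASK32 for v in values]
    offset = n // 2
    while offset:
        # The comprehension reads the previous step's `vals` for every lane.
        vals = [vals[i] | vals[i ^ offset] for i in range(n)]
        offset //= 2
    return vals


if __name__ == "__main__":
    for lanes in (16, 32):
        data = [1 << (i % 8) for i in range(lanes)]
        assert redux_sync_or(data) == butterfly_or(data)
        print(lanes, hex(butterfly_or(data)[0]))
```

Running the file checks that both formulations agree for 16 and 32 lanes and prints `0xff` for each. The butterfly needs log2(n) shuffle steps. This thread does not say whether the eventual fix used this pattern.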