facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

python3 -m xformers.info got AttributeError: 'NoneType' object has no attribute 'start' #1059

Open channingxiao18 opened 5 months ago

channingxiao18 commented 5 months ago

🐛 Bug

Command

python3 -m xformers.info

/usr/local/lib/python3.10/dist-packages/transformers/utils/hub.py:124: FutureWarning: Using TRANSFORMERS_CACHE is deprecated and will be removed in v5 of Transformers. Use HF_HOME instead.
  warnings.warn(
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/dist-packages/xformers/info.py", line 11, in <module>
    from . import __version__, _cpp_lib, _is_opensource, _is_triton_available, ops
  File "/usr/local/lib/python3.10/dist-packages/xformers/ops/__init__.py", line 8, in <module>
    from .fmha import (
  File "/usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/__init__.py", line 10, in <module>
    from . import (
  File "/usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/triton_splitk.py", line 548, in <module>
    _get_splitk_kernel(num_groups)
  File "/usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/triton_splitk.py", line 503, in _get_splitk_kernel
    _fwd_kernel_splitK_unrolled = unroll_varargs(_fwd_kernel_splitK, N=num_groups)
  File "/usr/local/lib/python3.10/dist-packages/xformers/triton/vararg_kernel.py", line 166, in unroll_varargs
    jitted_fn = triton.jit(fn)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 864, in jit
    return decorator(fn)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 853, in decorator
    return JITFunction(
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 699, in __init__
    self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():]
AttributeError: 'NoneType' object has no attribute 'start'
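The final frame fails the way it does because triton's jit.py assumes its def-matching regex always finds a match in the function source it receives; when the source text is corrupted, re.search returns None and the chained .start() raises. A minimal sketch of just that failure mode (standalone strings, not xformers code):

import re

good_src = "def kernel(x):\n    return x\n"
bad_src = "d"  # a corrupted source string, like the one surfaced later in this thread

print(re.search(r"^def\s+\w+\s*\(", good_src, re.MULTILINE).start())  # prints 0
match = re.search(r"^def\s+\w+\s*\(", bad_src, re.MULTILINE)
print(match)   # None: no def line to match
match.start()  # AttributeError: 'NoneType' object has no attribute 'start'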

To Reproduce

I compiled xformers on a Jetson Orin AGX with the following commands:

git clone --branch=v0.0.26 --depth=1 --recursive https://github.com/facebookresearch/xformers /opt/xformers

cd /opt/xformers

XFORMERS_MORE_DETAILS=1 MAX_JOBS=$(nproc) \
  python3 setup.py --verbose bdist_wheel --dist-dir /opt

pip3 install --no-cache-dir --verbose /opt/xformers*.whl

Environment

Please copy and paste the output from the environment collection script from PyTorch (or fill out the checklist below manually).

You can run the script with:

# For security purposes, please check the contents of collect_env.py before running it.
python -m torch.utils.collect_env

Collecting environment information...
PyTorch version: 2.2.0
Is debug build: False
CUDA used to build PyTorch: 12.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (aarch64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.5
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.136-tegra-aarch64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Orin (nvgpu)
Nvidia driver version: N/A
cuDNN version: Probably one of the following:
/usr/lib/aarch64-linux-gnu/libcudnn.so.8.9.4
/usr/lib/aarch64-linux-gnu/libcudnn_adv_infer.so.8.9.4
/usr/lib/aarch64-linux-gnu/libcudnn_adv_train.so.8.9.4
/usr/lib/aarch64-linux-gnu/libcudnn_cnn_infer.so.8.9.4
/usr/lib/aarch64-linux-gnu/libcudnn_cnn_train.so.8.9.4
/usr/lib/aarch64-linux-gnu/libcudnn_ops_infer.so.8.9.4
/usr/lib/aarch64-linux-gnu/libcudnn_ops_train.so.8.9.4
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: aarch64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: ARM
Model name: Cortex-A78AE
Model: 1
Thread(s) per core: 1
Core(s) per cluster: 4
Socket(s): -
Cluster(s): 3
Stepping: r0p1
CPU max MHz: 2201.6001
CPU min MHz: 115.2000
BogoMIPS: 62.50
Flags: fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm lrcpc dcpop asimddp uscat ilrcpc flagm paca pacg
L1d cache: 768 KiB (12 instances)
L1i cache: 768 KiB (12 instances)
L2 cache: 3 MiB (12 instances)
L3 cache: 6 MiB (3 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; __user pointer sanitization
Vulnerability Spectre v2: Mitigation; CSV2, but not BHB
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] onnx==1.16.1
[pip3] onnx-graphsurgeon==0.3.12
[pip3] torch==2.2.0
[pip3] torchvision==0.17.2+c1d70fe
[pip3] triton==3.0.0
[conda] Could not collect

Additional context

triton was installed from JetPack. I also tested a build compiled from source, and the result was the same.

danthe3rd commented 5 months ago

cc @bottler @sgrigory maybe this initialization can be done lazily?
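Per the traceback, triton_splitk.py calls _get_splitk_kernel(num_groups) at module scope, so triton.jit runs as a side effect of import xformers. A minimal, generic sketch of the lazy pattern being suggested here (hypothetical names, not the actual xformers patch):

import functools

def _expensive_build(num_groups: int):
    # Stand-in for unroll_varargs/triton.jit, which can raise on
    # platforms where source inspection misbehaves.
    print(f"building split-K kernel for N={num_groups}")
    return lambda x: x * num_groups

@functools.lru_cache(None)
def get_splitk_kernel(num_groups: int):
    # Nothing is built at import time; the (possibly failing) build is
    # deferred to the first call and cached afterwards.
    return _expensive_build(num_groups)

kernel = get_splitk_kernel(4)  # builds on first use
kernel = get_splitk_kernel(4)  # cached, no rebuild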

channingxiao18 commented 5 months ago

I added some prints to jit.py:

    if not re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE):
        print(f"\n\n========src=========\n|{self.src}|\n\n=============\n")
        print(f"\n=========fn=========\n|{fn}||{fn.__name__}|\n\n=============\n")
        print(f"\n=========module=====\n|{fn.__module__}|\n\n=============\n")

and the outputs are

========src========= |d|

=============

=========fn========= |<function _fwd_kernel_splitK at 0xfffea4187370>||_fwd_kernel_splitK|

=============

=========module===== |xformers.ops.fmha.triton_splitk|

=============

It seems that the problem is related to _fwd_kernel_splitK.

danthe3rd commented 5 months ago

As a workaround, and if you don't need any triton kernel from xformers, you can try setting this env variable: XFORMERS_FORCE_DISABLE_TRITON=1
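Since the crash happens while xformers is being imported, the variable has to be set before the first import; a usage sketch:

import os

# Must run before anything imports xformers (equivalent to exporting
# XFORMERS_FORCE_DISABLE_TRITON=1 in the shell before launching).
os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1"

import xformers.ops  # triton-backed kernels should now be skipped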

channingxiao18 commented 5 months ago

As a workaround, and if you don't need any triton kernel from xformers, you can try setting this env variable: XFORMERS_FORCE_DISABLE_TRITON=1

This workaround did work.

I am using vllm to run LLMs, and I am not sure whether it needs the triton kernels from xformers. Thanks anyway.

channingxiao18 commented 5 months ago

I did some investigation and added some prints to xformers/triton/vararg_kernel.py:

@functools.lru_cache(None)
def unroll_varargs(kernel, N: int):
    """
    Specializes a triton kernel with variable number of inputs
    to a specific number of inputs `N`.
    NOTE: Because it's quite costly to call `triton.jit`,
    we cache the returned value with `lru_cache`
    """
    global _FILENAME_TO_SRC, _getlines_orig
    print(f"********\n|{kernel}|\n|{kernel.fn}|\n**********\n")
    k = triton.JITFunction(kernel.fn)
    #print(f"***k:***\n|{k}|\n****k.src***\n{k.src}\n*******")
    parsed = ast.parse(k.src)
    print(f"******parsed\n|{parsed}|\n******")
    nodeVisitor = _VisitorUnrollKernel(N=N)
    parsed = nodeVisitor.visit(parsed)
    parsed = ast.fix_missing_locations(parsed)

    # NOTE: `ast.unparse` requires python 3.9+
    if (sys.version_info.major, sys.version_info.minor) <= (3, 8):
        raise RuntimeError("Error: This functionality requires python 3.9 or above")
    new_src = ast.unparse(parsed)  # type: ignore
    #print(f"\n***New Source Code:\n{new_src}\n***\n")
    # Now we want to `eval` the function, but we need all this
    # boilerplate code to make sure triton can run `inspect.getsource`

    fn_filename = f"<unroll_varargs-{kernel.fn.__name__}-{N}>"

    # Create function given source
    code = compile(new_src, fn_filename, "exec")

    _locals: Dict[str, Any] = {}
    exec(code, kernel.fn.__globals__, _locals)
    assert len(_locals) == 1, len(_locals)
    fn = next(iter(_locals.values()))
    print(f"\n***Local Function Names:\n{_locals.keys()}\n***\n")
    # Patch `getlines` only the first time
    if not _FILENAME_TO_SRC:
        _getlines_orig = linecache.getlines
        linecache.getlines = _monkey_patched_getlines
        # placed after the `linecache.getlines = _monkey_patched_getlines` assignment
        print(f"\n***Is linecache.getlines patched? {'Yes' if linecache.getlines is _monkey_patched_getlines else 'No'}\n***\n")
    _FILENAME_TO_SRC[fn_filename] = new_src
    print(f"\n***FN_Filename:\n{fn_filename}\n***\n")
    print(f"\n***_FILENAME_TO_SRC Keys:\n{_FILENAME_TO_SRC.keys()}\n***\n")
    print(f"*****fn*****\n|{fn}|\n******")
    import inspect
    #print(f"******src**\n|{fn}|\n*****")
    lines = linecache.getlines(fn_filename)
     print(f"\n***Lines from {fn_filename}:\n{''.join(lines)}\n***\n")

    fn_source = inspect.getsource(fn)
    print(f"\n\n***Function Source via inspect|{fn_source}|\n\n======\n")
    print(f"\n\n***getsource**|{inspect.getsource(fn)}|\n\n======\n")
    print(f"\n***Direct Source Retrieval:\n{_FILENAME_TO_SRC.get(fn.__code__.co_filename, 'Source not found')}\n***\n")
    jitted_fn = triton.jit(fn)
    jitted_fn.src = new_src
    return jitted_fn

and the print outputs are:

***New Source Code:
def _fwd_kernel_splitK(Q, K, V, sm_scale, Out_splitK, LSE_splitk, block_tables, Seq_len, Seq_starts, additive_bias, stride_qz, stride_qm, stride_qg, stride_qh, stride_qk, stride_kz, stride_kn, stride_kg, stride_kh, stride_kk, stride_vz, stride_vn, stride_vg, stride_vh, stride_vk, stride_osk_z, stride_osk_g, stride_osk_h, stride_osk_s, stride_osk_m, stride_osk_k, stride_lsek_z, stride_lsek_g, stride_lsek_h, stride_lsek_s, stride_lsek_m, stride_blocktablesz, stride_blocktablesl, stride_bias_b, stride_bias_g, stride_bias_h, stride_bias_qm, stride_bias_km, kv_cache_blocks_per_row: tl.constexpr, Z: tl.constexpr, N_CTX_Q: tl.constexpr, N_CTX_K: tl.constexpr, BLOCK_N_PER_SPLIT: tl.constexpr, H: tl.constexpr, G: tl.constexpr, BLOCK_DMODEL: tl.constexpr, USE_SEQ_LEN: tl.constexpr, PACKED_PER_VAL: tl.constexpr, N_GROUPS: tl.constexpr, BOUNDS_CHECKS_N: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, IS_SPLITK: tl.constexpr, IS_CAUSAL: tl.constexpr, NUM_QUERIES_CAUSAL: tl.constexpr, USE_PAGED_ATTENTION: tl.constexpr, PAGE_SIZE: tl.constexpr, WRITE_LSE: tl.constexpr, HAS_ADDITIVE_BIAS: tl.constexpr): xxx

***Local Function Names: dict_keys(['_fwd_kernel_splitK'])


***Is linecache.getlines patched? Yes


***FN_Filename:

***

***_FILENAME_TO_SRC Keys:
dict_keys([''])
***

*****fn*****
||
******

***Lines from :
def _fwd_kernel_splitK(Q, K, V, sm_scale, Out_splitK, LSE_splitk, block_tables, Seq_len, Seq_starts, additive_bias, stride_qz, stride_qm, stride_qg, stride_qh, stride_qk, stride_kz, stride_kn, stride_kg, stride_kh, stride_kk, stride_vz, stride_vn, stride_vg, stride_vh, stride_vk, stride_osk_z, stride_osk_g, stride_osk_h, stride_osk_s, stride_osk_m, stride_osk_k, stride_lsek_z, stride_lsek_g, stride_lsek_h, stride_lsek_s, stride_lsek_m, stride_blocktablesz, stride_blocktablesl, stride_bias_b, stride_bias_g, stride_bias_h, stride_bias_qm, stride_bias_km, kv_cache_blocks_per_row: tl.constexpr, Z: tl.constexpr, N_CTX_Q: tl.constexpr, N_CTX_K: tl.constexpr, BLOCK_N_PER_SPLIT: tl.constexpr, H: tl.constexpr, G: tl.constexpr, BLOCK_DMODEL: tl.constexpr, USE_SEQ_LEN: tl.constexpr, PACKED_PER_VAL: tl.constexpr, N_GROUPS: tl.constexpr, BOUNDS_CHECKS_N: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, IS_SPLITK: tl.constexpr, IS_CAUSAL: tl.constexpr, NUM_QUERIES_CAUSAL: tl.constexpr, USE_PAGED_ATTENTION: tl.constexpr, PAGE_SIZE: tl.constexpr, WRITE_LSE: tl.constexpr, HAS_ADDITIVE_BIAS: tl.constexpr):
***

***Function Source via inspect|d|

======

***getsource**|d|

======

***Direct Source Retrieval:
def _fwd_kernel_splitK(Q, K, V, sm_scale, Out_splitK, LSE_splitk, block_tables, Seq_len, Seq_starts, additive_bias, stride_qz, stride_qm, stride_qg, stride_qh, stride_qk, stride_kz, stride_kn, stride_kg, stride_kh, stride_kk, stride_vz, stride_vn, stride_vg, stride_vh, stride_vk, stride_osk_z, stride_osk_g, stride_osk_h, stride_osk_s, stride_osk_m, stride_osk_k, stride_lsek_z, stride_lsek_g, stride_lsek_h, stride_lsek_s, stride_lsek_m, stride_blocktablesz, stride_blocktablesl, stride_bias_b, stride_bias_g, stride_bias_h, stride_bias_qm, stride_bias_km, kv_cache_blocks_per_row: tl.constexpr, Z: tl.constexpr, N_CTX_Q: tl.constexpr, N_CTX_K: tl.constexpr, BLOCK_N_PER_SPLIT: tl.constexpr, H: tl.constexpr, G: tl.constexpr, BLOCK_DMODEL: tl.constexpr, USE_SEQ_LEN: tl.constexpr, PACKED_PER_VAL: tl.constexpr, N_GROUPS: tl.constexpr, BOUNDS_CHECKS_N: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, IS_SPLITK: tl.constexpr, IS_CAUSAL: tl.constexpr, NUM_QUERIES_CAUSAL: tl.constexpr, USE_PAGED_ATTENTION: tl.constexpr, PAGE_SIZE: tl.constexpr, WRITE_LSE: tl.constexpr, HAS_ADDITIVE_BIAS: tl.constexpr):
***

Based on the debug info, the problematic part was inspect.getsource(fn). I guess inspect.getsource may not work properly with dynamically defined or modified code?
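For context on that guess: unroll_varargs only gives the generated function a retrievable source because it patches linecache.getlines, and inspect.getsource goes through linecache. A self-contained sketch of that mechanism (hypothetical filename and function, mirroring vararg_kernel.py):

import inspect
import linecache

_FILENAME_TO_SRC = {}
_getlines_orig = linecache.getlines

def _patched_getlines(filename, module_globals=None):
    # Serve generated source for our synthetic filenames, else defer.
    if filename in _FILENAME_TO_SRC:
        return _FILENAME_TO_SRC[filename].splitlines(keepends=True)
    return _getlines_orig(filename, module_globals)

linecache.getlines = _patched_getlines

new_src = "def generated_fn(x):\n    return x + 1\n"
fn_filename = "<unroll-demo-generated_fn>"
_FILENAME_TO_SRC[fn_filename] = new_src

code = compile(new_src, fn_filename, "exec")
scope = {}
exec(code, scope, scope)
fn = scope["generated_fn"]

# inspect.getsource() resolves the function through linecache.getlines().
# On a healthy setup this prints the generated def; if anything returns
# the wrong text for the synthetic filename, triton's regex check fails
# exactly as in the traceback above.
print(inspect.getsource(fn))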