channingxiao18 opened 5 months ago
cc @bottler @sgrigory maybe this initialization can be done lazily?
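A minimal sketch of what that lazy initialization might look like, assuming the import-time call `_get_splitk_kernel(num_groups)` seen in the traceback below is the eager step (a hypothetical refactor, not the actual xformers code):

```python
import functools

from xformers.triton.vararg_kernel import unroll_varargs

@functools.lru_cache(None)
def _get_splitk_kernel(num_groups: int):
    # The costly unroll_varargs + triton.jit work would run only the first
    # time a split-K attention call actually needs this kernel, instead of
    # during `import xformers`. (_fwd_kernel_splitK is the vararg kernel
    # defined in triton_splitk.py.)
    return unroll_varargs(_fwd_kernel_splitK, N=num_groups)

# triton_splitk.py would then drop the eager module-level call
#     _get_splitk_kernel(num_groups)
# and invoke _get_splitk_kernel(...) from the dispatch path instead.
```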
I added some print statements to triton's jit.py:

```python
if not re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE):
    print(f"\n\n========src=========\n|{self.src}|\n\n=============\n")
    print(f"\n=========fn=========\n|{fn}||{fn.__name__}|\n\n=============\n")
    print(f"\n=========module=====\n|{fn.__module__}|\n\n=============\n")
```
and the outputs are:

```
========src=========
|d|
=============

=========fn=========
|<function _fwd_kernel_splitK at 0xfffea4187370>||_fwd_kernel_splitK|
=============

=========module=====
|xformers.ops.fmha.triton_splitk|
=============
```
It seems the problem is related to `_fwd_kernel_splitK`: when `triton.jit` processes that kernel, `self.src` is just `d`, so there is no `def ...` line for the regex to match.
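The failing line is easy to reproduce in isolation: with that `src` value, `re.search` returns `None` and calling `.start()` on it raises exactly this `AttributeError`. A minimal sketch, assuming only the `src` value shown above:

```python
import re

src = "d"  # what triton's JITFunction received here, per the print above

match = re.search(r"^def\s+\w+\s*\(", src, re.MULTILINE)
print(match)  # -> None, since "d" contains no function definition

# triton/runtime/jit.py line 699 then effectively does:
#     self.src = self.src[match.start():]
# which raises: AttributeError: 'NoneType' object has no attribute 'start'
```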
As a workaround, if you don't need any Triton kernel from xformers, you can try setting this environment variable:

```
XFORMERS_FORCE_DISABLE_TRITON=1
```
This workaround did work. I am using vLLM to run LLMs, and I am not sure whether it needs the Triton kernels from xformers. Thanks anyway.
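Since the crash happens at import time, the variable has to be in the environment before xformers is first imported. A minimal sketch, assuming you set it from Python rather than the shell:

```python
import os

# Must be set before the first `import xformers` anywhere in the process,
# because the failing kernel is built while the module is being imported.
os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1"

import xformers.ops  # the Triton split-K kernel should no longer be built here
```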
I did some more investigation. I added some print statements to /xformers/triton/vararg_kernel.py:
```python
@functools.lru_cache(None)
def unroll_varargs(kernel, N: int):
    """
    Specializes a triton kernel with variable number of inputs
    to a specific number of inputs `N`.
    NOTE: Because it's quite costly to call `triton.jit`,
    we cache the returned value with `lru_cache`
    """
    global _FILENAME_TO_SRC, _getlines_orig
    print(f"********\n|{kernel}|\n|{kernel.fn}|\n**********\n")
    k = triton.JITFunction(kernel.fn)
    # print(f"***k:***\n|{k}|\n****k.src***\n{k.src}\n*******")
    parsed = ast.parse(k.src)
    print(f"******parsed\n|{parsed}|\n******")
    nodeVisitor = _VisitorUnrollKernel(N=N)
    parsed = nodeVisitor.visit(parsed)
    parsed = ast.fix_missing_locations(parsed)

    # NOTE: `ast.unparse` requires python 3.9+
    if (sys.version_info.major, sys.version_info.minor) <= (3, 8):
        raise RuntimeError("Error: This functionality requires python 3.9 or above")
    new_src = ast.unparse(parsed)  # type: ignore
    # print(f"\n***New Source Code:\n{new_src}\n***\n")

    # Now we want to `eval` the function, but we need all this
    # boilerplate code to make sure triton can run `inspect.getsource`
    fn_filename = f"<unroll_varargs-{kernel.fn.__name__}-{N}>"

    # Create function given source
    code = compile(new_src, fn_filename, "exec")
    _locals: Dict[str, Any] = {}
    exec(code, kernel.fn.__globals__, _locals)
    assert len(_locals) == 1, len(_locals)
    fn = next(iter(_locals.values()))
    print(f"\n***Local Function Names:\n{_locals.keys()}\n***\n")

    # Patch `getlines` only the first time
    if not _FILENAME_TO_SRC:
        _getlines_orig = linecache.getlines
        linecache.getlines = _monkey_patched_getlines
    # this print goes after `linecache.getlines = _monkey_patched_getlines`
    print(f"\n***Is linecache.getlines patched? {'Yes' if linecache.getlines is _monkey_patched_getlines else 'No'}\n***\n")
    _FILENAME_TO_SRC[fn_filename] = new_src
    print(f"\n***FN_Filename:\n{fn_filename}\n***\n")
    print(f"\n***_FILENAME_TO_SRC Keys:\n{_FILENAME_TO_SRC.keys()}\n***\n")
    print(f"*****fn*****\n|{fn}|\n******")

    import inspect
    # print(f"******src**\n|{fn}|\n*****")
    lines = linecache.getlines(fn_filename)
    print(f"\n***Lines from {fn_filename}:\n{''.join(lines)}\n***\n")
    fn_source = inspect.getsource(fn)
    print(f"\n\n***Function Source via inspect|{fn_source}|\n\n======\n")
    print(f"\n\n***getsource**|{inspect.getsource(fn)}|\n\n======\n")
    print(f"\n***Direct Source Retrieval:\n{_FILENAME_TO_SRC.get(fn.__code__.co_filename, 'Source not found')}\n***\n")

    jitted_fn = triton.jit(fn)
    jitted_fn.src = new_src
    return jitted_fn
```
and the print outputs are:

```
***New Source Code:
def _fwd_kernel_splitK(Q, K, V, sm_scale, Out_splitK, LSE_splitk, block_tables, Seq_len, Seq_starts, additive_bias, stride_qz, stride_qm, stride_qg, stride_qh, stride_qk, stride_kz, stride_kn, stride_kg, stride_kh, stride_kk, stride_vz, stride_vn, stride_vg, stride_vh, stride_vk, stride_osk_z, stride_osk_g, stride_osk_h, stride_osk_s, stride_osk_m, stride_osk_k, stride_lsek_z, stride_lsek_g, stride_lsek_h, stride_lsek_s, stride_lsek_m, stride_blocktablesz, stride_blocktablesl, stride_bias_b, stride_bias_g, stride_bias_h, stride_bias_qm, stride_bias_km, kv_cache_blocks_per_row: tl.constexpr, Z: tl.constexpr, N_CTX_Q: tl.constexpr, N_CTX_K: tl.constexpr, BLOCK_N_PER_SPLIT: tl.constexpr, H: tl.constexpr, G: tl.constexpr, BLOCK_DMODEL: tl.constexpr, USE_SEQ_LEN: tl.constexpr, PACKED_PER_VAL: tl.constexpr, N_GROUPS: tl.constexpr, BOUNDS_CHECKS_N: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, IS_SPLITK: tl.constexpr, IS_CAUSAL: tl.constexpr, NUM_QUERIES_CAUSAL: tl.constexpr, USE_PAGED_ATTENTION: tl.constexpr, PAGE_SIZE: tl.constexpr, WRITE_LSE: tl.constexpr, HAS_ADDITIVE_BIAS: tl.constexpr): xxx

***Local Function Names:
dict_keys(['_fwd_kernel_splitK'])

***Is linecache.getlines patched? Yes

***FN_Filename:
```

So everything inside `unroll_varargs` looks correct up to the final `jitted_fn = triton.jit(fn)` call: the unrolled source is generated and `linecache` is patched. The bogus `d` source only appears afterwards, inside triton's own `JITFunction.__init__`.
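For context on the machinery above: `inspect.getsource` cannot see code created with `compile`/`exec` unless something serves the source text for the synthetic filename, which is what the `_monkey_patched_getlines` / `_FILENAME_TO_SRC` pair does. A self-contained sketch of the same mechanism, independent of xformers (the demo function and filename are made up):

```python
import inspect
import linecache

_FILENAME_TO_SRC = {}
_getlines_orig = linecache.getlines

def _monkey_patched_getlines(filename, module_globals=None):
    # Serve our generated source for known fake filenames; defer otherwise.
    if filename in _FILENAME_TO_SRC:
        return _FILENAME_TO_SRC[filename].splitlines(keepends=True)
    return _getlines_orig(filename, module_globals)

linecache.getlines = _monkey_patched_getlines

src = "def generated(x):\n    return x + 1\n"
fake_filename = "<generated-demo>"
_FILENAME_TO_SRC[fake_filename] = src

code = compile(src, fake_filename, "exec")
namespace = {}
exec(code, namespace)
fn = namespace["generated"]

# Without the patch this would raise OSError("could not get source code").
print(inspect.getsource(fn))
```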
# 🐛 Bug

## Command

```
python3 -m xformers.info
```

```
/usr/local/lib/python3.10/dist-packages/transformers/utils/hub.py:124: FutureWarning: Using TRANSFORMERS_CACHE is deprecated and will be removed in v5 of Transformers. Use HF_HOME instead.
  warnings.warn(
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/dist-packages/xformers/info.py", line 11, in <module>
    from . import __version__, _cpp_lib, _is_opensource, _is_triton_available, ops
  File "/usr/local/lib/python3.10/dist-packages/xformers/ops/__init__.py", line 8, in <module>
    from .fmha import (
  File "/usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/__init__.py", line 10, in <module>
    from . import (
  File "/usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/triton_splitk.py", line 548, in <module>
    _get_splitk_kernel(num_groups)
  File "/usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/triton_splitk.py", line 503, in _get_splitk_kernel
    _fwd_kernel_splitK_unrolled = unroll_varargs(_fwd_kernel_splitK, N=num_groups)
  File "/usr/local/lib/python3.10/dist-packages/xformers/triton/vararg_kernel.py", line 166, in unroll_varargs
    jitted_fn = triton.jit(fn)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 864, in jit
    return decorator(fn)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 853, in decorator
    return JITFunction(
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 699, in __init__
    self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():]
AttributeError: 'NoneType' object has no attribute 'start'
```

## To Reproduce
I compiled xformers on a Jetson Orin AGX with the following commands:
```bash
git clone --branch=v0.0.26 --depth=1 --recursive https://github.com/facebookresearch/xformers /opt/xformers
cd /opt/xformers
XFORMERS_MORE_DETAILS=1 MAX_JOBS=$(nproc) \
python3 setup.py --verbose bdist_wheel --dist-dir /opt
pip3 install --no-cache-dir --verbose /opt/xformers*.whl
```
## Environment
```
Collecting environment information...
PyTorch version: 2.2.0
Is debug build: False
CUDA used to build PyTorch: 12.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (aarch64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.29.5
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.136-tegra-aarch64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Orin (nvgpu)
Nvidia driver version: N/A
cuDNN version: Probably one of the following:
/usr/lib/aarch64-linux-gnu/libcudnn.so.8.9.4
/usr/lib/aarch64-linux-gnu/libcudnn_adv_infer.so.8.9.4
/usr/lib/aarch64-linux-gnu/libcudnn_adv_train.so.8.9.4
/usr/lib/aarch64-linux-gnu/libcudnn_cnn_infer.so.8.9.4
/usr/lib/aarch64-linux-gnu/libcudnn_cnn_train.so.8.9.4
/usr/lib/aarch64-linux-gnu/libcudnn_ops_infer.so.8.9.4
/usr/lib/aarch64-linux-gnu/libcudnn_ops_train.so.8.9.4
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: aarch64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: ARM
Model name: Cortex-A78AE
Model: 1
Thread(s) per core: 1
Core(s) per cluster: 4
Socket(s): -
Cluster(s): 3
Stepping: r0p1
CPU max MHz: 2201.6001
CPU min MHz: 115.2000
BogoMIPS: 62.50
Flags: fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm lrcpc dcpop asimddp uscat ilrcpc flagm paca pacg
L1d cache: 768 KiB (12 instances)
L1i cache: 768 KiB (12 instances)
L2 cache: 3 MiB (12 instances)
L3 cache: 6 MiB (3 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; __user pointer sanitization
Vulnerability Spectre v2: Mitigation; CSV2, but not BHB
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] onnx==1.16.1
[pip3] onnx-graphsurgeon==0.3.12
[pip3] torch==2.2.0
[pip3] torchvision==0.17.2+c1d70fe
[pip3] triton==3.0.0
[conda] Could not collect
```
## Additional context
Triton was installed from JetPack. I also tested a build compiled from source, and the result was the same.