triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License
13.19k stars 1.62k forks source link

LLVM ERROR: mma16816 data type not supported when invoking `tl.dot` with dequantized tensor #4652

Open shadowpa0327 opened 1 month ago

shadowpa0327 commented 1 month ago

Problem Statement

I am trying to dequantize a quantized tensor (packed into int32) and multiply it with another tensor in fp16. However, I hit an unexpected error when invoking `tl.dot`: `LLVM ERROR: mma16816 data type not supported`. If I additionally multiply the dequantized tensor by 1.0 (`x * scales + zeros * 1.0`) and cast the result back to `tl.float16`, the program runs correctly.

This only seems to happen with triton==3.0.0: after downgrading to triton 2.3.0, the same code works. Does anyone know what might cause this, or whether there is a bug in my implementation?

Dependency

Error message

LLVM ERROR: mma16816 data type not supported
 #0 0x000073e3ec533088 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) Signals.cpp:0:0
 #1 0x000073e3e8a6f880 triton_stacktrace_signal_handler(void*) /project/python/src/llvm.cc:424:3
 #2 0x000073e3ec530bac llvm::sys::RunSignalHandlers() Signals.cpp:0:0
 #3 0x000073e3ec53373d SignalHandler(int) Signals.cpp:0:0
 #4 0x000073e4f5042520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #5 0x000073e4f50969fc __pthread_kill_implementation ./nptl/./nptl/pthread_kill.c:44:76
 #6 0x000073e4f50969fc __pthread_kill_internal ./nptl/./nptl/pthread_kill.c:78:10
 #7 0x000073e4f50969fc pthread_kill ./nptl/./nptl/pthread_kill.c:89:10
 #8 0x000073e4f5042476 gsignal ./signal/../sysdeps/posix/raise.c:27:6
 #9 0x000073e4f50287f3 abort ./stdlib/./stdlib/abort.c:81:7
#10 0x000073e3ec4975ac llvm::report_fatal_error(llvm::Twine const&, bool) ErrorHandling.cpp:0:0
#11 0x000073e3ec4973d6 (/home/cc2869/.conda/envs/Palu/lib/python3.10/site-packages/triton/_C/libtriton.so+0x6a973d6)
#12 0x000073e3e890155d /project/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp:500:5
#13 0x000073e3e8904257 getLoadMatrixFn(mlir::triton::MemDescType, mlir::LLVM::SharedMemoryObject const&, mlir::triton::gpu::NvidiaMmaEncodingAttr, int, unsigned int, int, llvm::SmallVector<int, 12u>, llvm::SmallVector<int, 12u>, llvm::SmallVector<mlir::Value, 6u>, mlir::Value, std::map<std::array<int, 3ul>, mlir::Value, std::less<std::array<int, 3ul> >, std::allocator<std::pair<std::array<int, 3ul> const, mlir::Value> > >&, bool, mlir::LLVMTypeConverter const*, mlir::ConversionPatternRewriter&, mlir::Location)::$_0::operator()(int, int, int) const /project/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp:562:19
#14 0x000073e3e8904257 void std::__invoke_impl<void, getLoadMatrixFn(mlir::triton::MemDescType, mlir::LLVM::SharedMemoryObject const&, mlir::triton::gpu::NvidiaMmaEncodingAttr, int, unsigned int, int, llvm::SmallVector<int, 12u>, llvm::SmallVector<int, 12u>, llvm::SmallVector<mlir::Value, 6u>, mlir::Value, std::map<std::array<int, 3ul>, mlir::Value, std::less<std::array<int, 3ul> >, std::allocator<std::pair<std::array<int, 3ul> const, mlir::Value> > >&, bool, mlir::LLVMTypeConverter const*, mlir::ConversionPatternRewriter&, mlir::Location)::$_0&, int, int, int>(std::__invoke_other, getLoadMatrixFn(mlir::triton::MemDescType, mlir::LLVM::SharedMemoryObject const&, mlir::triton::gpu::NvidiaMmaEncodingAttr, int, unsigned int, int, llvm::SmallVector<int, 12u>, llvm::SmallVector<int, 12u>, llvm::SmallVector<mlir::Value, 6u>, mlir::Value, std::map<std::array<int, 3ul>, mlir::Value, std::less<std::array<int, 3ul> >, std::allocator<std::pair<std::array<int, 3ul> const, mlir::Value> > >&, bool, mlir::LLVMTypeConverter const*, mlir::ConversionPatternRewriter&, mlir::Location)::$_0&, int&&, int&&, int&&) /opt/rh/gcc-toolset-13/root/usr/lib/gcc/x86_64-redhat-linux/13/../../../../include/c++/13/bits/invoke.h:61:14
#15 0x000073e3e8904257 std::enable_if<is_invocable_r_v<void, getLoadMatrixFn(mlir::triton::MemDescType, mlir::LLVM::SharedMemoryObject const&, mlir::triton::gpu::NvidiaMmaEncodingAttr, int, unsigned int, int, llvm::SmallVector<int, 12u>, llvm::SmallVector<int, 12u>, llvm::SmallVector<mlir::Value, 6u>, mlir::Value, std::map<std::array<int, 3ul>, mlir::Value, std::less<std::array<int, 3ul> >, std::allocator<std::pair<std::array<int, 3ul> const, mlir::Value> > >&, bool, mlir::LLVMTypeConverter const*, mlir::ConversionPatternRewriter&, mlir::Location)::$_0&, int, int, int>, void>::type std::__invoke_r<void, getLoadMatrixFn(mlir::triton::MemDescType, mlir::LLVM::SharedMemoryObject const&, mlir::triton::gpu::NvidiaMmaEncodingAttr, int, unsigned int, int, llvm::SmallVector<int, 12u>, llvm::SmallVector<int, 12u>, llvm::SmallVector<mlir::Value, 6u>, mlir::Value, std::map<std::array<int, 3ul>, mlir::Value, std::less<std::array<int, 3ul> >, std::allocator<std::pair<std::array<int, 3ul> const, mlir::Value> > >&, bool, mlir::LLVMTypeConverter const*, mlir::ConversionPatternRewriter&, mlir::Location)::$_0&, int, int, int>(getLoadMatrixFn(mlir::triton::MemDescType, mlir::LLVM::SharedMemoryObject const&, mlir::triton::gpu::NvidiaMmaEncodingAttr, int, unsigned int, int, llvm::SmallVector<int, 12u>, llvm::SmallVector<int, 12u>, llvm::SmallVector<mlir::Value, 6u>, mlir::Value, std::map<std::array<int, 3ul>, mlir::Value, std::less<std::array<int, 3ul> >, std::allocator<std::pair<std::array<int, 3ul> const, mlir::Value> > >&, bool, mlir::LLVMTypeConverter const*, mlir::ConversionPatternRewriter&, mlir::Location)::$_0&, int&&, int&&, int&&) /opt/rh/gcc-toolset-13/root/usr/lib/gcc/x86_64-redhat-linux/13/../../../../include/c++/13/bits/invoke.h:111:2
#16 0x000073e3e8904257 std::_Function_handler<void (int, int, int), getLoadMatrixFn(mlir::triton::MemDescType, mlir::LLVM::SharedMemoryObject const&, mlir::triton::gpu::NvidiaMmaEncodingAttr, int, unsigned int, int, llvm::SmallVector<int, 12u>, llvm::SmallVector<int, 12u>, llvm::SmallVector<mlir::Value, 6u>, mlir::Value, std::map<std::array<int, 3ul>, mlir::Value, std::less<std::array<int, 3ul> >, std::allocator<std::pair<std::array<int, 3ul> const, mlir::Value> > >&, bool, mlir::LLVMTypeConverter const*, mlir::ConversionPatternRewriter&, mlir::Location)::$_0>::_M_invoke(std::_Any_data const&, int&&, int&&, int&&) /opt/rh/gcc-toolset-13/root/usr/lib/gcc/x86_64-redhat-linux/13/../../../../include/c++/13/bits/std_function.h:290:9
#17 0x000073e3e890309c loadArg(mlir::ConversionPatternRewriter&, mlir::Location, mlir::triton::MemDescType, mlir::triton::gpu::DotOperandEncodingAttr, mlir::LLVM::SharedMemoryObject const&, mlir::LLVMTypeConverter const*, mlir::Value, bool) /project/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp:646:25
#18 0x000073e3e8903e49 SharedToDotOperandMMAv2::convertLayout(int, mlir::ConversionPatternRewriter&, mlir::Location, mlir::Value, mlir::triton::gpu::DotOperandEncodingAttr, mlir::LLVM::SharedMemoryObject const&, mlir::LLVMTypeConverter const*, mlir::Value) /project/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp:0:7
#19 0x000073e3e890d892 (anonymous namespace)::LocalLoadOpConversion::lowerSharedToDotOperandMMA(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, mlir::LLVMTypeConverter const*, mlir::ConversionPatternRewriter&, mlir::triton::gpu::NvidiaMmaEncodingAttr const&, mlir::triton::gpu::DotOperandEncodingAttr const&, bool) const /project/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp:0:0
#20 0x000073e3e890d892 (anonymous namespace)::LocalLoadOpConversion::lowerSharedToDotOperand(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, mlir::LLVMTypeConverter const*, mlir::ConversionPatternRewriter&) const /project/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp:135:17
#21 0x000073e3e890d892 (anonymous namespace)::LocalLoadOpConversion::matchAndRewrite(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, mlir::ConversionPatternRewriter&) const /project/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp:68:14
#22 0x000073e3e88c055e mlir::ConvertOpToLLVMPattern<mlir::triton::gpu::LocalLoadOp>::matchAndRewrite(mlir::Operation*, llvm::ArrayRef<mlir::Value>, mlir::ConversionPatternRewriter&) const /root/.triton/llvm/llvm-657ec732-almalinux-x64/include/mlir/Conversion/LLVMCommon/Pattern.h:165:5
#23 0x000073e3ea4f1781 mlir::ConversionPattern::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const DialectConversion.cpp:0:0
#24 0x000073e3ea53564a void llvm::function_ref<void ()>::callback_fn<mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>)::$_0>(long) PatternApplicator.cpp:0:0
#25 0x000073e3ea53206f mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) PatternApplicator.cpp:0:0
#26 0x000073e3ea4f2547 (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) DialectConversion.cpp:0:0
#27 0x000073e3ea4f1817 mlir::OperationConverter::convert(mlir::ConversionPatternRewriter&, mlir::Operation*) DialectConversion.cpp:0:0
#28 0x000073e3ea4f29a0 mlir::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>) DialectConversion.cpp:0:0
#29 0x000073e3ea4fa75b mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig) DialectConversion.cpp:0:0
#30 0x000073e3e8952a3e (anonymous namespace)::ConvertTritonGPUToLLVM::runOnOperation() /project/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp:175:16
#31 0x000073e3e8cf7c06 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) Pass.cpp:0:0
#32 0x000073e3e8cf8432 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) Pass.cpp:0:0
#33 0x000073e3e8cfac5a mlir::PassManager::run(mlir::Operation*) Pass.cpp:0:0
#34 0x000073e3e8a2a832 mlir::LogicalResult::failed() const /root/.triton/llvm/llvm-657ec732-almalinux-x64/include/mlir/Support/LogicalResult.h:44:33
#35 0x000073e3e8a2a832 mlir::failed(mlir::LogicalResult) /root/.triton/llvm/llvm-657ec732-almalinux-x64/include/mlir/Support/LogicalResult.h:72:58
#36 0x000073e3e8a2a832 init_triton_ir(pybind11::module_&&)::$_223::operator()(mlir::PassManager&, mlir::ModuleOp&) const /project/python/src/ir.cc:1662:13
#37 0x000073e3e8a2a832 void pybind11::detail::argument_loader<mlir::PassManager&, mlir::ModuleOp&>::call_impl<void, init_triton_ir(pybind11::module_&&)::$_223&, 0ul, 1ul, pybind11::detail::void_type>(init_triton_ir(pybind11::module_&&)::$_223&, std::integer_sequence<unsigned long, 0ul, 1ul>, pybind11::detail::void_type&&) && /root/.triton/pybind11/pybind11-2.11.1/include/pybind11/detail/../cast.h:1480:16
#38 0x000073e3e8a2a832 std::enable_if<std::is_void<void>::value, pybind11::detail::void_type>::type pybind11::detail::argument_loader<mlir::PassManager&, mlir::ModuleOp&>::call<void, pybind11::detail::void_type, init_triton_ir(pybind11::module_&&)::$_223&>(init_triton_ir(pybind11::module_&&)::$_223&) && /root/.triton/pybind11/pybind11-2.11.1/include/pybind11/detail/../cast.h:1454:35
#39 0x000073e3e8a2a832 void pybind11::cpp_function::initialize<init_triton_ir(pybind11::module_&&)::$_223, void, mlir::PassManager&, mlir::ModuleOp&, pybind11::name, pybind11::is_method, pybind11::sibling>(init_triton_ir(pybind11::module_&&)::$_223&&, void (*)(mlir::PassManager&, mlir::ModuleOp&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(pybind11::detail::function_call&)::operator()(pybind11::detail::function_call&) const /root/.triton/pybind11/pybind11-2.11.1/include/pybind11/pybind11.h:0:0
#40 0x000073e3e8a2a832 void pybind11::cpp_function::initialize<init_triton_ir(pybind11::module_&&)::$_223, void, mlir::PassManager&, mlir::ModuleOp&, pybind11::name, pybind11::is_method, pybind11::sibling>(init_triton_ir(pybind11::module_&&)::$_223&&, void (*)(mlir::PassManager&, mlir::ModuleOp&), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::'lambda'(pybind11::detail::function_call&)::__invoke(pybind11::detail::function_call&) /root/.triton/pybind11/pybind11-2.11.1/include/pybind11/pybind11.h:224:21
#41 0x000073e3e89da92d pybind11::cpp_function::dispatcher(_object*, _object*, _object*) /root/.triton/pybind11/pybind11-2.11.1/include/pybind11/pybind11.h:0:30
#42 0x00000000004fdc87 _PyErr_Occurred /usr/local/src/conda/python-3.10.14/Include/internal/pycore_pyerrors.h:14:18
#43 0x00000000004fdc87 _Py_CheckFunctionResult /usr/local/src/conda/python-3.10.14/Objects/call.c:39:14
#44 0x00000000004fdc87 cfunction_call /usr/local/src/conda/python-3.10.14/Objects/methodobject.c:554:12
#45 0x00000000004f741b _Py_LeaveRecursiveCall /usr/local/src/conda/python-3.10.14/Include/internal/pycore_ceval.h:100:28
#46 0x00000000004f741b _PyObject_MakeTpCall /usr/local/src/conda/python-3.10.14/Objects/call.c:216:9
#47 0x0000000000509cbf _PyObject_VectorcallTstate /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:112:16
#48 0x0000000000509cbf _PyObject_VectorcallTstate /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:99:1
#49 0x0000000000509cbf method_vectorcall /usr/local/src/conda/python-3.10.14/Objects/classobject.c:53:18
#50 0x00000000004f2c16 _Py_CheckFunctionResult /usr/local/src/conda/python-3.10.14/Objects/call.c:38:8
#51 0x00000000004f2c16 _PyObject_VectorcallTstate /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:115:12
#52 0x00000000004f2c16 PyObject_Vectorcall /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:123:12
#53 0x00000000004f2c16 call_function /usr/local/src/conda/python-3.10.14/Python/ceval.c:5893:13
#54 0x00000000004f2c16 codeobject.c.cf05b00b /usr/local/src/conda/python-3.10.14/Python/ceval.c:4181:23
#55 0x00000000004fe0cf _Py_REFCNT /usr/local/src/conda/python-3.10.14/Include/object.h:131:14
#56 0x00000000004fe0cf _PyEval_Vector /usr/local/src/conda/python-3.10.14/Python/ceval.c:5074:9
#57 0x00000000004fe0cf _PyFunction_Vectorcall /usr/local/src/conda/python-3.10.14/Objects/call.c:342:16
#58 0x00000000004f2c16 _Py_CheckFunctionResult /usr/local/src/conda/python-3.10.14/Objects/call.c:38:8
#59 0x00000000004f2c16 _PyObject_VectorcallTstate /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:115:12
#60 0x00000000004f2c16 PyObject_Vectorcall /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:123:12
#61 0x00000000004f2c16 call_function /usr/local/src/conda/python-3.10.14/Python/ceval.c:5893:13
#62 0x00000000004f2c16 codeobject.c.cf05b00b /usr/local/src/conda/python-3.10.14/Python/ceval.c:4181:23
#63 0x00000000004fe0cf _Py_REFCNT /usr/local/src/conda/python-3.10.14/Include/object.h:131:14
#64 0x00000000004fe0cf _PyEval_Vector /usr/local/src/conda/python-3.10.14/Python/ceval.c:5074:9
#65 0x00000000004fe0cf _PyFunction_Vectorcall /usr/local/src/conda/python-3.10.14/Objects/call.c:342:16
#66 0x00000000004ee40f _Py_CheckFunctionResult /usr/local/src/conda/python-3.10.14/Objects/call.c:38:8
#67 0x00000000004ee40f _PyObject_VectorcallTstate /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:115:12
#68 0x00000000004ee40f PyObject_Vectorcall /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:123:12
#69 0x00000000004ee40f call_function /usr/local/src/conda/python-3.10.14/Python/ceval.c:5893:13
#70 0x00000000004ee40f _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.10.14/Python/ceval.c:4213:19
#71 0x00000000004fe0cf _Py_REFCNT /usr/local/src/conda/python-3.10.14/Include/object.h:131:14
#72 0x00000000004fe0cf _PyEval_Vector /usr/local/src/conda/python-3.10.14/Python/ceval.c:5074:9
#73 0x00000000004fe0cf _PyFunction_Vectorcall /usr/local/src/conda/python-3.10.14/Objects/call.c:342:16
#74 0x00000000004ef4a3 _Py_CheckFunctionResult /usr/local/src/conda/python-3.10.14/Objects/call.c:38:8
#75 0x00000000004ef4a3 _PyObject_VectorcallTstate /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:115:12
#76 0x00000000004ef4a3 PyObject_Vectorcall /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:123:12
#77 0x00000000004ef4a3 call_function /usr/local/src/conda/python-3.10.14/Python/ceval.c:5893:13
#78 0x00000000004ef4a3 classobject.c.0e1a9fd1 /usr/local/src/conda/python-3.10.14/Python/ceval.c:4231:19
#79 0x00000000005099ce _Py_REFCNT /usr/local/src/conda/python-3.10.14/Include/object.h:131:14
#80 0x00000000005099ce _PyEval_Vector /usr/local/src/conda/python-3.10.14/Python/ceval.c:5074:9
#81 0x00000000005099ce _PyFunction_Vectorcall /usr/local/src/conda/python-3.10.14/Objects/call.c:342:16
#82 0x00000000005099ce _PyObject_VectorcallTstate /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:114:11
#83 0x00000000005099ce method_vectorcall /usr/local/src/conda/python-3.10.14/Objects/classobject.c:53:18
#84 0x000000000050a508 PyVectorcall_Call /usr/local/src/conda/python-3.10.14/Objects/call.c:269:5
#85 0x000000000050a508 _PyObject_Call /usr/local/src/conda/python-3.10.14/Objects/call.c:290:16
#86 0x000000000050a508 PyObject_Call /usr/local/src/conda/python-3.10.14/Objects/call.c:317:12
#87 0x00000000004f0c69 do_call_core /usr/local/src/conda/python-3.10.14/Python/ceval.c:5945:12
#88 0x00000000004f0c69 classobject.c.0e1a9fd1 /usr/local/src/conda/python-3.10.14/Python/ceval.c:4277:22
#89 0x00000000004fe0cf _Py_REFCNT /usr/local/src/conda/python-3.10.14/Include/object.h:131:14
#90 0x00000000004fe0cf _PyEval_Vector /usr/local/src/conda/python-3.10.14/Python/ceval.c:5074:9
#91 0x00000000004fe0cf _PyFunction_Vectorcall /usr/local/src/conda/python-3.10.14/Objects/call.c:342:16
#92 0x000000000050a508 PyVectorcall_Call /usr/local/src/conda/python-3.10.14/Objects/call.c:269:5
#93 0x000000000050a508 _PyObject_Call /usr/local/src/conda/python-3.10.14/Objects/call.c:290:16
#94 0x000000000050a508 PyObject_Call /usr/local/src/conda/python-3.10.14/Objects/call.c:317:12
#95 0x00000000004f0c69 do_call_core /usr/local/src/conda/python-3.10.14/Python/ceval.c:5945:12
#96 0x00000000004f0c69 classobject.c.0e1a9fd1 /usr/local/src/conda/python-3.10.14/Python/ceval.c:4277:22
#97 0x00000000004fe0cf _Py_REFCNT /usr/local/src/conda/python-3.10.14/Include/object.h:131:14
#98 0x00000000004fe0cf _PyEval_Vector /usr/local/src/conda/python-3.10.14/Python/ceval.c:5074:9
#99 0x00000000004fe0cf _PyFunction_Vectorcall /usr/local/src/conda/python-3.10.14/Objects/call.c:342:16
#100 0x00000000004ee40f _Py_CheckFunctionResult /usr/local/src/conda/python-3.10.14/Objects/call.c:38:8
#101 0x00000000004ee40f _PyObject_VectorcallTstate /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:115:12
#102 0x00000000004ee40f PyObject_Vectorcall /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:123:12
#103 0x00000000004ee40f call_function /usr/local/src/conda/python-3.10.14/Python/ceval.c:5893:13
#104 0x00000000004ee40f _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.10.14/Python/ceval.c:4213:19
#105 0x00000000004fe0cf _Py_REFCNT /usr/local/src/conda/python-3.10.14/Include/object.h:131:14
#106 0x00000000004fe0cf _PyEval_Vector /usr/local/src/conda/python-3.10.14/Python/ceval.c:5074:9
#107 0x00000000004fe0cf _PyFunction_Vectorcall /usr/local/src/conda/python-3.10.14/Objects/call.c:342:16
#108 0x00000000004ee40f _Py_CheckFunctionResult /usr/local/src/conda/python-3.10.14/Objects/call.c:38:8
#109 0x00000000004ee40f _PyObject_VectorcallTstate /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:115:12
#110 0x00000000004ee40f PyObject_Vectorcall /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:123:12
#111 0x00000000004ee40f call_function /usr/local/src/conda/python-3.10.14/Python/ceval.c:5893:13
#112 0x00000000004ee40f _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.10.14/Python/ceval.c:4213:19
#113 0x00000000004fe0cf _Py_REFCNT /usr/local/src/conda/python-3.10.14/Include/object.h:131:14
#114 0x00000000004fe0cf _PyEval_Vector /usr/local/src/conda/python-3.10.14/Python/ceval.c:5074:9
#115 0x00000000004fe0cf _PyFunction_Vectorcall /usr/local/src/conda/python-3.10.14/Objects/call.c:342:16
#116 0x00000000004ee40f _Py_CheckFunctionResult /usr/local/src/conda/python-3.10.14/Objects/call.c:38:8
#117 0x00000000004ee40f _PyObject_VectorcallTstate /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:115:12
#118 0x00000000004ee40f PyObject_Vectorcall /usr/local/src/conda/python-3.10.14/Include/cpython/abstract.h:123:12
#119 0x00000000004ee40f call_function /usr/local/src/conda/python-3.10.14/Python/ceval.c:5893:13
#120 0x00000000004ee40f _PyEval_EvalFrameDefault /usr/local/src/conda/python-3.10.14/Python/ceval.c:4213:19
#121 0x00000000005950f2 _Py_REFCNT /usr/local/src/conda/python-3.10.14/Include/object.h:131:14
#122 0x00000000005950f2 _PyEval_Vector /usr/local/src/conda/python-3.10.14/Python/ceval.c:5074:9
#123 0x0000000000595037 PyEval_EvalCode /usr/local/src/conda/python-3.10.14/Python/ceval.c:1135:1
#124 0x00000000005c5e67 run_eval_code_obj /usr/local/src/conda/python-3.10.14/Python/pythonrun.c:1292:8
#125 0x00000000005c0fb0 _Py_DECREF /usr/local/src/conda/python-3.10.14/Include/object.h:492:8
#126 0x00000000005c0fb0 run_mod /usr/local/src/conda/python-3.10.14/Python/pythonrun.c:1313:5
#127 0x000000000045970e pyrun_file.cold /usr/local/src/conda/python-3.10.14/Python/pythonrun.c:1208:15
#128 0x00000000005bb53f _PyRun_SimpleFileObject /usr/local/src/conda/python-3.10.14/Python/pythonrun.c:456:13
#129 0x00000000005bb2a3 _PyRun_AnyFileObject /usr/local/src/conda/python-3.10.14/Python/pythonrun.c:90:15
#130 0x00000000005b805d pymain_run_file_obj /usr/local/src/conda/python-3.10.14/Modules/main.c:358:17
#131 0x00000000005b805d pymain_run_file /usr/local/src/conda/python-3.10.14/Modules/main.c:376:15
#132 0x00000000005b805d pymain_run_python /usr/local/src/conda/python-3.10.14/Modules/main.c:591:21
#133 0x00000000005b805d Py_RunMain /usr/local/src/conda/python-3.10.14/Modules/main.c:670:5
#134 0x0000000000588679 Py_BytesMain /usr/local/src/conda/python-3.10.14/Modules/main.c:1091:1
#135 0x000073e4f5029d90 __libc_start_call_main ./csu/../sysdeps/nptl/libc_start_call_main.h:58:16
#136 0x000073e4f5029e40 call_init ./csu/../csu/libc-start.c:128:20
#137 0x000073e4f5029e40 __libc_start_main ./csu/../csu/libc-start.c:379:5
#138 0x000000000058852e _start (/home/cc2869/.conda/envs/Palu/bin/python3.10+0x58852e)
Aborted (core dumped)

Code to reproduce

To inspect the kernel implementation, see the `_ab_qx_fwd` function below.

import torch
import random
import triton
import triton.language as tl
#from kernel.packing import *
import numpy as np
import torch

def quant_and_pack_vcache(v: torch.FloatTensor, group_size: int, bits: int):
    """Group-wise asymmetric quantization of a 4-D tensor, packed along dim 3.

    The last dimension of ``v`` is split into groups of ``group_size``
    elements. Each group is mapped onto the integer range
    ``[0, 2**bits - 1]`` using its own minimum and ``(max - min)`` range, and
    the resulting integers are packed by ``pack_tensor`` along dimension 3.

    Args:
        v: 4-D tensor whose last dimension is divisible by ``group_size``.
            It is reshaped with ``view`` and is not modified.
        group_size: number of consecutive last-dim elements sharing one
            scale / minimum.
        bits: bit width of each quantized value.

    Returns:
        ``(code, scale, mn)``: the packed codes, plus the per-group scale and
        minimum. ``scale`` and ``mn`` keep a trailing singleton dimension,
        i.e. their shape is ``v.shape[:-1] + (num_groups, 1)``.
    """
    orig_shape = v.shape
    assert len(orig_shape) == 4
    assert orig_shape[-1] % group_size == 0
    levels = 2 ** bits - 1

    # View the last dimension as (num_groups, group_size) so the statistics
    # can be reduced per group.
    grouped = v.view(orig_shape[:-1] + (orig_shape[-1] // group_size, group_size))
    mn = grouped.min(dim=-1, keepdim=True).values
    mx = grouped.max(dim=-1, keepdim=True).values
    scale = (mx - mn) / levels

    # (x - min) / scale, clamped to the representable range and rounded.
    # The subtraction allocates a new tensor, so the in-place ops never
    # touch the caller's storage.
    quantized = (grouped - mn).div_(scale).clamp_(0, levels).round_()
    quantized = quantized.to(torch.int32).view(orig_shape)

    code = pack_tensor(quantized, bits, pack_dim=3)
    return code, scale, mn

def unpack_and_dequant_vcache(v_code: torch.FloatTensor, 
                              scale: torch.FloatTensor, 
                              mn: torch.FloatTensor,
                              group_size: int, 
                              bits: int,
                              ):
    """Unpack codes along dim 3 and dequantize them to float16.

    Args:
        v_code: 4-D tensor of packed codes, as consumed by ``unpack_tensor``.
        scale: multiplier applied to the unpacked values; it must broadcast
            against the unpacked tensor.
        mn: offset added after scaling; same broadcasting requirement.
        group_size: accepted for interface compatibility; not read by the
            current implementation.
        bits: bit width of each packed value; must be 2, 4 or 8.

    Returns:
        ``unpacked.to(float16) * scale + mn`` viewed with the shape of the
        unpacked tensor.
    """
    assert bits in [2, 4, 8]
    assert len(v_code.shape) == 4
    unpacked = unpack_tensor(v_code, bits, pack_dim=3)
    unpacked_shape = unpacked.shape
    dequantized = unpacked.to(torch.float16) * scale + mn
    return dequantized.view(unpacked_shape)

def pack_tensor(data, bits, pack_dim):
    """Pack ``bits``-wide integers along ``pack_dim`` into int32 words.

    Every ``32 // bits`` consecutive elements along ``pack_dim`` are combined
    into one int32: element ``k`` of a run is shifted left by ``bits * k`` and
    OR-ed into the word, so the first element lands in the lowest bits.

    Args:
        data: integer tensor; each value is expected to fit in ``bits`` bits
            (values are not masked here).
        bits: bit width of each value; must be 2, 4 or 8.
        pack_dim: non-negative index of the dimension to pack. Its length
            must be divisible by ``32 // bits``.

    Returns:
        An int32 tensor on ``data.device`` with the same shape as ``data``
        except that ``pack_dim`` is shrunk by a factor of ``32 // bits``.
    """
    # Validate ``bits`` before dividing by it so an unsupported width fails
    # with the intended message instead of a ZeroDivisionError.
    assert bits in [2,4,8], "Only 2, 4, 8 bits are supported"
    shape = data.shape
    feat_per_int = 32 // bits
    assert shape[pack_dim] % feat_per_int == 0, "Dimension length must be divisible by number of features per int"
    # BS, nh, T, nd // 16 # 16 is for 2bit
    code = torch.zeros(shape[:pack_dim] + (shape[pack_dim] // feat_per_int,) + shape[pack_dim+1:],
                       dtype=torch.int32,
                       device=data.device)
    src_index = [slice(None)] * len(shape)
    dst_index = [slice(None)] * len(shape)
    for word in range(code.shape[pack_dim]):
        dst_index[pack_dim] = word
        # Index with tuples: multi-dimensional indexing with a plain list is
        # deprecated in PyTorch and slated to be reinterpreted as a tensor index.
        dst = tuple(dst_index)
        for slot in range(feat_per_int):
            src_index[pack_dim] = word * feat_per_int + slot
            code[dst] |= data[tuple(src_index)] << (bits * slot)
    return code

def unpack_tensor(v_code: torch.FloatTensor, 
                  bits: int, 
                  pack_dim: int):
    """Expand packed words along ``pack_dim`` back into ``bits``-wide values.

    Each element along ``pack_dim`` is treated as a word holding
    ``32 // bits`` values, lowest bits first. Output position ``p`` along
    ``pack_dim`` reads word ``p // feat_per_int``, shifts it right by
    ``(p % feat_per_int) * bits`` and keeps the low ``bits`` bits.

    Args:
        v_code: packed tensor. Despite the annotation it must be an integer
            tensor, because it is right-shifted.
        bits: bit width of each packed value; must be 2, 4 or 8.
        pack_dim: dimension that was packed; only 2 and 3 are supported
            (the ``pack_dim == 2`` path assumes a 4-D tensor).

    Returns:
        An int16 tensor shaped like ``v_code`` with ``pack_dim`` enlarged by
        a factor of ``32 // bits``.

    Raises:
        NotImplementedError: if ``pack_dim`` is neither 2 nor 3.
    """
    assert bits in [2,4,8]
    shape = v_code.shape
    feat_per_int = 32 // bits
    positions = torch.arange(shape[pack_dim] * feat_per_int, device=v_code.device)
    word_idx = positions // feat_per_int
    shifts = (positions % feat_per_int) * bits
    low_bits_mask = 0xFF >> (8 - bits)

    if pack_dim == 2:
        # Align the per-position shifts with dim 2 of a 4-D tensor.
        shifts = shifts[None, None, :, None]
    elif pack_dim != 3:
        raise NotImplementedError

    gather_index = [slice(None)] * len(shape)
    gather_index[pack_dim] = word_idx
    # Index with a tuple: multi-dimensional indexing with a plain list is
    # deprecated in PyTorch and slated to be reinterpreted as a tensor index.
    words = v_code[tuple(gather_index)]
    # Masking after the (arithmetic) shift discards any sign-extension bits.
    return ((words >> shifts).to(torch.int16)) & low_bits_mask

@triton.jit
def _pack_along_last_dim(
    bits: tl.constexpr,
    intensor_ptr,
    code_ptr,
    N,
    num_feats: tl.constexpr,
    feat_per_int: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr
):
    # Packs `feat_per_int` consecutive values of each input row into one int32
    # word of the output.
    #
    # Launch layout, as read from the program ids below:
    #   axis 0 selects a block of BLOCK_SIZE_N rows;
    #   axis 1 selects which packed word of the row this program produces.
    #
    # Addressing treats the input as rows of `num_feats` elements and the
    # output as rows of `num_feats // feat_per_int` words. Rows at index >= N
    # are masked out on both load and store.
    num_int_per_y_dim = num_feats // feat_per_int
    bid = tl.program_id(axis=0)
    yid = tl.program_id(axis=1)
    offs_N = bid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    block_start = intensor_ptr + offs_N * num_feats + yid * feat_per_int # offset of the first element at current tile
    packed = tl.zeros((BLOCK_SIZE_N,), dtype=tl.int32)
    for i in range(feat_per_int):
        ptr = block_start + i
        # NOTE(review): `other` is a float literal although the loaded values
        # are shifted as integers below; confirm it is converted as intended.
        element = tl.load(ptr, mask=offs_N<N, other=0.)
        # The i-th value of the run occupies bits [i*bits, (i+1)*bits).
        element = element << (i * bits)
        # Combine the value using bitwise OR
        packed = packed | element
    tl.store(code_ptr + offs_N * num_int_per_y_dim + yid, packed, mask=offs_N < N)

@triton.jit
def _minmax_along_last_dim(
    x_ptr,
    mn_ptr, mx_ptr,
    total_elements: tl.constexpr, 
    N: tl.constexpr,
    num_groups: tl.constexpr, 
    group_size: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr
):
    # Computes the minimum and maximum of each run of `group_size` consecutive
    # elements starting at x_ptr.
    #
    # Each program (axis 0) handles BLOCK_SIZE_N groups: group g reads elements
    # [g * group_size, (g + 1) * group_size) and writes its results to
    # mn_ptr[g] and mx_ptr[g].
    bid = tl.program_id(axis=0)
    offsets_b = bid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # (BLOCK_SIZE_N, group_size) tile of flat element offsets, one row per group.
    offsets = offsets_b[:, None] * group_size + tl.arange(0, group_size)[None, :]
    mask = offsets < total_elements
    # No `other` value is supplied for masked-out lanes; the stores below are
    # masked by group index, so results for groups >= N * num_groups are never
    # written.
    x = tl.load(x_ptr + offsets, mask=mask)
    mx_val = tl.max(x, axis=1)
    mn_val = tl.min(x, axis=1)
    # tl.device_print('shape', mn_val[:, None].shape)
    tl.store(mn_ptr+offsets_b, mn_val, mask=offsets_b<N*num_groups)
    tl.store(mx_ptr+offsets_b, mx_val, mask=offsets_b<N*num_groups)

def triton_quantize_and_pack_along_last_dim(data: torch.Tensor, group_size: int, bit: int):
    """Group-wise quantize a 4-D tensor along its last dim and pack it into int32.

    The last dimension (length ``T``) is split into groups of ``group_size``
    elements. Per-group min/max are computed by ``_minmax_along_last_dim``,
    the values are mapped onto ``[0, 2**bit - 1]``, and
    ``_pack_along_last_dim`` packs ``32 // bit`` values into each int32 word.

    Args:
        data: 4-D tensor of shape ``(B, nh, D, T)`` with ``T`` divisible by
            ``group_size``.
        group_size: number of consecutive last-dim elements sharing one
            scale / minimum.
        bit: bit width of each quantized value.

    Returns:
        ``(code, scale, mn)`` where ``code`` is int32 with shape
        ``(B, nh, D, -1)`` and ``scale`` / ``mn`` have shape
        ``(B, nh, D, num_groups)``.
    """
    assert len(data.shape) == 4
    B, nh, D, T = data.shape
    # ================== Get Scale & Zeros ===============
    assert T % group_size == 0
    num_groups = T // group_size
    rows = B * nh * D
    stats_shape = (B, nh, D, num_groups)
    max_level = 2 ** bit - 1
    BLOCK_SIZE_N = 128

    # One kernel row per (row, group) pair; each program reduces BLOCK_SIZE_N groups.
    grouped = data.reshape(rows, num_groups, group_size)
    mn = torch.empty((rows, num_groups), device=grouped.device, dtype=grouped.dtype)
    mx = torch.empty((rows, num_groups), device=grouped.device, dtype=grouped.dtype)
    minmax_grid = (triton.cdiv(rows * num_groups, BLOCK_SIZE_N),)
    _minmax_along_last_dim[minmax_grid](
        grouped, mn, mx,
        grouped.numel(), rows, num_groups, group_size,
        BLOCK_SIZE_N=BLOCK_SIZE_N, num_warps=8,
    )

    # Quantize: (x - min) / scale, clamped to the representable range and rounded.
    scale = (mx - mn) / max_level
    levels = (grouped - mn.unsqueeze(-1)).div_(scale.unsqueeze(-1))
    levels = levels.clamp_(0, max_level).round_().to(torch.int32).view(-1, T)

    # Pack: grid axis 0 covers blocks of rows, axis 1 covers the packed words of a row.
    feat_per_int = 32 // bit
    ints_per_row = T // feat_per_int
    code = torch.zeros(rows, ints_per_row, device=levels.device, dtype=torch.int32)
    pack_grid = (triton.cdiv(rows, BLOCK_SIZE_N), ints_per_row)
    _pack_along_last_dim[pack_grid](
        bit, levels, code, rows,
        T, feat_per_int,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        num_warps=8,
    )
    return code.view(B, nh, D, -1), scale.reshape(stats_shape), mn.reshape(stats_shape)

def get_configs():
    """Return the Triton configs to try for the kernel.

    The full sweep over block sizes, warp counts and pipeline stages is still
    built, but a single hand-picked config is what gets returned; the
    commented-out ``return`` lines are the alternatives for other settings.
    """
    configs = [
        triton.Config({'BLOCK_SIZE_L': block_l, 'BLOCK_SIZE_R': block_r},
                      num_stages=num_stages, num_warps=num_warps)
        for block_l in [16, 32, 64, 128]
        for block_r in [16, 32]
        for num_warps in [1, 4, 8, 16]
        for num_stages in [1, 2, 3]
    ]
    # return configs
    # return [triton.Config({'BLOCK_SIZE_L': 128, 'BLOCK_SIZE_R': 32}, num_warps=4, num_stages=3)] # for gs=4
    # return [triton.Config({'BLOCK_SIZE_L': 64, 'BLOCK_SIZE_R': 32}, num_warps=4, num_stages=3)] # for gs=2
    pinned = triton.Config({'BLOCK_SIZE_L': 64, 'BLOCK_SIZE_R': 32}, num_warps=4, num_stages=1) # for gs=1
    return [pinned]

# @triton.autotune(
#     configs= get_configs(),
#     key=["seq_len"]
# )
@triton.jit
def _ab_qx_fwd(
    bits, group_size,
    # ptrs (shapes as noted by the author; the launcher in this file passes
    # 3-D tensors, i.e. without the leading batch dimension -- TODO confirm)
    # a_ptr: (bs, num_heads, seq_len(assume 1), head_dim)
    # b_ptr: (num_groups, rank, head_group_size*head_dim)
    # x_ptr: (bs, num_groups, seq_len, rank)
    a_ptr, b_ptr, x_ptr, 
    scales_ptr, zeros_ptr, out_ptr,
    # strides
    stride_az, stride_aa ,stride_ad,
    stride_bz, stride_br, stride_bd,
    stride_xhg, stride_xl, stride_xr,
    stride_scales_hg, stride_scales_xl, stide_scales_g,
    stride_zeros_hg, stride_zeros_xl, stide_zeros_g,
    #NOTE(brian1009): Debug, check dequant first
    #stride_ohg, stride_ol, stride_or,
    stide_oz, stride_oa, stride_ol,
    R, D, seq_len,
    BLOCK_SIZE_D: tl.constexpr,
    BLOCK_SIZE_R: tl.constexpr,
    BLOCK_SIZE_L: tl.constexpr,
    NUM_GROUPS: tl.constexpr,
    THETA: tl.constexpr,
):
    """Fused unpack/dequantize of ``x``, projection by ``b`` and dot with ``a``.

    One program handles one head (grid axis 0) and one block of
    ``BLOCK_SIZE_L`` sequence positions (grid axis 1). For that tile it

    1. unpacks ``bits``-wide values from the packed words of ``x`` and applies
       ``x * scales + zeros`` (one scale/zero per sequence row),
    2. accumulates ``x @ b`` over the rank dimension in blocks of
       ``BLOCK_SIZE_R``, for two adjacent ``BLOCK_SIZE_D``-wide column slabs
       of ``b``,
    3. multiplies the two results element-wise with the matching slabs of
       ``a``, reduces over the ``D`` axis and stores one value per sequence
       position.

    Notes:
      * No ``tl.load``/``tl.store`` here uses a mask, so nothing guards
        against out-of-bounds accesses when sizes are not multiples of the
        block sizes.
      * ``group_size``, ``D``, ``seq_len``, ``THETA``, ``stide_scales_g`` and
        ``stide_zeros_g`` are accepted but never read in the body.
      * The head -> head-group mapping hard-codes 32 heads.
      * This is the reproducer for the "mma16816 data type not supported"
        crash; see the FIXME inside the loop for the author's workaround.
    """
    pid_h = tl.program_id(axis=0) # parallel along heads
    pid_l = tl.program_id(axis=1) # parallel along blocks of sequence dim

    # Map this head to its head group; the constant 32 is a hard-coded head count.
    HEAD_GROUPS_ID = pid_h // (32 // NUM_GROUPS)

    offs_ds = tl.arange(0, BLOCK_SIZE_D)
    offs_rs = tl.arange(0, BLOCK_SIZE_R)
    offs_ls = (pid_l * BLOCK_SIZE_L) + tl.arange(0, BLOCK_SIZE_L)

    # Number of quantized values stored in one 32-bit packed word.
    feat_per_int = 32 // bits

    # Load A and B
    A_ptrs = a_ptr + pid_h * stride_az + (0*stride_aa + offs_ds[None, :]*stride_ad) # assume a is always (bs, nh, 1, d)
    B_ptrs = b_ptr + pid_h * stride_bz + (offs_rs[:, None]*stride_br + offs_ds[None, :]*stride_bd) 
    # feat_per_int neighbouring rank columns point at the same packed word of x;
    # they are told apart later by `shifter`.
    X_ptrs = x_ptr + HEAD_GROUPS_ID * stride_xhg + (offs_ls[:, None]*stride_xl + (offs_rs[None, :] // feat_per_int) * stride_xr) # (BLOCK_SIZE_L, BLOCK_SIZE_R)
    scales_ptr = scales_ptr + HEAD_GROUPS_ID * stride_scales_hg + (offs_ls[:, None]*stride_scales_xl) # (BLOCK_SIZE_L, 1)
    zeros_ptr = zeros_ptr + HEAD_GROUPS_ID * stride_zeros_hg + (offs_ls[:, None]*stride_zeros_xl) # (BLOCK_SIZE_L, 1)
    # Set output ptr
    O_ptrs = out_ptr + pid_h * stide_oz + (0*stride_oa + offs_ls[None, :] * stride_ol) # follow the shape assumption of a, we set 0*stride_oa
    # NOTE(brian1009) debug
    #O_ptrs = out_ptr + pid_h * stride_ohg + (offs_ls[:, None] * stride_ol + offs_rs[None, :] * stride_or) 

    # parameters for dequantization
    # NOTE(brian1009): Since we do not have group-wise quant yet,
    # we don't need to update the scales and zeros for each rank dimension,
    # as they are the same.
    # Bit offset of each rank column inside its packed word.
    shifter = (offs_rs % feat_per_int) * bits
    # Mask keeping the low `bits` bits (0xF for bits=4, 0xFF for bits=8).
    num = 0xFF >> (8-bits)
    scales = tl.load(scales_ptr)
    zeros = tl.load(zeros_ptr)
    #zeros = (zeros * 1.0).to(tl.float16)

    # float32 accumulators, one per BLOCK_SIZE_D-wide column slab of b.
    xb_0 = tl.zeros((BLOCK_SIZE_L, BLOCK_SIZE_D), dtype=tl.float32)
    xb_1 = tl.zeros((BLOCK_SIZE_L, BLOCK_SIZE_D), dtype=tl.float32)

    for _ in range(0, tl.cdiv(R, BLOCK_SIZE_R)):
        # Load next block of B, X
        x = tl.load(X_ptrs)
        # Unpack: shift each column's value down to bit 0, then mask it out.
        x = (x >> shifter[None, :] & num)
        #FIXME: Multiplying by this 1.0 makes execution normal on Triton==3.0.0
        #x = x * scales + zeros * 1.0 
        #x = x.to(tl.float16)   
        # Dequantize; the result dtype is left to Triton's implicit promotion
        # (presumably int32 with fp16 scales/zeros -- TODO confirm).
        x = x * scales + zeros
        b_0 = tl.load(B_ptrs)
        b_1 = tl.load(B_ptrs + BLOCK_SIZE_D * stride_bd)
        # Accumulate along the rank dimension
        xb_0 += tl.dot(x, b_0)
        xb_1 += tl.dot(x, b_1)
        # Advance the pointer to the next blocks
        B_ptrs += BLOCK_SIZE_R * stride_br
        # x advances in packed words, hence the division by feat_per_int.
        X_ptrs += (BLOCK_SIZE_R // feat_per_int) * stride_xr

    xb_0 = xb_0.to(tl.float16)
    xb_1 = xb_1.to(tl.float16)

    # RoPE (TBD)

    # # GEMV
    a_0 = tl.load(A_ptrs)
    a_1 = tl.load(A_ptrs + BLOCK_SIZE_D * stride_ad)
    # (1, BLOCK_SIZE_D) broadcast against (BLOCK_SIZE_L, BLOCK_SIZE_D), reduced over D.
    abx_0 = tl.sum(a_0 * xb_0, 1)
    abx_1 = tl.sum(a_1 * xb_1, 1)
    abx = abx_0 + abx_1
    tl.store(O_ptrs, abx[None, :])

def triton_ab_qx_rope(
    a: torch.Tensor, b: torch.Tensor, x_q: torch.Tensor, 
    x_scales: torch.Tensor, x_zeros: torch.Tensor, 
    x_bits: int, x_quant_group_size: int):
    """Launch ``_ab_qx_fwd`` on packed, quantized ``x_q``.

    Args:
        a: 3-D tensor unpacked as ``(num_heads, _, head_dim)``.
        b: 3-D tensor unpacked as ``(num_heads, rank, head_dim)``.
        x_q: 3-D packed tensor unpacked as
            ``(num_groups, seq_len, rank // (32 // x_bits))``.
        x_scales, x_zeros: 3-D dequantization parameters; only their strides
            and data pointers are used here.
        x_bits: quantization bit-width, 4 or 8.
        x_quant_group_size: forwarded to the kernel as ``group_size``.

    Returns:
        Tensor of shape ``(num_heads, 1, seq_len)`` with ``a``'s dtype/device.

    Raises:
        AssertionError: on an unsupported bit-width, a non-3-D input, or when
            ``a``, ``b`` and ``x_q`` disagree on heads, head_dim or rank.
    """
    assert x_bits in [4, 8]
    assert a.dim() == 3
    assert b.dim() == 3
    assert x_q.dim() == 3

    feat_per_int = 32 // x_bits
    num_heads, _, head_dim = a.shape
    b_num_heads, b_rank, b_head_dim = b.shape
    num_groups, seq_len, packed_rank_per_head_groups = x_q.shape
    rank_per_head_groups = packed_rank_per_head_groups * feat_per_int

    # The kernel indexes a and b with the same head id and walks b's rows in
    # lock-step with x's unpacked columns, so the shapes have to agree.
    assert (b_num_heads, b_head_dim) == (num_heads, head_dim), \
        "a and b must agree on num_heads and head_dim"
    assert b_rank == rank_per_head_groups, \
        "rank of b must equal the unpacked rank of x_q"

    out = torch.empty((num_heads, 1, seq_len), dtype=a.dtype, device=a.device)
    BLOCK_SIZE_D = 64
    NUM_GROUPS = num_groups
    # One program per (head, block of BLOCK_SIZE_L sequence positions).
    grid = lambda META: (
       num_heads ,triton.cdiv(seq_len, META['BLOCK_SIZE_L']),
    )
    _ab_qx_fwd[grid](
        x_bits, x_quant_group_size,
        a, b, x_q, x_scales, x_zeros, out,
        a.stride(0), a.stride(1), a.stride(2),
        b.stride(0), b.stride(1), b.stride(2),
        x_q.stride(0), x_q.stride(1), x_q.stride(2),
        x_scales.stride(0), x_scales.stride(1), x_scales.stride(2),
        x_zeros.stride(0), x_zeros.stride(1), x_zeros.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        # R, D, seq_len -- D is the head dimension (previously num_heads was
        # passed in this slot by mistake).
        rank_per_head_groups, head_dim, seq_len,
        BLOCK_SIZE_D=BLOCK_SIZE_D,
        BLOCK_SIZE_L = 32,
        BLOCK_SIZE_R = 64,
        # num_stages=num_stages,
        # num_warps=num_warps,
        NUM_GROUPS=NUM_GROUPS,
        THETA=10000.
    )
    return out

def torch_ab_qx_rope2(a, b, x_q, x_scales, x_zeros, x_bits, x_quant_group_size):
    """PyTorch reference: dequantize ``x_q``, then compute ``a @ (x @ b)^T`` head by head.

    Returns a float16 tensor of shape
    ``(a.shape[0], a.shape[1], <second-to-last dim of the dequantized x>)``.
    The head -> group mapping uses a hard-coded head count of 32.
    """
    # Author's note on the dequantized shape: (num_groups, seq_len, rank_per_head_groups)
    dequant = unpack_and_dequant_vcache(x_q, x_scales, x_zeros, x_quant_group_size, x_bits)
    #return dequant
    result = torch.empty(
        (a.shape[0], a.shape[1], dequant.shape[-2]),
        dtype=torch.float16,
        device=a.device,
    )
    for head in range(b.shape[0]):
        group = head // (32 // x_q.shape[1])
        x_head = dequant[:, group].squeeze(0)
        projected = torch.matmul(x_head, b[head])
        result[head] = torch.matmul(a[head], projected.transpose(1, 0))
    return result

def test_correctness(args):
    """Compare the Triton kernel against the PyTorch reference on random fp16 data.

    Reads ``num_heads``, ``head_dim``, ``total_rank``, ``num_groups`` and
    ``x_bits`` from ``args``, uses a fixed sequence length of 512 on the CUDA
    device, and prints whether the two outputs are ``allclose``.
    """
    n_heads, d_head = args.num_heads, args.head_dim
    n_groups = args.num_groups
    rank_per_group = args.total_rank // n_groups
    n_bits = args.x_bits
    n_tokens = 512
    tensor_opts = dict(dtype=torch.float16, device=torch.device('cuda'))

    # Creation order is kept (A, B, X) so the random streams are unchanged.
    a_mat = torch.randn(n_heads, 1, d_head, **tensor_opts)
    b_mat = torch.randn(n_heads, rank_per_group, d_head, **tensor_opts)
    x_mat = torch.randn(n_groups, n_tokens, rank_per_group, **tensor_opts)

    packed, scales, zeros = triton_quantize_and_pack_along_last_dim(
        x_mat.unsqueeze(0), group_size=rank_per_group, bit=n_bits)

    reference = torch_ab_qx_rope2(a_mat, b_mat, packed, scales, zeros, n_bits, rank_per_group)
    candidate = triton_ab_qx_rope(
        a_mat, b_mat,
        packed.squeeze(0), scales.squeeze(0), zeros.squeeze(0),
        n_bits, rank_per_group,
    )
    #print(candidate)
    matches = torch.allclose(reference, candidate, atol=1, rtol=1e-4)
    print("Correctness: ", matches)

def main(args):
    """Derive ``num_groups`` and ``group_rank`` on ``args``, then run the correctness check."""
    groups = args.num_heads // args.group_size
    setattr(args, "num_groups", groups)
    setattr(args, "group_rank", args.total_rank // args.num_groups)
    test_correctness(args)

if __name__ == '__main__':
    # Command-line entry point: parse the problem sizes and run the check.
    import argparse
    # Build the parser once; an earlier bare ArgumentParser() was created and
    # immediately discarded.
    parser = argparse.ArgumentParser(description="Argument Parser")
    parser.add_argument("--total_rank", type=int, default=1024, help="Total rank")
    parser.add_argument("--num_heads", type=int, default=32, help="Number of heads, default to 32 (llama)")
    parser.add_argument("--head_dim", type=int, default=128, help="Head dimension, default to 128 (llama)")
    parser.add_argument("--group_size", type=int, default=4, help="Number of heads per group")
    parser.add_argument("--target_seq_lens", nargs="+", type=int, 
                        default=[4096, 16384, 65536, 262144], help="Target sequence lengths")
    parser.add_argument("--x_bits", type=int, default=4, help="Number of bits for quantization")
    args = parser.parse_args()
    main(args)
mobicham commented 1 week ago

This bug is still present in the master branch as of October 17th; it happens with integer-packed weights. One way to fix it is to replace the for loop with:

for k in tl.range(0, total_blocks_k, 1, num_stages=1):

Strangely, the error also disappears if you add an if statement inside the for loop:

for k in range(0, total_blocks_k, 1):
    if(k < total_blocks_k):