Open karupayun opened 1 year ago
We created the following TTIR, which was previously working for us. However, it is now segfaulting at HEAD on main.
The TTIR essentially loads two parameters (one s8, one f32), converts the s8 to f32, then performs a dot.
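For readers without the attachment, the computation described above can be sketched in plain Python. This is only a reference for the semantics (widen the s8 operand to float, then matrix-multiply), not the TTIR itself. The function names and the tiny 2x2 inputs are made up for illustration, which operand is s8 is an assumption, and Python floats are double precision rather than f32.

```python
def s8_to_float(matrix):
    """Widen signed 8-bit integers to floats, checking the s8 range."""
    for row in matrix:
        for v in row:
            if not -128 <= v <= 127:
                raise ValueError(f"{v} does not fit in s8")
    return [[float(v) for v in row] for row in matrix]


def dot(a, b):
    """Plain matrix product of a (M x K) and b (K x N)."""
    if len(a[0]) != len(b):
        raise ValueError("inner dimensions do not match")
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


a_s8 = [[1, -2], [3, 4]]            # hypothetical s8 operand
b_f32 = [[0.5, 1.0], [2.0, -1.0]]   # hypothetical f32 operand
result = dot(s8_to_float(a_s8), b_f32)
print(result)
```

The conversion happens before the dot, so the dot itself only ever sees floating-point operands, which matches the order of operations described in the report.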
I did a bisect and narrowed down the commit where the segfault started to https://github.com/openai/triton/commit/52c146f66b79b6079bcd28c55312fc6ea1852519. The issue can be solved by setting int minBitwidth = 0; (at https://github.com/openai/triton/blob/17d633a64e43337037d2e873b029fab92422762f/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp#L299) but also by commenting out https://github.com/openai/triton/blob/17d633a64e43337037d2e873b029fab92422762f/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp#L250, as we do internally.
I attached a lit test that runs the entire pipeline for a convenient repro: crash.mlir.txt
Segfault stacktrace:
Stack dump:
0. Program arguments: build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt --triton-rewrite-tensor-pointer --inline --canonicalize --triton-combine --triton-reorder-broadcast --canonicalize --cse --loop-invariant-code-motion --symbol-dce --convert-triton-to-tritongpu --tritongpu-coalesce --tritongpu-remove-layout-conversions --tritongpu-accelerate-matmul --tritongpu-remove-layout-conversions --tritongpu-optimize-dot-operands --tritongpu-pipeline --tritongpu-prefetch --tritongpu-optimize-dot-operands --tritongpu-remove-layout-conversions --tritongpu-decompose-conversions --tritongpu-reorder-instructions --cse --symbol-dce --convert-scf-to-cf --convert-index-to-llvm --convert-triton-gpu-to-llvm test.mlir
#0 0x00005596154f701b llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x3bb901b)
#1 0x00005596154f4e74 SignalHandler(int) Signals.cpp:0:0
#2 0x00007f99ef77b540 (/lib/x86_64-linux-gnu/libc.so.6+0x3c540)
#3 0x000055961542a373 mlir::detail::OperandStorage::OperandStorage(mlir::Operation*, mlir::OpOperand*, mlir::ValueRange) (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x3aec373)
#4 0x000055961541938e mlir::Operation::create(mlir::Location, mlir::OperationName, mlir::TypeRange, mlir::ValueRange, mlir::DictionaryAttr, mlir::OpaqueProperties, mlir::BlockRange, unsigned int) (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x3adb38e)
#5 0x0000559615419733 mlir::Operation::create(mlir::Location, mlir::OperationName, mlir::TypeRange, mlir::ValueRange, mlir::NamedAttrList&&, mlir::OpaqueProperties, mlir::BlockRange, mlir::RegionRange) (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x3adb733)
#6 0x0000559615419c1a mlir::Operation::create(mlir::OperationState const&) (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x3adbc1a)
#7 0x000055961539878f mlir::OpBuilder::create(mlir::OperationState const&) (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x3a5a78f)
#8 0x0000559612b0b078 mlir::LLVM::InsertElementOp mlir::OpBuilder::create<mlir::LLVM::InsertElementOp, mlir::Value&, mlir::LLVM::ExtractElementOp, mlir::Value>(mlir::Location, mlir::Value&, mlir::LLVM::ExtractElementOp&&, mlir::Value&&) /usr/local/google/home/karupayun/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/mlir/IR/Builders.h:491:22
#9 0x0000559612b0b078 MMA16816SmemLoader::loadX4(int, int, llvm::ArrayRef<mlir::Value>, mlir::Type, mlir::Type) const /usr/local/google/home/karupayun/projects/triton/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp:380:15
#10 0x0000559612b0c55e getLoadMatrixFn(mlir::Value, mlir::LLVM::SharedMemoryObject const&, mlir::triton::gpu::MmaEncodingAttr, int, unsigned int, int, llvm::SmallVector<int, 12u>, llvm::SmallVector<int, 12u>, mlir::Value, mlir::Value, std::map<std::pair<unsigned int, unsigned int>, mlir::Value, std::less<std::pair<unsigned int, unsigned int> >, std::allocator<std::pair<std::pair<unsigned int, unsigned int> const, mlir::Value> > >&, bool, TritonGPUToLLVMTypeConverter*, mlir::ConversionPatternRewriter&, mlir::Location)::'lambda'(int, int)::operator()(int, int) const /usr/local/google/home/karupayun/projects/triton/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp:538:5
#11 0x0000559612b09a8a loadArg(mlir::ConversionPatternRewriter&, mlir::Location, mlir::Value, mlir::triton::gpu::DotOperandEncodingAttr, mlir::LLVM::SharedMemoryObject const&, TritonGPUToLLVMTypeConverter*, mlir::Value, bool) /usr/local/google/home/karupayun/projects/triton/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp:611:23
#12 0x0000559612a79827 ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(mlir::triton::gpu::ConvertLayoutOp, mlir::triton::gpu::ConvertLayoutOpAdaptor, mlir::ConversionPatternRewriter&, mlir::triton::gpu::MmaEncodingAttr const&, mlir::triton::gpu::DotOperandEncodingAttr const&, bool) const /usr/local/google/home/karupayun/projects/triton/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp:950:51
#13 0x0000559612a821b6 ConvertLayoutOpConversion::lowerSharedToDotOperand(mlir::triton::gpu::ConvertLayoutOp, mlir::triton::gpu::ConvertLayoutOpAdaptor, mlir::ConversionPatternRewriter&) const /usr/local/google/home/karupayun/projects/triton/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp:849:39
#14 0x0000559612a87065 ConvertLayoutOpConversion::matchAndRewrite(mlir::triton::gpu::ConvertLayoutOp, mlir::triton::gpu::ConvertLayoutOpAdaptor, mlir::ConversionPatternRewriter&) const /usr/local/google/home/karupayun/projects/triton/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp:77:59
#15 0x0000559612a771d4 mlir::ConvertOpToLLVMPattern<mlir::triton::gpu::ConvertLayoutOp>::matchAndRewrite(mlir::Operation*, llvm::ArrayRef<mlir::Value>, mlir::ConversionPatternRewriter&) const /usr/local/google/home/karupayun/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/mlir/Conversion/LLVMCommon/Pattern.h:170:3
#16 0x0000559614f07ba1 mlir::ConversionPattern::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x35c9ba1)
#17 0x0000559614f4d022 mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x360f022)
#18 0x0000559614f13439 (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) DialectConversion.cpp:0:0
#19 0x0000559614f13a10 (anonymous namespace)::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>, llvm::function_ref<void (mlir::Diagnostic&)>) DialectConversion.cpp:0:0
#20 0x0000559614f15a40 mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget&, mlir::FrozenRewritePatternSet const&, llvm::DenseSet<mlir::Operation*, llvm::DenseMapInfo<mlir::Operation*, void> >*) (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x35d7a40)
#21 0x0000559612a36192 (anonymous namespace)::ConvertTritonGPUToLLVM::runOnOperation() /usr/local/google/home/karupayun/projects/triton/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp:544:15
#22 0x0000559612b39bd1 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x11fbbd1)
#23 0x0000559612b3a421 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x11fc421)
#24 0x0000559612b3af8a mlir::PassManager::run(mlir::Operation*) (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x11fcf8a)
#25 0x0000559612b2bf9b performActions(llvm::raw_ostream&, std::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) MlirOptMain.cpp:0:0
#26 0x0000559612b2cad5 processBuffer(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, mlir::MlirOptMainConfig const&, mlir::DialectRegistry&, llvm::ThreadPool*) MlirOptMain.cpp:0:0
#27 0x0000559612b2cbc0 mlir::LogicalResult llvm::function_ref<mlir::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::'lambda'(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&)>(long, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&) MlirOptMain.cpp:0:0
#28 0x0000559615457105 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::function_ref<mlir::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&)>, llvm::raw_ostream&, bool, bool) (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x3b19105)
#29 0x0000559612b2aa43 mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x11eca43)
#30 0x0000559612b2cef3 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x11eeef3)
#31 0x0000559611c2dad9 std::vector<std::unique_ptr<mlir::DialectExtensionBase, std::default_delete<mlir::DialectExtensionBase> >, std::allocator<std::unique_ptr<mlir::DialectExtensionBase, std::default_delete<mlir::DialectExtensionBase> > > >::~vector() /usr/include/c++/11/bits/stl_vector.h:680:15
#32 0x0000559611c2dad9 mlir::DialectRegistry::~DialectRegistry() /usr/local/google/home/karupayun/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/mlir/IR/DialectRegistry.h:109:7
#33 0x0000559611c2dad9 main /usr/local/google/home/karupayun/projects/triton/bin/triton-opt.cpp:11:1
#34 0x00007f99ef7666ca __libc_start_call_main ./csu/../sysdeps/nptl/libc_start_call_main.h:74:3
#35 0x00007f99ef766785 call_init ./csu/../csu/libc-start.c:128:20
#36 0x00007f99ef766785 __libc_start_main ./csu/../csu/libc-start.c:347:5
#37 0x0000559611cc91b1 _start (build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x38b1b1)
Segmentation fault
I think 8-bit, 32-bit mixed precision is not well supported yet.
To get around the issue, though with no performance guarantee, you can try allow_tf32=False.