triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/

Assertion error from linear layouts #4727

Closed: peterbell10 closed this issue 1 day ago

peterbell10 commented 4 days ago

I am running into an assertion error in the codegen for local_load, originating in the linear layouts code. Here is a minified reproducer:

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 2056 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @test_fn() attributes {noinline = false} {
    %0 = triton_gpu.local_alloc  { allocation.offset = 0 : i32} : () -> !tt.memdesc<4x128xf32, #shared, #triton_gpu.shared_memory, mutable>
    %1 = triton_gpu.local_load %0 : !tt.memdesc<4x128xf32, #shared, #triton_gpu.shared_memory, mutable> -> tensor<4x128xf32, #blocked>
    tt.return
  }
}
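
For reference, the #blocked layout above tiles the 4x128 tensor exactly once, since each dimension's coverage is sizePerThread * threadsPerWarp * warpsPerCTA. A quick sanity check in plain Python (just arithmetic, not a Triton API):

size_per_thread = [1, 4]
threads_per_warp = [1, 32]
warps_per_cta = [4, 1]
# Per-dimension coverage of one CTA tile for the #blocked layout.
tile = [s * t * w for s, t, w in zip(size_per_thread, threads_per_warp, warps_per_cta)]
print(tile)  # [4, 128] -- matches the 4x128 tensor shape, so no replication is involved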

When lowering to LLVM IR, it fails with the following error:

$ triton-opt --convert-triton-gpu-to-llvm repro.ttgir

triton-opt: /root/code/triton/lib/Tools/LinearLayout.cpp:512: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.  Program arguments: /root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt --convert-triton-gpu-to-llvm repro.ttgir
 #0 0x00005621e7032ff7 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c94ff7)
 #1 0x00005621e7030b1e llvm::sys::RunSignalHandlers() (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c92b1e)
 #2 0x00005621e70336af SignalHandler(int) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c956af)
 #3 0x00007f20a484c420 __restore_rt (/usr/lib/x86_64-linux-gnu/libpthread.so.0+0x14420)
 #4 0x00007f20a431900b raise /build/glibc-LcI20x/glibc-2.31/signal/../sysdeps/unix/sysv/linux/raise.c:51:1
 #5 0x00007f20a42f8859 abort /build/glibc-LcI20x/glibc-2.31/stdlib/abort.c:81:7
 #6 0x00007f20a42f8729 get_sysdep_segment_value /build/glibc-LcI20x/glibc-2.31/intl/loadmsgcat.c:509:8
 #7 0x00007f20a42f8729 _nl_load_domain /build/glibc-LcI20x/glibc-2.31/intl/loadmsgcat.c:970:34
 #8 0x00007f20a4309fd6 (/usr/lib/x86_64-linux-gnu/libc.so.6+0x33fd6)
 #9 0x00005621e4aac52a mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const /root/code/triton/lib/Tools/LinearLayout.cpp:520:37
#10 0x00005621e46b26dd mlir::emitTransferBetweenRegistersAndShared(mlir::RankedTensorType, mlir::triton::MemDescType, mlir::Type, std::optional<int>, mlir::Value, llvm::ArrayRef<mlir::Value>, mlir::Location, mlir::RewriterBase&, mlir::triton::TargetInfoBase const&, std::function<void (mlir::VectorType, mlir::Value)>) /root/code/triton/lib/Conversion/TritonGPUToLLVM/Utility.cpp:307:61
#11 0x00005621e46b31f3 mlir::loadSharedToDistributed(mlir::RankedTensorType, mlir::triton::MemDescType, mlir::Type, mlir::LLVM::SharedMemoryObject, mlir::Location, mlir::RewriterBase&, mlir::triton::TargetInfoBase const&) /root/code/triton/lib/Conversion/TritonGPUToLLVM/Utility.cpp:386:55
#12 0x00005621e47a4185 (anonymous namespace)::LocalLoadOpConversion::lowerSharedToDistributed(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, mlir::LLVMTypeConverter const*, mlir::ConversionPatternRewriter&) const /root/code/triton/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp:172:69
#13 0x00005621e47a3d05 (anonymous namespace)::LocalLoadOpConversion::matchAndRewrite(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, mlir::ConversionPatternRewriter&) const /root/code/triton/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp:124:47
#14 0x00005621e47ac85d mlir::ConvertOpToLLVMPattern<mlir::triton::gpu::LocalLoadOp>::matchAndRewrite(mlir::Operation*, llvm::ArrayRef<mlir::Value>, mlir::ConversionPatternRewriter&) const /root/.triton/llvm/llvm-c08c6a71-ubuntu-x64/include/mlir/Conversion/LLVMCommon/Pattern.h:166:77
#15 0x00005621e6b3bd10 mlir::ConversionPattern::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279dd10)
#16 0x00005621e6b7a65b mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>)::$_2::operator()() const (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x27dc65b)
#17 0x00005621e6b771df mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x27d91df)
#18 0x00005621e6b3cca1 (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279eca1)
#19 0x00005621e6b3bdb4 mlir::OperationConverter::convert(mlir::ConversionPatternRewriter&, mlir::Operation*) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279ddb4)
#20 0x00005621e6b3d1bf mlir::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279f1bf)
#21 0x00005621e6b438fb mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x27a58fb)
#22 0x00005621e4d6e312 (anonymous namespace)::ConvertTritonGPUToLLVM::runOnOperation() /root/code/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp:178:15
#23 0x00005621e6081996 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1ce3996)
#24 0x00005621e6082140 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1ce4140)
#25 0x00005621e60845f5 mlir::PassManager::run(mlir::Operation*) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1ce65f5)
#26 0x00005621e607dccf performActions(llvm::raw_ostream&, std::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cdfccf)
#27 0x00005621e607d8fd llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_2>(long, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cdf8fd)
#28 0x00005621e6fb2656 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c14656)
#29 0x00005621e6078721 mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cda721)
#30 0x00005621e60789d3 mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cda9d3)
#31 0x00005621e6078da6 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cdada6)
#32 0x00005621e4e72ad0 main /root/code/triton/bin/triton-opt.cpp:9:0
#33 0x00007f20a42fa083 __libc_start_main /build/glibc-LcI20x/glibc-2.31/csu/../csu/libc-start.c:342:3
#34 0x00005621e468707e _start (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2e907e)
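
For context, the failing assertion in LinearLayout::reshapeIns requires that the product of the new input-dimension sizes equal the layout's current total input size. A minimal sketch of that invariant in Python (illustrative names, not Triton's actual implementation):

from functools import reduce
from operator import mul

def reshape_ins_check(total_in_dim_size, new_in_dims):
    # new_in_dims: (name, size) pairs for the requested reshape, as in the
    # C++ signature ArrayRef<std::pair<StringAttr, int>>.
    new_total = reduce(mul, (size for _, size in new_in_dims), 1)
    # Mirrors the assertion that fires above: a reshape must preserve the
    # total number of input elements.
    assert total_in_dim_size == new_total, (
        f"cannot reshape {total_in_dim_size} inputs into {new_total}")

Given frames #9 and #10, this suggests emitTransferBetweenRegistersAndShared is constructing new input dims whose product disagrees with the total input size of the layout it is reshaping.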

cc @Jokeren @jlebar

Jokeren commented 4 days ago

Just to confirm, the TritonGPU IR is generated from valid Triton Python code?

peterbell10 commented 4 days ago

It came from the lowering of a new operator I'm adding, but I'll see if I can reproduce it with an existing operator.

peterbell10 commented 4 days ago

This produces the same error on the current master branch:

import torch

import triton
import triton.language as tl

@triton.jit
def test_fn(out_ptr, a_ptr, workspace, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
    # Build a TMA descriptor on-device in the workspace buffer, then fence so
    # it is visible to the tensormap proxy before use.
    desc_ptr = workspace
    tl.extra.cuda.experimental_device_tensormap_create2d(desc_ptr=desc_ptr, global_address=a_ptr, load_size=[4, N_BLOCK], global_size=[M, N], element_ty=a_ptr.dtype.element_ty)
    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_ptr)

    # Load a 4 x N_BLOCK tile through the descriptor and store it out; the
    # descriptor load is what produces the local_alloc/local_load pair in the
    # TTGIR reproducer above.
    gather = tl._experimental_descriptor_load(desc_ptr, [0, 0], [4, N_BLOCK], a_ptr.dtype.element_ty)
    tl.store(out_ptr + tl.arange(0, 4)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :], gather)

out = torch.empty((4, 128), dtype=torch.float32, device="cuda")
inp = torch.arange(4 * 128, dtype=torch.float32, device="cuda").reshape(4, 128)
workspace = torch.empty(128, dtype=torch.uint8, device="cuda")
test_fn[(1,)](out, inp, workspace, 4, 128, 4, 128)
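
Assuming the crash is fixed, the kernel simply copies inp into out (it loads the full 4x128 tile at offset [0, 0] and stores it back row-major), so a one-line correctness check would be:

torch.testing.assert_close(out, inp)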

Jokeren commented 4 days ago

I'll take a look today.