triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Assertion error from linear layouts #4727

Open peterbell10 opened 2 months ago

peterbell10 commented 2 months ago

I am running into an assertion error in the codegen for local_load, coming from the linear layouts code. Here is a minified reproducer:

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 2056 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @test_fn() attributes {noinline = false} {
    %0 = triton_gpu.local_alloc  { allocation.offset = 0 : i32} : () -> !tt.memdesc<4x128xf32, #shared, #triton_gpu.shared_memory, mutable>
    %1 = triton_gpu.local_load %0 : !tt.memdesc<4x128xf32, #shared, #triton_gpu.shared_memory, mutable> -> tensor<4x128xf32, #blocked>
    tt.return
  }
}
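
For context, here is a paraphrase of how the #blocked layout distributes the 4x128 tensor (a sketch of standard blocked-layout semantics, not Triton internals): warpsPerCTA = [4, 1] gives each of the 4 warps one row, threadsPerWarp = [1, 32] splits that row across 32 lanes, and sizePerThread = [1, 4] gives each lane 4 contiguous f32 values.

def owner(row: int, col: int) -> tuple[int, int, int]:
    # warpsPerCTA = [4, 1]: warps split dim 0, so one row per warp
    warp = row
    # threadsPerWarp = [1, 32]: lanes split dim 1 in chunks of sizePerThread
    lane = col // 4
    # sizePerThread = [1, 4]: each lane holds 4 contiguous f32 registers
    reg = col % 4
    return warp, lane, reg

# e.g. element (2, 70) lives in warp 2, lane 17, register 2
assert owner(2, 70) == (2, 17, 2)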

When lowering to LLVM IR, it fails with the following error:

$ triton-opt --convert-triton-gpu-to-llvm repro.ttgir

triton-opt: /root/code/triton/lib/Tools/LinearLayout.cpp:512: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.  Program arguments: /root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt --convert-triton-gpu-to-llvm repro.ttgir
 #0 0x00005621e7032ff7 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c94ff7)
 #1 0x00005621e7030b1e llvm::sys::RunSignalHandlers() (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c92b1e)
 #2 0x00005621e70336af SignalHandler(int) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c956af)
 #3 0x00007f20a484c420 __restore_rt (/usr/lib/x86_64-linux-gnu/libpthread.so.0+0x14420)
 #4 0x00007f20a431900b raise /build/glibc-LcI20x/glibc-2.31/signal/../sysdeps/unix/sysv/linux/raise.c:51:1
 #5 0x00007f20a42f8859 abort /build/glibc-LcI20x/glibc-2.31/stdlib/abort.c:81:7
 #6 0x00007f20a42f8729 get_sysdep_segment_value /build/glibc-LcI20x/glibc-2.31/intl/loadmsgcat.c:509:8
 #7 0x00007f20a42f8729 _nl_load_domain /build/glibc-LcI20x/glibc-2.31/intl/loadmsgcat.c:970:34
 #8 0x00007f20a4309fd6 (/usr/lib/x86_64-linux-gnu/libc.so.6+0x33fd6)
 #9 0x00005621e4aac52a mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const /root/code/triton/lib/Tools/LinearLayout.cpp:520:37
#10 0x00005621e46b26dd mlir::emitTransferBetweenRegistersAndShared(mlir::RankedTensorType, mlir::triton::MemDescType, mlir::Type, std::optional<int>, mlir::Value, llvm::ArrayRef<mlir::Value>, mlir::Location, mlir::RewriterBase&, mlir::triton::TargetInfoBase const&, std::function<void (mlir::VectorType, mlir::Value)>) /root/code/triton/lib/Conversion/TritonGPUToLLVM/Utility.cpp:307:61
#11 0x00005621e46b31f3 mlir::loadSharedToDistributed(mlir::RankedTensorType, mlir::triton::MemDescType, mlir::Type, mlir::LLVM::SharedMemoryObject, mlir::Location, mlir::RewriterBase&, mlir::triton::TargetInfoBase const&) /root/code/triton/lib/Conversion/TritonGPUToLLVM/Utility.cpp:386:55
#12 0x00005621e47a4185 (anonymous namespace)::LocalLoadOpConversion::lowerSharedToDistributed(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, mlir::LLVMTypeConverter const*, mlir::ConversionPatternRewriter&) const /root/code/triton/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp:172:69
#13 0x00005621e47a3d05 (anonymous namespace)::LocalLoadOpConversion::matchAndRewrite(mlir::triton::gpu::LocalLoadOp, mlir::triton::gpu::LocalLoadOpAdaptor, mlir::ConversionPatternRewriter&) const /root/code/triton/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp:124:47
#14 0x00005621e47ac85d mlir::ConvertOpToLLVMPattern<mlir::triton::gpu::LocalLoadOp>::matchAndRewrite(mlir::Operation*, llvm::ArrayRef<mlir::Value>, mlir::ConversionPatternRewriter&) const /root/.triton/llvm/llvm-c08c6a71-ubuntu-x64/include/mlir/Conversion/LLVMCommon/Pattern.h:166:77
#15 0x00005621e6b3bd10 mlir::ConversionPattern::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279dd10)
#16 0x00005621e6b7a65b mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>)::$_2::operator()() const (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x27dc65b)
#17 0x00005621e6b771df mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<llvm::LogicalResult (mlir::Pattern const&)>) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x27d91df)
#18 0x00005621e6b3cca1 (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279eca1)
#19 0x00005621e6b3bdb4 mlir::OperationConverter::convert(mlir::ConversionPatternRewriter&, mlir::Operation*) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279ddb4)
#20 0x00005621e6b3d1bf mlir::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x279f1bf)
#21 0x00005621e6b438fb mlir::applyPartialConversion(mlir::Operation*, mlir::ConversionTarget const&, mlir::FrozenRewritePatternSet const&, mlir::ConversionConfig) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x27a58fb)
#22 0x00005621e4d6e312 (anonymous namespace)::ConvertTritonGPUToLLVM::runOnOperation() /root/code/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp:178:15
#23 0x00005621e6081996 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1ce3996)
#24 0x00005621e6082140 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1ce4140)
#25 0x00005621e60845f5 mlir::PassManager::run(mlir::Operation*) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1ce65f5)
#26 0x00005621e607dccf performActions(llvm::raw_ostream&, std::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cdfccf)
#27 0x00005621e607d8fd llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::$_2>(long, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cdf8fd)
#28 0x00005621e6fb2656 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2c14656)
#29 0x00005621e6078721 mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer> >, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cda721)
#30 0x00005621e60789d3 mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cda9d3)
#31 0x00005621e6078da6 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x1cdada6)
#32 0x00005621e4e72ad0 main /root/code/triton/bin/triton-opt.cpp:9:0
#33 0x00007f20a42fa083 __libc_start_main /build/glibc-LcI20x/glibc-2.31/csu/../csu/libc-start.c:342:3
#34 0x00005621e468707e _start (/root/code/triton/python/build/cmake.linux-x86_64-cpython-3.11/bin/triton-opt+0x2e907e)
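
For reference, the assertion that fires is a simple size invariant of reshapeIns: the product of the new input dimension sizes must equal the layout's total input size. A minimal Python paraphrase (the dim names and numbers below are illustrative, not taken from the crash):

import math

def reshape_ins_precondition(total_in_dim_size, new_in_dims):
    # mirrors the std::accumulate product check at LinearLayout.cpp:512
    new_total = math.prod(new_in_dims.values())
    assert total_in_dim_size == new_total, (
        f"cannot reshape total input size {total_in_dim_size} "
        f"into dims with product {new_total}")

# The 4x128 f32 tensor has 512 elements, so this decomposition would pass:
reshape_ins_precondition(512, {"register": 4, "lane": 32, "warp": 4})
# The crash means the layout built for the shared side multiplies out to a
# different total than the decomposition it is being reshaped into.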

cc @Jokeren @jlebar

Jokeren commented 2 months ago

Just to confirm: was the TritonGPU IR generated from valid Triton Python code?

peterbell10 commented 2 months ago

It came from the lowering of a new operator I'm adding, but I'll see if I can reproduce it with an existing operator.

peterbell10 commented 2 months ago

This produces the same error on the current master branch:

import torch

import triton
import triton.language as tl


@triton.jit
def test_fn(out_ptr, a_ptr, workspace, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
    # Build a TMA descriptor on-device for [4, N_BLOCK] tiles of a_ptr.
    desc_ptr = workspace
    tl.extra.cuda.experimental_device_tensormap_create2d(
        desc_ptr=desc_ptr,
        global_address=a_ptr,
        load_size=[4, N_BLOCK],
        global_size=[M, N],
        element_ty=a_ptr.dtype.element_ty,
    )
    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(desc_ptr)

    # The descriptor load of a 4-row block is what hits the local_load crash.
    gather = tl._experimental_descriptor_load(desc_ptr, [0, 0], [4, N_BLOCK], a_ptr.dtype.element_ty)
    tl.store(out_ptr + tl.arange(0, 4)[:, None] * N_BLOCK + tl.arange(0, N_BLOCK)[None, :], gather)


out = torch.empty((4, 128), dtype=torch.float32, device="cuda")
inp = torch.arange(4 * 128, dtype=torch.float32, device="cuda").reshape(4, 128)
workspace = torch.empty(128, dtype=torch.uint8, device="cuda")
test_fn[(1,)](out, inp, workspace, 4, 128, 4, 128)

Jokeren commented 2 months ago

I'll take a look today

peterbell10 commented 2 months ago

Reopening this as it seems the TMA hardware does support swizzling with only 4 rows of data.

I get this result if it's helpful (a sketch reproducing the swizzle pattern follows the dump):

unswizzled:
tensor([[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,  11.,
          12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,  22.,  23.,
          24.,  25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,  33.,  34.,  35.,
          36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,  45.,  46.,  47.,
          48.,  49.,  50.,  51.,  52.,  53.,  54.,  55.,  56.,  57.,  58.,  59.,
          60.,  61.,  62.,  63.],
        [128., 129., 130., 131., 132., 133., 134., 135., 136., 137., 138., 139.,
         140., 141., 142., 143., 144., 145., 146., 147., 148., 149., 150., 151.,
         152., 153., 154., 155., 156., 157., 158., 159., 160., 161., 162., 163.,
         164., 165., 166., 167., 168., 169., 170., 171., 172., 173., 174., 175.,
         176., 177., 178., 179., 180., 181., 182., 183., 184., 185., 186., 187.,
         188., 189., 190., 191.],
        [256., 257., 258., 259., 260., 261., 262., 263., 264., 265., 266., 267.,
         268., 269., 270., 271., 272., 273., 274., 275., 276., 277., 278., 279.,
         280., 281., 282., 283., 284., 285., 286., 287., 288., 289., 290., 291.,
         292., 293., 294., 295., 296., 297., 298., 299., 300., 301., 302., 303.,
         304., 305., 306., 307., 308., 309., 310., 311., 312., 313., 314., 315.,
         316., 317., 318., 319.],
        [384., 385., 386., 387., 388., 389., 390., 391., 392., 393., 394., 395.,
         396., 397., 398., 399., 400., 401., 402., 403., 404., 405., 406., 407.,
         408., 409., 410., 411., 412., 413., 414., 415., 416., 417., 418., 419.,
         420., 421., 422., 423., 424., 425., 426., 427., 428., 429., 430., 431.,
         432., 433., 434., 435., 436., 437., 438., 439., 440., 441., 442., 443.,
         444., 445., 446., 447.]], device='cuda:0', dtype=torch.float16)

swizzled:
tensor([[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,  11.,
          12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,  22.,  23.,
          24.,  25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,  33.,  34.,  35.,
          36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,  44.,  45.,  46.,  47.,
          48.,  49.,  50.,  51.,  52.,  53.,  54.,  55.,  56.,  57.,  58.,  59.,
          60.,  61.,  62.,  63.],
        [136., 137., 138., 139., 140., 141., 142., 143., 128., 129., 130., 131.,
         132., 133., 134., 135., 152., 153., 154., 155., 156., 157., 158., 159.,
         144., 145., 146., 147., 148., 149., 150., 151., 168., 169., 170., 171.,
         172., 173., 174., 175., 160., 161., 162., 163., 164., 165., 166., 167.,
         184., 185., 186., 187., 188., 189., 190., 191., 176., 177., 178., 179.,
         180., 181., 182., 183.],
        [272., 273., 274., 275., 276., 277., 278., 279., 280., 281., 282., 283.,
         284., 285., 286., 287., 256., 257., 258., 259., 260., 261., 262., 263.,
         264., 265., 266., 267., 268., 269., 270., 271., 304., 305., 306., 307.,
         308., 309., 310., 311., 312., 313., 314., 315., 316., 317., 318., 319.,
         288., 289., 290., 291., 292., 293., 294., 295., 296., 297., 298., 299.,
         300., 301., 302., 303.],
        [408., 409., 410., 411., 412., 413., 414., 415., 400., 401., 402., 403.,
         404., 405., 406., 407., 392., 393., 394., 395., 396., 397., 398., 399.,
         384., 385., 386., 387., 388., 389., 390., 391., 440., 441., 442., 443.,
         444., 445., 446., 447., 432., 433., 434., 435., 436., 437., 438., 439.,
         424., 425., 426., 427., 428., 429., 430., 431., 416., 417., 418., 419.,
         420., 421., 422., 423.]], dtype=torch.float16)
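
The swizzled rows are consistent with the usual XOR pattern: with f16 data the 16-byte swizzle unit is 8 elements, and with perPhase = 1 the phase is just the row index, so 8-element group g of row r holds group g ^ r of the original row. A small sketch (my reconstruction, not the actual TMA behavior) that reproduces the swizzled tensor above from the unswizzled one:

import torch

def xor_swizzle(x: torch.Tensor, group: int = 8) -> torch.Tensor:
    # group = 8 f16 elements = 16 bytes, the swizzle atom
    rows, cols = x.shape
    out = torch.empty_like(x)
    for r in range(rows):
        phase = r  # perPhase = 1; maxPhase covers all 4 rows
        for g in range(cols // group):
            src = g ^ phase  # XOR the 16-byte group index with the phase
            out[r, g * group:(g + 1) * group] = x[r, src * group:(src + 1) * group]
    return out

# xor_swizzle(unswizzled) matches the swizzled tensor printed above.
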
Jokeren commented 2 months ago

I think the problem is on this line: "int tileRows = 8;". I'll try to address it tomorrow.
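
Speculative arithmetic connecting that hard-coded value to the assertion (my reading, not a confirmed diagnosis): if the shared layout is always built from 8-row tiles, the 4-row tensor gets padded up to 8 rows, so the layout's total input size no longer matches the tensor's element count, which is exactly the product check that reshapeIns asserts.

tensor_rows, tensor_cols = 4, 128
tile_rows = 8  # the hard-coded value in question
padded_rows = ((tensor_rows + tile_rows - 1) // tile_rows) * tile_rows  # -> 8
print(padded_rows * tensor_cols, "vs", tensor_rows * tensor_cols)  # 1024 vs 512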