microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License

addptr operand produced by an unsupported operation: `select` #15

Open manbearian opened 9 months ago

manbearian commented 9 months ago

Created from #7.

I don't know what the original Triton code that produced this looked like, but there is a `select` in the address expression.

Not shown in the error message, but when I looked at the provided test case I also saw an atomic operation, which is not currently supported either.

repro.zip

triton-shared-opt -triton-to-linalg 43.mlir

Error output:

%59 = "arith.select"(%57, %58, %45) {MetaUse} : (tensor<1x1xi1>, tensor<1x1xi64>, tensor<1x1xi64>) -> tensor<1x1xi64>
encountered addptr operand produced by an unsupported operation
UNREACHABLE executed at /home/ianb/src/triton/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:641!
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.  Program arguments: build/cmake.linux-x86_64-cpython-3.8/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt -triton-to-linalg /home/ianb/test/ttirs_linalg_failed/43.mlir
 #0 0x000055bbe74f737b llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (build/cmake.linux-x86_64-cpython-3.8/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x483637b)
 #1 0x000055bbe74f50b4 SignalHandler(int) Signals.cpp:0:0
 #2 0x00007fc951f661f0 __restore_rt (/lib/x86_64-linux-gnu/libpthread.so.0+0x141f0)
 #3 0x00007fc951a10fbb raise ./signal/../sysdeps/unix/sysv/linux/raise.c:50:1
 #4 0x00007fc9519f6864 abort ./stdlib/abort.c:81:7
 #5 0x000055bbe74346aa (build/cmake.linux-x86_64-cpython-3.8/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x47736aa)
 #6 0x000055bbe414d0ab /home/ianb/src/triton/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:640:5
 #7 0x000055bbe414d164 llvm::SmallVectorTemplateCommon<mlir::OpFoldResult, void>::getFirstEl() const /home/ianb/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/llvm/ADT/SmallVector.h:133:46
 #8 0x000055bbe414d164 llvm::SmallVectorTemplateCommon<mlir::OpFoldResult, void>::SmallVectorTemplateCommon(unsigned long) /home/ianb/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/llvm/ADT/SmallVector.h:138:49
 #9 0x000055bbe414d164 llvm::SmallVectorTemplateBase<mlir::OpFoldResult, true>::SmallVectorTemplateBase(unsigned long) /home/ianb/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/llvm/ADT/SmallVector.h:491:42
#10 0x000055bbe414d164 llvm::SmallVectorImpl<mlir::OpFoldResult>::SmallVectorImpl(unsigned int) /home/ianb/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/llvm/ADT/SmallVector.h:592:9
#11 0x000055bbe414d164 llvm::SmallVector<mlir::OpFoldResult, 6u>::SmallVector() /home/ianb/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/llvm/ADT/SmallVector.h:1202:19
#12 0x000055bbe414d164 mlir::triton::PtrState::PtrState() /home/ianb/src/triton/third_party/triton_shared/include/triton-shared/Analysis/PtrAnalysis.h:41:7
#13 0x000055bbe414d164 mlir::triton::PtrAnalysis::visitOperandMul(mlir::arith::MulIOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/ianb/src/triton/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:356:12
#14 0x000055bbe414ce62 mlir::triton::PtrAnalysis::visitOperand(mlir::Value, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/ianb/src/triton/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:0:5
#15 0x000055bbe414ca70 llvm::SmallVectorBase<unsigned int>::size() const /home/ianb/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/llvm/ADT/SmallVector.h:91:32
#16 0x000055bbe414ca70 mlir::triton::PtrState::getRank() const /home/ianb/src/triton/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:48:3
#17 0x000055bbe414ca70 mlir::triton::PtrAnalysis::visitOperandAdd(mlir::arith::AddIOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/ianb/src/triton/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:341:17
#18 0x000055bbe414cde1 mlir::triton::PtrAnalysis::visitOperand(mlir::Value, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/ianb/src/triton/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:0:5
#19 0x000055bbe414ec16 mlir::Value::operator bool() const /home/ianb/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/mlir/IR/Value.h:117:43
#20 0x000055bbe414ec16 mlir::triton::PtrAnalysis::visitOperandAddptr(mlir::triton::AddPtrOp, mlir::triton::PtrState&, mlir::Location, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>> const&) /home/ianb/src/triton/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:528:3
#21 0x000055bbe414f714 mlir::triton::PtrAnalysis::rewriteAddptrOp(mlir::triton::AddPtrOp, mlir::ConversionPatternRewriter&, llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>>&) /home/ianb/src/triton/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:689:26
#22 0x000055bbe407b3e1 llvm::SmallDenseMap<mlir::Value, mlir::triton::PtrState, 4u, llvm::DenseMapInfo<mlir::Value, void>, llvm::detail::DenseMapPair<mlir::Value, mlir::triton::PtrState>>::~SmallDenseMap() /home/ianb/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/llvm/ADT/DenseMap.h:960:11
#23 0x000055bbe407b3e1 (anonymous namespace)::AddPtrConverter::matchAndRewrite(mlir::triton::AddPtrOp, mlir::triton::AddPtrOpAdaptor, mlir::ConversionPatternRewriter&) const /home/ianb/src/triton/third_party/triton_shared/lib/Conversion/TritonToLinalg/TritonToLinalg.cpp:436:3
#24 0x000055bbe3f6d2a7 mlir::OpConversionPattern<mlir::triton::AddPtrOp>::matchAndRewrite(mlir::Operation*, llvm::ArrayRef<mlir::Value>, mlir::ConversionPatternRewriter&) const /home/ianb/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/include/mlir/Transforms/DialectConversion.h:536:73
#25 0x000055bbe664a5b1 mlir::ConversionPattern::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const (build/cmake.linux-x86_64-cpython-3.8/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x39895b1)
#26 0x000055bbe66966b2 mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) (build/cmake.linux-x86_64-cpython-3.8/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x39d56b2)
#27 0x000055bbe6656d19 (anonymous namespace)::OperationLegalizer::legalize(mlir::Operation*, mlir::ConversionPatternRewriter&) DialectConversion.cpp:0:0
#28 0x000055bbe66572f0 (anonymous namespace)::OperationConverter::convertOperations(llvm::ArrayRef<mlir::Operation*>, llvm::function_ref<void (mlir::Diagnostic&)>) DialectConversion.cpp:0:0
#29 0x000055bbe66596b0 mlir::applyFullConversion(mlir::Operation*, mlir::ConversionTarget&, mlir::FrozenRewritePatternSet const&) (build/cmake.linux-x86_64-cpython-3.8/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x39986b0)
#30 0x000055bbe40731b2 (anonymous namespace)::TritonToLinalgPass::runOnOperation() /home/ianb/src/triton/third_party/triton_shared/lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp:194:16
#31 0x000055bbe40a9991 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) (build/cmake.linux-x86_64-cpython-3.8/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x13e8991)
#32 0x000055bbe40aa1e1 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) (build/cmake.linux-x86_64-cpython-3.8/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x13e91e1)
#33 0x000055bbe40aad4a mlir::PassManager::run(mlir::Operation*) (build/cmake.linux-x86_64-cpython-3.8/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x13e9d4a)
#34 0x000055bbe409af7b performActions(llvm::raw_ostream&, std::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) MlirOptMain.cpp:0:0
#35 0x000055bbe409bab5 processBuffer(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::MlirOptMainConfig const&, mlir::DialectRegistry&, llvm::ThreadPool*) MlirOptMain.cpp:0:0
#36 0x000055bbe409bba0 mlir::LogicalResult llvm::function_ref<mlir::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::'lambda'(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>(long, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) MlirOptMain.cpp:0:0
#37 0x000055bbe73e24c5 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<mlir::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, bool, bool) (build/cmake.linux-x86_64-cpython-3.8/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x47214c5)
#38 0x000055bbe4099aa3 mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) (build/cmake.linux-x86_64-cpython-3.8/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x13d8aa3)
#39 0x000055bbe409bed3 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) (build/cmake.linux-x86_64-cpython-3.8/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x13daed3)
#40 0x000055bbe309cd5b main /home/ianb/src/triton/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt.cpp:16:33
#41 0x00007fc9519f8565 __libc_start_main ./csu/../csu/libc-start.c:332:16
#42 0x000055bbe309cc5e _start (build/cmake.linux-x86_64-cpython-3.8/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt+0x3dbc5e)
yuanfz98 commented 8 months ago

@nhat-nguyen Would you mind taking a look at this issue?

# Inductor-generated kernel; assumes `import triton`, `import triton.language as tl`,
# and Inductor's `triton_helpers` are in scope.
@triton.jit
def triton_(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr2, xnumel, rnumel):
    xnumel = 16384
    XBLOCK: tl.constexpr = 1
    rnumel = 384
    RBLOCK: tl.constexpr = 512
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (384*x0)), rmask, other=0).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0)
    tmp8 = tl.load(in_ptr2 + (r1 + (384*x0)), rmask, other=0)
    tmp9 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp11 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
    tmp18 = tl.load(in_out_ptr0 + (r1 + (384*x0)), rmask, other=0)
    tmp27 = tl.load(in_ptr5 + (r1 + (384*x0)), rmask)
    tmp32 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp1 * tmp2
    tmp4 = tl.broadcast_to(tmp3, [RBLOCK])
    tmp6 = tl.where(rmask, tmp4, 0)
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp6, 0))
    tmp10 = tmp8 - tmp9
    tmp12 = tmp10 * tmp11
    tmp13 = tmp3 * tmp12
    tmp14 = tl.broadcast_to(tmp13, [RBLOCK])
    tmp16 = tl.where(rmask, tmp14, 0)
    tmp17 = triton_helpers.promote_to_tensor(tl.sum(tmp16, 0))
    tmp19 = 384.0
    tmp20 = tmp11 / tmp19
    tmp21 = tmp3 * tmp19
    tmp22 = tmp21 - tmp7
    tmp23 = tmp12 * tmp17
    tmp24 = tmp22 - tmp23
    tmp25 = tmp20 * tmp24
    tmp26 = tmp18 + tmp25
    tmp28 = tmp27.to(tl.float32)
    tmp29 = 1.25
    tmp30 = tmp28 * tmp29
    tmp31 = tmp26 * tmp30
    tmp33 = tl.where(tmp32 < 0, tmp32 + 65, tmp32)
    tmp34 = tl.full([1], -1, tl.int64)
    tmp35 = tmp32 == tmp34
    tmp36 = 0.0
    tmp37 = tl.where(tmp35, tmp36, tmp31)
    tl.store(in_out_ptr0 + (r1 + (384*x0)), tmp31, rmask)
    tl.atomic_add(out_ptr2 + (tl.broadcast_to(r1 + (384*tmp33), [RBLOCK])), tmp37, rmask)
The corresponding TTIR:

module {
  tt.func public @triton__0d1d2d3d4d5d6d7d8d9de10de(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<bf16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<i1, 1> {tt.divisibility = 16 : i32}, %arg7: !tt.ptr<i64, 1> {tt.divisibility = 16 : i32}, %arg8: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg10: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} {
    %c384_i32 = arith.constant 384 : i32
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant dense<384> : tensor<1xi64>
    %cst_1 = arith.constant dense<65> : tensor<1xi64>
    %cst_2 = arith.constant dense<0> : tensor<1xi64>
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<512xf32>
    %cst_4 = arith.constant dense<0.000000e+00> : tensor<512xbf16>
    %cst_5 = arith.constant dense<-1> : tensor<1xi64>
    %cst_6 = arith.constant dense<1.250000e+00> : tensor<512xf32>
    %cst_7 = arith.constant dense<3.840000e+02> : tensor<512xf32>
    %cst_8 = arith.constant dense<3.840000e+02> : tensor<1xf32>
    %cst_9 = arith.constant dense<384> : tensor<512xi32>
    %0 = tt.get_program_id x : i32
    %1 = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32>
    %2 = arith.cmpi slt, %1, %cst_9 : tensor<512xi32>
    %3 = arith.muli %0, %c384_i32 : i32
    %4 = tt.splat %3 : (i32) -> tensor<512xi32>
    %5 = arith.addi %1, %4 : tensor<512xi32>
    %6 = tt.splat %arg1 : (!tt.ptr<bf16, 1>) -> tensor<512x!tt.ptr<bf16, 1>>
    %7 = tt.addptr %6, %5 : tensor<512x!tt.ptr<bf16, 1>>, tensor<512xi32>
    %8 = tt.load %7, %2, %cst_4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xbf16>
    %9 = arith.extf %8 : tensor<512xbf16> to tensor<512xf32>
    %10 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<512x!tt.ptr<f32, 1>>
    %11 = tt.addptr %10, %1 : tensor<512x!tt.ptr<f32, 1>>, tensor<512xi32>
    %12 = tt.load %11, %2, %cst_3 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<512xf32>
    %13 = tt.splat %arg3 : (!tt.ptr<f32, 1>) -> tensor<512x!tt.ptr<f32, 1>>
    %14 = tt.addptr %13, %5 : tensor<512x!tt.ptr<f32, 1>>, tensor<512xi32>
    %15 = tt.load %14, %2, %cst_3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xf32>
    %16 = tt.addptr %arg4, %0 : !tt.ptr<f32, 1>, i32
    %17 = tt.splat %16 : (!tt.ptr<f32, 1>) -> tensor<1x!tt.ptr<f32, 1>>
    %18 = tt.load %17 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1xf32>
    %19 = tt.addptr %arg5, %0 : !tt.ptr<f32, 1>, i32
    %20 = tt.splat %19 : (!tt.ptr<f32, 1>) -> tensor<1x!tt.ptr<f32, 1>>
    %21 = tt.load %20 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1xf32>
    %22 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<512x!tt.ptr<f32, 1>>
    %23 = tt.addptr %22, %5 : tensor<512x!tt.ptr<f32, 1>>, tensor<512xi32>
    %24 = tt.load %23, %2, %cst_3 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xf32>
    %25 = tt.splat %arg6 : (!tt.ptr<i1, 1>) -> tensor<512x!tt.ptr<i1, 1>>
    %26 = tt.addptr %25, %5 : tensor<512x!tt.ptr<i1, 1>>, tensor<512xi32>
    %27 = tt.bitcast %26 : tensor<512x!tt.ptr<i1, 1>> -> tensor<512x!tt.ptr<i8, 1>>
    %28 = tt.load %27, %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<512xi8>
    %29 = tt.addptr %arg7, %0 : !tt.ptr<i64, 1>, i32
    %30 = tt.splat %29 : (!tt.ptr<i64, 1>) -> tensor<1x!tt.ptr<i64, 1>>
    %31 = tt.load %30 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1xi64>
    %32 = arith.mulf %9, %12 : tensor<512xf32>
    %33 = arith.select %2, %32, %cst_3 : tensor<512xi1>, tensor<512xf32>
    %34 = "tt.reduce"(%33) <{axis = 0 : i32}> ({
    ^bb0(%arg11: f32, %arg12: f32):
      %70 = arith.addf %arg11, %arg12 : f32
      tt.reduce.return %70 : f32
    }) : (tensor<512xf32>) -> f32
    %35 = arith.addf %34, %cst : f32
    %36 = tt.broadcast %18 : (tensor<1xf32>) -> tensor<512xf32>
    %37 = arith.subf %15, %36 : tensor<512xf32>
    %38 = tt.broadcast %21 : (tensor<1xf32>) -> tensor<512xf32>
    %39 = arith.mulf %37, %38 : tensor<512xf32>
    %40 = arith.mulf %32, %39 : tensor<512xf32>
    %41 = arith.select %2, %40, %cst_3 : tensor<512xi1>, tensor<512xf32>
    %42 = "tt.reduce"(%41) <{axis = 0 : i32}> ({
    ^bb0(%arg11: f32, %arg12: f32):
      %70 = arith.addf %arg11, %arg12 : f32
      tt.reduce.return %70 : f32
    }) : (tensor<512xf32>) -> f32
    %43 = arith.addf %42, %cst : f32
    %44 = arith.divf %21, %cst_8 : tensor<1xf32>
    %45 = arith.mulf %32, %cst_7 : tensor<512xf32>
    %46 = tt.splat %35 : (f32) -> tensor<512xf32>
    %47 = arith.subf %45, %46 : tensor<512xf32>
    %48 = tt.splat %43 : (f32) -> tensor<512xf32>
    %49 = arith.mulf %39, %48 : tensor<512xf32>
    %50 = arith.subf %47, %49 : tensor<512xf32>
    %51 = tt.broadcast %44 : (tensor<1xf32>) -> tensor<512xf32>
    %52 = arith.mulf %51, %50 : tensor<512xf32>
    %53 = arith.addf %24, %52 : tensor<512xf32>
    %54 = arith.sitofp %28 : tensor<512xi8> to tensor<512xf32>
    %55 = arith.mulf %54, %cst_6 : tensor<512xf32>
    %56 = arith.mulf %53, %55 : tensor<512xf32>
    %57 = arith.cmpi slt, %31, %cst_2 : tensor<1xi64>
    %58 = arith.addi %31, %cst_1 : tensor<1xi64>
    %59 = arith.select %57, %58, %31 : tensor<1xi1>, tensor<1xi64>
    %60 = arith.cmpi eq, %31, %cst_5 : tensor<1xi64>
    %61 = tt.broadcast %60 : (tensor<1xi1>) -> tensor<512xi1>
    %62 = arith.select %61, %cst_3, %56 : tensor<512xi1>, tensor<512xf32>
    tt.store %23, %56, %2 {cache = 1 : i32, evict = 1 : i32} : tensor<512xf32>
    %63 = arith.muli %59, %cst_0 : tensor<1xi64>
    %64 = tt.broadcast %63 : (tensor<1xi64>) -> tensor<512xi64>
    %65 = arith.extsi %1 : tensor<512xi32> to tensor<512xi64>
    %66 = arith.addi %65, %64 : tensor<512xi64>
    %67 = tt.splat %arg8 : (!tt.ptr<f32, 1>) -> tensor<512x!tt.ptr<f32, 1>>
    %68 = tt.addptr %67, %66 : tensor<512x!tt.ptr<f32, 1>>, tensor<512xi64>
    %69 = "tt.atomic_rmw"(%68, %62, %2) <{atomic_rmw_op = 5 : i32, scope = 1 : i32, sem = 4 : i32}> : (tensor<512x!tt.ptr<f32, 1>>, tensor<512xf32>, tensor<512xi1>) -> tensor<512xf32>
    tt.return
  }
}

We should support:

  1. visitOperandExtSI, which is straightforward
  2. visitOperandSelect, visitOperandCmpI, visitOperandLoad. None of these is obvious enough for me to resolve on my own. Could you please give me some suggestions?
  3. tt.atomic_rmw, which can probably be mapped roughly to memref.atomic_rmw; that part is fine
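The mapping proposed in item 3 can be modeled in plain Python. This is a toy sketch under my own assumptions (not triton-shared code): a masked read-modify-write that returns the old values, matching the result semantics of tt.atomic_rmw.

```python
# Toy model (assumption, not triton-shared code) of a masked
# tt.atomic_rmw "fadd": add `values` into `memory` at `offsets` wherever
# `mask` is set, returning the previous contents like the Triton op does.
def atomic_add(memory, offsets, values, mask):
    old = [0.0] * len(offsets)
    for i, (off, v, m) in enumerate(zip(offsets, values, mask)):
        if m:
            old[i] = memory[off]   # read the old value
            memory[off] += v       # then update in place
    return old

mem = [1.0, 2.0]
old = atomic_add(mem, [0, 1, 0], [10.0, 20.0, 30.0], [True, True, False])
# mem is now [11.0, 22.0]; old is [1.0, 2.0, 0.0]
```

A real lowering to memref.atomic_rmw would additionally have to preserve atomicity per element; this sketch only captures the masked value flow.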
nhat-nguyen commented 8 months ago

@yuanfz98 Sorry for the late response.

The case you're hitting is basically the limit of our PtrAnalysis pass. We attempt to convert Triton pointer loads on a best-effort basis: pointers that can be shown at compile time to address contiguous memory are converted to a single memref load. The common cases almost always involve tl.arange as the base case of the recursion; that's where we set up the initial offsets, strides, and sizes of the memref to load from.
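The recursion described above can be illustrated with a small Python model (my own sketch, not the actual C++ PtrAnalysis): each address expression folds into a (start, stride) pair, with tl.arange as the base case, and any producer whose offsets are only known at runtime (select, load, ...) makes the analysis fail — which is exactly the error in this issue.

```python
# Toy Python model (assumption: not the real C++ implementation) of the
# PtrAnalysis recursion over an address expression, here encoded as
# nested tuples like ("add", ("arange",), ("const", 5)).

def analyze(expr):
    """Return (start, stride) for an address expression, or None if unsupported."""
    kind = expr[0]
    if kind == "arange":            # base case: tl.arange(0, n) -> start 0, stride 1
        return (0, 1)
    if kind == "const":             # splatted scalar -> stride 0
        return (expr[1], 0)
    if kind == "add":
        a, b = analyze(expr[1]), analyze(expr[2])
        if a is None or b is None:
            return None
        return (a[0] + b[0], a[1] + b[1])
    if kind == "mul":
        a, b = analyze(expr[1]), analyze(expr[2])
        if a is None or b is None or (a[1] and b[1]):
            return None             # stride * stride would be non-affine
        return (a[0] * b[0], a[0] * b[1] + a[1] * b[0])
    return None                     # select, load, ... -> unsupported

# r1 + 384*x0 with x0 = 7: contiguous with stride 1, so a single memref load works
assert analyze(("add", ("arange",), ("mul", ("const", 384), ("const", 7)))) == (2688, 1)
# offsets produced by a select are rejected, as in this issue
assert analyze(("select",)) is None
```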

For this particular case, the offsets used by tl.atomic_add(out_ptr2 + (tl.broadcast_to(r1 + (384*tmp33), [RBLOCK])), tmp37, rmask) are themselves derived from another load:

tmp32 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')

It is impossible to know at compile time whether tmp33 contains only contiguous offsets, so we can't really support this case with the current approach. If we somehow had a user annotation guaranteeing that tmp33 is contiguous, it would become possible.

Another feature we have in mind, though not a priority, is a fallback mode: it would handle cases where PtrAnalysis fails, at the expense of performance. This would probably involve loading each element individually and storing it into a local buffer. It will take me some time to give this more thought (I'll be away all of December). In the meantime, if you have any ideas or further thoughts on this, we would love to hear them in the Discussions tab.
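The fallback mode described above can be sketched in Python (my assumption of what "loading each individual element" would mean, not a planned implementation): when contiguity can't be proven, lower the tensor load to a scalar gather loop into a local buffer.

```python
# Sketch (assumption, not the planned triton-shared implementation) of the
# fallback mode: a masked gather that loads one scalar at a time into a
# local buffer, with `other` filling the masked-off lanes like tl.load does.
def gather_load(memory, offsets, mask, other=0.0):
    buf = [other] * len(offsets)
    for i, (off, m) in enumerate(zip(offsets, mask)):
        if m:
            buf[i] = memory[off]   # one scalar load per element
    return buf

# Arbitrary runtime offsets are fine here; only performance suffers.
assert gather_load([10, 20, 30, 40], [2, 0, 3], [True, False, True]) == [30, 0.0, 40]
```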

Now, for the select op: technically we can support it provided that neither the true nor the false value ends up being produced by a load. I think the code would look something like:

visitOperandSelect(selectOp, state, loc, rewriter, knownPtrs) {

    PtrState condState;
    visitOperand(selectOp.getCondition(), condState, loc, rewriter, knownPtrs);

    PtrState trueState;
    visitOperand(selectOp.getTrueValue(), trueState, loc, rewriter, knownPtrs);

    PtrState falseState;
    visitOperand(selectOp.getFalseValue(), falseState, loc, rewriter, knownPtrs);

    // Now generate code to select the correct strides and offsets from the
    // two states depending on the condition, and merge the result into `state`.
}