Open chsasank opened 4 months ago
Contributor should complete tutorials here: https://triton-lang.org/main/getting-started/tutorials/index.html
A way to look at LLVM IR generated is given here: https://github.com/triton-lang/triton?tab=readme-ov-file#tips-for-hacking
The IR dump below was produced by running: LLVM_IR_ENABLE_DUMP=1 python 01-vector-add.py 2> vector_add.ll
import triton
import triton.language as tl
@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    # Element-wise vector addition: stores x + y into the output vector, with each
    # program handling one BLOCK_SIZE-wide slice of the data.
    #
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that `offsets` holds integer element indices, not pointers; the pointers
    # are only formed below, when the indices are added to a base pointer.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
Only the first IR dump is pasted here:
; *** IR Dump After Annotation2MetadataPass on [module] ***
; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
@global_smem = external addrspace(3) global [0 x i8], align 16
; Element-wise float add kernel. %0 and %1 are the two input pointers, %2 is the
; output pointer (all in addrspace(1)), and %3 is the i32 bound that element
; indices are compared against. Each thread handles two groups of four
; consecutive floats.
define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3) !dbg !7 {
; %5 = %ctaid.x, read through inline PTX; %6 = %5 * 1024 is the first element
; index handled by this thread block.
%5 = call i32 asm "mov.u32 $0, %ctaid.x;", "=r"(), !dbg !10
%6 = mul i32 %5, 1024, !dbg !11
; %7 = tid.x, split into %8 = tid.x % 32 and %9 = tid.x / 32.
%7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !12
%8 = urem i32 %7, 32, !dbg !12
%9 = udiv i32 %7, 32, !dbg !12
; Build the thread's in-block offset %37 bit by bit: every select yields 0 or a
; distinct power of two and the xor chain combines them. Bits 0..4 of %8
; contribute 4, 8, 16, 32 and 64.
%10 = and i32 %8, 1, !dbg !12
%11 = icmp eq i32 %10, 0, !dbg !12
%12 = select i1 %11, i32 0, i32 4, !dbg !12
%13 = xor i32 0, %12, !dbg !12
%14 = and i32 %8, 2, !dbg !12
%15 = icmp eq i32 %14, 0, !dbg !12
%16 = select i1 %15, i32 0, i32 8, !dbg !12
%17 = xor i32 %13, %16, !dbg !12
%18 = and i32 %8, 4, !dbg !12
%19 = icmp eq i32 %18, 0, !dbg !12
%20 = select i1 %19, i32 0, i32 16, !dbg !12
%21 = xor i32 %17, %20, !dbg !12
%22 = and i32 %8, 8, !dbg !12
%23 = icmp eq i32 %22, 0, !dbg !12
%24 = select i1 %23, i32 0, i32 32, !dbg !12
%25 = xor i32 %21, %24, !dbg !12
%26 = and i32 %8, 16, !dbg !12
%27 = icmp eq i32 %26, 0, !dbg !12
%28 = select i1 %27, i32 0, i32 64, !dbg !12
%29 = xor i32 %25, %28, !dbg !12
; Bits 0 and 1 of %9 contribute 128 and 256.
%30 = and i32 %9, 1, !dbg !12
%31 = icmp eq i32 %30, 0, !dbg !12
%32 = select i1 %31, i32 0, i32 128, !dbg !12
%33 = xor i32 %29, %32, !dbg !12
%34 = and i32 %9, 2, !dbg !12
%35 = icmp eq i32 %34, 0, !dbg !12
%36 = select i1 %35, i32 0, i32 256, !dbg !12
%37 = xor i32 %33, %36, !dbg !12
; %44 is the same set of contributions xor'ed onto the constant 512: the offset
; of this thread's second group of four elements.
%38 = xor i32 512, %12, !dbg !12
%39 = xor i32 %38, %16, !dbg !12
%40 = xor i32 %39, %20, !dbg !12
%41 = xor i32 %40, %24, !dbg !12
%42 = xor i32 %41, %28, !dbg !12
%43 = xor i32 %42, %32, !dbg !12
%44 = xor i32 %43, %36, !dbg !12
%45 = add i32 %37, 0, !dbg !12
%46 = add i32 %44, 0, !dbg !12
; Global element indices: block start plus in-block offset.
%47 = add i32 %6, %45, !dbg !13
%48 = add i32 %6, %46, !dbg !13
; Bounds predicates (signed index < %3). There is one predicate per group of
; four elements, computed from the group's first index only.
%49 = icmp slt i32 %47, %3, !dbg !14
%50 = icmp slt i32 %48, %3, !dbg !14
; Addresses of both groups in the first input (%0).
%51 = getelementptr float, ptr addrspace(1) %0, i32 %47, !dbg !15
%52 = getelementptr float, ptr addrspace(1) %0, i32 %48, !dbg !15
; Inline PTX: zero the four result registers, then, only if predicate %49 is
; set, load four 32-bit values from %51 with a single ld.global.v4.b32.
%53 = call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(ptr addrspace(1) %51, i1 %49), !dbg !16
; Unpack the four i32 results and reinterpret each one as a float.
%54 = extractvalue { i32, i32, i32, i32 } %53, 0, !dbg !16
%55 = bitcast i32 %54 to <1 x float>, !dbg !16
%56 = extractvalue { i32, i32, i32, i32 } %53, 1, !dbg !16
%57 = bitcast i32 %56 to <1 x float>, !dbg !16
%58 = extractvalue { i32, i32, i32, i32 } %53, 2, !dbg !16
%59 = bitcast i32 %58 to <1 x float>, !dbg !16
%60 = extractvalue { i32, i32, i32, i32 } %53, 3, !dbg !16
%61 = bitcast i32 %60 to <1 x float>, !dbg !16
%62 = extractelement <1 x float> %55, i32 0, !dbg !16
%63 = extractelement <1 x float> %57, i32 0, !dbg !16
%64 = extractelement <1 x float> %59, i32 0, !dbg !16
%65 = extractelement <1 x float> %61, i32 0, !dbg !16
; Same predicated load for the second group of the first input (%52, predicate %50).
%66 = call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(ptr addrspace(1) %52, i1 %50), !dbg !16
%67 = extractvalue { i32, i32, i32, i32 } %66, 0, !dbg !16
%68 = bitcast i32 %67 to <1 x float>, !dbg !16
%69 = extractvalue { i32, i32, i32, i32 } %66, 1, !dbg !16
%70 = bitcast i32 %69 to <1 x float>, !dbg !16
%71 = extractvalue { i32, i32, i32, i32 } %66, 2, !dbg !16
%72 = bitcast i32 %71 to <1 x float>, !dbg !16
%73 = extractvalue { i32, i32, i32, i32 } %66, 3, !dbg !16
%74 = bitcast i32 %73 to <1 x float>, !dbg !16
%75 = extractelement <1 x float> %68, i32 0, !dbg !16
%76 = extractelement <1 x float> %70, i32 0, !dbg !16
%77 = extractelement <1 x float> %72, i32 0, !dbg !16
%78 = extractelement <1 x float> %74, i32 0, !dbg !16
; Addresses of both groups in the second input (%1), followed by the same two
; predicated vector loads and unpacking.
%79 = getelementptr float, ptr addrspace(1) %1, i32 %47, !dbg !17
%80 = getelementptr float, ptr addrspace(1) %1, i32 %48, !dbg !17
%81 = call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(ptr addrspace(1) %79, i1 %49), !dbg !18
%82 = extractvalue { i32, i32, i32, i32 } %81, 0, !dbg !18
%83 = bitcast i32 %82 to <1 x float>, !dbg !18
%84 = extractvalue { i32, i32, i32, i32 } %81, 1, !dbg !18
%85 = bitcast i32 %84 to <1 x float>, !dbg !18
%86 = extractvalue { i32, i32, i32, i32 } %81, 2, !dbg !18
%87 = bitcast i32 %86 to <1 x float>, !dbg !18
%88 = extractvalue { i32, i32, i32, i32 } %81, 3, !dbg !18
%89 = bitcast i32 %88 to <1 x float>, !dbg !18
%90 = extractelement <1 x float> %83, i32 0, !dbg !18
%91 = extractelement <1 x float> %85, i32 0, !dbg !18
%92 = extractelement <1 x float> %87, i32 0, !dbg !18
%93 = extractelement <1 x float> %89, i32 0, !dbg !18
%94 = call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(ptr addrspace(1) %80, i1 %50), !dbg !18
%95 = extractvalue { i32, i32, i32, i32 } %94, 0, !dbg !18
%96 = bitcast i32 %95 to <1 x float>, !dbg !18
%97 = extractvalue { i32, i32, i32, i32 } %94, 1, !dbg !18
%98 = bitcast i32 %97 to <1 x float>, !dbg !18
%99 = extractvalue { i32, i32, i32, i32 } %94, 2, !dbg !18
%100 = bitcast i32 %99 to <1 x float>, !dbg !18
%101 = extractvalue { i32, i32, i32, i32 } %94, 3, !dbg !18
%102 = bitcast i32 %101 to <1 x float>, !dbg !18
%103 = extractelement <1 x float> %96, i32 0, !dbg !18
%104 = extractelement <1 x float> %98, i32 0, !dbg !18
%105 = extractelement <1 x float> %100, i32 0, !dbg !18
%106 = extractelement <1 x float> %102, i32 0, !dbg !18
; The eight scalar additions: first group (%107..%110), second group (%111..%114).
%107 = fadd float %62, %90, !dbg !19
%108 = fadd float %63, %91, !dbg !19
%109 = fadd float %64, %92, !dbg !19
%110 = fadd float %65, %93, !dbg !19
%111 = fadd float %75, %103, !dbg !19
%112 = fadd float %76, %104, !dbg !19
%113 = fadd float %77, %105, !dbg !19
%114 = fadd float %78, %106, !dbg !19
; Addresses of both groups in the output (%2).
%115 = getelementptr float, ptr addrspace(1) %2, i32 %47, !dbg !20
%116 = getelementptr float, ptr addrspace(1) %2, i32 %48, !dbg !20
; Reinterpret the four sums of the first group as i32 so they can be passed to
; the inline-PTX store as "r" register operands.
%117 = insertelement <1 x float> undef, float %107, i32 0, !dbg !21
%118 = bitcast <1 x float> %117 to i32, !dbg !21
%119 = insertelement <1 x float> undef, float %108, i32 0, !dbg !21
%120 = bitcast <1 x float> %119 to i32, !dbg !21
%121 = insertelement <1 x float> undef, float %109, i32 0, !dbg !21
%122 = bitcast <1 x float> %121 to i32, !dbg !21
%123 = insertelement <1 x float> undef, float %110, i32 0, !dbg !21
%124 = bitcast <1 x float> %123 to i32, !dbg !21
; Inline PTX: st.global.v4.b32 of the four values to %115, executed only when
; predicate %125 (which is %49 and'ed with true) is set.
%125 = and i1 true, %49, !dbg !21
call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %118, i32 %120, i32 %122, i32 %124, ptr addrspace(1) %115, i1 %125), !dbg !21
; Same packing and predicated store for the second group (%116, predicate from %50).
%126 = insertelement <1 x float> undef, float %111, i32 0, !dbg !21
%127 = bitcast <1 x float> %126 to i32, !dbg !21
%128 = insertelement <1 x float> undef, float %112, i32 0, !dbg !21
%129 = bitcast <1 x float> %128 to i32, !dbg !21
%130 = insertelement <1 x float> undef, float %113, i32 0, !dbg !21
%131 = bitcast <1 x float> %130 to i32, !dbg !21
%132 = insertelement <1 x float> undef, float %114, i32 0, !dbg !21
%133 = bitcast <1 x float> %132 to i32, !dbg !21
%134 = and i1 true, %50, !dbg !21
call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %127, i32 %129, i32 %131, i32 %133, ptr addrspace(1) %116, i1 %134), !dbg !21
ret void, !dbg !22
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
!llvm.module.flags = !{!0, !1}
!llvm.dbg.cu = !{!2}
!nvvm.annotations = !{!4, !5}
!llvm.ident = !{!6}
!0 = !{i32 2, !"Debug Info Version", i32 3}
!1 = !{i32 4, !"nvvm-reflect-ftz", i32 1}
!2 = distinct !DICompileUnit(language: DW_LANG_C, file: !3, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly)
!3 = !DIFile(filename: "01-vector-add.py", directory: "/home/sasank/code")
!4 = !{ptr @add_kernel, !"kernel", i32 1}
!5 = !{ptr @add_kernel, !"maxntidx", i32 128}
!6 = !{!"clang version 3.8.0 (tags/RELEASE_380/final)"}
!7 = distinct !DISubprogram(name: "add_kernel", linkageName: "add_kernel", scope: !3, file: !3, line: 28, type: !8, scopeLine: 28, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2)
!8 = !DISubroutineType(cc: DW_CC_normal, types: !9)
!9 = !{}
!10 = !DILocation(line: 37, column: 24, scope: !7)
!11 = !DILocation(line: 42, column: 24, scope: !7)
!12 = !DILocation(line: 43, column: 41, scope: !7)
!13 = !DILocation(line: 43, column: 28, scope: !7)
!14 = !DILocation(line: 45, column: 21, scope: !7)
!15 = !DILocation(line: 48, column: 24, scope: !7)
!16 = !DILocation(line: 48, column: 16, scope: !7)
!17 = !DILocation(line: 49, column: 24, scope: !7)
!18 = !DILocation(line: 49, column: 16, scope: !7)
!19 = !DILocation(line: 50, column: 17, scope: !7)
!20 = !DILocation(line: 52, column: 26, scope: !7)
!21 = !DILocation(line: 52, column: 35, scope: !7)
!22 = !DILocation(line: 52, column: 4, scope: !7)
This doesn't look very nice. We have
I checked out version 1.1.1 and hacked around to get the LLVM IR out of the dumped pickle binary:
Triton code:
@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector
    y_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    n_elements,  # Size of the vector
    **meta,  # Optional meta-parameters for the kernel
):
    # Element-wise vector addition: stores x + y into the output vector, with each
    # program handling one BLOCK_SIZE-wide slice of the data.
    BLOCK_SIZE = meta['BLOCK_SIZE']  # How many inputs each program should process
    # There are multiple 'program's processing different data. We identify which program
    # we are here
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
    # This program will process inputs that are offset from the initial data.
    # for instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that `offsets` holds integer element indices, not pointers; the pointers
    # are only formed below, when the indices are added to a base pointer
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM
    tl.store(output_ptr + offsets, output, mask=mask)
Triton IR:
def void add_kernel(f32* x_ptr .aligned(16) , f32* y_ptr .aligned(16) , f32* output_ptr .aligned(16) , i32 n_elements .multipleof(16) )
{
entry:
%0 = get_program_id(0) i32;
%1 = mul i32 %0, 1024;
%3 = make_range[0 : 1024] i32<1024>;
%4 = splat i32<1024> %1;
%6 = add i32<1024> %4, %3;
%9 = splat i32<1024> n_elements;
%11 = icmp_slt i1<1024> %6, %9;
%14 = splat f32*<1024> x_ptr;
%16 = getelementptr f32*<1024> %14, %6;
%19 = splat f32<1024> undef;
%20 = masked_load f32<1024> %16, %11, %19;
%24 = splat f32*<1024> y_ptr;
%26 = getelementptr f32*<1024> %24, %6;
%29 = splat f32<1024> undef;
%30 = masked_load f32<1024> %26, %11, %29;
%34 = fadd f32<1024> %20, %30;
%37 = splat f32*<1024> output_ptr;
%39 = getelementptr f32*<1024> %37, %6;
masked_store void %39, %34, %11;
ret void;
}
LLVM IR:
; ModuleID = 'add_kernel'
source_filename = "add_kernel"
; Element-wise float add kernel as emitted by the older (1.1.1) code generator.
; %0 and %1 are the two input pointers, %2 is the output pointer (all
; addrspace(1), declared align 16), and %3 is the i32 bound that element indices
; are compared against. Each thread handles two groups of four consecutive floats.
define void @add_kernel(float addrspace(1)* align 16 %0, float addrspace(1)* align 16 %1, float addrspace(1)* align 16 %2, i32 %3) {
entry:
; %4 = tid.x; %5 = tid.x % 32, %6 = tid.x / 32; %8 = %6 * 32 + %5 recombines
; them, and %9 = %8 * 4 is the thread's first in-block element offset.
%4 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%5 = urem i32 %4, 32
%6 = udiv i32 %4, 32
%7 = mul i32 %6, 32
%8 = add i32 %7, %5
%9 = mul i32 %8, 4
; Named per-element offsets (+0..+3 and +512..+515). These %idx_0_* values are
; not referenced again in this function.
%idx_0_0 = add i32 %9, 0
%idx_0_1 = add i32 %9, 1
%idx_0_2 = add i32 %9, 2
%idx_0_3 = add i32 %9, 3
%idx_0_4 = add i32 %9, 512
%idx_0_5 = add i32 %9, 513
%idx_0_6 = add i32 %9, 514
%idx_0_7 = add i32 %9, 515
; %10 = ctaid.x (read through an intrinsic here); %11 = %10 * 1024 is the first
; element index handled by this thread block.
%10 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%11 = mul i32 %10, 1024
; The same eight in-block offsets again, computed as (0 + %9) + k.
%12 = add i32 0, %9
%13 = add i32 %12, 0
%14 = add i32 0, %9
%15 = add i32 %14, 1
%16 = add i32 0, %9
%17 = add i32 %16, 2
%18 = add i32 0, %9
%19 = add i32 %18, 3
%20 = add i32 0, %9
%21 = add i32 %20, 512
%22 = add i32 0, %9
%23 = add i32 %22, 513
%24 = add i32 0, %9
%25 = add i32 %24, 514
%26 = add i32 0, %9
%27 = add i32 %26, 515
; Global element indices: (block start + %9) + k for each of the eight elements.
%28 = add i32 %11, %12
%29 = add i32 %28, 0
%30 = add i32 %11, %14
%31 = add i32 %30, 1
%32 = add i32 %11, %16
%33 = add i32 %32, 2
%34 = add i32 %11, %18
%35 = add i32 %34, 3
%36 = add i32 %11, %20
%37 = add i32 %36, 512
%38 = add i32 %11, %22
%39 = add i32 %38, 513
%40 = add i32 %11, %24
%41 = add i32 %40, 514
%42 = add i32 %11, %26
%43 = add i32 %42, 515
; One bounds predicate (signed index < %3) per element. Only %44 and %48, the
; predicates of each group's first element, are used further down.
%44 = icmp slt i32 %29, %3
%45 = icmp slt i32 %31, %3
%46 = icmp slt i32 %33, %3
%47 = icmp slt i32 %35, %3
%48 = icmp slt i32 %37, %3
%49 = icmp slt i32 %39, %3
%50 = icmp slt i32 %41, %3
%51 = icmp slt i32 %43, %3
; Per-element addresses in the first input (%0). Only %52 and %60 are consumed
; by the loads below.
%52 = getelementptr float, float addrspace(1)* %0, i32 %28
%53 = getelementptr float, float addrspace(1)* %52, i32 0
%54 = getelementptr float, float addrspace(1)* %0, i32 %30
%55 = getelementptr float, float addrspace(1)* %54, i32 1
%56 = getelementptr float, float addrspace(1)* %0, i32 %32
%57 = getelementptr float, float addrspace(1)* %56, i32 2
%58 = getelementptr float, float addrspace(1)* %0, i32 %34
%59 = getelementptr float, float addrspace(1)* %58, i32 3
%60 = getelementptr float, float addrspace(1)* %0, i32 %36
%61 = getelementptr float, float addrspace(1)* %60, i32 512
%62 = getelementptr float, float addrspace(1)* %0, i32 %38
%63 = getelementptr float, float addrspace(1)* %62, i32 513
%64 = getelementptr float, float addrspace(1)* %0, i32 %40
%65 = getelementptr float, float addrspace(1)* %64, i32 514
%66 = getelementptr float, float addrspace(1)* %0, i32 %42
%67 = getelementptr float, float addrspace(1)* %66, i32 515
; Inline PTX: if predicate %44 is set, load four 32-bit values from %52 with a
; single ld.global.v4.b32. The result registers are not initialised when the
; predicate is false.
%68 = call { i32, i32, i32, i32 } asm sideeffect "@$4 ld.global.v4.b32 {$0,$1,$2,$3}, [ $5 + 0];", "=r,=r,=r,=r,b,l"(i1 %44, float addrspace(1)* %52)
; Unpack the four i32 results and reinterpret each one as a float.
%69 = extractvalue { i32, i32, i32, i32 } %68, 0
%70 = bitcast i32 %69 to <1 x float>
%71 = extractvalue { i32, i32, i32, i32 } %68, 1
%72 = bitcast i32 %71 to <1 x float>
%73 = extractvalue { i32, i32, i32, i32 } %68, 2
%74 = bitcast i32 %73 to <1 x float>
%75 = extractvalue { i32, i32, i32, i32 } %68, 3
%76 = bitcast i32 %75 to <1 x float>
%77 = extractelement <1 x float> %70, i64 0
%78 = extractelement <1 x float> %72, i64 0
%79 = extractelement <1 x float> %74, i64 0
%80 = extractelement <1 x float> %76, i64 0
; Second group: the pointer %60 does not include the +512 element offset; the
; asm adds it as a 2048-byte displacement instead. Guarded by predicate %48.
%81 = call { i32, i32, i32, i32 } asm sideeffect "@$4 ld.global.v4.b32 {$0,$1,$2,$3}, [ $5 + 2048];", "=r,=r,=r,=r,b,l"(i1 %48, float addrspace(1)* %60)
%82 = extractvalue { i32, i32, i32, i32 } %81, 0
%83 = bitcast i32 %82 to <1 x float>
%84 = extractvalue { i32, i32, i32, i32 } %81, 1
%85 = bitcast i32 %84 to <1 x float>
%86 = extractvalue { i32, i32, i32, i32 } %81, 2
%87 = bitcast i32 %86 to <1 x float>
%88 = extractvalue { i32, i32, i32, i32 } %81, 3
%89 = bitcast i32 %88 to <1 x float>
%90 = extractelement <1 x float> %83, i64 0
%91 = extractelement <1 x float> %85, i64 0
%92 = extractelement <1 x float> %87, i64 0
%93 = extractelement <1 x float> %89, i64 0
; Per-element addresses in the second input (%1). Only %94 and %102 are consumed
; by the loads below.
%94 = getelementptr float, float addrspace(1)* %1, i32 %28
%95 = getelementptr float, float addrspace(1)* %94, i32 0
%96 = getelementptr float, float addrspace(1)* %1, i32 %30
%97 = getelementptr float, float addrspace(1)* %96, i32 1
%98 = getelementptr float, float addrspace(1)* %1, i32 %32
%99 = getelementptr float, float addrspace(1)* %98, i32 2
%100 = getelementptr float, float addrspace(1)* %1, i32 %34
%101 = getelementptr float, float addrspace(1)* %100, i32 3
%102 = getelementptr float, float addrspace(1)* %1, i32 %36
%103 = getelementptr float, float addrspace(1)* %102, i32 512
%104 = getelementptr float, float addrspace(1)* %1, i32 %38
%105 = getelementptr float, float addrspace(1)* %104, i32 513
%106 = getelementptr float, float addrspace(1)* %1, i32 %40
%107 = getelementptr float, float addrspace(1)* %106, i32 514
%108 = getelementptr float, float addrspace(1)* %1, i32 %42
%109 = getelementptr float, float addrspace(1)* %108, i32 515
; Same two predicated vector loads and unpacking for the second input.
%110 = call { i32, i32, i32, i32 } asm sideeffect "@$4 ld.global.v4.b32 {$0,$1,$2,$3}, [ $5 + 0];", "=r,=r,=r,=r,b,l"(i1 %44, float addrspace(1)* %94)
%111 = extractvalue { i32, i32, i32, i32 } %110, 0
%112 = bitcast i32 %111 to <1 x float>
%113 = extractvalue { i32, i32, i32, i32 } %110, 1
%114 = bitcast i32 %113 to <1 x float>
%115 = extractvalue { i32, i32, i32, i32 } %110, 2
%116 = bitcast i32 %115 to <1 x float>
%117 = extractvalue { i32, i32, i32, i32 } %110, 3
%118 = bitcast i32 %117 to <1 x float>
%119 = extractelement <1 x float> %112, i64 0
%120 = extractelement <1 x float> %114, i64 0
%121 = extractelement <1 x float> %116, i64 0
%122 = extractelement <1 x float> %118, i64 0
%123 = call { i32, i32, i32, i32 } asm sideeffect "@$4 ld.global.v4.b32 {$0,$1,$2,$3}, [ $5 + 2048];", "=r,=r,=r,=r,b,l"(i1 %48, float addrspace(1)* %102)
%124 = extractvalue { i32, i32, i32, i32 } %123, 0
%125 = bitcast i32 %124 to <1 x float>
%126 = extractvalue { i32, i32, i32, i32 } %123, 1
%127 = bitcast i32 %126 to <1 x float>
%128 = extractvalue { i32, i32, i32, i32 } %123, 2
%129 = bitcast i32 %128 to <1 x float>
%130 = extractvalue { i32, i32, i32, i32 } %123, 3
%131 = bitcast i32 %130 to <1 x float>
%132 = extractelement <1 x float> %125, i64 0
%133 = extractelement <1 x float> %127, i64 0
%134 = extractelement <1 x float> %129, i64 0
%135 = extractelement <1 x float> %131, i64 0
; The eight scalar additions: first group (%136..%139), second group (%140..%143).
%136 = fadd float %77, %119
%137 = fadd float %78, %120
%138 = fadd float %79, %121
%139 = fadd float %80, %122
%140 = fadd float %90, %132
%141 = fadd float %91, %133
%142 = fadd float %92, %134
%143 = fadd float %93, %135
; Per-element addresses in the output (%2). Only %145 and %153 are consumed by
; the stores below.
%144 = getelementptr float, float addrspace(1)* %2, i32 %28
%145 = getelementptr float, float addrspace(1)* %144, i32 0
%146 = getelementptr float, float addrspace(1)* %2, i32 %30
%147 = getelementptr float, float addrspace(1)* %146, i32 1
%148 = getelementptr float, float addrspace(1)* %2, i32 %32
%149 = getelementptr float, float addrspace(1)* %148, i32 2
%150 = getelementptr float, float addrspace(1)* %2, i32 %34
%151 = getelementptr float, float addrspace(1)* %150, i32 3
%152 = getelementptr float, float addrspace(1)* %2, i32 %36
%153 = getelementptr float, float addrspace(1)* %152, i32 512
%154 = getelementptr float, float addrspace(1)* %2, i32 %38
%155 = getelementptr float, float addrspace(1)* %154, i32 513
%156 = getelementptr float, float addrspace(1)* %2, i32 %40
%157 = getelementptr float, float addrspace(1)* %156, i32 514
%158 = getelementptr float, float addrspace(1)* %2, i32 %42
%159 = getelementptr float, float addrspace(1)* %158, i32 515
; Stores use ordinary LLVM instructions rather than inline PTX: pack the four
; sums into a <4 x float> and branch around the vector store when the group's
; predicate is false.
%160 = bitcast float addrspace(1)* %145 to <4 x float> addrspace(1)*
%161 = insertelement <4 x float> undef, float %136, i64 0
%162 = insertelement <4 x float> %161, float %137, i64 1
%163 = insertelement <4 x float> %162, float %138, i64 2
%164 = insertelement <4 x float> %163, float %139, i64 3
br i1 %44, label %165, label %166
165: ; preds = %entry
store <4 x float> %164, <4 x float> addrspace(1)* %160, align 16
br label %166
166: ; preds = %entry, %165
; Second group: same pattern, guarded by %48.
%167 = bitcast float addrspace(1)* %153 to <4 x float> addrspace(1)*
%168 = insertelement <4 x float> undef, float %140, i64 0
%169 = insertelement <4 x float> %168, float %141, i64 1
%170 = insertelement <4 x float> %169, float %142, i64 2
%171 = insertelement <4 x float> %170, float %143, i64 3
br i1 %48, label %172, label %173
172: ; preds = %166
store <4 x float> %171, <4 x float> addrspace(1)* %167, align 16
br label %173
173: ; preds = %166, %172
call void @llvm.donothing()
call void @llvm.donothing()
ret void
}
; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #0
; Function Attrs: nounwind readnone willreturn
declare void @llvm.donothing() #1
attributes #0 = { nounwind readnone }
attributes #1 = { nounwind readnone willreturn }
!nvvm.annotations = !{!0, !1}
!0 = !{void (float addrspace(1)*, float addrspace(1)*, float addrspace(1)*, i32)* @add_kernel, !"kernel", i32 1}
!1 = !{void (float addrspace(1)*, float addrspace(1)*, float addrspace(1)*, i32)* @add_kernel, !"maxntidx", i32 128}
PTX
//
// Generated by LLVM NVPTX Back-End
//
.version 7.3
.target sm_75
.address_size 64
// .globl add_kernel
// Element-wise float add kernel. Parameters 0 and 1 are the two input pointers,
// parameter 2 is the output pointer, and parameter 3 is the 32-bit bound that
// element indices are compared against. Each thread handles two groups of four
// consecutive floats, 512 elements (2048 bytes) apart.
.visible .entry add_kernel(
.param .u64 add_kernel_param_0,
.param .u64 add_kernel_param_1,
.param .u64 add_kernel_param_2,
.param .u32 add_kernel_param_3
)
.maxntid 128, 1, 1
{
.reg .pred %p<7>;
.reg .f32 %f<29>;
.reg .b32 %r<23>;
.reg .b64 %rd<10>;
// %r2 = bound; %rd6 / %rd7 = input base pointers; %rd8 = output base pointer.
ld.param.u32 %r2, [add_kernel_param_3];
ld.param.u64 %rd6, [add_kernel_param_0];
ld.param.u64 %rd7, [add_kernel_param_1];
mov.u32 %r19, %tid.x;
ld.param.u64 %rd8, [add_kernel_param_2];
// %r22 = ctaid.x * 1024 + tid.x * 4: index of the first element of group one.
// %r1 = %r22 + 512: index of the first element of group two.
shl.b32 %r20, %r19, 2;
mov.u32 %r21, %ctaid.x;
mad.lo.s32 %r22, %r21, 1024, %r20;
add.s32 %r1, %r22, 512;
// %p1 / %p2 guard the loads of group one / group two (index < bound);
// %p5 is the negation of %p1 and is used to skip the first store.
setp.ge.s32 %p5, %r22, %r2;
setp.lt.s32 %p1, %r22, %r2;
setp.lt.s32 %p2, %r1, %r2;
// %rd9 = byte offset of group one (%r22 * 4), shared by all three arrays.
mul.wide.s32 %rd9, %r22, 4;
add.s64 %rd2, %rd6, %rd9;
// Predicated 4 x 32-bit vector loads from the first input: group one at +0,
// group two at +2048 bytes.
@%p1 ld.global.v4.b32 {%r3,%r4,%r5,%r6}, [ %rd2 + 0];
@%p2 ld.global.v4.b32 {%r7,%r8,%r9,%r10}, [ %rd2 + 2048];
add.s64 %rd4, %rd7, %rd9;
// Same two predicated vector loads from the second input.
@%p1 ld.global.v4.b32 {%r11,%r12,%r13,%r14}, [ %rd4 + 0];
@%p2 ld.global.v4.b32 {%r15,%r16,%r17,%r18}, [ %rd4 + 2048];
add.s64 %rd1, %rd8, %rd9;
// Skip the add and store of group one when its first index is out of bounds.
@%p5 bra LBB0_2;
mov.b32 %f13, %r3;
mov.b32 %f14, %r4;
mov.b32 %f15, %r5;
mov.b32 %f16, %r6;
mov.b32 %f21, %r11;
mov.b32 %f22, %r12;
mov.b32 %f23, %r13;
mov.b32 %f24, %r14;
add.f32 %f5, %f13, %f21;
add.f32 %f6, %f14, %f22;
add.f32 %f7, %f15, %f23;
add.f32 %f8, %f16, %f24;
st.global.v4.f32 [%rd1], {%f5, %f6, %f7, %f8};
LBB0_2:
// Skip the add and store of group two when its first index is out of bounds.
setp.ge.s32 %p6, %r1, %r2;
@%p6 bra LBB0_4;
mov.b32 %f17, %r7;
mov.b32 %f18, %r8;
mov.b32 %f19, %r9;
mov.b32 %f20, %r10;
mov.b32 %f25, %r15;
mov.b32 %f26, %r16;
mov.b32 %f27, %r17;
mov.b32 %f28, %r18;
add.f32 %f1, %f17, %f25;
add.f32 %f2, %f18, %f26;
add.f32 %f3, %f19, %f27;
add.f32 %f4, %f20, %f28;
st.global.v4.f32 [%rd1+2048], {%f1, %f2, %f3, %f4};
LBB0_4:
ret;
}
I don't like seeing inline assembly here :(.
Why do we need inline assembly? Can't LLVM intrinsics get us to this?
Had a quick look at the IR. My observations are the following:
Thanks a lot for your analysis. I wonder why these assembly instructions are used instead of LLVM intrinsics.
For kernel programming, we should be able to create a Triton backend for Nvidia GPUs etc. The paper can be found here: https://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf. The implementation itself is here: https://github.com/triton-lang/triton