chsasank / llama.lisp

Lisp dialect designed for HPC and AI
GNU Lesser General Public License v2.1
15 stars 6 forks source link

Triton -> C-Lisp compiler #79

Open chsasank opened 4 months ago

chsasank commented 4 months ago

For kernel programming, we should be able to create a Triton backend for Nvidia GPUs etc. The paper is here: https://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf. The implementation itself is here: https://github.com/triton-lang/triton

chsasank commented 3 months ago

The contributor should first complete the tutorials here: https://triton-lang.org/main/getting-started/tutorials/index.html

chsasank commented 3 months ago

A way to look at LLVM IR generated is given here: https://github.com/triton-lang/triton?tab=readme-ov-file#tips-for-hacking

chsasank commented 3 months ago

IR Dump after LLVM_IR_ENABLE_DUMP=1 python 01-vector-add.py 2> vector_add.ll

import triton
import triton.language as tl

# Element-wise vector addition kernel: each program instance loads one
# BLOCK_SIZE-wide slice of x and y, adds them, and stores the result to output.
@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a block of integer element indices, not pointers; the
    # pointers are formed below by adding offsets to x_ptr / y_ptr / output_ptr:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM (same mask, so out-of-range positions are not written).
    tl.store(output_ptr + offsets, output, mask=mask)

Only pasting the first IR dump:

; *** IR Dump After Annotation2MetadataPass on [module] ***
; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"

@global_smem = external addrspace(3) global [0 x i8], align 16

define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3) !dbg !7 {
  %5 = call i32 asm "mov.u32 $0, %ctaid.x;", "=r"(), !dbg !10
  %6 = mul i32 %5, 1024, !dbg !11
  %7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !12
  %8 = urem i32 %7, 32, !dbg !12
  %9 = udiv i32 %7, 32, !dbg !12
  %10 = and i32 %8, 1, !dbg !12
  %11 = icmp eq i32 %10, 0, !dbg !12
  %12 = select i1 %11, i32 0, i32 4, !dbg !12
  %13 = xor i32 0, %12, !dbg !12
  %14 = and i32 %8, 2, !dbg !12
  %15 = icmp eq i32 %14, 0, !dbg !12
  %16 = select i1 %15, i32 0, i32 8, !dbg !12
  %17 = xor i32 %13, %16, !dbg !12
  %18 = and i32 %8, 4, !dbg !12
  %19 = icmp eq i32 %18, 0, !dbg !12
  %20 = select i1 %19, i32 0, i32 16, !dbg !12
  %21 = xor i32 %17, %20, !dbg !12
  %22 = and i32 %8, 8, !dbg !12
  %23 = icmp eq i32 %22, 0, !dbg !12
  %24 = select i1 %23, i32 0, i32 32, !dbg !12
  %25 = xor i32 %21, %24, !dbg !12
  %26 = and i32 %8, 16, !dbg !12
  %27 = icmp eq i32 %26, 0, !dbg !12
  %28 = select i1 %27, i32 0, i32 64, !dbg !12
  %29 = xor i32 %25, %28, !dbg !12
  %30 = and i32 %9, 1, !dbg !12
  %31 = icmp eq i32 %30, 0, !dbg !12
  %32 = select i1 %31, i32 0, i32 128, !dbg !12
  %33 = xor i32 %29, %32, !dbg !12
  %34 = and i32 %9, 2, !dbg !12
  %35 = icmp eq i32 %34, 0, !dbg !12
  %36 = select i1 %35, i32 0, i32 256, !dbg !12
  %37 = xor i32 %33, %36, !dbg !12
  %38 = xor i32 512, %12, !dbg !12
  %39 = xor i32 %38, %16, !dbg !12
  %40 = xor i32 %39, %20, !dbg !12
  %41 = xor i32 %40, %24, !dbg !12
  %42 = xor i32 %41, %28, !dbg !12
  %43 = xor i32 %42, %32, !dbg !12
  %44 = xor i32 %43, %36, !dbg !12
  %45 = add i32 %37, 0, !dbg !12
  %46 = add i32 %44, 0, !dbg !12
  %47 = add i32 %6, %45, !dbg !13
  %48 = add i32 %6, %46, !dbg !13
  %49 = icmp slt i32 %47, %3, !dbg !14
  %50 = icmp slt i32 %48, %3, !dbg !14
  %51 = getelementptr float, ptr addrspace(1) %0, i32 %47, !dbg !15
  %52 = getelementptr float, ptr addrspace(1) %0, i32 %48, !dbg !15
  %53 = call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(ptr addrspace(1) %51, i1 %49), !dbg !16
  %54 = extractvalue { i32, i32, i32, i32 } %53, 0, !dbg !16
  %55 = bitcast i32 %54 to <1 x float>, !dbg !16
  %56 = extractvalue { i32, i32, i32, i32 } %53, 1, !dbg !16
  %57 = bitcast i32 %56 to <1 x float>, !dbg !16
  %58 = extractvalue { i32, i32, i32, i32 } %53, 2, !dbg !16
  %59 = bitcast i32 %58 to <1 x float>, !dbg !16
  %60 = extractvalue { i32, i32, i32, i32 } %53, 3, !dbg !16
  %61 = bitcast i32 %60 to <1 x float>, !dbg !16
  %62 = extractelement <1 x float> %55, i32 0, !dbg !16
  %63 = extractelement <1 x float> %57, i32 0, !dbg !16
  %64 = extractelement <1 x float> %59, i32 0, !dbg !16
  %65 = extractelement <1 x float> %61, i32 0, !dbg !16
  %66 = call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(ptr addrspace(1) %52, i1 %50), !dbg !16
  %67 = extractvalue { i32, i32, i32, i32 } %66, 0, !dbg !16
  %68 = bitcast i32 %67 to <1 x float>, !dbg !16
  %69 = extractvalue { i32, i32, i32, i32 } %66, 1, !dbg !16
  %70 = bitcast i32 %69 to <1 x float>, !dbg !16
  %71 = extractvalue { i32, i32, i32, i32 } %66, 2, !dbg !16
  %72 = bitcast i32 %71 to <1 x float>, !dbg !16
  %73 = extractvalue { i32, i32, i32, i32 } %66, 3, !dbg !16
  %74 = bitcast i32 %73 to <1 x float>, !dbg !16
  %75 = extractelement <1 x float> %68, i32 0, !dbg !16
  %76 = extractelement <1 x float> %70, i32 0, !dbg !16
  %77 = extractelement <1 x float> %72, i32 0, !dbg !16
  %78 = extractelement <1 x float> %74, i32 0, !dbg !16
  %79 = getelementptr float, ptr addrspace(1) %1, i32 %47, !dbg !17
  %80 = getelementptr float, ptr addrspace(1) %1, i32 %48, !dbg !17
  %81 = call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(ptr addrspace(1) %79, i1 %49), !dbg !18
  %82 = extractvalue { i32, i32, i32, i32 } %81, 0, !dbg !18
  %83 = bitcast i32 %82 to <1 x float>, !dbg !18
  %84 = extractvalue { i32, i32, i32, i32 } %81, 1, !dbg !18
  %85 = bitcast i32 %84 to <1 x float>, !dbg !18
  %86 = extractvalue { i32, i32, i32, i32 } %81, 2, !dbg !18
  %87 = bitcast i32 %86 to <1 x float>, !dbg !18
  %88 = extractvalue { i32, i32, i32, i32 } %81, 3, !dbg !18
  %89 = bitcast i32 %88 to <1 x float>, !dbg !18
  %90 = extractelement <1 x float> %83, i32 0, !dbg !18
  %91 = extractelement <1 x float> %85, i32 0, !dbg !18
  %92 = extractelement <1 x float> %87, i32 0, !dbg !18
  %93 = extractelement <1 x float> %89, i32 0, !dbg !18
  %94 = call { i32, i32, i32, i32 } asm sideeffect "mov.u32 $0, 0x0;\0A\09mov.u32 $1, 0x0;\0A\09mov.u32 $2, 0x0;\0A\09mov.u32 $3, 0x0;\0A\09@$5 ld.global.v4.b32 { $0, $1, $2, $3 }, [ $4 + 0 ];", "=r,=r,=r,=r,l,b"(ptr addrspace(1) %80, i1 %50), !dbg !18
  %95 = extractvalue { i32, i32, i32, i32 } %94, 0, !dbg !18
  %96 = bitcast i32 %95 to <1 x float>, !dbg !18
  %97 = extractvalue { i32, i32, i32, i32 } %94, 1, !dbg !18
  %98 = bitcast i32 %97 to <1 x float>, !dbg !18
  %99 = extractvalue { i32, i32, i32, i32 } %94, 2, !dbg !18
  %100 = bitcast i32 %99 to <1 x float>, !dbg !18
  %101 = extractvalue { i32, i32, i32, i32 } %94, 3, !dbg !18
  %102 = bitcast i32 %101 to <1 x float>, !dbg !18
  %103 = extractelement <1 x float> %96, i32 0, !dbg !18
  %104 = extractelement <1 x float> %98, i32 0, !dbg !18
  %105 = extractelement <1 x float> %100, i32 0, !dbg !18
  %106 = extractelement <1 x float> %102, i32 0, !dbg !18
  %107 = fadd float %62, %90, !dbg !19
  %108 = fadd float %63, %91, !dbg !19
  %109 = fadd float %64, %92, !dbg !19
  %110 = fadd float %65, %93, !dbg !19
  %111 = fadd float %75, %103, !dbg !19
  %112 = fadd float %76, %104, !dbg !19
  %113 = fadd float %77, %105, !dbg !19
  %114 = fadd float %78, %106, !dbg !19
  %115 = getelementptr float, ptr addrspace(1) %2, i32 %47, !dbg !20
  %116 = getelementptr float, ptr addrspace(1) %2, i32 %48, !dbg !20
  %117 = insertelement <1 x float> undef, float %107, i32 0, !dbg !21
  %118 = bitcast <1 x float> %117 to i32, !dbg !21
  %119 = insertelement <1 x float> undef, float %108, i32 0, !dbg !21
  %120 = bitcast <1 x float> %119 to i32, !dbg !21
  %121 = insertelement <1 x float> undef, float %109, i32 0, !dbg !21
  %122 = bitcast <1 x float> %121 to i32, !dbg !21
  %123 = insertelement <1 x float> undef, float %110, i32 0, !dbg !21
  %124 = bitcast <1 x float> %123 to i32, !dbg !21
  %125 = and i1 true, %49, !dbg !21
  call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %118, i32 %120, i32 %122, i32 %124, ptr addrspace(1) %115, i1 %125), !dbg !21
  %126 = insertelement <1 x float> undef, float %111, i32 0, !dbg !21
  %127 = bitcast <1 x float> %126 to i32, !dbg !21
  %128 = insertelement <1 x float> undef, float %112, i32 0, !dbg !21
  %129 = bitcast <1 x float> %128 to i32, !dbg !21
  %130 = insertelement <1 x float> undef, float %113, i32 0, !dbg !21
  %131 = bitcast <1 x float> %130 to i32, !dbg !21
  %132 = insertelement <1 x float> undef, float %114, i32 0, !dbg !21
  %133 = bitcast <1 x float> %132 to i32, !dbg !21
  %134 = and i1 true, %50, !dbg !21
  call void asm sideeffect "@$5 st.global.v4.b32 [ $4 + 0 ], { $0, $1, $2, $3 };", "r,r,r,r,l,b"(i32 %127, i32 %129, i32 %131, i32 %133, ptr addrspace(1) %116, i1 %134), !dbg !21
  ret void, !dbg !22
}

; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0

attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

!llvm.module.flags = !{!0, !1}
!llvm.dbg.cu = !{!2}
!nvvm.annotations = !{!4, !5}
!llvm.ident = !{!6}

!0 = !{i32 2, !"Debug Info Version", i32 3}
!1 = !{i32 4, !"nvvm-reflect-ftz", i32 1}
!2 = distinct !DICompileUnit(language: DW_LANG_C, file: !3, producer: "triton", isOptimized: true, runtimeVersion: 0, emissionKind: LineTablesOnly)
!3 = !DIFile(filename: "01-vector-add.py", directory: "/home/sasank/code")
!4 = !{ptr @add_kernel, !"kernel", i32 1}
!5 = !{ptr @add_kernel, !"maxntidx", i32 128}
!6 = !{!"clang version 3.8.0 (tags/RELEASE_380/final)"}
!7 = distinct !DISubprogram(name: "add_kernel", linkageName: "add_kernel", scope: !3, file: !3, line: 28, type: !8, scopeLine: 28, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2)
!8 = !DISubroutineType(cc: DW_CC_normal, types: !9)
!9 = !{}
!10 = !DILocation(line: 37, column: 24, scope: !7)
!11 = !DILocation(line: 42, column: 24, scope: !7)
!12 = !DILocation(line: 43, column: 41, scope: !7)
!13 = !DILocation(line: 43, column: 28, scope: !7)
!14 = !DILocation(line: 45, column: 21, scope: !7)
!15 = !DILocation(line: 48, column: 24, scope: !7)
!16 = !DILocation(line: 48, column: 16, scope: !7)
!17 = !DILocation(line: 49, column: 24, scope: !7)
!18 = !DILocation(line: 49, column: 16, scope: !7)
!19 = !DILocation(line: 50, column: 17, scope: !7)
!20 = !DILocation(line: 52, column: 26, scope: !7)
!21 = !DILocation(line: 52, column: 35, scope: !7)
!22 = !DILocation(line: 52, column: 4, scope: !7)
chsasank commented 3 months ago

This doesn't look very nice. We have

  1. Inline assembly code.
  2. A lot of instructions for what should be fairly simple code.
chsasank commented 3 months ago

I checked out version 1.1.1 and hacked the LLVM IR out of the dumped pickle binary:

Triton code:

# Element-wise vector addition kernel (Triton 1.1.1-style signature): the block
# size is passed through the `meta` keyword arguments rather than as a parameter.
@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector
    y_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    n_elements,  # Size of the vector
    **meta,  # Meta-parameters for the kernel; 'BLOCK_SIZE' is read from here
):
    BLOCK_SIZE = meta['BLOCK_SIZE']  # How many inputs each program should process
    # There are multiple 'program's processing different data. We identify which program
    # we are here
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0
    # This program will process inputs that are offset from the initial data.
    # for instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a block of integer element indices, not pointers; the
    # pointers are formed below by adding offsets to x_ptr / y_ptr / output_ptr
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM (same mask, so out-of-range positions are not written)
    tl.store(output_ptr + offsets, output, mask=mask)

Triton IR:

def void add_kernel(f32* x_ptr .aligned(16) , f32* y_ptr .aligned(16) , f32* output_ptr .aligned(16) , i32 n_elements .multipleof(16) )
{
entry:
  %0 = get_program_id(0) i32;
  %1 = mul i32 %0, 1024;
  %3 = make_range[0 : 1024] i32<1024>;
  %4 = splat i32<1024> %1;
  %6 = add i32<1024> %4, %3;
  %9 = splat i32<1024> n_elements;
  %11 = icmp_slt i1<1024> %6, %9;
  %14 = splat f32*<1024> x_ptr;
  %16 = getelementptr f32*<1024> %14, %6;
  %19 = splat f32<1024> undef;
  %20 = masked_load f32<1024> %16, %11, %19;
  %24 = splat f32*<1024> y_ptr;
  %26 = getelementptr f32*<1024> %24, %6;
  %29 = splat f32<1024> undef;
  %30 = masked_load f32<1024> %26, %11, %29;
  %34 = fadd f32<1024> %20, %30;
  %37 = splat f32*<1024> output_ptr;
  %39 = getelementptr f32*<1024> %37, %6;
  masked_store void %39, %34, %11;
  ret void;
}

LLVM IR:

; ModuleID = 'add_kernel'
source_filename = "add_kernel"

define void @add_kernel(float addrspace(1)* align 16 %0, float addrspace(1)* align 16 %1, float addrspace(1)* align 16 %2, i32 %3) {
entry:
  %4 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %5 = urem i32 %4, 32
  %6 = udiv i32 %4, 32
  %7 = mul i32 %6, 32
  %8 = add i32 %7, %5
  %9 = mul i32 %8, 4
  %idx_0_0 = add i32 %9, 0
  %idx_0_1 = add i32 %9, 1
  %idx_0_2 = add i32 %9, 2
  %idx_0_3 = add i32 %9, 3
  %idx_0_4 = add i32 %9, 512
  %idx_0_5 = add i32 %9, 513
  %idx_0_6 = add i32 %9, 514
  %idx_0_7 = add i32 %9, 515
  %10 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %11 = mul i32 %10, 1024
  %12 = add i32 0, %9
  %13 = add i32 %12, 0
  %14 = add i32 0, %9
  %15 = add i32 %14, 1
  %16 = add i32 0, %9
  %17 = add i32 %16, 2
  %18 = add i32 0, %9
  %19 = add i32 %18, 3
  %20 = add i32 0, %9
  %21 = add i32 %20, 512
  %22 = add i32 0, %9
  %23 = add i32 %22, 513
  %24 = add i32 0, %9
  %25 = add i32 %24, 514
  %26 = add i32 0, %9
  %27 = add i32 %26, 515
  %28 = add i32 %11, %12
  %29 = add i32 %28, 0
  %30 = add i32 %11, %14
  %31 = add i32 %30, 1
  %32 = add i32 %11, %16
  %33 = add i32 %32, 2
  %34 = add i32 %11, %18
  %35 = add i32 %34, 3
  %36 = add i32 %11, %20
  %37 = add i32 %36, 512
  %38 = add i32 %11, %22
  %39 = add i32 %38, 513
  %40 = add i32 %11, %24
  %41 = add i32 %40, 514
  %42 = add i32 %11, %26
  %43 = add i32 %42, 515
  %44 = icmp slt i32 %29, %3
  %45 = icmp slt i32 %31, %3
  %46 = icmp slt i32 %33, %3
  %47 = icmp slt i32 %35, %3
  %48 = icmp slt i32 %37, %3
  %49 = icmp slt i32 %39, %3
  %50 = icmp slt i32 %41, %3
  %51 = icmp slt i32 %43, %3
  %52 = getelementptr float, float addrspace(1)* %0, i32 %28
  %53 = getelementptr float, float addrspace(1)* %52, i32 0
  %54 = getelementptr float, float addrspace(1)* %0, i32 %30
  %55 = getelementptr float, float addrspace(1)* %54, i32 1
  %56 = getelementptr float, float addrspace(1)* %0, i32 %32
  %57 = getelementptr float, float addrspace(1)* %56, i32 2
  %58 = getelementptr float, float addrspace(1)* %0, i32 %34
  %59 = getelementptr float, float addrspace(1)* %58, i32 3
  %60 = getelementptr float, float addrspace(1)* %0, i32 %36
  %61 = getelementptr float, float addrspace(1)* %60, i32 512
  %62 = getelementptr float, float addrspace(1)* %0, i32 %38
  %63 = getelementptr float, float addrspace(1)* %62, i32 513
  %64 = getelementptr float, float addrspace(1)* %0, i32 %40
  %65 = getelementptr float, float addrspace(1)* %64, i32 514
  %66 = getelementptr float, float addrspace(1)* %0, i32 %42
  %67 = getelementptr float, float addrspace(1)* %66, i32 515
  %68 = call { i32, i32, i32, i32 } asm sideeffect "@$4 ld.global.v4.b32 {$0,$1,$2,$3}, [ $5 + 0];", "=r,=r,=r,=r,b,l"(i1 %44, float addrspace(1)* %52)
  %69 = extractvalue { i32, i32, i32, i32 } %68, 0
  %70 = bitcast i32 %69 to <1 x float>
  %71 = extractvalue { i32, i32, i32, i32 } %68, 1
  %72 = bitcast i32 %71 to <1 x float>
  %73 = extractvalue { i32, i32, i32, i32 } %68, 2
  %74 = bitcast i32 %73 to <1 x float>
  %75 = extractvalue { i32, i32, i32, i32 } %68, 3
  %76 = bitcast i32 %75 to <1 x float>
  %77 = extractelement <1 x float> %70, i64 0
  %78 = extractelement <1 x float> %72, i64 0
  %79 = extractelement <1 x float> %74, i64 0
  %80 = extractelement <1 x float> %76, i64 0
  %81 = call { i32, i32, i32, i32 } asm sideeffect "@$4 ld.global.v4.b32 {$0,$1,$2,$3}, [ $5 + 2048];", "=r,=r,=r,=r,b,l"(i1 %48, float addrspace(1)* %60)
  %82 = extractvalue { i32, i32, i32, i32 } %81, 0
  %83 = bitcast i32 %82 to <1 x float>
  %84 = extractvalue { i32, i32, i32, i32 } %81, 1
  %85 = bitcast i32 %84 to <1 x float>
  %86 = extractvalue { i32, i32, i32, i32 } %81, 2
  %87 = bitcast i32 %86 to <1 x float>
  %88 = extractvalue { i32, i32, i32, i32 } %81, 3
  %89 = bitcast i32 %88 to <1 x float>
  %90 = extractelement <1 x float> %83, i64 0
  %91 = extractelement <1 x float> %85, i64 0
  %92 = extractelement <1 x float> %87, i64 0
  %93 = extractelement <1 x float> %89, i64 0
  %94 = getelementptr float, float addrspace(1)* %1, i32 %28
  %95 = getelementptr float, float addrspace(1)* %94, i32 0
  %96 = getelementptr float, float addrspace(1)* %1, i32 %30
  %97 = getelementptr float, float addrspace(1)* %96, i32 1
  %98 = getelementptr float, float addrspace(1)* %1, i32 %32
  %99 = getelementptr float, float addrspace(1)* %98, i32 2
  %100 = getelementptr float, float addrspace(1)* %1, i32 %34
  %101 = getelementptr float, float addrspace(1)* %100, i32 3
  %102 = getelementptr float, float addrspace(1)* %1, i32 %36
  %103 = getelementptr float, float addrspace(1)* %102, i32 512
  %104 = getelementptr float, float addrspace(1)* %1, i32 %38
  %105 = getelementptr float, float addrspace(1)* %104, i32 513
  %106 = getelementptr float, float addrspace(1)* %1, i32 %40
  %107 = getelementptr float, float addrspace(1)* %106, i32 514
  %108 = getelementptr float, float addrspace(1)* %1, i32 %42
  %109 = getelementptr float, float addrspace(1)* %108, i32 515
  %110 = call { i32, i32, i32, i32 } asm sideeffect "@$4 ld.global.v4.b32 {$0,$1,$2,$3}, [ $5 + 0];", "=r,=r,=r,=r,b,l"(i1 %44, float addrspace(1)* %94)
  %111 = extractvalue { i32, i32, i32, i32 } %110, 0
  %112 = bitcast i32 %111 to <1 x float>
  %113 = extractvalue { i32, i32, i32, i32 } %110, 1
  %114 = bitcast i32 %113 to <1 x float>
  %115 = extractvalue { i32, i32, i32, i32 } %110, 2
  %116 = bitcast i32 %115 to <1 x float>
  %117 = extractvalue { i32, i32, i32, i32 } %110, 3
  %118 = bitcast i32 %117 to <1 x float>
  %119 = extractelement <1 x float> %112, i64 0
  %120 = extractelement <1 x float> %114, i64 0
  %121 = extractelement <1 x float> %116, i64 0
  %122 = extractelement <1 x float> %118, i64 0
  %123 = call { i32, i32, i32, i32 } asm sideeffect "@$4 ld.global.v4.b32 {$0,$1,$2,$3}, [ $5 + 2048];", "=r,=r,=r,=r,b,l"(i1 %48, float addrspace(1)* %102)
  %124 = extractvalue { i32, i32, i32, i32 } %123, 0
  %125 = bitcast i32 %124 to <1 x float>
  %126 = extractvalue { i32, i32, i32, i32 } %123, 1
  %127 = bitcast i32 %126 to <1 x float>
  %128 = extractvalue { i32, i32, i32, i32 } %123, 2
  %129 = bitcast i32 %128 to <1 x float>
  %130 = extractvalue { i32, i32, i32, i32 } %123, 3
  %131 = bitcast i32 %130 to <1 x float>
  %132 = extractelement <1 x float> %125, i64 0
  %133 = extractelement <1 x float> %127, i64 0
  %134 = extractelement <1 x float> %129, i64 0
  %135 = extractelement <1 x float> %131, i64 0
  %136 = fadd float %77, %119
  %137 = fadd float %78, %120
  %138 = fadd float %79, %121
  %139 = fadd float %80, %122
  %140 = fadd float %90, %132
  %141 = fadd float %91, %133
  %142 = fadd float %92, %134
  %143 = fadd float %93, %135
  %144 = getelementptr float, float addrspace(1)* %2, i32 %28
  %145 = getelementptr float, float addrspace(1)* %144, i32 0
  %146 = getelementptr float, float addrspace(1)* %2, i32 %30
  %147 = getelementptr float, float addrspace(1)* %146, i32 1
  %148 = getelementptr float, float addrspace(1)* %2, i32 %32
  %149 = getelementptr float, float addrspace(1)* %148, i32 2
  %150 = getelementptr float, float addrspace(1)* %2, i32 %34
  %151 = getelementptr float, float addrspace(1)* %150, i32 3
  %152 = getelementptr float, float addrspace(1)* %2, i32 %36
  %153 = getelementptr float, float addrspace(1)* %152, i32 512
  %154 = getelementptr float, float addrspace(1)* %2, i32 %38
  %155 = getelementptr float, float addrspace(1)* %154, i32 513
  %156 = getelementptr float, float addrspace(1)* %2, i32 %40
  %157 = getelementptr float, float addrspace(1)* %156, i32 514
  %158 = getelementptr float, float addrspace(1)* %2, i32 %42
  %159 = getelementptr float, float addrspace(1)* %158, i32 515
  %160 = bitcast float addrspace(1)* %145 to <4 x float> addrspace(1)*
  %161 = insertelement <4 x float> undef, float %136, i64 0
  %162 = insertelement <4 x float> %161, float %137, i64 1
  %163 = insertelement <4 x float> %162, float %138, i64 2
  %164 = insertelement <4 x float> %163, float %139, i64 3
  br i1 %44, label %165, label %166

165:                                              ; preds = %entry
  store <4 x float> %164, <4 x float> addrspace(1)* %160, align 16
  br label %166

166:                                              ; preds = %entry, %165
  %167 = bitcast float addrspace(1)* %153 to <4 x float> addrspace(1)*
  %168 = insertelement <4 x float> undef, float %140, i64 0
  %169 = insertelement <4 x float> %168, float %141, i64 1
  %170 = insertelement <4 x float> %169, float %142, i64 2
  %171 = insertelement <4 x float> %170, float %143, i64 3
  br i1 %48, label %172, label %173

172:                                              ; preds = %166
  store <4 x float> %171, <4 x float> addrspace(1)* %167, align 16
  br label %173

173:                                              ; preds = %166, %172
  call void @llvm.donothing()
  call void @llvm.donothing()
  ret void
}

; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0

; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #0

; Function Attrs: nounwind readnone willreturn
declare void @llvm.donothing() #1

attributes #0 = { nounwind readnone }
attributes #1 = { nounwind readnone willreturn }

!nvvm.annotations = !{!0, !1}

!0 = !{void (float addrspace(1)*, float addrspace(1)*, float addrspace(1)*, i32)* @add_kernel, !"kernel", i32 1}
!1 = !{void (float addrspace(1)*, float addrspace(1)*, float addrspace(1)*, i32)* @add_kernel, !"maxntidx", i32 128}

PTX

//
// Generated by LLVM NVPTX Back-End
//

.version 7.3
.target sm_75
.address_size 64

    // .globl   add_kernel

.visible .entry add_kernel(
    .param .u64 add_kernel_param_0,
    .param .u64 add_kernel_param_1,
    .param .u64 add_kernel_param_2,
    .param .u32 add_kernel_param_3
)
.maxntid 128, 1, 1
{
    .reg .pred  %p<7>;
    .reg .f32   %f<29>;
    .reg .b32   %r<23>;
    .reg .b64   %rd<10>;

    ld.param.u32    %r2, [add_kernel_param_3];
    ld.param.u64    %rd6, [add_kernel_param_0];
    ld.param.u64    %rd7, [add_kernel_param_1];
    mov.u32     %r19, %tid.x;
    ld.param.u64    %rd8, [add_kernel_param_2];
    shl.b32     %r20, %r19, 2;
    mov.u32     %r21, %ctaid.x;
    mad.lo.s32  %r22, %r21, 1024, %r20;
    add.s32     %r1, %r22, 512;
    setp.ge.s32     %p5, %r22, %r2;
    setp.lt.s32     %p1, %r22, %r2;
    setp.lt.s32     %p2, %r1, %r2;
    mul.wide.s32    %rd9, %r22, 4;
    add.s64     %rd2, %rd6, %rd9;
    @%p1 ld.global.v4.b32 {%r3,%r4,%r5,%r6}, [ %rd2 + 0];
    @%p2 ld.global.v4.b32 {%r7,%r8,%r9,%r10}, [ %rd2 + 2048];
    add.s64     %rd4, %rd7, %rd9;
    @%p1 ld.global.v4.b32 {%r11,%r12,%r13,%r14}, [ %rd4 + 0];
    @%p2 ld.global.v4.b32 {%r15,%r16,%r17,%r18}, [ %rd4 + 2048];
    add.s64     %rd1, %rd8, %rd9;
    @%p5 bra    LBB0_2;
    mov.b32     %f13, %r3;
    mov.b32     %f14, %r4;
    mov.b32     %f15, %r5;
    mov.b32     %f16, %r6;
    mov.b32     %f21, %r11;
    mov.b32     %f22, %r12;
    mov.b32     %f23, %r13;
    mov.b32     %f24, %r14;
    add.f32     %f5, %f13, %f21;
    add.f32     %f6, %f14, %f22;
    add.f32     %f7, %f15, %f23;
    add.f32     %f8, %f16, %f24;
    st.global.v4.f32    [%rd1], {%f5, %f6, %f7, %f8};
LBB0_2:
    setp.ge.s32     %p6, %r1, %r2;
    @%p6 bra    LBB0_4;
    mov.b32     %f17, %r7;
    mov.b32     %f18, %r8;
    mov.b32     %f19, %r9;
    mov.b32     %f20, %r10;
    mov.b32     %f25, %r15;
    mov.b32     %f26, %r16;
    mov.b32     %f27, %r17;
    mov.b32     %f28, %r18;
    add.f32     %f1, %f17, %f25;
    add.f32     %f2, %f18, %f26;
    add.f32     %f3, %f19, %f27;
    add.f32     %f4, %f20, %f28;
    st.global.v4.f32    [%rd1+2048], {%f1, %f2, %f3, %f4};
LBB0_4:
    ret;

}
chsasank commented 3 months ago

I don't like seeing inline assembly here :(.

Why do we need inline assembly? Can't LLVM intrinsics get us to this?

dasdibye commented 3 months ago

Had a quick look at the IR. My observations are the following:

  1. Inline asm is used only for ld.v4 and st.v4. My guess is that no intrinsic/builtin for these is supported in LLVM IR in this compilation pipeline. If that is the case, they should be added.
  2. The compute happens normally via the fadd.
  3. The Triton code is unrolled by a factor of 2. Though you don't see any loop, it's implicit due to the BLOCK.
  4. You will see 4 ld.v4 and 2 st.v4 due to UF=2. There are also lots of extract and bitcast instructions around each such call to get the loaded/stored values.
  5. There are 8 fadd as 4 elements are loaded via ld.v4 and UF=2. All these add to the size of the code.
chsasank commented 3 months ago

Thanks a lot for your analysis. I wonder why these assembly instructions are used instead of LLVM intrinsics.