google / uVkCompute

A micro Vulkan compute pipeline and a collection of benchmarking compute shaders
Apache License 2.0
224 stars 38 forks source link

Add sample argmax kernel for a single subgroup #47

Closed qedawkins closed 2 months ago

qedawkins commented 11 months ago

(not expecting a review yet, this is still a draft)

qedawkins commented 11 months ago

Can't seem to add a reviewer, so @raikonenfnu

kuhar commented 11 months ago

@qedawkins @antiagainst ISA for the subgroup ops from llpc:

Command: amdllpc one_workgroup_argmax_subgroup_f32.comp -o /dev/null --gfxip=11.0 -v

Assembly:

.AMDGPU.disasm (size = 6467 bytes)
; Compute-shader entry point: per-lane strided max/index scan over a buffer,
; followed by a wave-wide max reduction and a single-lane store of the index.
_amdgpu_cs_main:
BB0_0:
    ; Prologue: form a 64-bit address in s[0:1] from the incoming s1 (low half)
    ; and the high half of the PC (s5), then load eight dwords into s[4:11].
    ; s[4:7] is used below as the resource for buffer loads and s[8:11] as the
    ; resource for the final buffer store.
    s_getpc_b64 s[4:5]                                                                   ; BE844700
    s_mov_b32 s0, s1                                                                     ; BE800001
    s_mov_b32 s1, s5                                                                     ; BE810005
    ; v3 = v0 & 0x3ff; it is later compared against the winning lane index.
    v_and_b32_e32 v3, 0x3ff, v0                                                          ; 360600FF 000003FF
    s_load_b256 s[4:11], s[0:1], 0x0                                                     ; F40C0100 F8000000
    ; v0 is reused as the running result index, initialised to 0.
    v_mov_b32_e32 v0, 0                                                                  ; 7E000280
    ; SCC = (s2 < 0x80); s2 presumably holds the element count -- confirm
    ; against the shader source.
    s_cmpk_lt_u32 s2, 0x80                                                               ; B6820080
    s_delay_alu instid0(VALU_DEP_2)                                                      ; BF870002
    ; v5 = v3 * 4: byte offset of this lane's first 32-bit element.
    v_lshlrev_b32_e32 v5, 2, v3                                                          ; 300A0682
    s_waitcnt lgkmcnt(0)                                                                 ; BF89FC07
    ; First element for this lane goes into v4 (the running max).
    buffer_load_b32 v4, v5, s[4:7], 0 offen                                              ; E0500000 80410405
    ; Skip the loop entirely when s2 < 0x80.
    s_cbranch_scc1 .LBB0_3                                                               ; BFA20000
    ; Loop setup: advance the byte offset by 0x100 (64 dwords), set the
    ; candidate index v6 = v3 + 64, and the trip counter s0 = (s2 >> 6) - 1.
    v_add_nc_u32_e32 v5, 0x100, v5                                                       ; 4A0A0AFF 00000100
    v_add_nc_u32_e32 v6, 64, v3                                                          ; 4A0C06C0
    v_mov_b32_e32 v0, 0                                                                  ; 7E000280
    s_lshr_b32 s0, s2, 6                                                                 ; 85008602
    s_delay_alu instid0(SALU_CYCLE_1)                                                    ; BF870009
    s_add_i32 s0, s0, -1                                                                 ; 8100C100
BB0_2:
    ; Loop body: load the next element into v7, keep the previous max in v8,
    ; then update max (v4) and, only when v8 < v7 (strictly), the index (v0).
    buffer_load_b32 v7, v5, s[4:7], 0 offen                                              ; E0500000 80410705
    s_waitcnt vmcnt(1)                                                                   ; BF8907F7
    v_mov_b32_e32 v8, v4                                                                 ; 7E100304
    v_add_nc_u32_e32 v5, 0x100, v5                                                       ; 4A0A0AFF 00000100
    s_add_i32 s0, s0, -1                                                                 ; 8100C100
    s_delay_alu instid0(SALU_CYCLE_1)                                                    ; BF870009
    ; SCC = (s0 != 0); consumed by the branch at the bottom of the loop.
    s_cmp_lg_u32 s0, 0                                                                   ; BF078000
    s_waitcnt vmcnt(0)                                                                   ; BF8903F7
    v_cmp_lt_f32_e32 vcc, v8, v7                                                         ; 7C220F08
    v_max_f32_e32 v4, v8, v7                                                             ; 20080F08
    v_cndmask_b32_e32 v0, v0, v6, vcc                                                    ; 02000D00
    v_add_nc_u32_e32 v6, 64, v6                                                          ; 4A0C0CC0
    s_cbranch_scc1 .LBB0_2                                                               ; BFA20000
BB0_3:
    s_waitcnt vmcnt(0)                                                                   ; BF8903F7
    ; Copy the per-lane max into v1, then flip exec so that the currently
    ; inactive lanes receive 0xff800000 (the f32 bit pattern of -infinity),
    ; and flip exec back.
    v_mov_b32_e32 v1, v4                                                                 ; 7E020304
    s_not_b64 exec, exec                                                                 ; BEFE1F7E
    v_mov_b32_e32 v1, 0xff800000                                                         ; 7E0202FF FF800000
    s_not_b64 exec, exec                                                                 ; BEFE1F7E
    ; Save exec in s[0:1] and enable all lanes for the cross-lane reduction.
    s_or_saveexec_b64 s[0:1], -1                                                         ; BE8023C1
    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)               ; BF870091
    ; Max reduction in v1 using DPP: two quad permutes, then row_half_mirror
    ; and row_mirror.
    v_max_f32_dpp v1, v1, v1 quad_perm:[1,0,3,2] row_mask:0xf bank_mask:0xf bound_ctrl:1 ; 200202FA FF08B101
    v_max_f32_dpp v1, v1, v1 quad_perm:[2,3,0,1] row_mask:0xf bank_mask:0xf bound_ctrl:1 ; 200202FA FF084E01
    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)               ; BF870091
    v_max_f32_dpp v1, v1, v1 row_half_mirror row_mask:0xf bank_mask:0xf bound_ctrl:1     ; 200202FA FF094101
    v_max_f32_dpp v1, v1, v1 row_mirror row_mask:0xf bank_mask:0xf bound_ctrl:1          ; 200202FA FF094001
    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)               ; BF870091
    ; Remaining reduction steps: exchange via v_permlanex16 and v_permlane64,
    ; each followed by a max with the exchanged value in v2.
    v_permlanex16_b32 v2, v1, -1, -1 op_sel:[1,0]                                        ; D65C0802 03058301
    v_max_f32_e32 v1, v1, v2                                                             ; 20020501
    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)               ; BF870091
    v_permlane64_b32 v2, v1                                                              ; 7E04CF01
    v_max_f32_e32 v1, v1, v2                                                             ; 20020501
    ; Restore the original exec mask.
    s_mov_b64 exec, s[0:1]                                                               ; BEFE0100
    s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)               ; BF870091
    v_mov_b32_e32 v5, v1                                                                 ; 7E0A0301
    ; vcc = per-lane mask of lanes whose own max (v4) equals the reduced max.
    v_cmp_eq_f32_e32 vcc, v4, v5                                                         ; 7C240B04
    ; Lowest set bit of the 64-bit vcc: min(ctz(vcc_lo), ctz(vcc_hi) + 32).
    s_ctz_i32_b32 s0, vcc_hi                                                             ; BE80086B
    s_ctz_i32_b32 s1, vcc_lo                                                             ; BE81086A
    s_add_i32 s0, s0, 32                                                                 ; 8100A000
    s_delay_alu instid0(SALU_CYCLE_1) | instskip(NEXT) | instid1(SALU_CYCLE_1)           ; BF870499
    s_min_u32 s0, s1, s0                                                                 ; 89800001
    ; Only the lane whose v3 equals that bit index stays enabled and stores
    ; its result index v0 at offset 0 of the buffer described by s[8:11].
    v_cmp_eq_u32_e32 vcc, s0, v3                                                         ; 7C940600
    s_and_saveexec_b64 s[0:1], vcc                                                       ; BE80216A
    s_cbranch_execz .LBB0_5                                                              ; BFA50000
    buffer_store_b32 v0, off, s[8:11], 0                                                 ; E0680000 80020000
BB0_5:
    ; Epilogue: send MSG_DEALLOC_VGPRS and end the program.
    s_nop 0                                                                              ; BF800000
    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)                                                 ; BFB60003
    s_endpgm                                                                             ; BFB00000

llvm IR for BB4:

._crit_edge:                                      ; preds = %.lr.ph, %.entry
  %laneMax.0.lcssa = phi float [ %10, %.entry ], [ %17, %.lr.ph ]
  %laneResult.0.lcssa = phi i32 [ 0, %.entry ], [ %19, %.lr.ph ]
  %21 = bitcast float %laneMax.0.lcssa to i32
  %22 = call i32 @llvm.amdgcn.set.inactive.i32(i32 %21, i32 -8388608)
  %23 = bitcast i32 %22 to float
  %24 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %22, i32 177, i32 15, i32 15, i1 true)
  %25 = bitcast i32 %24 to float
  %26 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %23, float %25)
  %27 = bitcast float %26 to i32
  %28 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %27, i32 78, i32 15, i32 15, i1 true)
  %29 = bitcast i32 %28 to float
  %30 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %26, float %29)
  %31 = bitcast float %30 to i32
  %32 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %31, i32 321, i32 15, i32 15, i1 true)
  %33 = bitcast i32 %32 to float
  %34 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %30, float %33)
  %35 = bitcast float %34 to i32
  %36 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %35, i32 320, i32 15, i32 15, i1 true)
  %37 = bitcast i32 %36 to float
  %38 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %34, float %37)
  %39 = bitcast float %38 to i32
  %40 = call i32 @llvm.amdgcn.permlanex16(i32 undef, i32 %39, i32 -1, i32 -1, i1 true, i1 false)
  %41 = bitcast i32 %40 to float
  %42 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %38, float %41)
  %43 = bitcast float %42 to i32
  %44 = call i32 @llvm.amdgcn.permlane64(i32 %43)
  %45 = bitcast i32 %44 to float
  %46 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %42, float %45)
  %47 = bitcast float %46 to i32
  %48 = call i32 @llvm.amdgcn.wwm.i32(i32 %47)
  %49 = bitcast i32 %48 to float
  %50 = fcmp oeq float %laneMax.0.lcssa, %49
  %51 = call i64 @llvm.amdgcn.ballot.i64(i1 %50)
  %52 = call i64 @llvm.cttz.i64(i64 %51, i1 true), !range !8
  %.fr1 = freeze i64 %52
  %53 = trunc i64 %.fr1 to i32
  %54 = icmp eq i32 %8, %53
  br i1 %54, label %55, label %56
kuhar commented 11 months ago

ISA for gfx90:

; Compute-shader entry point: per-lane strided max/index scan over a buffer,
; followed by a wave-wide max reduction and a single-lane store of the index.
_amdgpu_cs_main:
BB0_0:
    ; Prologue: form a 64-bit address in s[0:1] from the incoming s1 (low half)
    ; and the high half of the PC (s5), then load eight dwords into s[4:11].
    ; s[4:7] is used below as the resource for buffer loads and s[8:11] as the
    ; resource for the final buffer store.
    s_getpc_b64 s[4:5]                                                                   ; BE841C00
    s_mov_b32 s0, s1                                                                     ; BE800001
    s_mov_b32 s1, s5                                                                     ; BE810005
    s_load_dwordx8 s[4:11], s[0:1], 0x0                                                  ; C00E0100 00000000
    ; v6 = v0 * 4: byte offset of this lane's first 32-bit element.
    v_lshlrev_b32_e32 v6, 2, v0                                                          ; 240C0082
    ; SCC = (s2 < 0x80); s2 presumably holds the element count -- confirm
    ; against the shader source.
    s_cmpk_lt_u32 s2, 0x80                                                               ; B6020080
    ; v4 is the running result index, initialised to 0.
    v_mov_b32_e32 v4, 0                                                                  ; 7E080280
    s_waitcnt lgkmcnt(0)                                                                 ; BF8CC07F
    ; First element for this lane goes into v5 (the running max).
    buffer_load_dword v5, v6, s[4:7], 0 offen                                            ; E0501000 80010506
    ; Skip the loop entirely when s2 < 0x80.
    s_cbranch_scc1 .LBB0_3                                                               ; BF850000
    ; Loop setup: trip counter s0 = (s2 >> 6) - 1, byte offset advanced by
    ; 0x100 (64 dwords), candidate index v7 = v0 + 64.
    s_lshr_b32 s0, s2, 6                                                                 ; 8F008602
    s_add_i32 s0, s0, -1                                                                 ; 8100C100
    v_add_u32_e32 v6, 0x100, v6                                                          ; 680C0CFF 00000100
    v_add_u32_e32 v7, 64, v0                                                             ; 680E00C0
    v_mov_b32_e32 v4, 0                                                                  ; 7E080280
BB0_2:
    ; Loop body: load the next element into v8, keep the previous max in v9,
    ; then update max (v5) and, only when v9 < v8 (strictly), the index (v4).
    buffer_load_dword v8, v6, s[4:7], 0 offen                                            ; E0501000 80010806
    s_waitcnt vmcnt(1)                                                                   ; BF8C0F71
    v_mov_b32_e32 v9, v5                                                                 ; 7E120305
    s_add_i32 s0, s0, -1                                                                 ; 8100C100
    v_add_u32_e32 v6, 0x100, v6                                                          ; 680C0CFF 00000100
    ; SCC = (s0 != 0); consumed by the branch at the bottom of the loop.
    s_cmp_lg_u32 s0, 0                                                                   ; BF078000
    s_waitcnt vmcnt(0)                                                                   ; BF8C0F70
    v_cmp_lt_f32_e32 vcc, v9, v8                                                         ; 7C821109
    v_max_f32_e32 v5, v9, v8                                                             ; 160A1109
    v_cndmask_b32_e32 v4, v4, v7, vcc                                                    ; 00080F04
    v_add_u32_e32 v7, 64, v7                                                             ; 680E0EC0
    s_cbranch_scc1 .LBB0_2                                                               ; BF850000
BB0_3:
    ; With all lanes enabled, preload v1 with 0xff800000 (the f32 bit pattern
    ; of -infinity), then restore exec.
    s_or_saveexec_b64 s[0:1], -1                                                         ; BE8021C1
    v_mov_b32_e32 v1, 0xff800000                                                         ; 7E0202FF FF800000
    s_mov_b64 exec, s[0:1]                                                               ; BEFE0100
    s_waitcnt vmcnt(0)                                                                   ; BF8C0F70
    ; Copy the per-lane max into v2, then flip exec so that the currently
    ; inactive lanes receive -infinity, and flip exec back.
    v_mov_b32_e32 v2, v5                                                                 ; 7E040305
    s_not_b64 exec, exec                                                                 ; BEFE057E
    v_mov_b32_e32 v2, 0xff800000                                                         ; 7E0402FF FF800000
    s_not_b64 exec, exec                                                                 ; BEFE057E
    ; Save exec in s[0:1] and enable all lanes for the cross-lane reduction.
    s_or_saveexec_b64 s[0:1], -1                                                         ; BE8021C1
    ; Max reduction in v2 using DPP: two quad permutes, then row_half_mirror
    ; and row_mirror. v3 is preloaded with -infinity for the row_bcast step.
    v_max_f32_dpp v2, v2, v2 quad_perm:[1,0,3,2] row_mask:0xf bank_mask:0xf bound_ctrl:1 ; 160404FA FF08B102
    v_mov_b32_e32 v3, 0xff800000                                                         ; 7E0602FF FF800000
    s_nop 0                                                                              ; BF800000
    v_max_f32_dpp v2, v2, v2 quad_perm:[2,3,0,1] row_mask:0xf bank_mask:0xf bound_ctrl:1 ; 160404FA FF084E02
    s_nop 1                                                                              ; BF800001
    v_max_f32_dpp v2, v2, v2 row_half_mirror row_mask:0xf bank_mask:0xf bound_ctrl:1     ; 160404FA FF094102
    s_nop 1                                                                              ; BF800001
    v_max_f32_dpp v2, v2, v2 row_mirror row_mask:0xf bank_mask:0xf bound_ctrl:1          ; 160404FA FF094002
    s_nop 1                                                                              ; BF800001
    ; Remaining reduction steps: row_bcast:15 (row_mask 0xa) into v3 and
    ; row_bcast:31 (row_mask 0x8) into v1, each followed by a max.
    v_mov_b32_dpp v3, v2 row_bcast:15 row_mask:0xa bank_mask:0xf bound_ctrl:1            ; 7E0602FA AF094202
    v_max_f32_e32 v2, v2, v3                                                             ; 16040702
    s_nop 1                                                                              ; BF800001
    v_mov_b32_dpp v1, v2 row_bcast:31 row_mask:0x8 bank_mask:0xf bound_ctrl:1            ; 7E0202FA 8F094302
    v_max_f32_e32 v1, v2, v1                                                             ; 16020302
    ; The reduced value is read from lane 63 into the scalar register s2.
    v_readlane_b32 s2, v1, 63                                                            ; D2890002 00017F01
    ; Restore the original exec mask.
    s_mov_b64 exec, s[0:1]                                                               ; BEFE0100
    ; vcc = per-lane mask of lanes whose own max (v5) equals the reduced max.
    v_cmp_eq_f32_e32 vcc, s2, v5                                                         ; 7C840A02
    ; Lowest set bit of the 64-bit vcc: min(ff1(vcc_lo), ff1(vcc_hi) + 32).
    s_ff1_i32_b32 s0, vcc_hi                                                             ; BE80106B
    s_add_i32 s0, s0, 32                                                                 ; 8100A000
    s_ff1_i32_b32 s1, vcc_lo                                                             ; BE81106A
    s_min_u32 s0, s1, s0                                                                 ; 83800001
    ; Only the lane whose v0 equals that bit index stays enabled and stores
    ; its result index v4 at offset 0 of the buffer described by s[8:11].
    v_cmp_eq_u32_e32 vcc, s0, v0                                                         ; 7D940000
    s_and_saveexec_b64 s[0:1], vcc                                                       ; BE80206A
    s_cbranch_execz .LBB0_5                                                              ; BF880000
    buffer_store_dword v4, off, s[8:11], 0                                               ; E0700000 80020400
BB0_5:
    s_endpgm                                                                             ; BF810000

llvm IR for BB4:

._crit_edge:                                      ; preds = %.lr.ph, %.entry
  %laneMax.0.lcssa = phi float [ %9, %.entry ], [ %16, %.lr.ph ]
  %laneResult.0.lcssa = phi i32 [ 0, %.entry ], [ %18, %.lr.ph ]
  %20 = bitcast float %laneMax.0.lcssa to i32
  %21 = call i32 @llvm.amdgcn.set.inactive.i32(i32 %20, i32 -8388608)
  %22 = bitcast i32 %21 to float
  %23 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %21, i32 177, i32 15, i32 15, i1 true)
  %24 = bitcast i32 %23 to float
  %25 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %22, float %24)
  %26 = bitcast float %25 to i32
  %27 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %26, i32 78, i32 15, i32 15, i1 true)
  %28 = bitcast i32 %27 to float
  %29 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %25, float %28)
  %30 = bitcast float %29 to i32
  %31 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %30, i32 321, i32 15, i32 15, i1 true)
  %32 = bitcast i32 %31 to float
  %33 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %29, float %32)
  %34 = bitcast float %33 to i32
  %35 = call i32 @llvm.amdgcn.update.dpp.i32(i32 undef, i32 %34, i32 320, i32 15, i32 15, i1 true)
  %36 = bitcast i32 %35 to float
  %37 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %33, float %36)
  %38 = bitcast float %37 to i32
  %39 = call i32 @llvm.amdgcn.update.dpp.i32(i32 -8388608, i32 %38, i32 322, i32 10, i32 15, i1 true)
  %40 = bitcast i32 %39 to float
  %41 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %37, float %40)
  %42 = bitcast float %41 to i32
  %43 = call i32 @llvm.amdgcn.update.dpp.i32(i32 -8388608, i32 %42, i32 323, i32 8, i32 15, i1 true)
  %44 = bitcast i32 %43 to float
  %45 = call reassoc nnan nsz arcp contract afn float @llvm.maxnum.f32(float %41, float %44)
  %46 = bitcast float %45 to i32
  %47 = call i32 @llvm.amdgcn.readlane(i32 %46, i32 63)
  %48 = call i32 @llvm.amdgcn.wwm.i32(i32 %47)
  %49 = bitcast i32 %48 to float
  %50 = fcmp oeq float %laneMax.0.lcssa, %49
  %51 = call i64 @llvm.amdgcn.ballot.i64(i1 %50)
  %52 = call i64 @llvm.cttz.i64(i64 %51, i1 true), !range !8
  %.fr1 = freeze i64 %52
  %53 = trunc i64 %.fr1 to i32
  %54 = icmp eq i32 %LocalInvocationId.i0, %53
  br i1 %54, label %55, label %56
antiagainst commented 11 months ago

Awesome, thanks Quinn and Jakub! We can get these intrinsics via HIP (courtesy of ChatGPT):

For GLSL subgroupMax:

// Returns the maximum of `val` across the warp, delivered to every
// participating lane (the semantics of GLSL subgroupMax).
//
// Uses a butterfly (XOR) exchange: at each step a lane combines its value with
// the lane whose index differs in one bit, so after log2(warpSize) steps all
// lanes hold the same maximum. A __shfl_down tree reduction would leave the
// complete result in lane 0 only, which does not match subgroupMax and breaks
// callers that compare each lane's own value against the reduced maximum.
//
// NOTE(review): assumes every lane of the warp executes this call; values read
// from inactive lanes are not defined -- confirm at call sites.
__inline__ __device__ float warpMax(float val) {
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        // Exchange with the partner lane (laneId ^ offset) and keep the larger.
        val = max(val, __shfl_xor(val, offset));
    }
    return val;
}

For GLSL subgroupBallot, we can use __ballot.

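For GLSL subgroupBallotFindLSB, we can use __ffsll, I think (note that __ffsll returns a 1-based bit position, or 0 when no bit is set, so subtract 1 to get the 0-based index that subgroupBallotFindLSB returns).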

kuhar commented 2 months ago

Closing due to inactivity