maleadt opened 6 years ago
it should have still been able to vectorize accesses to elements 6 and 7
Looks like LSV doesn't split the chain when it turns out to be misaligned, e.g.:
LSV: Loads to try and vectorize:
%17 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 6), align 8
%22 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 7), align 4
%29 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 8), align 32
%39 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 9), align 4
This could be split into vectorizable chains of elements 6, 7 and 8, 9, but LSV keeps going until it discards the entire chain due to misalignment. Maybe isLegalToVectorizeLoadChain should return false for this chain? I'm not sure what the semantics of that API call are.
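A minimal sketch of the suggested behaviour follows. It assumes 4-byte elements and a 128-byte-aligned @shmem1, as in the IR here. It reuses the split-and-retry idea that LSV already applies when a chain is not legal (splitOddVectorElts in the patch further down), simplified to plain halving. knownAlign and splitUntilAligned are hypothetical helpers, and target legality checks (e.g. which vector widths PTX supports) are not modelled.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Largest alignment (in bytes) provable for an access at ByteOffset from a
// base aligned to BaseAlign (a power of two): the lowest set bit of the
// offset, capped at BaseAlign.
uint64_t knownAlign(uint64_t BaseAlign, uint64_t ByteOffset) {
  if (ByteOffset == 0)
    return BaseAlign;
  uint64_t LowBit = ByteOffset & (~ByteOffset + 1);
  return LowBit < BaseAlign ? LowBit : BaseAlign;
}

// Hypothetical split-and-retry: accept the chain of Count consecutive
// elements starting at element First if its first access is aligned to the
// whole vector width; otherwise halve it and retry both halves instead of
// discarding the entire chain. Accepted (First, Count) pairs go to Out.
void splitUntilAligned(uint64_t BaseAlign, unsigned EltBytes, unsigned First,
                       unsigned Count,
                       std::vector<std::pair<unsigned, unsigned>> &Out) {
  const uint64_t VecBytes = uint64_t(EltBytes) * Count;
  if (Count == 1 ||
      knownAlign(BaseAlign, uint64_t(First) * EltBytes) >= VecBytes) {
    Out.emplace_back(First, Count);
    return;
  }
  const unsigned Half = Count / 2;
  splitUntilAligned(BaseAlign, EltBytes, First, Half, Out);
  splitUntilAligned(BaseAlign, EltBytes, First + Half, Count - Half, Out);
}
```

For the chain above, splitUntilAligned(128, 4, 6, 4, Out) yields (6, 2) and (8, 2). The 16-byte chain at byte offset 24 only has 8-byte alignment (knownAlign(128, 24) == 8, matching the `align 8` on element 6). Both 8-byte halves are sufficiently aligned, however.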
Finally, here is some LLVM IR that contains both of these problematic patterns (chains with gaps and misaligned chains), generated from the MWEs above:
; ModuleID = 'kernel'
source_filename = "kernel"
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
target triple = "nvptx64-nvidia-cuda"
@shmem1 = addrspace(3) global [25 x float] zeroinitializer, align 128
@shmem2 = local_unnamed_addr addrspace(3) global [25 x float] zeroinitializer, align 128
; Function Attrs: nounwind readnone
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
define void @ptxcall_kernel_3({ [2 x i64], { i64 } }, i64) local_unnamed_addr {
L40.L45_crit_edge.i.2:
%2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%3 = mul nuw nsw i32 %2, 5
%mulconv = add nuw nsw i32 %3, 5
%4 = zext i32 %mulconv to i64
%5 = add nsw i64 %4, -30
%6 = getelementptr [25 x float], [25 x float] addrspace(3)* @shmem2, i64 0, i64 %5
%7 = load float, float addrspace(3)* %6, align 4
%8 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 0), align 128
%9 = fdiv float %7, %8
store float %9, float addrspace(3)* %6, align 4
%10 = add nsw i64 %4, -29
%11 = getelementptr [25 x float], [25 x float] addrspace(3)* @shmem2, i64 0, i64 %10
%.promoted.1 = load float, float addrspace(3)* %11, align 4
%12 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 1), align 4
%13 = fsub float %.promoted.1, %12
%14 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 6), align 8
%15 = fdiv float %13, %14
store float %15, float addrspace(3)* %11, align 4
%16 = add nsw i64 %4, -28
%17 = getelementptr [25 x float], [25 x float] addrspace(3)* @shmem2, i64 0, i64 %16
%.promoted.2 = load float, float addrspace(3)* %17, align 4
%18 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 2), align 8
%19 = fsub float %.promoted.2, %18
%20 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 7), align 4
%21 = fsub float %19, %20
%22 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 12), align 16
%23 = fdiv float %21, %22
store float %23, float addrspace(3)* %17, align 4
%24 = add nsw i64 %4, -27
%25 = getelementptr [25 x float], [25 x float] addrspace(3)* @shmem2, i64 0, i64 %24
%.promoted.3 = load float, float addrspace(3)* %25, align 4
%26 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 3), align 4
%27 = fsub float %.promoted.3, %26
%28 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 8), align 32
%29 = fsub float %27, %28
%30 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 13), align 4
%31 = fsub float %29, %30
%32 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 18), align 8
%33 = fdiv float %31, %32
store float %33, float addrspace(3)* %25, align 4
%34 = add nsw i64 %4, -26
%35 = getelementptr [25 x float], [25 x float] addrspace(3)* @shmem2, i64 0, i64 %34
%.promoted.4 = load float, float addrspace(3)* %35, align 4
%36 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 4), align 16
%37 = fsub float %.promoted.4, %36
%38 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 9), align 4
%39 = fsub float %37, %38
%40 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 14), align 8
%41 = fsub float %39, %40
%42 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 19), align 4
%43 = fsub float %41, %42
%44 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 24), align 32
%45 = fdiv float %43, %44
store float %45, float addrspace(3)* %35, align 4
ret void
}
attributes #0 = { nounwind readnone }
cc @alinas, @arsenm, @jholewinski
So this probably needs a fix to the Load Store Vectorizer to vectorize chains with gaps.
Interesting... So the heuristic would be: vectorize (up to 16 bytes) if doing so reduces the number of load instructions? (Obviously we can't do this for stores.)
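The proposed heuristic can be made concrete with a minimal C++ cost model. All names are hypothetical, and this is not LSV code. It assumes the following:

- elements are 4 bytes wide;
- the array base is at least 16-byte aligned;
- any aligned 8- or 16-byte window touched by more than one access may be loaded as a whole, with the gap lanes ignored.

The model also counts the loaded-but-unused lanes, which are the extra traffic and registers on the cost side of the trade-off.

```cpp
#include <cassert>
#include <set>

// Result of the hypothetical gap-tolerant load cost model.
struct GapTolerantLoads {
  int V4 = 0;          // 16-byte loads (e.g. LDS.128)
  int V2 = 0;          // 8-byte loads (e.g. LDS.64)
  int Scalar = 0;      // 4-byte loads
  int UnusedLanes = 0; // lanes loaded only to bridge a gap
  int total() const { return V4 + V2 + Scalar; }
};

// Used holds the indices of the accessed 4-byte elements. Because the base is
// assumed 16-byte aligned, elements 4*W .. 4*W+3 form an aligned 16-byte
// window and each pair 2*H .. 2*H+1 within it an aligned 8-byte half.
GapTolerantLoads countGapTolerantLoads(const std::set<int> &Used) {
  GapTolerantLoads C;
  if (Used.empty())
    return C;
  const int LastWindow = *Used.rbegin() / 4;
  for (int W = 0; W <= LastWindow; ++W) {
    bool E[4];
    int N = 0;
    for (int K = 0; K < 4; ++K) {
      E[K] = Used.count(4 * W + K) != 0;
      N += E[K];
    }
    const bool Low = E[0] || E[1], High = E[2] || E[3];
    if (Low && High) {
      // Accesses in both halves: one 16-byte load replaces >= 2 loads.
      ++C.V4;
      C.UnusedLanes += 4 - N;
      continue;
    }
    // At most one half is touched: a full half becomes one 8-byte load, and a
    // lone access stays scalar (widening it would not save an instruction).
    for (int H = 0; H < 2; ++H) {
      const int InHalf = E[2 * H] + E[2 * H + 1];
      if (InHalf == 2)
        ++C.V2;
      else if (InHalf == 1)
        ++C.Scalar;
    }
  }
  return C;
}
```

The `dia` elements accessed by the 5x5 MWE kernel in this thread are {0, 1, 2, 3, 4, 6, 7, 8, 9, 12, 13, 14, 18, 19, 24}. For that set, the model gives three 16-byte loads, two 8-byte loads and one scalar load, so 6 loads instead of 15. It leaves two unused lanes, elements 5 and 15. This is the same 16/8/4-byte pattern (at 0x0, 0x10, 0x30 / 0x20, 0x48 / 0x60) that ptxas from CUDA 10 emits for `dia` in the SASS later in this thread.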
I'm not sure this would always be beneficial. There's the obvious question about increasing memory traffic. For global memory maybe this is OK: a vectorized load is always aligned, so presumably it can't span two cache lines. But I'd expect that shared memory operates at a granularity of 1 or 4 bytes (since on Volta shared memory is fungible with L1 cache), so this would actually result in us loading more data.
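The cache-line argument is plain arithmetic. If the line size is a multiple of the access size, an access aligned to its own size cannot straddle a line boundary. Below is a small C++ check; the 128-byte line size is assumed purely for illustration, and the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// True if an access of Size bytes starting at Addr touches two distinct
// LineBytes-sized lines (all quantities in bytes).
bool spansTwoLines(uint64_t Addr, uint64_t Size, uint64_t LineBytes) {
  return Addr / LineBytes != (Addr + Size - 1) / LineBytes;
}

// Exhaustively checks that every Size-aligned, Size-byte access below Limit
// stays within a single line. This holds whenever LineBytes is a multiple of
// Size, e.g. 16-byte vectors and 128-byte lines.
bool alignedAccessesNeverSpan(uint64_t Size, uint64_t LineBytes, uint64_t Limit) {
  for (uint64_t Addr = 0; Addr + Size <= Limit; Addr += Size)
    if (spansTwoLines(Addr, Size, LineBytes))
      return false;
  return true;
}
```

By contrast, spansTwoLines(120, 16, 128) is true. An unaligned 16-byte access at byte 120 covers bytes 120 to 135 and therefore touches two 128-byte lines.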
The other problem it creates is increased register pressure. From the SASS, it appears that a vectorized load loads into four consecutive registers: LDS.128 R4, [RZ]. I'm not sure how LLVM is going to know that this is a good trade-off.
Not saying we shouldn't do it, just that I'm worried about pessimizing other workloads.
This could be split into vectorizable chains of elements 6 7 and 8 9, but LSV goes through until it discards the entire chain due to misalignment.
That looks like a bug to me that we can/should fix.
int idx = threadIdx.x - BLOCK_SIZE;
Isn't idx negative for threadIdx.x < 5?
Perhaps this is just for your MWE, but indexing past the beginning of an array is (?) UB, and so I wouldn't be surprised if this causes LLVM to do unreasonable things...
Yeah, that's fair. It's an artifact of reducing the original benchmark, but it doesn't seem to change the generated code. I'll update the examples.
Making LSV bail out earlier and have it split the chain fixes one of these issues:
Index: Transforms/Vectorize/LoadStoreVectorizer.cpp
===================================================================
--- Transforms/Vectorize/LoadStoreVectorizer.cpp (revision 336178)
+++ Transforms/Vectorize/LoadStoreVectorizer.cpp (working copy)
@@ -874,11 +878,13 @@
Chain = NewChain;
ChainSize = Chain.size();
- // Check if it's legal to vectorize this chain. If not, split the chain and
- // try again.
+ // Check if it's legal to vectorize this chain, and whether it's aligned.
+ // If not, split the chain and try again.
unsigned EltSzInBytes = Sz / 8;
unsigned SzInBytes = EltSzInBytes * ChainSize;
- if (!TTI.isLegalToVectorizeStoreChain(SzInBytes, Alignment, AS)) {
+ if (!TTI.isLegalToVectorizeStoreChain(SzInBytes, Alignment, AS) ||
+ (S0->getPointerAddressSpace() != 0 &&
+ accessIsMisaligned(SzInBytes, AS, Alignment))) {
auto Chains = splitOddVectorElts(Chain, Sz);
return vectorizeStoreChain(Chains.first, InstructionsProcessed) |
vectorizeStoreChain(Chains.second, InstructionsProcessed);
@@ -1022,11 +1037,13 @@
Chain = NewChain;
ChainSize = Chain.size();
- // Check if it's legal to vectorize this chain. If not, split the chain and
- // try again.
+ // Check if it's legal to vectorize this chain, and whether it's aligned.
+ // If not, split the chain and try again.
unsigned EltSzInBytes = Sz / 8;
unsigned SzInBytes = EltSzInBytes * ChainSize;
- if (!TTI.isLegalToVectorizeLoadChain(SzInBytes, Alignment, AS)) {
+ if (!TTI.isLegalToVectorizeLoadChain(SzInBytes, Alignment, AS) ||
+ (L0->getPointerAddressSpace() != 0 &&
+ accessIsMisaligned(SzInBytes, AS, Alignment))) {
auto Chains = splitOddVectorElts(Chain, Sz);
return vectorizeLoadChain(Chains.first, InstructionsProcessed) |
vectorizeLoadChain(Chains.second, InstructionsProcessed);
LSV: Loads to try and vectorize:
%17 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 6), align 8
%22 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 7), align 4
%29 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 8), align 32
%39 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 9), align 4
LSV: Target said misaligned is allowed? 0 and fast? 0
LSV: Loads to try and vectorize:
%17 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 6), align 8
%22 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 7), align 4
LSV: Loads to vectorize:
%17 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 6), align 8
%22 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 7), align 4
LSV: Loads to try and vectorize:
%30 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 8), align 32
%40 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 9), align 4
LSV: Loads to vectorize:
%30 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 8), align 32
%40 = load float, float addrspace(3)* getelementptr inbounds ([25 x float], [25 x float] addrspace(3)* @shmem1, i64 0, i64 9), align 4
I'll put something on Phabricator as soon as I have some time.
Had a go with the machine scheduler memop cluster mutation, and while it does seem to work well (and combined with improved alignment & load/store vectorization even results in drastically fewer loads/stores in the SASS code), it didn't fix the performance problem I was looking at. So in the meantime, dumping a patch here for when I have time to revisit this again:
Index: MCTargetDesc/NVPTXBaseInfo.h
===================================================================
--- MCTargetDesc/NVPTXBaseInfo.h (revision 336178)
+++ MCTargetDesc/NVPTXBaseInfo.h (working copy)
@@ -33,6 +33,8 @@
namespace NVPTXII {
enum {
// These must be kept in sync with TSFlags in NVPTXInstrFormats.td
+ IsLoadFlag = 0x20,
+ IsStoreFlag = 0x40,
IsTexFlag = 0x80,
IsSuldMask = 0x300,
IsSuldShift = 8,
Index: NVPTXInstrFormats.td
===================================================================
--- NVPTXInstrFormats.td (revision 336178)
+++ NVPTXInstrFormats.td (working copy)
@@ -29,6 +29,7 @@
dag InOperandList = ins;
let AsmString = asmstr;
let Pattern = pattern;
+ let UseNamedOperandTable = 1;
// TSFlagFields
bits<4> VecInstType = VecNOP.Value;
Index: NVPTXInstrInfo.cpp
===================================================================
--- NVPTXInstrInfo.cpp (revision 336178)
+++ NVPTXInstrInfo.cpp (working copy)
@@ -23,6 +23,7 @@
using namespace llvm;
#define GET_INSTRINFO_CTOR_DTOR
+#define GET_INSTRINFO_NAMED_OPS
#include "NVPTXGenInstrInfo.inc"
// Pin the vtable to this file.
@@ -30,6 +31,15 @@
NVPTXInstrInfo::NVPTXInstrInfo() : NVPTXGenInstrInfo(), RegInfo() {}
+MachineOperand *NVPTXInstrInfo::getNamedOperand(MachineInstr &MI,
+ unsigned OperandName) const {
+ int Idx = NVPTX::getNamedOperandIdx(MI.getOpcode(), OperandName);
+ if (Idx == -1)
+ return nullptr;
+
+ return &MI.getOperand(Idx);
+}
+
void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
MachineBasicBlock::iterator I,
const DebugLoc &DL, unsigned DestReg,
@@ -206,3 +216,48 @@
BuildMI(&MBB, DL, get(NVPTX::GOTO)).addMBB(FBB);
return 2;
}
+
+bool NVPTXInstrInfo::getMemOpBaseRegImmOfs(MachineInstr &LdSt, unsigned &BaseReg,
+ int64_t &OffsetValue,
+ const TargetRegisterInfo *TRI) const {
+ const MCInstrDesc &MCID = LdSt.getDesc();
+ // TODO: identify vector loads
+ if (MCID.TSFlags & NVPTXII::IsLoadFlag) {
+ const MachineOperand *Offset = getNamedOperand(LdSt, NVPTX::OpName::offset);
+ const MachineOperand *Base = getNamedOperand(LdSt, NVPTX::OpName::addr);
+ // FIXME: this is a WIP, need to match up with the tablegen defs
+
+ if (Offset) {
+ OffsetValue = Offset->getImm();
+ } else {
+ OffsetValue = 0; // XXX: is this correct? eg. LDV_f32_v4_avar, no offset
+ // TODO: make sure all offset args are actually named offset.
+ // TODO: check for special instrs (eg. amdgpu has instr with two offsets)
+ }
+
+ assert(Base && "Cannot handle load without address argument");
+ if (Base->isReg()) {
+ BaseReg = Base->getReg();
+ }
+ else if (Base->isGlobal()) {
+ // HACK: NVPTX demotes global variables to shared memory during ISel...
+ // let's try with a pseudo register, this is only used to unique.
+ BaseReg = (unsigned)(uintptr_t)Base->getGlobal();
+ } else {
+ // FIXME: other cases, eg. load from param mem
+ return false;
+ }
+ return true;
+ }
+ else if (MCID.TSFlags & NVPTXII::IsStoreFlag) {
+ // TODO
+ }
+ return false;
+}
+
+bool NVPTXInstrInfo::shouldClusterMemOps(MachineInstr &FirstLdSt, unsigned BaseReg1,
+ MachineInstr &SecondLdSt, unsigned BaseReg2,
+ unsigned NumLoads) const {
+ // nvcc is crazy aggressive
+ return true;
+}
Index: NVPTXInstrInfo.h
===================================================================
--- NVPTXInstrInfo.h (revision 336178)
+++ NVPTXInstrInfo.h (working copy)
@@ -19,6 +19,7 @@
#include "llvm/CodeGen/TargetInstrInfo.h"
#define GET_INSTRINFO_HEADER
+#define GET_INSTRINFO_OPERAND_ENUM
#include "NVPTXGenInstrInfo.inc"
namespace llvm {
@@ -31,6 +32,17 @@
const NVPTXRegisterInfo &getRegisterInfo() const { return RegInfo; }
+ /// Returns the operand named \p Op. If \p MI does not have an
+ /// operand named \c Op, this function returns nullptr.
+ LLVM_READONLY
+ MachineOperand *getNamedOperand(MachineInstr &MI, unsigned OperandName) const;
+
+ LLVM_READONLY
+ const MachineOperand *getNamedOperand(const MachineInstr &MI,
+ unsigned OpName) const {
+ return getNamedOperand(const_cast<MachineInstr &>(MI), OpName);
+ }
+
/* The following virtual functions are used in register allocation.
* They are not implemented because the existing interface and the logic
* at the caller side do not work for the elementized vector load and store.
@@ -64,6 +76,14 @@
MachineBasicBlock *FBB, ArrayRef<MachineOperand> Cond,
const DebugLoc &DL,
int *BytesAdded = nullptr) const override;
+
+ // memop clustering
+ bool getMemOpBaseRegImmOfs(MachineInstr &LdSt, unsigned &BaseReg,
+ int64_t &Offset,
+ const TargetRegisterInfo *TRI) const final;
+ bool shouldClusterMemOps(MachineInstr &FirstLdSt, unsigned BaseReg1,
+ MachineInstr &SecondLdSt, unsigned BaseReg2,
+ unsigned NumLoads) const override;
};
} // namespace llvm
Index: NVPTXInstrInfo.td
===================================================================
--- NVPTXInstrInfo.td (revision 336178)
+++ NVPTXInstrInfo.td (working copy)
@@ -2295,7 +2296,7 @@
"\t$dst, [$addr+$offset];", []>;
}
-let mayLoad=1, hasSideEffects=0 in {
+let IsLoad=1, mayLoad=1, hasSideEffects=0 in {
defm LD_i8 : LD<Int16Regs>;
defm LD_i16 : LD<Int16Regs>;
defm LD_i32 : LD<Int32Regs>;
@@ -2345,7 +2346,7 @@
" \t[$addr+$offset], $src;", []>;
}
-let mayStore=1, hasSideEffects=0 in {
+let IsStore=1, mayStore=1, hasSideEffects=0 in {
defm ST_i8 : ST<Int16Regs>;
defm ST_i16 : ST<Int16Regs>;
defm ST_i32 : ST<Int32Regs>;
@@ -2433,7 +2434,7 @@
"ld${isVol:volatile}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t{{$dst1, $dst2, $dst3, $dst4}}, [$addr+$offset];", []>;
}
-let mayLoad=1, hasSideEffects=0 in {
+let IsLoad=1, mayLoad=1, hasSideEffects=0 in {
defm LDV_i8 : LD_VEC<Int16Regs>;
defm LDV_i16 : LD_VEC<Int16Regs>;
defm LDV_i32 : LD_VEC<Int32Regs>;
@@ -2528,7 +2529,7 @@
"$fromWidth \t[$addr+$offset], {{$src1, $src2, $src3, $src4}};", []>;
}
-let mayStore=1, hasSideEffects=0 in {
+let IsStore=1, mayStore=1, hasSideEffects=0 in {
defm STV_i8 : ST_VEC<Int16Regs>;
defm STV_i16 : ST_VEC<Int16Regs>;
defm STV_i32 : ST_VEC<Int32Regs>;
Index: NVPTXIntrinsics.td
===================================================================
--- NVPTXIntrinsics.td (revision 336178)
+++ NVPTXIntrinsics.td (working copy)
@@ -1728,6 +1728,7 @@
[]>, Requires<[hasLDU]>;
}
+let IsLoad = 1 in {
defm INT_PTX_LDU_GLOBAL_i8 : LDU_G<"u8 \t$result, [$src];", Int16Regs>;
defm INT_PTX_LDU_GLOBAL_i16 : LDU_G<"u16 \t$result, [$src];", Int16Regs>;
defm INT_PTX_LDU_GLOBAL_i32 : LDU_G<"u32 \t$result, [$src];", Int32Regs>;
@@ -1738,6 +1739,7 @@
defm INT_PTX_LDU_GLOBAL_f64 : LDU_G<"f64 \t$result, [$src];", Float64Regs>;
defm INT_PTX_LDU_GLOBAL_p32 : LDU_G<"u32 \t$result, [$src];", Int32Regs>;
defm INT_PTX_LDU_GLOBAL_p64 : LDU_G<"u64 \t$result, [$src];", Int64Regs>;
+}
// vector
@@ -1778,6 +1780,7 @@
!strconcat("ldu.global.", TyStr), []>;
}
+let IsLoad = 1 in {
defm INT_PTX_LDU_G_v2i8_ELE
: VLDU_G_ELE_V2<"v2.u8 \t{{$dst1, $dst2}}, [$src];", Int16Regs>;
defm INT_PTX_LDU_G_v2i16_ELE
@@ -1811,8 +1814,8 @@
defm INT_PTX_LDU_G_v4f32_ELE
: VLDU_G_ELE_V4<"v4.f32 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];",
Float32Regs>;
+}
//-----------------------------------
// Support for ldg on sm_35 or later
//-----------------------------------
@@ -1839,6 +1842,7 @@
[]>, Requires<[hasLDG]>;
}
+let IsLoad = 1 in {
defm INT_PTX_LDG_GLOBAL_i8
: LDG_G<"u8 \t$result, [$src];", Int16Regs>;
defm INT_PTX_LDG_GLOBAL_i16
@@ -1859,6 +1863,7 @@
: LDG_G<"u32 \t$result, [$src];", Int32Regs>;
defm INT_PTX_LDG_GLOBAL_p64
: LDG_G<"u64 \t$result, [$src];", Int64Regs>;
+}
// vector
@@ -1900,6 +1905,7 @@
}
// FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.
+let IsLoad = 1 in {
defm INT_PTX_LDG_G_v2i8_ELE
: VLDG_G_ELE_V2<"v2.u8 \t{{$dst1, $dst2}}, [$src];", Int16Regs>;
defm INT_PTX_LDG_G_v2i16_ELE
@@ -1928,8 +1934,8 @@
: VLDG_G_ELE_V4<"v4.b32 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", Float16x2Regs>;
defm INT_PTX_LDG_G_v4f32_ELE
: VLDG_G_ELE_V4<"v4.f32 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", Float32Regs>;
+}
-
multiclass NG_TO_G<string Str, Intrinsic Intrin> {
def _yes : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src),
!strconcat("cvta.", Str, ".u32 \t$result, $src;"),
Index: NVPTXSubtarget.h
===================================================================
--- NVPTXSubtarget.h (revision 336178)
+++ NVPTXSubtarget.h (working copy)
@@ -85,6 +85,8 @@
NVPTXSubtarget &initializeSubtargetDependencies(StringRef CPU, StringRef FS);
void ParseSubtargetFeatures(StringRef CPU, StringRef FS);
+
+ bool enableMachineScheduler() const override { return true; }
};
} // End llvm namespace
Index: NVPTXTargetMachine.cpp
===================================================================
--- NVPTXTargetMachine.cpp (revision 336178)
+++ NVPTXTargetMachine.cpp (working copy)
@@ -21,6 +21,7 @@
#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/MachineScheduler.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Pass.h"
@@ -173,6 +174,14 @@
void addFastRegAlloc(FunctionPass *RegAllocPass) override;
void addOptimizedRegAlloc(FunctionPass *RegAllocPass) override;
+ ScheduleDAGInstrs *
+ createMachineScheduler(MachineSchedContext *C) const override {
+ ScheduleDAGMILive *DAG = createGenericSchedLive(C);
+ DAG->addMutation(createLoadClusterDAGMutation(DAG->TII, DAG->TRI));
+ DAG->addMutation(createStoreClusterDAGMutation(DAG->TII, DAG->TRI));
+ return DAG;
+ }
+
private:
// If the opt level is aggressive, add GVN; otherwise, add EarlyCSE. This
// function is only called in opt mode.
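Roughly speaking, the load/store cluster mutations enabled by this patch work in three steps:

1. Take each memory operation's (base, offset) pair, as reported by getMemOpBaseRegImmOfs. The patch uses the global's address as a pseudo base register, so all loads of one shared array share a base.
2. Sort the operations.
3. Ask shouldClusterMemOps whether neighbouring operations should be scheduled back to back.

The following is a simplified, standalone C++ model of that idea, not LLVM's implementation. MemOp and clusterMemOps are hypothetical names. The cluster length passed to the predicate only loosely mirrors NumLoads. Unlike LLVM, which leaves that decision to the target hook, the model also never clusters across different base registers.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <tuple>
#include <vector>

struct MemOp {
  unsigned Id;      // position in the original instruction order
  unsigned BaseReg; // base register (or pseudo register for a global)
  int64_t Offset;   // immediate byte offset from the base
};

using ClusterPredicate =
    std::function<bool(const MemOp &, const MemOp &, unsigned)>;

// Sorts memops by (base, offset) and grows a cluster while consecutive ops
// share a base and ShouldCluster accepts the pair; the third argument is the
// length the cluster would have after adding the new op.
std::vector<std::vector<unsigned>>
clusterMemOps(std::vector<MemOp> Ops, const ClusterPredicate &ShouldCluster) {
  std::sort(Ops.begin(), Ops.end(), [](const MemOp &A, const MemOp &B) {
    return std::tie(A.BaseReg, A.Offset, A.Id) <
           std::tie(B.BaseReg, B.Offset, B.Id);
  });
  std::vector<std::vector<unsigned>> Clusters;
  for (size_t I = 0; I < Ops.size(); ++I) {
    const bool Extend =
        I > 0 && Ops[I].BaseReg == Ops[I - 1].BaseReg &&
        ShouldCluster(Ops[I - 1], Ops[I],
                      unsigned(Clusters.back().size() + 1));
    if (Extend)
      Clusters.back().push_back(Ops[I].Id);
    else
      Clusters.push_back({Ops[I].Id});
  }
  return Clusters;
}
```

With a predicate that always returns true, like the patch's shouldClusterMemOps, every run of operations sharing a base becomes one offset-sorted cluster. Capping the cluster length through the predicate is one obvious knob for the kind of per-GPU tuning discussed in this thread.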
@jlebar is the performance testing/reporting infrastructure as used for the gpucc paper available somewhere? Would be convenient in order to tune the clustering behavior on a variety of GPUs.
Not really. Like the rest of gpucc, it got thrown away when we did the complete rewrite into clang.
It's a source of much personal embarrassment that we don't have good GPU compiler benchmarks here at Google. We have some XLA models and infrastructure for running them on P100s and V100s, but XLA currently only uses shared memory in one limited case, and I don't expect this to help with that. I'd also be surprised if grouping loads/stores gets us to vectorize more loads/stores to global memory, because fundamentally PTX doesn't (?) have a way for us to tell it that a pointer is aligned, and so it seems hard for it to apply this optimization. In addition, most of our kernels already run at peak theoretical performance, so compiler optimizations like this are currently not applicable to most of the code we generate.
When we start generating our own matmuls and convolutions instead of calling into cudnn, I expect this may change...
For now your own benchmarks may be the best guide.
Is this still relevant? There have been some alignment patches to LoadStoreVectorizer. The posted CUDA example optimizes away to nothing for me.
@arsenm Thanks for checking in; it's been a while since I looked at this, I'll check after the holidays.
I take it this would be due to improvements to the middle end, as NVPTX still doesn't seem to use a MachineScheduler/ClusterDAGMutation?
I only skimmed through and saw the alignment patch for LSV; I don't know what else would have mattered.
The posted examples indeed optimize away, but the issue reported here still seems present. Slightly modifying the code to avoid complete removal of the loop:
#define BLOCK_SIZE 5
__global__ void kernel() {
__shared__ float dia[BLOCK_SIZE][BLOCK_SIZE];
__shared__ float peri_col[BLOCK_SIZE][BLOCK_SIZE];
// initialize so that the loop below doesn't get optimized away
dia[threadIdx.x][threadIdx.y] = 1.0f;
int idx = threadIdx.x;
for (int i = 0; i < BLOCK_SIZE; i++) {
for (int j = 0; j < i; j++)
peri_col[idx][i] -= dia[j][i];
peri_col[idx][i] /= dia[i][i];
}
}
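As a cross-check of the offsets in the PTX that follows, this hypothetical helper lists the byte offsets of the `dia` loads performed by the loop nest, in program order. It assumes a row-major array of 4-byte floats, so dia[j][i] sits at byte 4 * (5 * j + i). The idx-dependent peri_col accesses are left out.

```cpp
#include <cassert>
#include <vector>

// Byte offsets, in program order, of the loads from dia[j][i] performed by
// the MWE's loop nest for an N x N float array (row-major, 4-byte elements).
std::vector<unsigned> diaLoadOffsets(unsigned N) {
  std::vector<unsigned> Offsets;
  for (unsigned i = 0; i < N; ++i) {
    for (unsigned j = 0; j < i; ++j)
      Offsets.push_back(4 * (N * j + i)); // peri_col[idx][i] -= dia[j][i]
    Offsets.push_back(4 * (N * i + i));   // peri_col[idx][i] /= dia[i][i]
  }
  return Offsets;
}
```

diaLoadOffsets(5) returns 0, 4, 24, 8, 28, 48, 12, 32, 52, 72, 16, 36, 56, 76, 96. This is the same sequence as the _ZZ6kernelvE3dia+N operands in the Clang output. These 15 `dia` loads, plus the 5 peri_col loads, account for the 20 LDS.U.32 instructions in the SASS listings.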
With Clang 15, this results in the following PTX code:
❯ clang -O3 -S --cuda-device-only --cuda-gpu-arch=sm_50 mwe.cu -o -
//
// Generated by LLVM NVPTX Back-End
//
.version 7.5
.target sm_50
.address_size 64
// .globl _Z6kernelv
// _ZZ6kernelvE3dia has been demoted
// _ZZ6kernelvE8peri_col has been demoted
.visible .entry _Z6kernelv()
{
.reg .b32 %r<4>;
.reg .f32 %f<36>;
.reg .b64 %rd<9>;
// demoted variable
.shared .align 4 .b8 _ZZ6kernelvE3dia[100];
// demoted variable
.shared .align 4 .b8 _ZZ6kernelvE8peri_col[100];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %tid.y;
mul.wide.u32 %rd1, %r1, 20;
mov.u64 %rd2, _ZZ6kernelvE3dia;
add.s64 %rd3, %rd2, %rd1;
mul.wide.u32 %rd4, %r2, 4;
add.s64 %rd5, %rd3, %rd4;
mov.u32 %r3, 1065353216;
st.shared.u32 [%rd5], %r3;
mul.wide.s32 %rd6, %r1, 20;
mov.u64 %rd7, _ZZ6kernelvE8peri_col;
add.s64 %rd8, %rd7, %rd6;
ld.shared.f32 %f1, [%rd8];
ld.shared.f32 %f2, [_ZZ6kernelvE3dia];
div.rn.f32 %f3, %f1, %f2;
st.shared.f32 [%rd8], %f3;
ld.shared.f32 %f4, [%rd8+4];
ld.shared.f32 %f5, [_ZZ6kernelvE3dia+4];
sub.f32 %f6, %f4, %f5;
st.shared.f32 [%rd8+4], %f6;
ld.shared.f32 %f7, [_ZZ6kernelvE3dia+24];
div.rn.f32 %f8, %f6, %f7;
st.shared.f32 [%rd8+4], %f8;
ld.shared.f32 %f9, [%rd8+8];
ld.shared.f32 %f10, [_ZZ6kernelvE3dia+8];
sub.f32 %f11, %f9, %f10;
ld.shared.f32 %f12, [_ZZ6kernelvE3dia+28];
sub.f32 %f13, %f11, %f12;
st.shared.f32 [%rd8+8], %f13;
ld.shared.f32 %f14, [_ZZ6kernelvE3dia+48];
div.rn.f32 %f15, %f13, %f14;
st.shared.f32 [%rd8+8], %f15;
ld.shared.f32 %f16, [%rd8+12];
ld.shared.f32 %f17, [_ZZ6kernelvE3dia+12];
sub.f32 %f18, %f16, %f17;
ld.shared.f32 %f19, [_ZZ6kernelvE3dia+32];
sub.f32 %f20, %f18, %f19;
ld.shared.f32 %f21, [_ZZ6kernelvE3dia+52];
sub.f32 %f22, %f20, %f21;
st.shared.f32 [%rd8+12], %f22;
ld.shared.f32 %f23, [_ZZ6kernelvE3dia+72];
div.rn.f32 %f24, %f22, %f23;
st.shared.f32 [%rd8+12], %f24;
ld.shared.f32 %f25, [%rd8+16];
ld.shared.f32 %f26, [_ZZ6kernelvE3dia+16];
sub.f32 %f27, %f25, %f26;
ld.shared.f32 %f28, [_ZZ6kernelvE3dia+36];
sub.f32 %f29, %f27, %f28;
ld.shared.f32 %f30, [_ZZ6kernelvE3dia+56];
sub.f32 %f31, %f29, %f30;
ld.shared.f32 %f32, [_ZZ6kernelvE3dia+76];
sub.f32 %f33, %f31, %f32;
st.shared.f32 [%rd8+16], %f33;
ld.shared.f32 %f34, [_ZZ6kernelvE3dia+96];
div.rn.f32 %f35, %f33, %f34;
st.shared.f32 [%rd8+16], %f35;
ret;
}
i.e. the loads/stores are still spread around, resulting in non-vectorized operations at the SASS level:
❯ clang -O3 -c --cuda-device-only --cuda-gpu-arch=sm_50 mwe.cu -o mwe.o
❯ cuobjdump -sass mwe.o | grep LDS
/*0078*/ LDS.U.32 R3, [RZ] }
/*0090*/ LDS.U.32 R2, [R0] ;
/*0118*/ LDS.U.32 R2, [R0+0x4] ;
/*0128*/ LDS.U.32 R3, [0x4] ;
/*0148*/ LDS.U.32 R3, [0x18] ;
/*01d8*/ LDS.U.32 R2, [R0+0x8] ;
/*01e8*/ LDS.U.32 R3, [0x8] ;
/*01f0*/ LDS.U.32 R4, [0x1c] ;
/*0218*/ LDS.U.32 R3, [0x30] ;
/*02a8*/ LDS.U.32 R2, [R0+0xc] ;
/*02b0*/ LDS.U.32 R3, [0xc] ;
/*02b8*/ LDS.U.32 R4, [0x20] ;
/*02c8*/ LDS.U.32 R7, [0x34] ;
/*02f8*/ LDS.U.32 R3, [0x48] ;
/*0388*/ LDS.U.32 R2, [R0+0x10] ;
/*0390*/ LDS.U.32 R3, [0x10] ;
/*0398*/ LDS.U.32 R4, [0x24] ;
/*03a8*/ LDS.U.32 R7, [0x38] ;
/*03b0*/ LDS.U.32 R8, [0x4c] ;
/*03f0*/ LDS.U.32 R3, [0x60] ;
Interestingly, nvcc from CUDA 12 has "regressed" as well:
❯ /opt/cuda-12.0/bin/nvcc mwe.cu -arch sm_50 -ptx -o -
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-31968024
// Cuda compilation tools, release 12.0, V12.0.76
// Based on NVVM 7.0.1
//
.version 8.0
.target sm_50
.address_size 64
// .globl _Z6kernelv
// _ZZ6kernelvE3dia has been demoted
// _ZZ6kernelvE8peri_col has been demoted
.visible .entry _Z6kernelv()
{
.reg .f32 %f<36>;
.reg .b32 %r<11>;
// demoted variable
.shared .align 4 .b8 _ZZ6kernelvE3dia[100];
// demoted variable
.shared .align 4 .b8 _ZZ6kernelvE8peri_col[100];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %tid.y;
mul.lo.s32 %r3, %r1, 20;
mov.u32 %r4, _ZZ6kernelvE3dia;
add.s32 %r5, %r4, %r3;
shl.b32 %r6, %r2, 2;
add.s32 %r7, %r5, %r6;
mov.u32 %r8, 1065353216;
st.shared.u32 [%r7], %r8;
mov.u32 %r9, _ZZ6kernelvE8peri_col;
add.s32 %r10, %r9, %r3;
ld.shared.f32 %f1, [%r10];
ld.shared.f32 %f2, [_ZZ6kernelvE3dia];
div.rn.f32 %f3, %f1, %f2;
st.shared.f32 [%r10], %f3;
ld.shared.f32 %f4, [%r10+4];
ld.shared.f32 %f5, [_ZZ6kernelvE3dia+4];
sub.f32 %f6, %f4, %f5;
ld.shared.f32 %f7, [_ZZ6kernelvE3dia+24];
div.rn.f32 %f8, %f6, %f7;
st.shared.f32 [%r10+4], %f8;
ld.shared.f32 %f9, [%r10+8];
ld.shared.f32 %f10, [_ZZ6kernelvE3dia+8];
sub.f32 %f11, %f9, %f10;
ld.shared.f32 %f12, [_ZZ6kernelvE3dia+28];
sub.f32 %f13, %f11, %f12;
ld.shared.f32 %f14, [_ZZ6kernelvE3dia+48];
div.rn.f32 %f15, %f13, %f14;
st.shared.f32 [%r10+8], %f15;
ld.shared.f32 %f16, [%r10+12];
ld.shared.f32 %f17, [_ZZ6kernelvE3dia+12];
sub.f32 %f18, %f16, %f17;
ld.shared.f32 %f19, [_ZZ6kernelvE3dia+32];
sub.f32 %f20, %f18, %f19;
ld.shared.f32 %f21, [_ZZ6kernelvE3dia+52];
sub.f32 %f22, %f20, %f21;
ld.shared.f32 %f23, [_ZZ6kernelvE3dia+72];
div.rn.f32 %f24, %f22, %f23;
st.shared.f32 [%r10+12], %f24;
ld.shared.f32 %f25, [%r10+16];
ld.shared.f32 %f26, [_ZZ6kernelvE3dia+16];
sub.f32 %f27, %f25, %f26;
ld.shared.f32 %f28, [_ZZ6kernelvE3dia+36];
sub.f32 %f29, %f27, %f28;
ld.shared.f32 %f30, [_ZZ6kernelvE3dia+56];
sub.f32 %f31, %f29, %f30;
ld.shared.f32 %f32, [_ZZ6kernelvE3dia+76];
sub.f32 %f33, %f31, %f32;
ld.shared.f32 %f34, [_ZZ6kernelvE3dia+96];
div.rn.f32 %f35, %f33, %f34;
st.shared.f32 [%r10+16], %f35;
ret;
}
❯ /opt/cuda-12.0/bin/nvcc mwe.cu -arch sm_50 -c -o mwe.o
❯ /opt/cuda-12.0/bin/cuobjdump mwe.o -sass | grep LDS
/*0068*/ LDS.U.32 R4, [RZ] ;
/*0070*/ LDS.U.32 R3, [R0+0x64] ;
/*0118*/ LDS.U.32 R4, [0x18] ;
/*0128*/ LDS.U.32 R2, [R0+0x68] ;
/*0130*/ LDS.U.32 R3, [0x4] ;
/*01d8*/ LDS.U.32 R5, [0x30] ;
/*01e8*/ LDS.U.32 R2, [R0+0x6c] ;
/*01f0*/ LDS.U.32 R3, [0x8] ;
/*01f8*/ LDS.U.32 R4, [0x1c] ;
/*0298*/ LDS.U.32 R8, [0x48] ;
/*02a8*/ LDS.U.32 R2, [R0+0x70] ;
/*02b0*/ LDS.U.32 R4, [0xc] ;
/*02b8*/ LDS.U.32 R5, [0x20] ;
/*02c8*/ LDS.U.32 R7, [0x34] ;
/*0378*/ LDS.U.32 R2, [R0+0x74] ;
/*0388*/ LDS.U.32 R9, [0x60] ;
/*0390*/ LDS.U.32 R4, [0x10] ;
/*0398*/ LDS.U.32 R5, [0x24] ;
/*03a8*/ LDS.U.32 R7, [0x38] ;
/*03b0*/ LDS.U.32 R8, [0x4c] ;
Going back to CUDA 10 (requiring GCC 7, which I got from the gcc:7 Docker image) still produces nicely-grouped operations:
❯ /opt/cuda-10.0/bin/nvcc -arch sm_50 -ptx mwe.cu -o -
//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-24817639
// Cuda compilation tools, release 10.0, V10.0.130
// Based on LLVM 3.4svn
//
.version 6.3
.target sm_30
.address_size 64
// .globl _Z6kernelv
// _ZZ6kernelvE3dia has been demoted
// _ZZ6kernelvE8peri_col has been demoted
.visible .entry _Z6kernelv(
)
{
.reg .f32 %f<36>;
.reg .b32 %r<11>;
// demoted variable
.shared .align 4 .b8 _ZZ6kernelvE3dia[100];
// demoted variable
.shared .align 4 .b8 _ZZ6kernelvE8peri_col[100];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %tid.y;
mul.lo.s32 %r3, %r1, 20;
mov.u32 %r4, _ZZ6kernelvE3dia;
add.s32 %r5, %r4, %r3;
shl.b32 %r6, %r2, 2;
add.s32 %r7, %r5, %r6;
mov.u32 %r8, 1065353216;
st.shared.u32 [%r7], %r8;
mov.u32 %r9, _ZZ6kernelvE8peri_col;
add.s32 %r10, %r9, %r3;
ld.shared.f32 %f1, [%r10];
ld.shared.f32 %f2, [_ZZ6kernelvE3dia];
div.rn.f32 %f3, %f1, %f2;
ld.shared.f32 %f4, [%r10+4];
ld.shared.f32 %f5, [_ZZ6kernelvE3dia+4];
ld.shared.f32 %f6, [_ZZ6kernelvE3dia+24];
ld.shared.f32 %f7, [%r10+8];
ld.shared.f32 %f8, [_ZZ6kernelvE3dia+8];
ld.shared.f32 %f9, [_ZZ6kernelvE3dia+28];
ld.shared.f32 %f10, [_ZZ6kernelvE3dia+48];
ld.shared.f32 %f11, [%r10+12];
ld.shared.f32 %f12, [_ZZ6kernelvE3dia+12];
ld.shared.f32 %f13, [_ZZ6kernelvE3dia+32];
ld.shared.f32 %f14, [_ZZ6kernelvE3dia+52];
ld.shared.f32 %f15, [_ZZ6kernelvE3dia+72];
ld.shared.f32 %f16, [%r10+16];
ld.shared.f32 %f17, [_ZZ6kernelvE3dia+16];
ld.shared.f32 %f18, [_ZZ6kernelvE3dia+36];
ld.shared.f32 %f19, [_ZZ6kernelvE3dia+56];
ld.shared.f32 %f20, [_ZZ6kernelvE3dia+76];
ld.shared.f32 %f21, [_ZZ6kernelvE3dia+96];
st.shared.f32 [%r10], %f3;
sub.f32 %f22, %f4, %f5;
div.rn.f32 %f23, %f22, %f6;
st.shared.f32 [%r10+4], %f23;
sub.f32 %f24, %f7, %f8;
sub.f32 %f25, %f24, %f9;
div.rn.f32 %f26, %f25, %f10;
st.shared.f32 [%r10+8], %f26;
sub.f32 %f27, %f11, %f12;
sub.f32 %f28, %f27, %f13;
sub.f32 %f29, %f28, %f14;
div.rn.f32 %f30, %f29, %f15;
st.shared.f32 [%r10+12], %f30;
sub.f32 %f31, %f16, %f17;
sub.f32 %f32, %f31, %f18;
sub.f32 %f33, %f32, %f19;
sub.f32 %f34, %f33, %f20;
div.rn.f32 %f35, %f34, %f21;
st.shared.f32 [%r10+16], %f35;
ret;
}
... which then result in better vectorization at the SASS level:
❯ /opt/cuda-10.0/bin/nvcc -arch sm_50 -c mwe.cu -o mwe.o
❯ /opt/cuda-10.0/bin/cuobjdump -sass mwe.o | grep LDS
/*0058*/ LDS.U.32 R5, [R0+0x64] ;
/*0068*/ LDS.U.32 R10, [RZ] ;
/*0078*/ LDS.U.128 R4, [RZ] ;
/*0088*/ LDS.U.128 R8, [0x10] ;
/*0090*/ LDS.U.32 R20, [R0+0x68] ;
/*0098*/ LDS.U.32 R19, [R0+0x6c] ;
/*00a8*/ LDS.U.128 R12, [0x30] ;
/*00b0*/ LDS.U.32 R18, [0x60] ;
/*00b8*/ LDS.U.32 R4, [R0+0x70] ;
/*00c8*/ LDS.U.32 R9, [R0+0x74] ;
/*00d0*/ LDS.U.64 R16, [0x20] ;
/*00e8*/ LDS.U.64 R2, [0x48] }
Though maybe this was intentionally changed to reduce register pressure, as noted by Justin above. I currently don't have the time to dive into this, but seeing how not much has changed, I guess this issue is still relevant.
Given the following MWE:
nvcc generates PTX with loads clustered together:
Note the pessimistic alignment of 4 for the shared memory arrays. I guess this follows the C spec, but it rules out vectorization. However, when ptxas assembles this code, it knows the memory layout and "physical" alignment of the shared memory arrays, and as a result nicely vectorizes this code:
We tackle this differently, using LLVM's Load Store Vectorizer by specifying a much more optimistic alignment that enables vectorization (https://github.com/JuliaGPU/CUDAnative.jl/pull/204):
While our PTX now contains vectorized loads (16 loads, vs 20 with nvcc), the resulting SASS still performs more loads (remaining at 16, since we don't cluster our loads for ptxas to vectorize, vs only 12 for nvcc). This is due to ptxas vectorizing load chains with gaps, e.g.:
LDS.128 R8, [0x10]
corresponding to loading elements 4 to 7 from kernel_dia, while kernel_dia+20 isn't actually used by the code (and R9 isn't used in the SASS code either). LLVM refuses to vectorize these accesses, emitting 3 loads instead (it should have still been able to vectorize accesses to elements 6 and 7, looking into that right now).
So this probably needs a fix to the Load Store Vectorizer to vectorize chains with gaps. Alternatively, we could cluster loads like nvcc, but the existing load/store clustering DAG mutation isn't quite as aggressive and as such only enables ptxas to merge a couple of loads (maybe there are some missing bits, since NVPTX hasn't been using the machine scheduler before):
cc @jlebar
reduced from rodinia/lud
ref https://lists.llvm.org/pipermail/llvm-dev/2018-June/124209.htm