Open jeromeku opened 3 months ago
As a follow-up, I am trying to implement the above using epilogue visitor trees (EVT).
I am encountering 2 problems:
1. EVT args: I get a 'no instance of constructor "cute::tuple<T...>" matches the argument list' compiler error when trying to construct the arguments for a DefaultGemmVisitor kernel with a simple EVT that does a columnwise bias broadcast. I've followed the nested structure of the EVT, which consists of a store node and an EVTCompute node that contains the bias, accumulator, and compute nodes, but clearly there's still a mistake.
2. Using GemmUniversalAdapter with a DefaultGemmVisitor that uses a threadblock swizzle other than streamk to construct the arguments (as opposed to directly using the gemm kernel per above), I get:
utlass/include/cutlass/gemm/kernel/gemm_universal.h(78): here is inaccessible
detected during instantiation of class "cutlass::gemm::device::GemmUniversalAdapter<GemmKernel_, std::enable_if_t<<expression>, void>> [with GemmKernel_=EVTKernel]"
I tried to tweak the streamk with broadcast example with the above changes (simpler EVT and ThreadBlockIdentitySwizzle).
Below is the full script:
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h"
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "cute/tensor.hpp"
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
#define CUTLASS_CHECK(status) \
  { \
    cutlass::Status error = status; \
    if (error != cutlass::Status::kSuccess) \
    { \
      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
                << std::endl; \
      exit(EXIT_FAILURE); \
    } \
  }
using DType = cutlass::half_t;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
constexpr int stages = 3;
using ThreadBlockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;
// EVT
constexpr int Alignment = 128 / cutlass::sizeof_bits_v<DType>;
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape)
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp-level tile size (concept: GemmShape)
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop
constexpr int EVTEpilogueStages = 1; // Number of epilogue stages in EVT
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
    ThreadblockShape,
    WarpShape,
    DType,
    Alignment,
    EVTEpilogueStages>;
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, cutlass::half_t,
    cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
    cutlass::multiplies, DType, DType,
    cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
    Compute0,
    Scale,
    Accum>;
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
    OutputTileThreadMap, cutlass::half_t, cutlass::FloatRoundStyle::round_to_nearest,
    cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
    D,
    EVTCompute0>;
using EVTKernel =
    typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
        DType, LayoutA, cutlass::ComplexTransform::kNone, Alignment,
        DType, LayoutB, cutlass::ComplexTransform::kNone, Alignment,
        DType, LayoutC, Alignment,
        DType,
        DType,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        EVTD,
        ThreadBlockSwizzle,
        NumStages,
        cutlass::arch::OpMultiplyAdd,
        EVTEpilogueStages>::GemmKernel;
using DeviceGemmEVT = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;
template <typename Gemm>
void test(int M = 8, int N = 4, int K = 8, bool verbose = true, int batch_count = 1,
          cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
          DType alpha = DType(1.0), DType beta = DType(0.0))
{
  cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord(M, N, K);
  cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A;
  cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B;
  cutlass::HostTensor<DType, typename Gemm::LayoutC> tensor_Z;
  cutlass::HostTensor<DType, typename Gemm::LayoutC> tensor_Broadcast;
  tensor_A.resize({problem_size.m(), problem_size.k()});
  tensor_B.resize({problem_size.k(), problem_size.n()});
  tensor_Z.resize({problem_size.m(), problem_size.n()});
  tensor_Broadcast.resize({problem_size.m(), 1});
  cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), tensor_A.capacity());
  cutlass::reference::host::BlockFill(tensor_B.host_data(), tensor_B.capacity(), DType(1.0));
  cutlass::reference::host::BlockFill(tensor_Z.host_data(), tensor_Z.capacity(), DType(0.0));
  cutlass::reference::host::BlockFillSequential(tensor_Broadcast.host_data(), tensor_Broadcast.capacity());
  tensor_A.sync_device();
  tensor_B.sync_device();
  tensor_Z.sync_device();
  tensor_Broadcast.sync_device();
  if (verbose)
  {
    std::cout << "tensor_A:\n"
              << tensor_A.host_view() << std::endl;
    std::cout << "tensor_B:\n"
              << tensor_B.host_view() << std::endl;
    std::cout << "tensor_Broadcast:\n"
              << tensor_Broadcast.host_view() << std::endl;
  }
  typename EVTD::Arguments callback_args{
      {
          {}, // Compute0
          {tensor_Broadcast.device_data(), DType(0), {_0{}, _1{}, int32_t(problem_size.m())}}, // bias / scale
          {} // Accum
      }, // EvtCompute0
      {tensor_Z.device_data(), {problem_size.n(), _1{}, problem_size.mn().product()}}, // D
  };
  typename EVTKernel::Arguments evtArgs{
      cutlass::gemm::GemmUniversalMode::kGemm, // universal mode
      problem_size,                            // problem_size
      1,                                       // batch count / splitk slices
      callback_args,                           // argument of EVT callbacks
      tensor_A.device_data(),                  // ptr_A
      tensor_B.device_data(),                  // ptr_B
      nullptr,                                 // ptr_C (unused)
      nullptr,                                 // ptr_D (unused)
      problem_size.mk().product(),             // batch_stride_A
      problem_size.nk().product(),             // batch_stride_B
      0,                                       // batch_stride_C (unused)
      0,                                       // batch_stride_D (unused)
      tensor_A.layout().stride(0),             // stride_a
      tensor_B.layout().stride(0),             // stride_b
      0,                                       // stride_c (unused)
      0                                        // stride_d (unused)
  };
}
int main()
{
  int M = 8;
  int N = 8;
  int K = 8;
  std::cout << "GemmEVT" << std::endl;
  test<EVTKernel>(M, N, K);
}
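One thing that may be worth checking in callback_args is the order of the nested braces. The sketch below is a stand-in model built on std::tuple, not the real CUTLASS types. It assumes that the Arguments of an Sm80EVT<NodeOp, ChildOps...> node list the children's arguments first and the node op's arguments last; that assumption is not confirmed anywhere in this thread. All struct names are hypothetical.

```cpp
#include <tuple>
#include <type_traits>

// Hypothetical stand-ins for per-node Arguments structs (not CUTLASS types).
struct BroadcastArgs { const float* ptr; float null_default; int stride_l; };
struct AccumArgs {};
struct ComputeArgs {};

// Assumed ordering for a tree node EVT<Compute, Broadcast, Accum>:
// children first (Broadcast, Accum), then the node op (Compute).
using TreeArgs = std::tuple<BroadcastArgs, AccumArgs, ComputeArgs>;

// Arguments supplied in the same order as the tuple elements are accepted.
constexpr bool children_first_ok =
    std::is_constructible_v<TreeArgs, BroadcastArgs, AccumArgs, ComputeArgs>;

// Arguments supplied node-op-first do not match any tuple constructor.
constexpr bool node_first_ok =
    std::is_constructible_v<TreeArgs, ComputeArgs, BroadcastArgs, AccumArgs>;

static_assert(children_first_ok, "children-first order should construct");
static_assert(!node_first_ok, "node-first order should be rejected");
```

If the assumption holds for the real types, the inner braces of callback_args would need to be ordered as scale, accumulator, compute rather than compute, scale, accumulator.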
Not an expert, but I recently ran into exactly the same problem when crafting my custom epilogue visitor tree. Here is what I think.
For your first problem (cute type), change
using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, cutlass::half_t,
    cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
to
using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, cutlass::half_t,
    cute::Stride<cute::Int<1>, cute::Int<0>, int32_t>>; // <-- the only change is the type of the last parameter, which should be int or int64_t
Do the same with type D.
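To illustrate why the type of the last stride mode matters, here is a minimal sketch that models the stride with std::tuple and std::integral_constant. These are stand-ins for cute::Stride and cute::Int, so the behaviour of the real cute types is an assumption here. It shows that a runtime int32_t can only initialise a slot that was declared as a runtime integer. It also shows that, in this model, the static modes passed in the arguments must match the declared constants, which may matter because the script passes {_0{}, _1{}, ...} while Scale declares Int<1>, Int<0>.

```cpp
#include <cstdint>
#include <tuple>
#include <type_traits>

// Stand-in for cute::Int<N>, a compile-time constant (not the real cute type).
template <int N>
using Int = std::integral_constant<int, N>;

// Stride declared with a fully static last mode, as in the original Scale.
using StaticStride = std::tuple<Int<1>, Int<0>, Int<0>>;
// Stride declared with a runtime last mode, as in the suggested fix.
using DynamicStride = std::tuple<Int<1>, Int<0>, std::int32_t>;

// A runtime int32_t cannot initialise a compile-time Int<0> slot.
constexpr bool static_takes_runtime =
    std::is_constructible_v<StaticStride, Int<1>, Int<0>, std::int32_t>;

// The same arguments fit once the last slot is a runtime integer.
constexpr bool dynamic_takes_runtime =
    std::is_constructible_v<DynamicStride, Int<1>, Int<0>, std::int32_t>;

// Swapping the static modes (Int<0>, Int<1>) no longer matches the declaration.
constexpr bool swapped_static_modes =
    std::is_constructible_v<DynamicStride, Int<0>, Int<1>, std::int32_t>;

static_assert(!static_takes_runtime, "runtime value into static slot");
static_assert(dynamic_takes_runtime, "runtime value into runtime slot");
static_assert(!swapped_static_modes, "static modes must match");
```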
For your second problem, the reason the example works well with ThreadblockSwizzleStreamK but doesn't work with GemmIdentityThreadblockSwizzle<> is that, inside the DefaultGemmWithVisitor class, this difference in parameter leads to dispatch to different GemmKernel types, as described here: https://github.com/NVIDIA/cutlass/blob/19f3cc33f1642b490ed7126ea0141f79c0045527/include/cutlass/gemm/kernel/default_gemm_universal_with_visitor.h#L140
To be more specific, example 47 uses GemmWithEpilogueVisitorStreamk (link) as the EVTKernelStreamK, but your example will use GemmWithEpilogueVisitor (link) as your EVTKernel.
With that clear, apparently GemmWithEpilogueVisitor contains some issues while GemmWithEpilogueVisitorStreamk works just fine. I had my own fix, but I suggest you post your whole error message and let a maintainer fix it.
@hwu36
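The dispatch described above can be sketched as a compile-time type selection. This is only a model with hypothetical stand-in types; the real DefaultGemmWithVisitor takes many more template parameters, and its exact selection logic should be read from the linked header.

```cpp
#include <type_traits>

// Hypothetical stand-ins for the swizzle and kernel classes.
struct StreamKSwizzle {};
template <int N>
struct IdentitySwizzle {};
struct VisitorKernelStreamK {};
struct VisitorKernel {};

// Sketch: choose the kernel type from the swizzle type at compile time.
template <class Swizzle>
struct DefaultVisitorGemm {
  using GemmKernel = std::conditional_t<
      std::is_same_v<Swizzle, StreamKSwizzle>,
      VisitorKernelStreamK,  // taken only for the stream-K swizzle
      VisitorKernel>;        // taken for every other swizzle
};

constexpr bool streamk_picks_streamk_kernel =
    std::is_same_v<DefaultVisitorGemm<StreamKSwizzle>::GemmKernel,
                   VisitorKernelStreamK>;
constexpr bool identity_picks_plain_kernel =
    std::is_same_v<DefaultVisitorGemm<IdentitySwizzle<1>>::GemmKernel,
                   VisitorKernel>;

static_assert(streamk_picks_streamk_kernel, "stream-K path");
static_assert(identity_picks_plain_kernel, "non-stream-K path");
```

Because the two branches are different classes, a problem in one of them (for example in how the device adapter accesses its members) would only show up for the swizzles that select it.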
What is your question?
Trying to understand the behavior of Gemm with a column-broadcasted bias vector epilogue.
When defining a device GemmUniversalWithBroadcast with the following config:
I get a core dump whenever I try to run the above with M != K. Running with M == N, I get the correct GEMM, but the epilogue is broadcasted incorrectly (row-wise vs column-wise).
When I run the above using GemmUniversalAdapter as the device handle, the op runs for all M and N. However, the A and B inputs are transposed because of an internal transpose that the adapter does, while the epilogue op is performed correctly.
Questions
GemmUniversalWithBroadcast?
GemmUniversalAdapter transpose layouts internally?
Repro
Here is a simple script for reproducing the above.
GemmUniversalWithBroadcast will fail to run with M != N.
GemmUniversalWithBroadcast runs with M == N, but the epilogue is incorrect.
GemmUniversalAdapter runs, but with operands A and B transposed.
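The internal transpose mentioned above can be understood through the identity (AB)^T = B^T A^T. A row-major buffer read as column-major is the transposed matrix, so a row-major GEMM can be run as a column-major GEMM with the operands swapped and M and N exchanged. The sketch below checks this on the host in plain C++. The helper names are hypothetical and this is not the adapter's code.

```cpp
#include <vector>

// Row-major GEMM: D[m][n] = sum_k A[m][k] * B[k][n].
std::vector<int> gemm_row_major(const std::vector<int>& A,
                                const std::vector<int>& B,
                                int M, int N, int K) {
  std::vector<int> D(M * N, 0);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        D[m * N + n] += A[m * K + k] * B[k * N + n];
  return D;
}

// Column-major GEMM: D(m,n) = sum_k A(m,k) * B(k,n), element (i,j) of an
// R x C matrix stored at index i + j * R.
std::vector<int> gemm_col_major(const std::vector<int>& A,
                                const std::vector<int>& B,
                                int M, int N, int K) {
  std::vector<int> D(M * N, 0);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        D[m + n * M] += A[m + k * M] * B[k + n * K];
  return D;
}

// Returns true if the row-major result equals the column-major result
// computed on the same buffers with A/B swapped and M/N exchanged.
bool transpose_identity_holds(int M, int N, int K) {
  std::vector<int> A(M * K), B(K * N);
  for (int i = 0; i < M * K; ++i) A[i] = i + 1;
  for (int i = 0; i < K * N; ++i) B[i] = 2 * i - 3;
  std::vector<int> D_row = gemm_row_major(A, B, M, N, K);
  // Row-major B (K x N) is column-major B^T (N x K); likewise for A.
  std::vector<int> D_col = gemm_col_major(B, A, N, M, K);
  return D_row == D_col;
}
```

Calling transpose_identity_holds(2, 3, 4) exercises a non-square case. Under the same swap, a per-row bias of D becomes a per-column bias of D^T, which might be related to the row-wise vs column-wise observation, though that is only a guess.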