NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[QST] Epilogue Broadcast: `Adapter` vs `GemmUniversal` #1459

Open jeromeku opened 3 months ago

jeromeku commented 3 months ago

What is your question? I am trying to understand the behavior of a GEMM with a column-broadcast bias-vector epilogue.

When defining a device GemmUniversalWithBroadcast with the following config:

using DType = cutlass::half_t;
using ElementWiseOp = cutlass::epilogue::thread::Identity<DType>;
using BinaryOp = cutlass::plus<DType>;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;

constexpr int stages = 3;
using ThreadBlockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;

using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
    DType,
    DType,
    DType,
    DType,
    DType,
    8,
    ElementWiseOp,
    BinaryOp>;

using GemmUniversal = cutlass::gemm::device::GemmUniversalWithBroadcast<
    DType,
    LayoutA,
    DType, LayoutB,
    DType, LayoutC,
    DType,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    EpilogueOutputOp,
    ThreadBlockSwizzle,
    stages>;

I get a core dump whenever I try to run the above with M != N. Running with M == N, I get the correct GEMM result, but the bias is broadcast incorrectly (row-wise instead of column-wise).
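To make the expected behavior concrete, this is the broadcast direction I am after (a sketch of just the indexing, not the exact operator composition of LinearCombinationBiasElementwise; Z_ref, AB, and bias are hypothetical names for the output, the accumulator, and the length-M bias vector):

// Column broadcast: the length-M bias vector is indexed by the row m
// and combined with every column n of the GEMM result.
for (int m = 0; m < M; ++m) {
  for (int n = 0; n < N; ++n) {
    Z_ref[m][n] = alpha * AB[m][n] + bias[m];
  }
}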

When I run the above using GemmUniversalAdapter as the device handle, the op runs for all M and N. However, the A and B inputs are transposed because of an internal transpose that the adapter performs, while the epilogue op is applied correctly.

Questions

- Why does GemmUniversalWithBroadcast crash when M != N, and why is the bias broadcast row-wise instead of column-wise when M == N?
- Why does GemmUniversalAdapter transpose the A and B operands internally, and what is the recommended way to get a column-broadcast bias epilogue with all-row-major inputs?

Repro

Here is a simple script for reproducing the above.

#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/functional.h"

#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
/////////////////////////////////////////////////////////////////////////////////////////////////

#define CUTLASS_CHECK(status)                                                                    \
  {                                                                                              \
    cutlass::Status error = status;                                                              \
    if (error != cutlass::Status::kSuccess)                                                      \
    {                                                                                            \
      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
                << std::endl;                                                                    \
      exit(EXIT_FAILURE);                                                                        \
    }                                                                                            \
  }

using DType = cutlass::half_t;
using ElementWiseOp = cutlass::epilogue::thread::Identity<DType>;
using BinaryOp = cutlass::plus<DType>;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;

constexpr int stages = 3;
using ThreadBlockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;

using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise<
    DType,
    DType,
    DType,
    DType,
    DType,
    8,
    ElementWiseOp,
    BinaryOp>;

using GemmKernel =
    typename cutlass::gemm::kernel::DefaultGemmWithBroadcast<
        DType, LayoutA, cutlass::ComplexTransform::kNone, 8, // transposed B operand
        DType, LayoutB, cutlass::ComplexTransform::kNone, 8, // transposed A operand
        DType, LayoutC,
        DType,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 32>,
        cutlass::gemm::GemmShape<64, 64, 32>,
        cutlass::gemm::GemmShape<16, 8, 16>,
        EpilogueOutputOp,
        ThreadBlockSwizzle,
        stages,
        cutlass::arch::OpMultiplyAdd>::GemmKernel;

using GemmUniversal = cutlass::gemm::device::GemmUniversalWithBroadcast<
    DType,
    LayoutA,
    DType, LayoutB,
    DType, LayoutC,
    DType,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    EpilogueOutputOp,
    ThreadBlockSwizzle,
    stages>;

using GemmAdapter = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

template <typename Gemm>
void test(int M = 8, int N = 4, int K = 8, bool verbose = true, int batch_count = 1,
          cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
          DType alpha = DType(1.0), DType beta = DType(0.0))
{
  cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord(M, N, K);
  cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A;
  cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B;
  cutlass::HostTensor<EpilogueOutputOp::ElementZ, typename Gemm::LayoutC> tensor_Z;
  cutlass::HostTensor<EpilogueOutputOp::ElementVector, typename Gemm::LayoutC> tensor_Broadcast;

  tensor_A.resize({problem_size.m(), problem_size.k()});
  tensor_B.resize({problem_size.k(), problem_size.n()});
  tensor_Z.resize({problem_size.m(), problem_size.n()});
  tensor_Broadcast.resize({problem_size.m(), 1});
  cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), tensor_A.capacity());
  cutlass::reference::host::BlockFill(tensor_B.host_data(), tensor_B.capacity(), typename Gemm::ElementB(1.0));
  // cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), tensor_B.capacity());
  cutlass::reference::host::BlockFill(tensor_Z.host_data(), tensor_Z.capacity(), EpilogueOutputOp::ElementZ(0.0));
  cutlass::reference::host::BlockFillSequential(tensor_Broadcast.host_data(), tensor_Broadcast.capacity());

  tensor_A.sync_device();
  tensor_B.sync_device();
  tensor_Z.sync_device();
  tensor_Broadcast.sync_device();
  if (verbose)
  {
    std::cout << "tensor_A:\n"
              << tensor_A.host_view() << std::endl;
    std::cout << "tensor_B:\n"
              << tensor_B.host_view() << std::endl;
    std::cout << "tensor_Broadcast:\n"
              << tensor_Broadcast.host_view() << std::endl;
  }

  typename Gemm::Arguments arguments{
      mode,
      problem_size,
      batch_count,
      {alpha, beta},
      tensor_A.device_data(),
      tensor_B.device_data(),
      nullptr, // C
      tensor_Z.device_data(),
      tensor_Broadcast.device_data(),
      nullptr,                             // T
      problem_size.m() * problem_size.k(), // batch stride A
      problem_size.n() * problem_size.k(), // batch stride B
      problem_size.m() * problem_size.n(), // batch stride C
      problem_size.m() * problem_size.n(), // batch stride Z
      problem_size.m(),                    // batch stride broadcast
      problem_size.m() * problem_size.n(), // batch stride T
      tensor_A.layout().stride(0),         // stride A
      tensor_B.layout().stride(0),         // stride B
      tensor_Z.layout().stride(0),         // stride C
      tensor_Z.layout().stride(0),         // stride Z
      0,                                   // This must be zero for broadcast
      tensor_Z.layout().stride(0),         // stride T
  };

  Gemm gemm_op;

  size_t workspace_size = Gemm::get_workspace_size(arguments);

  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  cutlass::Status status = gemm_op.initialize(arguments, workspace.get());

  CUTLASS_CHECK(status);

  status = gemm_op();

  CUTLASS_CHECK(status);
  tensor_Z.sync_host();
  std::cout << "tensor_Z:\n"
            << tensor_Z.host_view() << std::endl;
}
int main()
{
  int M = 8;
  int N = 8;
  int K = 8;

  // NOTE: Running with `GemmUniversalWithBroadcast` will segfault if M != N
  std::cout << "GemmUniversalBroadcast" << std::endl;
  test<GemmUniversal>(M, N, K);
  std::cout << " ----------------------- " << std::endl;

  std::cout << "GemmAdapterBroadcast" << std::endl;
  test<GemmAdapter>(M, N, K);
}
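
For reference, I build the repro roughly like this (assuming CUTLASS is checked out at $CUTLASS_DIR and an SM80 GPU; adjust paths and flags as needed):

nvcc -std=c++17 -arch=sm_80 --expt-relaxed-constexpr \
    -I$CUTLASS_DIR/include -I$CUTLASS_DIR/tools/util/include \
    repro.cu -o repro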
jeromeku commented 3 months ago

As a follow-up, I am trying to implement the above using an epilogue visitor tree (EVT).

I am encountering two problems:

1. An error related to the cute::Stride types used for the broadcast and aux-store visitors.
2. The configuration works when following example 47 with ThreadblockSwizzleStreamK, but not with GemmIdentityThreadblockSwizzle<>.

Below is the full script:

#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/functional.h"

#include "cutlass/gemm/kernel/default_gemm_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"

#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
#include "cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h"
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "cutlass/gemm/device/gemm_universal_with_broadcast.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

#include "cute/tensor.hpp"
using namespace cute;

/////////////////////////////////////////////////////////////////////////////////////////////////

#define CUTLASS_CHECK(status)                                                                          \
    {                                                                                                  \
        cutlass::Status error = status;                                                                \
        if (error != cutlass::Status::kSuccess)                                                        \
        {                                                                                              \
            std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
                      << std::endl;                                                                    \
            exit(EXIT_FAILURE);                                                                        \
        }                                                                                              \
    }

using DType = cutlass::half_t;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;

constexpr int stages = 3;
using ThreadBlockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>;

// EVT
constexpr int Alignment = 128 / cutlass::sizeof_bits_v<DType>;
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape)
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;          // Warp-level tile size (concept: GemmShape)
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;    // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages = 4;                                     // Number of global->shared pipeline stages used in the GEMM mainloop
constexpr int EVTEpilogueStages = 1;                             // Number of epilogue stages in EVT

using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
    ThreadblockShape,
    WarpShape,
    DType,
    Alignment,
    EVTEpilogueStages>;

using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, cutlass::half_t,
    cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
    cutlass::multiplies, DType, DType,
    cutlass::FloatRoundStyle::round_to_nearest>;

using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
    Compute0,
    Scale,
    Accum>;

using D = cutlass::epilogue::threadblock::VisitorAuxStore<
    OutputTileThreadMap, cutlass::half_t, cutlass::FloatRoundStyle::round_to_nearest,
    cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;

using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
    D,
    EVTCompute0>;
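
// Intended evaluation of the tree above (my reading of the Sm80EVT composition):
// per output element, EVTD computes
//     D = AuxStore( Compute0( Scale, Accum ) )
// i.e. the column-broadcast vector loaded by Scale is multiplied element-wise with
// the accumulator fetched by Accum, and the product is stored through D.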

using EVTKernel =
    typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
        DType, LayoutA, cutlass::ComplexTransform::kNone, Alignment,
        DType, LayoutB, cutlass::ComplexTransform::kNone, Alignment,
        DType, LayoutC, Alignment,
        DType,
        DType,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        EVTD,
        ThreadBlockSwizzle,
        NumStages,
        cutlass::arch::OpMultiplyAdd,
        EVTEpilogueStages>::GemmKernel;

using DeviceGemmEVT = cutlass::gemm::device::GemmUniversalAdapter<EVTKernel>;

template <typename Gemm>
void test(int M = 8, int N = 4, int K = 8, bool verbose = true, int batch_count = 1,
          cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm,
          DType alpha = DType(1.0), DType beta = DType(0.0))
{
    cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord(M, N, K);
    cutlass::HostTensor<typename Gemm::ElementA, typename Gemm::LayoutA> tensor_A;
    cutlass::HostTensor<typename Gemm::ElementB, typename Gemm::LayoutB> tensor_B;
    cutlass::HostTensor<DType, typename Gemm::LayoutC> tensor_Z;
    cutlass::HostTensor<DType, typename Gemm::LayoutC> tensor_Broadcast;

    tensor_A.resize({problem_size.m(), problem_size.k()});
    tensor_B.resize({problem_size.k(), problem_size.n()});
    tensor_Z.resize({problem_size.m(), problem_size.n()});
    tensor_Broadcast.resize({problem_size.m(), 1});
    cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), tensor_A.capacity());
    cutlass::reference::host::BlockFill(tensor_B.host_data(), tensor_B.capacity(), DType(1.0));

    cutlass::reference::host::BlockFill(tensor_Z.host_data(), tensor_Z.capacity(), DType(0.0));
    cutlass::reference::host::BlockFillSequential(tensor_Broadcast.host_data(), tensor_Broadcast.capacity());

    tensor_A.sync_device();
    tensor_B.sync_device();
    tensor_Z.sync_device();
    tensor_Broadcast.sync_device();
    if (verbose)
    {
        std::cout << "tensor_A:\n"
                  << tensor_A.host_view() << std::endl;
        std::cout << "tensor_B:\n"
                  << tensor_B.host_view() << std::endl;
        std::cout << "tensor_Broadcast:\n"
                  << tensor_Broadcast.host_view() << std::endl;
    }
    typename EVTD::Arguments callback_args{
        {
            {},                                                                                  // Compute0
            {tensor_Broadcast.device_data(), DType(0), {_0{}, _1{}, int32_t(problem_size.m())}}, // bias / scale
            {}                                                                                   // Accum
        },                                                                                       // EvtCompute0
        {tensor_Z.device_data(), {problem_size.n(), _1{}, problem_size.mn().product()}},         // D
    };
    typename EVTKernel::Arguments evtArgs{
        cutlass::gemm::GemmUniversalMode::kGemm, // universal mode
        problem_size,                            // problem_size
        1,                                       // batch count / splitk slices
        callback_args,                           // argument of EVT callbacks
        tensor_A.device_data(),                  // ptr_A
        tensor_B.device_data(),                  // ptr_B
        nullptr,                                 // ptr_C (unused)
        nullptr,                                 // ptr_D (unused)
        problem_size.mk().product(),             // batch_stride_A
        problem_size.nk().product(),             // batch_stride_B
        0,                                       // batch_stride_C (unused)
        0,                                       // batch_stride_D (unused)
        tensor_A.layout().stride(0),             // stride_a
        tensor_B.layout().stride(0),             // stride_b
        0,                                       // stride_c (unused)
        0                                        // stride_d (unused)
    };
}
int main()
{
    int M = 8;
    int N = 8;
    int K = 8;

    std::cout << "GemmEVT" << std::endl;
    test<EVTKernel>(M, N, K);
}
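
The script stops after building the arguments because of the problems above; once those are resolved, I would expect the remainder of test<>() to mirror the first repro, roughly like this (an untested sketch using the DeviceGemmEVT handle and the CUTLASS_CHECK macro defined above):

    DeviceGemmEVT gemm_op;
    size_t workspace_size = DeviceGemmEVT::get_workspace_size(evtArgs);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
    CUTLASS_CHECK(gemm_op.can_implement(evtArgs));
    CUTLASS_CHECK(gemm_op.initialize(evtArgs, workspace.get()));
    CUTLASS_CHECK(gemm_op());
    tensor_Z.sync_host();
    std::cout << "tensor_Z:\n" << tensor_Z.host_view() << std::endl;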
hgyhungry commented 3 months ago

Not an expert, but I recently ran into the exact same problem when crafting my own custom epilogue visitor tree. Here is what I think. For your first problem (the cute type): change

using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, cutlass::half_t,
    cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

To

using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, cutlass::half_t,
    cute::Stride<cute::Int<1>, cute::Int<0>, int32_t>>; // <-- the only change is the type of the last stride element, which should be a dynamic integer such as int or int64_t

Do the same for the D type.
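
Concretely, I think that change would look something like this (only the last, batch-mode stride element becomes a dynamic integer so it can accept the runtime value passed in the callback arguments):

using D = cutlass::epilogue::threadblock::VisitorAuxStore<
    OutputTileThreadMap, cutlass::half_t, cutlass::FloatRoundStyle::round_to_nearest,
    cute::Stride<int64_t, cute::Int<1>, int64_t>>; // batch stride is now dynamic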

hgyhungry commented 3 months ago

For your second problem: the reason the example works well with ThreadblockSwizzleStreamK but doesn't work with GemmIdentityThreadblockSwizzle<> is that, inside the DefaultGemmWithVisitor class, this difference in the swizzle parameter dispatches to different GemmKernel types, as described here: https://github.com/NVIDIA/cutlass/blob/19f3cc33f1642b490ed7126ea0141f79c0045527/include/cutlass/gemm/kernel/default_gemm_universal_with_visitor.h#L140

To be more specific, example 47 uses GemmWithEpilogueVisitorStreamk (link) as its EVTKernelStreamK, whereas your example will use GemmWithEpilogueVisitor (link) as your EVTKernel. With that clear, GemmWithEpilogueVisitor apparently contains some issues while GemmWithEpilogueVisitorStreamK works just fine. I had my own fix, but I suggest you post your whole error message and let a maintainer/expert fix it.
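
As a stop-gap, you can also just switch the swizzle so that the StreamK kernel gets selected, the way example 47 does. Untested against your exact script, but something like:

#include "cutlass/gemm/threadblock/threadblock_swizzle_streamk.h"

using ThreadBlockSwizzle = cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;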

thakkarV commented 3 months ago

@hwu36

github-actions[bot] commented 1 month ago

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.