NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[DOC] Need doc to migrate from cutlass::conv::kernel::DefaultConv2dFprop to cutlass::conv::kernel::ConvUniversal #1911

Open chacha21 opened 1 week ago

chacha21 commented 1 week ago

Thanks to recent updates, I was finally able to use CUTLASS to perform a basic 2D row convolution with strided input and output (see #1323). However, I understand that 3.6 will make the cutlass::conv::kernel::ConvUniversal API the new recommended practice.

I was unable to understand how to feed the epilogue parameters; the examples are just too hard to understand. Could there be a doc to help with such a migration?

Here is my code so far:

#include "cudaCUTLASS3.hpp"

#include <cutlass/cutlass.h>
#include "cutlass/conv/device/conv_universal_adapter.hpp"
#include "cutlass/conv/kernel/conv_universal.hpp"
#include "cutlass/conv/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include <cutlass/util/device_memory.h>

using ElementAct     = float;
using ElementFlt     = float;
using ElementOut     = float;
using ElementAcc     = float;
using ElementCompute = float;
using TileShapeMNK = cutlass::Shape<cutlass::_64, cutlass::_64, cutlass::Shape<cutlass::_32>>;
using ClusterShapeMNK = cutlass::Shape<cutlass::_1,cutlass::_1,cutlass::_1>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
  cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
  TileShapeMNK, ClusterShapeMNK,
  cutlass::epilogue::collective::EpilogueTileAuto,
  ElementAcc, ElementCompute,
  float, cutlass::layout::TensorNDHWC, 4,
  float, cutlass::layout::TensorNDHWC, 4,
  cutlass::epilogue::TmaWarpSpecialized
>::CollectiveOp;

using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder<
  cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
  cutlass::conv::Operator::kFprop,
  ElementAct, cutlass::layout::TensorNDHWC, 4,
  ElementFlt, cutlass::layout::TensorNDHWC, 4,
  ElementAcc,
  TileShapeMNK, ClusterShapeMNK,
  cutlass::conv::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
  cutlass::conv::collective::KernelScheduleAuto
>::CollectiveOp;

using ProblemShape=cutlass::conv::ConvProblemShape<CollectiveMainloop::DispatchPolicy::ConvOp, CollectiveMainloop::DispatchPolicy::NumSpatialDimensions>;
using ConvKernel = cutlass::conv::kernel::ConvUniversal<
  ProblemShape,
  CollectiveMainloop,
  CollectiveEpilogue
>;

using Conv = cutlass::conv::device::ConvUniversalAdapter<ConvKernel>;

void convolutionCUTLASS3Row(const float* src, size_t srcStride, float* dst, size_t dstStride, int width, int height, const float* kernelData, int kernelRadius, cudaStream_t stream)
{
  //src is a 2D array with stride srcStride (in bytes)
  //dst is a 2D array with stride dstStride (in bytes)
  //kernelData is a 1D array with 2*kernelRadius+1 elements

  const int kernelDiameter = 2 * kernelRadius + 1;

  const int srcStrideInElements = static_cast<int>(srcStride / sizeof(float));
  const int dstStrideInElements = static_cast<int>(dstStride / sizeof(float));

  cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kFprop, 2> problem_shape(
    cutlass::conv::Mode::kConvolution,
    {1, 1, height, width, 1},
    {height*srcStrideInElements, height*srcStrideInElements, srcStrideInElements, 1, 1},
    {1, 1, 1, kernelDiameter, 1},
    {kernelDiameter, kernelDiameter, kernelDiameter, kernelDiameter, 1},
    {height*dstStrideInElements, height*dstStrideInElements, dstStrideInElements, 1, 1},
    {0, 0},
    {0, 0},
    {0, 0},
    1
  );

  // Copy the compact strides derived by ConvProblemShape into the epilogue
  // stride types (ConvProblemShape only provides stride_C; it is reused for D).
  auto stride_C = Conv::ConvKernel::StrideC{};
  auto stride_D = Conv::ConvKernel::StrideD{};
  cute::for_each(cute::make_seq<cute::rank<0>(Conv::ConvKernel::StrideC{})>{}, [&](auto i) {
    cute::get<0, i>(stride_C) = problem_shape.stride_C[ProblemShape::RankT-2-i];
  });
  cute::for_each(cute::make_seq<cute::rank<0>(Conv::ConvKernel::StrideD{})>{}, [&](auto i) {
    cute::get<0, i>(stride_D) = problem_shape.stride_C[ProblemShape::RankT-2-i];
  });

  cutlass::KernelHardwareInfo hw_info;
  cudaGetDevice(&hw_info.device_id);
  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

  auto mainLoop_args = typename Conv::ConvKernel::MainloopArguments {
    src,
    kernelData
  };
  auto epilogue_args = typename Conv::ConvKernel::EpilogueArguments {
    /*dst,
    stride_C,
    dst,
    stride_D,*/
  };
  typename Conv::ConvKernel::TileScheduler::Arguments scheduler_args{};

  auto arguments = typename Conv::Arguments {
    /*problem_shape,
    mainLoop_args,
    epilogue_args,
    hw_info,
    scheduler_args*/
  };

  cutlass::Status status;
  Conv conv_op;
  status = conv_op.can_implement(arguments);
  size_t workspace_size = conv_op.get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  status = conv_op.initialize(arguments, workspace.get(), stream);
  status = conv_op.run(stream);
}
//end convolutionCUTLASS3Row()
thakkarV commented 1 week ago

I was unable to understand how to feed the epilogue parameters; the examples are just too hard to understand.

The epilogue builder API is identical and has not changed since the original 3.0 release. The input parameters are arch-agnostic and do not require the user to do anything fancy other than specify the parameters of their input problem traits. The template parameters are here: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/epilogue/collective/collective_builder.hpp#L53

and their names are for the most part self-documenting. Which of these are you having a difficult time with? Happy to explain.
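
For readers landing here, the parameters in this instantiation map onto the template parameter names in the linked header roughly as follows (an annotated sketch of the epilogue builder from the code above, not normative documentation):

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
  cutlass::arch::Sm90,                              // ArchTag: target architecture
  cutlass::arch::OpClassTensorOp,                   // OpClass: use tensor cores
  TileShapeMNK,                                     // TileShape_MNK: CTA tile shape
  ClusterShapeMNK,                                  // ClusterShape_MNK: thread-block cluster shape
  cutlass::epilogue::collective::EpilogueTileAuto,  // EpilogueTileType: let the builder pick
  ElementAcc,                                       // ElementAccumulator: mainloop accumulator type
  ElementCompute,                                   // ElementCompute: epilogue math type
  float, cutlass::layout::TensorNDHWC, 4,           // ElementC, GmemLayoutTagC, AlignmentC (in elements)
  float, cutlass::layout::TensorNDHWC, 4,           // ElementD, GmemLayoutTagD, AlignmentD (in elements)
  cutlass::epilogue::TmaWarpSpecialized             // Schedule: TMA warp-specialized epilogue
>::CollectiveOp;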

chacha21 commented 1 week ago

and their names are for the most part self-documenting. Which of these are you having a difficult time with? Happy to explain.

I can't tell how to fill:

auto epilogue_args = typename Conv::ConvKernel::EpilogueArguments {
    /*dst,
    stride_C,
    dst,
    stride_D,*/
  };
  typename Conv::ConvKernel::TileScheduler::Arguments scheduler_args{};

  auto arguments = typename Conv::Arguments {
    /*problem_shape,
    mainLoop_args,
    epilogue_args,
    hw_info,
    scheduler_args*/
  };

When you say

the input parameters are arch-agnostic and do not require the user to do anything fancy other than specify the parameters of their input problem traits

you have certainly acquired knowledge about what CUTLASS expects, but I just get lost in the templates.

I do not have any "problem traits"; I just have a basic float* pointer to 2D strided data. I understand that I have to build a tensor wrapper, but from cute::, or from cutlass::device::? Look at the code I have already written: it is just cherry-picking from existing examples, trying to guess how they must be mapped to my very basic 2D convolution goal.

thakkarV commented 1 week ago

https://github.com/NVIDIA/cutlass/blob/main/test/unit/conv/device_3x/testbed_conv.hpp#L306

This is the testbed code that fills in the arguments before calling the kernel. Hope it helps.

chacha21 commented 1 week ago

As you can see from my sample code, https://github.com/NVIDIA/cutlass/blob/main/test/unit/conv/device_3x/testbed_conv.hpp#L306 is indeed the code I took inspiration from.

But it explains absolutely nothing about where to find the concrete classes matching the expected trait templates. How am I supposed to wrap my float*? Are there good practices, or modern versus deprecated wrappers?

A navigable doc explaining each template argument would be invaluable for CUTLASS. Being a CUTLASS newbie means being lost in an ocean of templates, without any roadmap.

chacha21 commented 1 week ago

Where can I ask for help to understand how to fill in correct EpilogueArguments and Conv::Arguments? I am now able to compile, but there are assert errors at runtime, certainly because of invalid strides or shapes.

Here is an update to the code I am trying to get working:

#include "cudaCUTLASS3.hpp"

#include <cutlass/cutlass.h>
#include "cutlass/conv/device/conv_universal_adapter.hpp"
#include "cutlass/conv/kernel/conv_universal.hpp"
#include "cutlass/conv/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include <cutlass/util/device_memory.h>

using ElementAct     = float;
using ElementFlt     = float;
using ElementOut     = float;
using ElementAcc     = float;
using ElementCompute = float;
using TileShapeMNK = cutlass::Shape<cutlass::_64, cutlass::_64, cutlass::Shape<cutlass::_32>>;
using ClusterShapeMNK = cutlass::Shape<cutlass::_1,cutlass::_1,cutlass::_1>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
  cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
  TileShapeMNK, ClusterShapeMNK,
  cutlass::epilogue::collective::EpilogueTileAuto,
  ElementAcc, ElementCompute,
  float, cutlass::layout::TensorNHWC, 4,
  float, cutlass::layout::TensorNHWC, 4,
  cutlass::epilogue::TmaWarpSpecialized
>::CollectiveOp;

using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder<
  cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
  cutlass::conv::Operator::kFprop,
  ElementAct, cutlass::layout::TensorNHWC, 4,
  ElementFlt, cutlass::layout::TensorNHWC, 4,
  ElementAcc,
  TileShapeMNK, ClusterShapeMNK,
  cutlass::conv::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
  cutlass::conv::collective::KernelScheduleAuto
>::CollectiveOp;

using ProblemShape=cutlass::conv::ConvProblemShape<CollectiveMainloop::DispatchPolicy::ConvOp, CollectiveMainloop::DispatchPolicy::NumSpatialDimensions>;
using ConvKernel = cutlass::conv::kernel::ConvUniversal<
  ProblemShape,
  CollectiveMainloop,
  CollectiveEpilogue
>;

using Conv = cutlass::conv::device::ConvUniversalAdapter<ConvKernel>;

void convolutionCUTLASS3Row(const float* src, size_t srcStride, float* dst, size_t dstStride, int width, int height, const float* kernelData, int kernelRadius, cudaStream_t stream)
{
  //src is a 2D array with stride srcStride (in bytes)
  //dst is a 2D array with stride dstStride (in bytes)
  //kernelData is a 1D array with 2*kernelRadius+1 elements

  const int kernelDiameter = 2 * kernelRadius + 1;

  const int srcStrideInElements = static_cast<int>(srcStride / sizeof(float));
  const int dstStrideInElements = static_cast<int>(dstStride / sizeof(float));

  Conv::ConvKernel::ProblemShape problem_shape(
    cutlass::conv::Mode::kConvolution,
    {1, height, width, 1},
    {height*srcStrideInElements, height*srcStrideInElements, srcStrideInElements, 1},
    {1, 1, kernelDiameter, 1},
    {kernelDiameter, kernelDiameter, kernelDiameter, 1},
    {0, 0},
    {0, 0},
    {1, 1},
    {0, 0},
    1
  );

  auto stride_C = Conv::ConvKernel::StrideC {{height*dstStrideInElements, dstStrideInElements, 1}, cute::_1(), cute::_0()};
  auto stride_D = Conv::ConvKernel::StrideD {{height*dstStrideInElements, dstStrideInElements, 1}, cute::_1(), cute::_0()};

  cutlass::KernelHardwareInfo hw_info;
  cudaGetDevice(&hw_info.device_id);
  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

  auto mainLoop_args = typename Conv::ConvKernel::MainloopArguments {
    src,
    kernelData
  };
  auto epilogue_args = typename Conv::ConvKernel::EpilogueArguments {
    {},
    dst,
    stride_C,
    dst,
    stride_D
  };
  typename Conv::ConvKernel::TileScheduler::Arguments scheduler_args{};

  auto arguments = typename Conv::Arguments(
    problem_shape,
    mainLoop_args,
    epilogue_args,
    hw_info,
    scheduler_args
  );

  cutlass::Status status;
  Conv conv_op;
  status = conv_op.can_implement(arguments);
  size_t workspace_size = conv_op.get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  status = conv_op.initialize(arguments, workspace.get(), stream);
  status = conv_op.run(stream);
}
//end convolutionCUTLASS3Row()
thakkarV commented 1 week ago

This blog post was written in collaboration with us and explains how the epilogue arguments are constructed for arbitrarily complex EVT fusions: https://research.colfax-intl.com/epilogue_visitor_tree/
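
For the common case without a custom EVT, the fusion ("thread") arguments are just the scalars of the default linear combination, D = alpha * acc + beta * C. A minimal sketch, assuming the default fusion and reusing the identifiers from the listing above:

float alpha = 1.0f;  // scales the accumulator
float beta  = 0.0f;  // scales the C source tensor (0 means C is never read)
auto epilogue_args = typename Conv::ConvKernel::EpilogueArguments {
  {alpha, beta},  // fusion ("thread") arguments of the default linear combination
  dst, stride_C,  // C tensor: pointer and stride of the source term
  dst, stride_D   // D tensor: pointer and stride of the destination
};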

thakkarV commented 1 week ago

I am now able to compile, but there are assert errors at runtime, certainly because of invalid strides or shapes.

What does this manifest as? You can always compile with CUTLASS_DEBUG_TRACE_LEVEL=1 to get more host-side tracing debug info.
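
A sketch of enabling that tracing, assuming the macro is consumed at preprocessing time and must therefore be visible before the CUTLASS headers:

// Define before any CUTLASS include, or pass -DCUTLASS_DEBUG_TRACE_LEVEL=1
// on the compiler command line.
#define CUTLASS_DEBUG_TRACE_LEVEL 1
#include <cutlass/cutlass.h>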

chacha21 commented 1 week ago

CUTLASS_DEBUG_TRACE_LEVEL

I get a "CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA". I had no such problem with the same data and the CUTLASS 2.x kernels (see below). It looks like my CUTLASS parameters are not what they should be.

This blog post was written in collaboration with us and explains how the epilogue arguments are constructed for arbitrarily complex EVT fusions: https://research.colfax-intl.com/epilogue_visitor_tree/

The CUTLASS docs always give very high-level, abstract examples with N-dimensional scenarios and auto-generated template test code. What is really missing is a good old pragmatic example like: "here is a minimal didactic code sample showing how to build and run a 2D convolution of a 16x10 2D array by an 8x1 kernel."

In my convolutionCUTLASS3Row() code, which is a very basic usage of something CUTLASS should shine at, I think only the construction of the parameters is faulty. And I can't find anybody who understands what CUTLASS expects as input data.

Here is my CUTLASS 2.x code that works with the same input (for which I mostly guessed that I could use cutlass::make_TensorRef(), which is not really documented; I just found it by chance while browsing):

void convolutionCUTLASSRow(const float* src, size_t srcStride, float* dst, size_t dstStride, int width, int height, const float* kernelData, int kernelRadius, cudaStream_t stream)
{
  Conv2dFprop implicit_gemm_op;

  const int kernelDiameter = 2 * kernelRadius + 1;
  cutlass::Tensor4DCoord input_size(1, height, width, 1);
  cutlass::Tensor4DCoord filter_size(1, 1, kernelDiameter, 1);
  cutlass::Tensor4DCoord output_size(1, height, width, 1);

  cutlass::conv::Conv2dProblemSize problem_size(
    input_size,
    filter_size,
    cutlass::Tensor4DCoord(0, 0, kernelRadius, 0),
    cutlass::MatrixCoord(1, 1),
    cutlass::MatrixCoord(1, 1),
    output_size,
    cutlass::conv::Mode::kConvolution,
    1
  );

  const int srcStrideInElements = static_cast<int>(srcStride / sizeof(float));
  cutlass::layout::TensorNHWC src_layout(1, srcStrideInElements, height * srcStrideInElements);
  auto tensor_src = cutlass::make_TensorRef(const_cast<float*>(src), src_layout);

  cutlass::layout::TensorNHWC ker_layout(1, kernelDiameter, kernelDiameter);
  auto tensor_ker = cutlass::make_TensorRef(const_cast<float*>(kernelData), ker_layout);

  const int dstStrideInElements = static_cast<int>(dstStride / sizeof(float));
  cutlass::layout::TensorNHWC dst_layout(1, dstStrideInElements, height * dstStrideInElements);
  auto tensor_dst = cutlass::make_TensorRef(dst, dst_layout);

  using Arguments = typename Conv2dFprop::Arguments;
  Arguments arguments = Arguments(
    problem_size,
    tensor_src,
    tensor_ker,
    tensor_dst,
    tensor_dst,
    { 1.f, 0.f },
    cutlass::conv::SplitKMode::kSerial
  );

  cutlass::Status status;
  status = implicit_gemm_op.can_implement(arguments);

  size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  status = implicit_gemm_op.initialize(arguments, workspace.get(), stream);
  status = implicit_gemm_op.run(stream);
}
//end convolutionCUTLASSRow()
thakkarV commented 1 week ago

Well, the error tells you what is going wrong there. In general, please always call can_implement() yourself on any problem shape to make sure the kernel can actually accept the runtime arguments you have given it. In this case, one of your input tensors does not satisfy the alignment requirements of TMA. This would have been the case with 2.x or 3.x.
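
A minimal sketch of that check (error-handling style is hypothetical; cutlassGetStatusString() comes from cutlass/cutlass.h):

#include <iostream>

cutlass::Status status = conv_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
  std::cerr << "can_implement() failed: "
            << cutlass::cutlassGetStatusString(status) << std::endl;
  return;  // do not initialize/run a kernel that cannot accept these arguments
}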

To get the arguments right, I think you can greatly simplify your templates first to match the problem you are running. You are stamping out a 3D fprop conv template but only running a 1D conv within it. Why not just stamp out a 1D conv, which will get you some performance benefit while also simplifying the arguments? Examples here: https://github.com/NVIDIA/cutlass/blob/main/test/unit/conv/device_3x/fprop/sm90_conv1d_fprop_implicit_gemm_tf32_tf32_f32_tensorop_f32.cu#L57

Using a 1D/2D conv here will simplify the shapes and strides, so there is less chance of your parameters being incorrect.

What is really missing is a good old pragmatic example like: "here is a minimal didactic code sample showing how to build and run a 2D convolution of a 16x10 2D array by an 8x1 kernel."

Noted, we can work on adding a conv API example. @hwu36 @yzhaiustc CC

chacha21 commented 6 days ago

You are right, my problem can be expressed as NWC (N being H for my strided 2D data) rather than NHWC (with N=1). But I can't get rid of the "CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA", which makes no sense to me. As I said, the same input data does work with the CUTLASS 2.x code I have posted, so I just think I am getting something wrong in feeding the parameters. Any help?

I don't see how I could simplify it further.

using ElementAct     = float;
using ElementFlt     = float;
using ElementOut     = float;
using ElementAcc     = float;
using ElementCompute = float;
using TileShapeMNK = cutlass::Shape<cutlass::_64, cutlass::_64, cutlass::Shape<cutlass::_32>>;
using ClusterShapeMNK = cutlass::Shape<cutlass::_1,cutlass::_1,cutlass::_1>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
  cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
  TileShapeMNK, ClusterShapeMNK,
  cutlass::epilogue::collective::EpilogueTileAuto,
  ElementAcc, ElementCompute,
  float, cutlass::layout::TensorNWC, 4,
  float, cutlass::layout::TensorNWC, 4,
  cutlass::epilogue::TmaWarpSpecialized
>::CollectiveOp;

using CollectiveMainloop = typename cutlass::conv::collective::CollectiveBuilder<
  cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
  cutlass::conv::Operator::kFprop,
  ElementAct, cutlass::layout::TensorNWC, 4,
  ElementFlt, cutlass::layout::TensorNWC, 4,
  ElementAcc,
  TileShapeMNK, ClusterShapeMNK,
  cutlass::conv::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
  cutlass::conv::collective::KernelScheduleAuto
>::CollectiveOp;

using ProblemShape=cutlass::conv::ConvProblemShape<CollectiveMainloop::DispatchPolicy::ConvOp, CollectiveMainloop::DispatchPolicy::NumSpatialDimensions>;
using ConvKernel = cutlass::conv::kernel::ConvUniversal<
  ProblemShape,
  CollectiveMainloop,
  CollectiveEpilogue
>;

using Conv = cutlass::conv::device::ConvUniversalAdapter<ConvKernel>;

void convolutionCUTLASS3Row(const float* src, size_t srcStride, float* dst, size_t dstStride, int width, int height, const float* kernelData, int kernelRadius, cudaStream_t stream)
{
  //src is a 2D array with stride srcStride (in bytes)
  //dst is a 2D array with stride dstStride (in bytes)
  //kernelData is a 1D array with 2*kernelRadius+1 elements

  const int kernelDiameter = 2 * kernelRadius + 1;

  const int srcStrideInElements = static_cast<int>(srcStride / sizeof(float));
  const int dstStrideInElements = static_cast<int>(dstStride / sizeof(float));

  Conv::ConvKernel::ProblemShape problem_shape(
    cutlass::conv::Mode::kConvolution,
    {height, width, 1},
    {height*srcStrideInElements, srcStrideInElements, 1},
    {1, kernelDiameter, 1},
    {kernelDiameter, kernelDiameter, 1},
    {0},
    {0},
    {1},
    {0},
    1
  );

  auto stride_C = Conv::ConvKernel::StrideC {{dstStrideInElements, 1}, cute::_1(), cute::_0()};
  auto stride_D = Conv::ConvKernel::StrideD {{dstStrideInElements, 1}, cute::_1(), cute::_0()};

  cutlass::KernelHardwareInfo hw_info;
  cudaGetDevice(&hw_info.device_id);
  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

  auto mainLoop_args = typename Conv::ConvKernel::MainloopArguments {
    src,
    kernelData
  };
  auto epilogue_args = typename Conv::ConvKernel::EpilogueArguments {
    {},
    dst,
    stride_C,
    dst,
    stride_D
  };
  typename Conv::ConvKernel::TileScheduler::Arguments scheduler_args{};

  auto arguments = typename Conv::Arguments(
    problem_shape,
    mainLoop_args,
    epilogue_args,
    hw_info,
    scheduler_args
  );

  cutlass::Status status;
  Conv conv_op;
  size_t workspace_size = conv_op.get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  status = conv_op.can_implement(arguments); // check once, before initializing/running
  if (status == cutlass::Status::kSuccess)
  {
    status = conv_op.initialize(arguments, workspace.get(), stream);
    status = conv_op.run(stream);
  }//end if (status == cutlass::Status::kSuccess)
}
//end convolutionCUTLASS3Row()
thakkarV commented 6 days ago

Well, the problem is that C is the major mode here and you are setting the C extent to 1, while the templates you are using promise a vectorization of 4. For the kernel you are stamping out, the channel count has to be a multiple of 4. The runtime tracing is telling you exactly what is wrong.
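
In concrete numbers (a sketch; the 16-byte figure is the TMA alignment requirement mentioned later in this thread):

// The builders above were instantiated with alignment 4, in elements (not bytes):
constexpr int alignment_elems = 4;
constexpr int alignment_bytes = alignment_elems * int(sizeof(float));  // 4 * 4 = 16 bytes, the TMA requirement
// With C = 1, the contiguous extent of the activation/output tensors is
// 1 * sizeof(float) = 4 bytes, so can_implement() rejects the problem.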

which makes no sense to me.

What would be a clearer way we could rephrase this error message?

chacha21 commented 6 days ago

Well, the problem is that C is the major mode here and you are setting the C extent to 1, while the templates you are using promise a vectorization of 4. For the kernel you are stamping out, the channel count has to be a multiple of 4. The runtime tracing is telling you exactly what is wrong.

which makes no sense to me.

What would be a clearer way we could rephrase this error message?

Ok, so here are a few remarks (I have not succeeded yet)

[edit]

thakkarV commented 6 days ago

I can't remember where I read a doc for the "Alignment" template parameter: I thought it was in bytes.

Totally fair; we should document this. Both the 2.x and 3.x APIs treat alignment in units of elements.

I tried to set an Alignment (template parameter) to 1, but it leads to a compile error

Non-TMA-based (and therefore non-16B-aligned) convolutions are not supported on Hopper yet. We could emit a better error message here instead; the static assert should clearly say that lower alignments are unsupported. @yzhaiustc CC

You talk about "C shape", but the ConvProblemShape() constructor arguments use names like "_act" and "_flt", which is very confusing.

My bad. By C shape I mean the shape of the output tensor for fprop.

and I still don't understand why the problem_shape holds stride_C and stride_D

It does not contain a stride_D, just a stride_C. That is computed for you assuming a compact output tensor, but you are under no obligation to use it for the epilogue; you are welcome to provide any non-compact strides for the C and D tensors yourself. We cannot know the strides of the output tensor from the input shapes; only the shape of the output tensor can be inferred.
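
A sketch of the two options, reusing identifiers from the listings above (stride_D_pitched is a hypothetical name):

// Option 1: compact output; adopt the stride that ConvProblemShape derived
// (see the cute::for_each copy loop in the first listing of this thread).
// Option 2: pitched output rows; supply the stride yourself, as done above:
auto stride_D_pitched =
  Conv::ConvKernel::StrideD{{dstStrideInElements, 1}, cute::_1(), cute::_0()};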

chacha21 commented 6 days ago

Moreover, I don't understand the impact of alignment on shape. You tell me "channels have to be a multiple of 4"; here is what goes wrong with that assumption:

thakkarV commented 6 days ago
chacha21 commented 6 days ago
  • that's a property of your workload. If you must have C = 1, then you can continue using the 2.x kernels by compiling them for Hopper until we start supporting lower-alignment kernels in the 3.x API for Hopper

Currently, a 2D-strided convolution with CUTLASS 2.x is orders of magnitude slower than NPP. I expected improvements from 3.x.

  • Which tensor is that coming from? Your epilogue is also set to assume alignment of 4. Is the output channel count also a multiple of 4 with these shapes you are trying?

This code snippet is just the problem_shape (see my code), so it defines the shapes of the input and output. I tried to make everything compatible with 4-element alignment, but once again I failed. Even if I can't expect the numerical result I want with C=4 rather than C=1, it is at least supposed to run.

thakkarV commented 6 days ago

C is the input channel count; K would be the output channel count. It is not obvious to me what the output channel count is from looking at your problem shape, but if I had to guess, it is not a multiple of 4, which is why you are seeing this debug log message.
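
For reference, the tensor-letter convention CUTLASS uses for implicit-GEMM convolutions (standard CUTLASS naming, summarized as a comment):

// activation (input): [n, d, h, w, c]  n = batch, d/h/w = spatial, c = input channels
// filter:             [k, t, r, s, c]  k = output channels, t/r/s = spatial, c = input channels
// output:             [n, z, p, q, k]  z/p/q = output spatial, k = output channels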

chacha21 commented 6 days ago

But where would K be specified, if it is not in ConvProblemShape()?

thakkarV commented 6 days ago

The output channel count K is the first extent of the filter tensor:

      TensorExtent shape_act,                                              // [n,d,h,w,c]
      TensorStride stride_act,                                             // [n,d,h,w,c]
      TensorExtent shape_flt,                                              // [k,t,r,s,c]
      TensorStride stride_flt,                                             // [k,t,r,s,c]

I see in your shape that it is set to 1, which violates the alignment assumption for the contiguous dimension of the kernel you are stamping out.

chacha21 commented 6 days ago

OK, that's on me. I was confused because CUTLASS 2.x's cutlass::conv::Conv2dProblemSize() has an explicit output_size. I missed that in CUTLASS 3.x the shape_flt ("flt" meaning filter, which is not so obvious at first sight) describes not only the input filter but also part of the output. I wrongly filled shape_flt as [n,d,h,w,c] instead of [k,t,r,s,c]. This is certainly the main mistake; I am looking forward to testing.
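
For the row-convolution case in this thread, the corrected filter description would look roughly like this (a sketch using the identifiers from convolutionCUTLASS3Row; as discussed above, k = 1 and c = 1 would still be rejected by the alignment-4 TMA kernel stamped out here):

// 1D (NWC) fprop: the filter tensor is ordered [k, s, c], K first.
// For a single-channel row kernel of width kernelDiameter:
//   shape_flt  = {1, kernelDiameter, 1};  // k = 1, s = kernelDiameter, c = 1
//   stride_flt = {kernelDiameter, 1, 1};  // compact [k,s,c] strides, c contiguous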