NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[BUG] Python `EVT` `Pytorch` Emitter Broken #1462

Open jeromeku opened 5 months ago

jeromeku commented 5 months ago

Describe the bug
The Python PyTorch emitter (cutlass.emit.pytorch) does not output functioning code when compiling a Gemm that uses an EVT epilogue.

Steps/Code to reproduce bug
The script below reproduces the bug.

As posted, the script calls cutlass.emit.pytorch with jit=True, which attempts to JIT-compile the emitted extension; set jit to False to instead write out the generated code (see Additional context).

import torch
import cutlass
from cutlass import Tensor as FakeTensor

print_module = True

m = 8
n = 8
k = 8

type_A = torch.float16
type_B = torch.float16
type_C = torch.float16
type_D = torch.float16

tensor_A = torch.arange(m * k, dtype=type_A, device="cuda").reshape(m, k)
tensor_B = torch.ones(n * k, dtype=type_B, device="cuda").reshape(k, n)
tensor_C = torch.zeros(m * n, dtype=type_C, device="cuda").reshape(m, n)
tensor_D = torch.zeros_like(tensor_C)

plan = cutlass.op.Gemm(
    element=torch.float16,
    layout=cutlass.LayoutType.RowMajor,
    element_accumulator=torch.float32,
)

def epilogue_scale(accum, scale):
    D = scale * accum
    return D

# Construct inputs and outputs
scale = torch.arange(m, dtype=type_C, device="cuda").reshape(m, 1)
examples_tensors = {
    "accum": FakeTensor(
        element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor
    ),
    "scale": scale,
    "D": tensor_D,
}

epilogue_visitor = cutlass.epilogue.trace(epilogue_scale, examples_tensors)
visitor_args = {"scale": scale, "D": tensor_D}

plan.epilogue_visitor = epilogue_visitor

# This works
plan.run(
    tensor_A,
    tensor_B,
    tensor_C,
    tensor_D,
    visitor_args=visitor_args,
    print_module=print_module,
)

binary_op = torch.mul
ref_D = binary_op(tensor_A @ tensor_B, scale)
print(f"ref_D =\n {ref_D}")
print(f"tensor_D =\n {tensor_D}")
print(f"(tensor_D - ref_D).abs().max() = {(tensor_D - ref_D).abs().max()}")

# The emitted PyTorch extension below does not work; set jit to False to write out
# the generated (incorrect) code instead of compiling it.
op = plan.construct()
mod = cutlass.emit.pytorch(
    op, name="epilogue_broadcast", cc=plan.cc, sourcedir="epilogue", jit=True
)

Expected behavior
The JIT-compiled PyTorch module is expected to behave like the non-PyTorch path (plan.run above, which compiles and runs the kernel directly through the pycuda / C interface).
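
For concreteness, a sketch of what working behavior would look like, assuming the emitted extension exposes its usual run entry point for GEMM; the ability to pass the EVT operand (scale) to it is hypothetical and is exactly the missing piece.

# Hedged sketch only: mod.run(A, B) is the conventional entry point of the
# emitted extension; accepting the EVT operand `scale` as shown is hypothetical
# and is what this issue asks for.
mod = cutlass.emit.pytorch(
    op, name="epilogue_broadcast", cc=plan.cc, sourcedir="epilogue", jit=True
)
D_jit = mod.run(tensor_A, tensor_B, scale=scale)  # hypothetical EVT-aware call
torch.testing.assert_close(D_jit, (tensor_A @ tensor_B) * scale)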


Additional context
Below is the generated extension module (with jit set to False).

Issues with the generated code:

- The wrapper function only exposes (A, B, C, alpha, beta); there is no way to pass the EVT operands (the scale tensor and the D aux-store tensor).
- The kernel arguments are built with the default {alpha, beta} linear-combination parameters rather than the EVT visitor argument structure that a DefaultGemmWithVisitor kernel expects.

// This file was automatically generated by the CUTLASS 3.5.0 Python interface (https://github.com/nvidia/cutlass/python)

#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"

#include "cutlass/gemm_coord.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

// helper function allocating the memory
void *device_memory_allocation(size_t size, int device_id = 0)
{
    if (size > 0)
    {
        torch::Device device(torch::kCUDA, device_id);
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
        at::Tensor device_tensor = torch::empty({
                                                    (long)size,
                                                },
                                                options);
        return reinterpret_cast<void *>(device_tensor.data_ptr());
    }
    else
    {
        return nullptr;
    }
}

#include "cutlass/gemm/device/gemm_universal.h"

using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
    cutlass::gemm::GemmShape<256, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::half_t,
    8,
    1 /* epilogue stages */
    >;

using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

using Scale = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, cutlass::half_t,
    cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
    cutlass::multiplies, cutlass::half_t, float,
    cutlass::FloatRoundStyle::round_to_nearest>;

using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT<
    Compute0,
    Scale,
    Accum>;

using D = cutlass::epilogue::threadblock::VisitorAuxStore<
    OutputTileThreadMap, cutlass::half_t, cutlass::FloatRoundStyle::round_to_nearest,
    cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;

using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
    D,
    EVTCompute0>;

// Gemm operator cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8
using cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_base =
    typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
        cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
        cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
        cutlass::half_t, cutlass::layout::RowMajor, 8,
        float,
        float,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<256, 128, 32>,
        cutlass::gemm::GemmShape<64, 64, 32>,
        cutlass::gemm::GemmShape<16, 8, 16>,
        EVTD,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
        3,
        cutlass::arch::OpMultiplyAdd,
        1 /* epilogue stages */
        >::GemmKernel;

// Define named type
struct cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_type : public cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_base
{
};

using DeviceKernel = cutlass_tensorop_f16_s16816gemm_f16_256x128_32x3_tt_align8_type;
using ElementCompute = typename DeviceKernel::ElementC;

cutlass::Status epilogue_broadcast_kernel_run(int M, int N, int K,
                                              const DeviceKernel::ElementA *A, const DeviceKernel::ElementB *B, const DeviceKernel::ElementC *C, DeviceKernel::ElementC *D,
                                              ElementCompute alpha, ElementCompute beta)
{

    typename DeviceKernel::Arguments arguments{
        cutlass::gemm::GemmUniversalMode::kGemm,
        {M, N, K}, // problem size
        1,
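        // The next line fills the epilogue slot with the default {alpha, beta}
        // linear-combination parameters; an EVT kernel instead needs its visitor
        // argument structure (carrying the scale and D operands), which the
        // emitter does not yet generate.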
        {alpha, beta},
        A,
        B,
        C,
        D,
        0,
        0,
        0,
        0,                                               // batch strides
        DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
        DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
        DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
        DeviceKernel::LayoutC::packed({M, N}).stride(0)  // ldd
    };

    size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    DeviceKernel gemm_op;
    cutlass::Status status = gemm_op.initialize(arguments,
                                                workspace.get(),
                                                nullptr); // CUDA stream

    if (status != cutlass::Status::kSuccess)
    {
        return status;
    }

    status = gemm_op();
    return status;
}

at::Tensor epilogue_broadcast_kernel(const at::Tensor &A, const at::Tensor &B, at::optional<const at::Tensor> C, float alpha, float beta)
{
    int M = A.size(0);
    int N = B.size(1);
    int K = A.size(1);

    typename DeviceKernel::ElementC *ptrC = (C == at::nullopt) ? nullptr : reinterpret_cast<typename DeviceKernel::ElementC *>(C->contiguous().data_ptr());
    at::Tensor D = B.new_empty({M, N}, torch::kF16);

    cutlass::Status status = epilogue_broadcast_kernel_run(M, N, K,
                                                           reinterpret_cast<typename DeviceKernel::ElementA *>(A.contiguous().data_ptr()),
                                                           reinterpret_cast<typename DeviceKernel::ElementB *>(B.contiguous().data_ptr()),
                                                           ptrC,
                                                           reinterpret_cast<typename DeviceKernel::ElementC *>(D.contiguous().data_ptr()),
                                                           ElementCompute(alpha), ElementCompute(beta));

    TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
    return D;
}
jackkosaian commented 5 months ago

We haven't yet done the plumbing to emit the correct EVT arguments structures for creating a PyTorch extension for a kernel that uses EVT. Apologies that this hasn't been better documented and lacks a clear error indicating the lack of support.
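
In the meantime, the path that does work is the one already shown in the repro script: compile and launch through plan.run and pass the EVT operands via visitor_args, rather than going through an emitted extension. A minimal restatement, using the names from the script above:

# Currently-working path (no PyTorch extension): launch through plan.run and
# pass the EVT operands via visitor_args.
plan.epilogue_visitor = epilogue_visitor
plan.run(
    tensor_A, tensor_B, tensor_C, tensor_D,
    visitor_args={"scale": scale, "D": tensor_D},
)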

jeromeku commented 5 months ago

@jackkosaian Thanks for the response.

Are there any examples or documentation on how to properly construct arguments for an EVT, other than the streamk example?

Moreover, I'm having trouble with the different epilogue interfaces (see #1459) for a relatively simple example. Any help would be appreciated!

github-actions[bot] commented 4 months ago

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot] commented 1 month ago

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.