NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines
Other
5.61k stars 956 forks source link

[BUG] cudaMemcpy result = misaligned address #1608

Open Tangesion opened 4 months ago

Tangesion commented 4 months ago

I'm performing a batched matrix multiply of 4bit and I'm getting the error cudaMemcpy result = misaligned address. I followed the As an example, matrix C can be seen as

(0,0,0) | (0,0,1) | (0,0,2) | (1,0,0) | (1,0,1) | (1,0,2) |

(0,1,0) | (0,1,1) | (0,1,2) | (1,1,0) | (1,1,1) | (1,1,2) |

(0,2,0) | (0,2,1) | (0,2,2) | (1,2,0) | (1,2,1) | (1,2,2) |

(0,3,0) | (0,3,1) | (0,3,2) | (1,3,0) | (1,3,1) | (1,3,2) |

(0,4,0) | (0,4,1) | (0,4,2) | (1,4,0) | (1,4,1) | (1,4,2) |

(0,5,0) | (0,5,1) | (0,5,2) | (1,5,0) | (1,5,1) | (1,5,2) |

       batch 0          |           batch 1

where we denote each element with (batch_idx, row_idx, column_idx) In this example, batch size is 2, M is 6 and N is 3 The stride (batch_stride_C) between the first element of two batches is ldc * n

matrix A can be seen as

(0,0,0) | (0,0,1) | (1,0,0) | (1,0,1) |

(0,1,0) | (0,1,1) | (1,1,0) | (1,1,1) |

(0,2,0) | (0,2,1) | (1,2,0) | (1,2,1) |

(0,3,0) | (0,3,1) | (1,3,0) | (1,3,1) |

(0,4,0) | (0,4,1) | (1,4,0) | (1,4,1) |

(0,5,0) | (0,5,1) | (1,5,0) | (1,5,1) |

 batch 0      |      batch 1

, where batch size is 2, M is 6 and K is 2 The stride (batch_stride_A) between the first element of two batches is lda * k

matrix B can be seen as

(0,0,0) | (0,0,1) | (0,0,2) | ----------------------------- batch 0 (0,1,0) | (0,1,1) | (0,1,2) |

(1,0,0) | (1,0,1) | (1,0,2) | ----------------------------- batch 1 (1,1,0) | (1,1,1) | (1,1,2) |

, where the batch size is 2, N is 3 and K is 2 The stride (batch_stride_B) between the first element of two batches is k This is the format to perform my matrix multiplication, but unlike this example the A and C matrices are changed to row_major to accommodate the 4bit matrix multiplication Here is my code

#include <iostream>
#include <vector>

#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/device/gemm_array.h"
#include "cutlass/gemm/device/gemm_batched.h"

#pragma warning(disable : 4503)

/*
This example demonstrates how to use cutlass to compute a batched strided gemm in two different ways:
  1. By specifying pointers to the first matrices of the batch and the stride between the consecutive
     matrices of the batch (this is called a strided batched gemm).
  2. By copying pointers to all matrices of the batch to the device memory (this is called an array gemm).
In this example, both A and B matrix are non-transpose and column major matrix
batched_C = batched_A x batched_B
As an example, matrix C can be seen as
-----------------------------------------------------------
(0,0,0) | (0,0,1) | (0,0,2) | (1,0,0) | (1,0,1) | (1,0,2) |
-----------------------------------------------------------
(0,1,0) | (0,1,1) | (0,1,2) | (1,1,0) | (1,1,1) | (1,1,2) |
-----------------------------------------------------------
(0,2,0) | (0,2,1) | (0,2,2) | (1,2,0) | (1,2,1) | (1,2,2) |
-----------------------------------------------------------
(0,3,0) | (0,3,1) | (0,3,2) | (1,3,0) | (1,3,1) | (1,3,2) |
-----------------------------------------------------------
(0,4,0) | (0,4,1) | (0,4,2) | (1,4,0) | (1,4,1) | (1,4,2) |
-----------------------------------------------------------
(0,5,0) | (0,5,1) | (0,5,2) | (1,5,0) | (1,5,1) | (1,5,2) |
-----------------------------------------------------------
           batch 0          |           batch 1
where we denote each element with (batch_idx, row_idx, column_idx)
In this example, batch size is 2, M is 6 and N is 3
The stride (batch_stride_C) between the first element of two batches is ldc * n

matrix A can be seen as
---------------------------------------
(0,0,0) | (0,0,1) | (1,0,0) | (1,0,1) |
---------------------------------------
(0,1,0) | (0,1,1) | (1,1,0) | (1,1,1) |
---------------------------------------
(0,2,0) | (0,2,1) | (1,2,0) | (1,2,1) |
---------------------------------------
(0,3,0) | (0,3,1) | (1,3,0) | (1,3,1) |
---------------------------------------
(0,4,0) | (0,4,1) | (1,4,0) | (1,4,1) |
---------------------------------------
(0,5,0) | (0,5,1) | (1,5,0) | (1,5,1) |
---------------------------------------
     batch 0      |      batch 1
, where batch size is 2, M is 6 and K is 2
The stride (batch_stride_A) between the first element of two batches is lda * k

matrix B can be seen as
-----------------------------
(0,0,0) | (0,0,1) | (0,0,2) |
----------------------------- batch 0
(0,1,0) | (0,1,1) | (0,1,2) |
-------------------------------------
(1,0,0) | (1,0,1) | (1,0,2) |
----------------------------- batch 1
(1,1,0) | (1,1,1) | (1,1,2) |
-----------------------------
, where the batch size is 2, N is 3 and K is 2
The stride (batch_stride_B) between the first element of two batches is k

*/

cudaError_t cutlass_strided_batched_sgemm(
    int m,
    int n,
    int k,
    int32_t alpha,
    const cutlass::int4b_t *A,
    int lda,
    long long int batch_stride_A,
    const cutlass::int4b_t *B,
    int ldb,
    long long int batch_stride_B,
    int32_t *C,
    int ldc,
    long long int batch_stride_C,
    int32_t beta,
    int batch_count)
{

    typedef cutlass::int4b_t input_t;
    typedef int32_t output_t;

    using ElementAccumulator = output_t;
    using ElementCompute = output_t;
    using ElementInputA = input_t;                      
    using ElementInputB = input_t;                      
    using ElementOutput = output_t;                   
    const int pipe_stages = 2;

    using Gemm = cutlass::gemm::device::GemmBatched<
    cutlass::int4b_t, cutlass::layout::RowMajor, 
    cutlass::int4b_t, cutlass::layout::ColumnMajor, 
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator, 
    cutlass::arch::OpClassTensorOp, 
    cutlass::arch::Sm75,
    cutlass::gemm::GemmShape<128, 256, 128>,
    cutlass::gemm::GemmShape<64, 64, 128>, 
    cutlass::gemm::GemmShape<8, 8, 32>,
    cutlass::epilogue::thread::LinearCombination<ElementOutput, 
                                                4,
                                                ElementAccumulator, ElementCompute>,
    cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 
    pipe_stages>;

    Gemm gemm_op;

    cutlass::Status status = gemm_op({{m, n, k},
                                      {A, lda},
                                      batch_stride_A,
                                      {B, ldb},
                                      batch_stride_B,
                                      {C, ldc},
                                      batch_stride_C,
                                      {C, ldc},
                                      batch_stride_C,
                                      {alpha, beta},
                                      batch_count}); 

    if (status != cutlass::Status::kSuccess)
    {
        return cudaErrorUnknown;
    }

    return cudaSuccess;
}

cudaError_t run_batched_gemm(bool use_array)
{

    const char *gemm_desc = use_array ? "array" : "strided batched";
    std::cout << "Running " << gemm_desc << " gemm" << std::endl;

    // Arbitrary problem size
    int const m = 6;
    int const n = 3;
    int const k = 2;
    int const batch_count = 2;

    // A, B, C are non-transpose, B column major, A,C row major
    int const lda = k * batch_count;
    int const ldb = k * batch_count;
    int const ldc = n * batch_count;

    int const count_A = lda * m;
    int const count_B = ldb * n;
    int const count_C = ldc * m;

    // the memory is batched along K dimension
    long long int batch_stride_A = static_cast<long long int>(k);
    long long int batch_stride_B = static_cast<long long int>(k);
    long long int batch_stride_C = static_cast<long long int>(n);

    // alpha and beta
    int32_t alpha = int32_t(1);
    int32_t beta = int32_t(0);

    cudaError_t result = cudaSuccess;

    // allocate the host memory
/*     std::vector<cutlass::int4b_t> host_A(count_A);
    std::vector<cutlass::int4b_t> host_B(count_B);
    std::vector<int32_t> host_C(count_C);
    std::vector<int32_t> result_C(count_C); */

    cutlass::int4b_t *host_A;
    cutlass::int4b_t *host_B;
    int32_t *host_C;
    int32_t *result_C;

    cudaError_t cuda_result = cudaMallocHost(&host_A, count_A * sizeof(cutlass::int4b_t));
    if (cuda_result != cudaSuccess)
    {
        std::cerr << "cudaMallocHost result = " << cuda_result << std::endl;
        return cuda_result;
    }
    cuda_result = cudaMallocHost(&host_B, count_B * sizeof(cutlass::int4b_t));
    if (cuda_result != cudaSuccess)
    {
        std::cerr << "cudaMallocHost result = " << cuda_result << std::endl;
        return cuda_result;
    }
    cuda_result = cudaMallocHost(&host_C, count_C * sizeof(int32_t));
    if (cuda_result != cudaSuccess)
    {
        std::cerr << "cudaMallocHost result = " << cuda_result << std::endl;
        return cuda_result;
    }
    cuda_result = cudaMallocHost(&result_C, count_C * sizeof(int32_t));
    if (cuda_result != cudaSuccess)
    {
        std::cerr << "cudaMallocHost result = " << cuda_result << std::endl;
        return cuda_result;
    }

    // allocate the device memory
    cutlass::int4b_t *A;
    cutlass::int4b_t *B;
    int32_t *C;

    result = cudaMalloc(&A, count_A * sizeof(cutlass::int4b_t));
    if (result != cudaSuccess)
    {
        std::cerr << "cudaMalloc result = " << result << std::endl;
        return result;
    }
    result = cudaMalloc(&B, count_B * sizeof(cutlass::int4b_t));
    if (result != cudaSuccess)
    {
        std::cerr << "cudaMalloc result = " << result << std::endl;
        return result;
    }
    result = cudaMalloc(&C, count_C * sizeof(int32_t));
    if (result != cudaSuccess)
    {
        std::cerr << "cudaMalloc result = " << result << std::endl;
        return result;
    }

    // Limit range to avoid floating-point errors
    int const kRange = 8;

    // fill A
    for (int b_idx = 0; b_idx < batch_count; b_idx++)
    {
        for (int row_idx = 0; row_idx < m; row_idx++)
        {
            for (int col_idx = 0; col_idx < k; col_idx++)
            {
                host_A[col_idx + lda * row_idx + b_idx * k] = static_cast<cutlass::int4b_t>(1);
            }
        }   
    }
    // fill B
    int i = 0;
    for (int b_idx = 0; b_idx < batch_count; b_idx++)
    {
        for (int col_idx = 0; col_idx < n; col_idx++)
        {
            for (int row_idx = 0; row_idx < k; row_idx++)
            {
                host_B[row_idx + col_idx * ldb + b_idx * k] = static_cast<cutlass::int4b_t>(i ++ );
            }
        }
    }
    // fill C
    for (int b_idx = 0; b_idx < batch_count; b_idx++)
    {
        for (int row_idx = 0; row_idx < m; row_idx++)
        {
            for (int col_idx = 0; col_idx < n; col_idx++)
            {
                host_C[col_idx + row_idx * ldc + b_idx * n] = static_cast<int32_t>(1);
            }
        }
    }

    // copy host memory to device
    result = cudaMemcpy(A, host_A, count_A * sizeof(cutlass::int4b_t), cudaMemcpyHostToDevice);
    if (result != cudaSuccess)
    {
        std::cerr << "cudaMemcpy result = " << result << std::endl;
        return result;
    }
    result = cudaMemcpy(B, host_B, count_B * sizeof(cutlass::int4b_t), cudaMemcpyHostToDevice);
    if (result != cudaSuccess)
    {
        std::cerr << "cudaMemcpy result = " << result << std::endl;
        return result;
    }
    result = cudaMemcpy(C, host_C, count_C * sizeof(int32_t), cudaMemcpyHostToDevice);
    if (result != cudaSuccess)
    {
        std::cerr << "cudaMemcpy result = " << result << std::endl;
        return result;
    }

    // run cutlass
    //print A
    std::cout << "A:" << std::endl;
    for (int b_idx = 0; b_idx < batch_count; b_idx++)
    {
        for (int row_idx = 0; row_idx < m; row_idx++)
        {
            for (int col_idx = 0; col_idx < k; col_idx++)
            {
                std::cout << host_A[col_idx + row_idx * lda + b_idx * k] << " ";
            }
            std::cout << std::endl;
        }
        std::cout << std::endl;
    }
    //print B
    std::cout << "B:" << std::endl;
    for (int b_idx = 0; b_idx < batch_count; b_idx++)
    {
        for (int col_idx = 0; col_idx < n; col_idx++)
        {
            for (int row_idx = 0; row_idx < k; row_idx++)
            {
                std::cout << host_B[row_idx + col_idx * ldb + b_idx * k] << " ";
            }
            std::cout << std::endl;
        }
        std::cout << std::endl;
    }

    //print C
    std::cout << "C:" << std::endl;
    for (int b_idx = 0; b_idx < batch_count; b_idx++)
    {
        for (int row_idx = 0; row_idx < m; row_idx++)
        {
            for (int col_idx = 0; col_idx < n; col_idx++)
            {
                std::cout << host_C[col_idx + row_idx * ldc + b_idx * n] << " ";
            }
            std::cout << std::endl;
        }
        std::cout << std::endl;
    }

    result = cutlass_strided_batched_sgemm(
        m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
        beta, batch_count);
    if (result != cudaSuccess)
        return result;

    // copy device memory to host
    result = cudaMemcpy(result_C, C, count_C * sizeof(int32_t), cudaMemcpyDeviceToHost);
    if (result != cudaSuccess)
    {
        std::cerr << "cudaMemcpy result = " << result << std::endl;
        return result;
    }

    // free memory
    result = cudaFree(A);
    if (result != cudaSuccess)
    {
        std::cerr << "cudaFree result = " << result << std::endl;
        return result;
    }
    result = cudaFree(B);
    if (result != cudaSuccess)
    {
        std::cerr << "cudaFree result = " << result << std::endl;
        return result;
    }
    result = cudaFree(C);
    if (result != cudaSuccess)
    {
        std::cerr << "cudaFree result = " << result << std::endl;
        return result;
    }

    return result;
}

int main()
{

    cudaError_t result = cudaSuccess;
    for (bool use_array : {false, true})
    {
        result = run_batched_gemm(use_array);
        if (result == cudaSuccess)
        {
            std::cout << "Passed." << std::endl;
        }
        else
        {
            break;
        }
    }

    // Exit.
    return result == cudaSuccess ? 0 : -1;
}
HaoMyWorld commented 3 months ago

I have encountered a same issue. so, how to solve it?

github-actions[bot] commented 2 months ago

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.