NVIDIA / cccl

CUDA Core Compute Libraries

[Possible BUG]: compute-sanitizer initchecker reports uninitialized global memory reads from thrust::reduce_by_key #1790

Open · ssadasivam1 opened this issue 1 month ago

ssadasivam1 commented 1 month ago


Type of Bug

Something else

Component

Thrust

Describe the bug

compute-sanitizer's initcheck tool reports uninitialized global memory reads from thrust::reduce_by_key (reproduced on CUDA 12.5):

========= Uninitialized __global__ memory read of size 4 bytes
=========     at void thrust::THRUST_200500_520_NS::cuda_cub::core::_kernel_agent<thrust::THRUST_200500_520_NS::cuda_cub::__reduce_by_key::ReduceByKeyAgent

It is possible that these warnings from compute-sanitizer are harmless and can be safely ignored. Please feel free to close this issue if you have verified that these are just artifacts of the underlying thrust/cub algorithm and can be ignored.

How to Reproduce

  1. Compile the reproducer below as: /usr/local/cuda/bin/nvcc -O3 --expt-relaxed-constexpr -I /path/to/cccl/thrust/ -I /path/to/cccl/cub -I /path/to/cccl/libcudacxx/include/ test.cu

  2. Run the binary under compute-sanitizer's initcheck tool, which reports the uninitialized global memory reads: /usr/local/cuda/compute-sanitizer/compute-sanitizer --tool initcheck ./a.out

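#include <thrust/device_vector.h>
#include <thrust/functional.h>                 // thrust::equal_to
#include <thrust/iterator/discard_iterator.h>
#include <thrust/reduce.h>

#include <algorithm>   // std::max (callable from device code via --expt-relaxed-constexpr)
#include <iostream>
#include <string>
#include <vector>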

struct point {
    int x;
    int y;
};

struct pointMax{
    __host__ __device__
    point operator()(const point& a, const point& b) const {
        return point{std::max(a.x,b.x), std::max(a.y,b.y)};
    }
};

void printVec(const std::vector<int>& vec, const std::string& name) {
    std::cout << "Printing vector: " << name << std::endl;
    std::cout << "Size of input = " << vec.size() << std::endl;

    for(size_t ii=0; ii < vec.size(); ++ii) {
        std::cout << vec[ii] << ",";
    }
    std::cout << std::endl;
}

void printPointVec(const std::vector<point>& vec, const std::string& name) {
    std::cout << "Printing vector: " << name << std::endl;
    std::cout << "Size of input = " << vec.size() << std::endl;

    for(size_t ii=0; ii < vec.size(); ++ii) {
        std::cout << "(" << vec[ii].x << "," << vec[ii].y << "), ";
    }
    std::cout << std::endl;
}

void testReduceByKey() {
    constexpr int N = 605;
    std::vector<int> ids(4*N);
    std::vector<point> coords(4*N);
    for(int ii = 0; ii < N; ++ii) {
        for(int jj = 0; jj < 4; ++jj)
            ids[4*ii+jj] = ii;

        coords[4*ii]   = point{0,0};
        coords[4*ii+1] = point{100,0};
        coords[4*ii+2] = point{50, 50};
        coords[4*ii+3] = point{0,100};
    }

    //printVec(ids, "ids");

    thrust::device_vector<int> d_ids(ids);
    thrust::device_vector<point> d_coords(coords);
    thrust::device_vector<point> d_ur(4*N);

    auto endIt = thrust::reduce_by_key(d_ids.begin(), d_ids.end(), 
            d_coords.begin(),
            thrust::make_discard_iterator(),
            d_ur.begin(),
            thrust::equal_to<int>(), pointMax{});
    auto numEle = endIt.second-d_ur.begin();
    std::cout << "Number of output values = " << numEle  << std::endl;
    d_ur.resize(numEle);

    // Copy back only the numEle reduced values (numEle == N for this input).
    std::vector<point> ur(numEle);
    cudaMemcpy(ur.data(), thrust::raw_pointer_cast(d_ur.data()),
            numEle*sizeof(point), cudaMemcpyDeviceToHost);
    cudaDeviceSynchronize();

    //printPointVec(ur, "ur");
}

int main() {
    testReduceByKey();

    return 0;
}
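To have something to compare the device output against, here is a sequential, host-only sketch of what the reduce_by_key call above computes. Each run of consecutive equal keys forms one segment, and that segment's values are folded with a component-wise maximum. The names reduceByKeyHost, pointMaxHost and makeReproducerInput are hypothetical helpers, not part of Thrust. Output keys are not produced, mirroring the discard_iterator in the reproducer.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

struct point {
    int x;
    int y;
};

// Same component-wise maximum as pointMax in the reproducer (host-only here).
inline point pointMaxHost(const point& a, const point& b) {
    return point{std::max(a.x, b.x), std::max(a.y, b.y)};
}

// Hypothetical sequential reference: each run of consecutive equal keys is
// reduced to a single value. Keys are dropped, like the discard_iterator.
std::vector<point> reduceByKeyHost(const std::vector<int>& keys,
                                   const std::vector<point>& vals) {
    assert(keys.size() == vals.size());
    std::vector<point> out;
    for (std::size_t i = 0; i < keys.size(); ++i) {
        if (i == 0 || keys[i] != keys[i - 1]) {
            out.push_back(vals[i]);                          // start a new segment
        } else {
            out.back() = pointMaxHost(out.back(), vals[i]);  // fold into current segment
        }
    }
    return out;
}

// Builds the same input as testReduceByKey(): N segments of 4 points each.
void makeReproducerInput(int N, std::vector<int>& ids, std::vector<point>& coords) {
    ids.assign(4 * N, 0);
    coords.assign(4 * N, point{0, 0});
    for (int ii = 0; ii < N; ++ii) {
        for (int jj = 0; jj < 4; ++jj) ids[4 * ii + jj] = ii;
        coords[4 * ii]     = point{0, 0};
        coords[4 * ii + 1] = point{100, 0};
        coords[4 * ii + 2] = point{50, 50};
        coords[4 * ii + 3] = point{0, 100};
    }
}
```

With N = 605, this reference yields 605 values, each equal to (100,100). The ur vector should contain exactly that after the device call, so the two results can be compared element by element.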

Expected behavior

initcheck should report 0 errors.

Reproduction link

No response

Operating System

No response

nvidia-smi output

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.44                 Driver Version: 555.44         CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:01:00.0 Off |                    0 |
| N/A   36C    P0             67W /  300W |       1MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

NVCC version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Tue_May_14_02:18:54_PDT_2024
Cuda compilation tools, release 12.5, V12.5.54
Build cuda_12.5.r12.5/compiler.34287351_0

lilohuang commented 1 month ago

I can reproduce the bug with the code snippet provided by @ssadasivam1. I suspect there might be a bug in the Thrust ReduceByKeyAgent kernel when it loads the values array, although I haven't had a chance to investigate the details with cuda-gdb.

Interestingly, if we change the point structure to use 64-bit members:

struct point {
    int64_t x;
    int64_t y;
};

compute-sanitizer then reports an error like the following:

Uninitialized __global__ memory read of size 8 bytes.

The reported read size is now 8 bytes rather than 4 bytes, matching the size of the x and y members.

However, if we change the values array type of reduce_by_key() to a primitive data type like int32_t or int64_t instead of struct point, the NVIDIA compute-sanitizer does not produce any error.

If it is indeed an uninitialized global memory read, this may lead to a process crash or non-deterministic results. It would be great to address these concerns.

CC @miscco, @gevtushenko, @alliepiper, @jrhemstad in case they have any insights.

elstehle commented 1 month ago

Thank you for reporting the issue and providing a reproducer, @ssadasivam1. Thank you for reproducing the issue, @lilohuang.

After some initial investigation, it looks like the uninitialized reads reported by compute-sanitizer are of the padded tile states used by the decoupled look-back. My current understanding is that the values obtained from these uninitialized reads are never used further and, hence, this is nothing we have to worry about.
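To make the padding argument concrete, here is a deliberately simplified, sequential host-side model of a decoupled look-back. It is a sketch, not CUB's code. The names TileStates, Status, kPadding, exclusivePrefix and scanTiles are hypothetical, the scan operator is integer addition, and the "garbage" in the value slots is a fill value chosen by the caller. In this model the status slots, padding included, are initialized, while the padding value slots are not. A look-back window of 32 predecessors, one per warp lane, loads status and value for every slot, including padding slots. Only values up to and including the first inclusive predecessor are accumulated.

```cpp
#include <cassert>
#include <vector>

// Hypothetical, simplified model of decoupled look-back tile state.
enum class Status { Invalid, Partial, Inclusive, OutOfBounds };

constexpr int kPadding = 32;  // one warp's worth of predecessor slots

struct TileStates {
    std::vector<Status> status;  // initialized everywhere, padding included
    std::vector<int> partial;    // all slots start as `garbage`; real tiles
    std::vector<int> inclusive;  // overwrite theirs when they publish

    TileStates(int num_tiles, int garbage)
        : status(kPadding + num_tiles, Status::Invalid),
          partial(kPadding + num_tiles, garbage),
          inclusive(kPadding + num_tiles, garbage) {
        for (int i = 0; i < kPadding; ++i) status[i] = Status::OutOfBounds;
    }
};

// Exclusive prefix of tile `tile` (> 0). Like a warp, each "lane" of a window
// loads the status and value of one predecessor slot, even if that slot lies
// in the padding. Values after the first Inclusive predecessor are loaded but
// discarded.
int exclusivePrefix(const TileStates& s, int tile) {
    int prefix = 0;
    int top = kPadding + tile - 1;  // slot of the nearest predecessor
    for (;;) {
        bool done = false;
        for (int lane = 0; lane < kPadding; ++lane) {
            int slot = top - lane;
            Status st = s.status[slot];
            int v = (st == Status::Inclusive) ? s.inclusive[slot] : s.partial[slot];
            if (done) continue;  // loaded, but never used
            assert(st == Status::Partial || st == Status::Inclusive);
            prefix += v;
            done = (st == Status::Inclusive);
        }
        if (done) return prefix;
        top -= kPadding;  // slide the window further back
    }
}

// Runs tiles in order. Tile 0 publishes its aggregate as inclusive directly
// (no look-back); every later tile looks back, then publishes inclusive.
std::vector<int> scanTiles(const std::vector<int>& aggregates, int garbage) {
    const int n = static_cast<int>(aggregates.size());
    TileStates s(n, garbage);
    std::vector<int> exclusive(n, 0);
    for (int t = 0; t < n; ++t) {
        const int slot = kPadding + t;
        s.partial[slot] = aggregates[t];
        s.status[slot] = Status::Partial;
        if (t > 0) exclusive[t] = exclusivePrefix(s, t);
        s.inclusive[slot] = exclusive[t] + aggregates[t];
        s.status[slot] = Status::Inclusive;
    }
    return exclusive;
}
```

For example, with aggregates {3, 1, 4, 1, 5}, scanTiles returns {0, 3, 4, 8, 9} whether the garbage value is 0 or -999999. Tile 2's window loads padding slots, but those loads come after the first inclusive predecessor and are dropped. On a GPU, where the padding value slots are genuinely uninitialized, initcheck would be expected to flag exactly such loads even though their results are unused.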

I will verify that this is indeed the case, i.e., that all reads reported by compute-sanitizer fall within the padding area.

lilohuang commented 1 month ago

Thanks @elstehle for your quick response. If that's the case due to the decoupled look-back algorithm, may I know why there is a different result from the NVIDIA compute-sanitizer between the struct point { int x; int y; } and int64_t of the values array type, especially if they have the same size for each element?

It would be great to find a way to appease the compute-sanitizer so that CUDA programmers can trust the NVIDIA compute-sanitizer to identify any potential bugs while using the thrust/cub library. Thank you so much. 👍

elstehle commented 1 month ago

> Thanks @elstehle for your quick response. If that's the case due to the decoupled look-back algorithm, may I know why there is a different result from the NVIDIA compute-sanitizer between the struct point { int x; int y; } and int64_t of the values array type, especially if they have the same size for each element?

That is because we currently have specialized code paths that depend on the key and value types. As referenced below, we specialize on Traits<ValueT>::PRIMITIVE, which is basically true for the built-in types. In that case (e.g., int64_t), a single array stores the tile status as well as the partial and inclusive prefix results, and for that code path we do initialize the padded area. For the "non-primitive" case, we have three separate arrays. Only the status array must be initialized for the padded area; initializing the other two is not required and would incur some overhead. https://github.com/NVIDIA/cccl/blob/1b12fcab61a36afdb50930e0f67dbadb86a496bf/cub/cub/agent/single_pass_scan_operators.cuh#L852-L855
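Here is a rough host-side sketch of the two layouts described above, under stated assumptions. The names Status, TileStateModel, initializeTileStates and paddingValuesUninitialized are hypothetical. std::is_arithmetic stands in for CUB's Traits<ValueT>::PRIMITIVE, and CUB's actual selection may involve further conditions, such as the size of the packed descriptor. std::optional marks whether a slot was ever written, with std::nullopt modelling device memory that the initialization pass never touched. The sketch shows why int64_t and struct point { int x; int y; } can take different paths even though both are 8 bytes on common platforms.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <type_traits>
#include <vector>

struct point { int x; int y; };

// Same size, but only one of them is a builtin (arithmetic) type.
static_assert(sizeof(point) == sizeof(std::int64_t), "both are 8 bytes here");
static_assert(std::is_arithmetic<std::int64_t>::value, "builtin type");
static_assert(!std::is_arithmetic<point>::value, "user-defined struct");

enum class Status { Invalid, OutOfBounds };  // hypothetical status codes

// `values` stands for the value storage: the value field of each packed
// descriptor, or, in the split layout, the partial and inclusive arrays
// (treated alike here).
template <typename T>
struct TileStateModel {
    std::vector<std::optional<Status>> status;
    std::vector<std::optional<T>> values;
};

// Hypothetical model of the initialization pass. Packed layout: writing the
// whole {status, value} descriptor writes the value bytes too, padding
// included. Split layout: only the status array is written.
template <typename T>
TileStateModel<T> initializeTileStates(int padding, int num_tiles) {
    assert(padding >= 0 && num_tiles >= 0);
    const int n = padding + num_tiles;
    TileStateModel<T> m{std::vector<std::optional<Status>>(n),
                        std::vector<std::optional<T>>(n)};
    for (int i = 0; i < n; ++i) {
        m.status[i] = (i < padding) ? Status::OutOfBounds : Status::Invalid;
        if (std::is_arithmetic<T>::value) {
            m.values[i] = T{};  // packed descriptor: value written as well
        }
    }
    return m;
}

// True if any padding value slot was never written, i.e. a slot whose later
// load by the look-back would count as an uninitialized read.
template <typename T>
bool paddingValuesUninitialized(const TileStateModel<T>& m, int padding) {
    for (int i = 0; i < padding; ++i)
        if (!m.values[i].has_value()) return true;
    return false;
}
```

Under this model, the int64_t table has no unwritten padding slots, while the point table leaves every padding value slot unwritten. That matches the observation in this thread that only the struct value type triggers the report. Also writing those slots would silence it at the cost of the extra initialization work discussed above.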

> It would be great to find a way to appease the compute-sanitizer so that CUDA programmers can trust the NVIDIA compute-sanitizer to identify any potential bugs while using the thrust/cub library. Thank you so much. 👍

I will get a sense of the overhead, think it over, and get back to you on this.