NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[BUG] Non-ZFILL CP_ASYNC copy trait causes buffer overwriting with predicated copy #1716

Open cloudhan opened 2 months ago

cloudhan commented 2 months ago

Describe the bug

#include "cute/tensor.hpp"

using namespace cute;

__global__ void kernel(int *gmem) {
  int tid = threadIdx.x;
  gmem[tid * 4 + 0] = tid * 4 + 0;
  gmem[tid * 4 + 1] = tid * 4 + 1;
  gmem[tid * 4 + 2] = tid * 4 + 2;
  gmem[tid * 4 + 3] = tid * 4 + 3;
  __syncthreads();

  __shared__ int smem[128];
  smem[tid * 4 + 0] = -1;
  smem[tid * 4 + 1] = -1;
  smem[tid * 4 + 2] = -1;
  smem[tid * 4 + 3] = -1;
  __syncthreads();

  auto g = make_tensor(make_gmem_ptr(gmem), Layout<Shape<_1, _128>>{});
  auto s = make_tensor(make_smem_ptr(smem), Layout<Shape<_1, _128>>{});

  __syncthreads();
  if (thread0()) {
    print_tensor(s);
  }

  using CopyAtom = Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, int>;
  // using CopyAtom = Copy_Atom<UniversalCopy<uint128_t>, int>;

  auto tiled_copy = make_tiled_copy(CopyAtom{}, Layout<_32>{}, Layout<_4>{});
  auto thr_copy = tiled_copy.get_thread_slice(tid);
  auto tG = thr_copy.partition_S(coalesce(g));
  auto tS = thr_copy.partition_D(coalesce(s));
  auto p = make_tensor<bool>(size<1>(tG));
  for (int i = 0; i < size(p); i++) {
    p(i) = tid < 2;  // only activate two threads
  }

  copy_if(tiled_copy, p, tG, tS);
  cp_async_fence();
  cp_async_wait<0>();
  __syncthreads();

  if (thread0()) {
    print_tensor(s);
  }
}

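int main() {
  int *buffer;
  // The kernel writes 4 ints per thread (128 ints total), so allocate 128.
  cudaMalloc(&buffer, 128 * sizeof(int));
  kernel<<<1, 32>>>(buffer);
  cudaDeviceSynchronize();

  // Exit with 0 on success and nonzero on a CUDA error.
  return cudaGetLastError() != cudaSuccess;
}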

nvcc cp_async.cu -Iinclude -std=c++20 --expt-relaxed-constexpr -gencode=arch=compute_80,code=sm_80 && ./a.out

produces

smem_ptr[32b](0x7f6985000000) o (_1,_128):(_0,_1):
   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1   -1
smem_ptr[32b](0x7f6985000000) o (_1,_128):(_0,_1):
    0    1    2    3    4    5    6    7    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0

Expected behavior

When _ZFILL is not requested, the predicated-out values should not be touched. Otherwise this might cause

  1. a buffer overwrite, when the dst buffer is not large enough but another __shared__ buffer declared after it is large enough to absorb the write.
  2. a memory fault, when the write goes out of bounds.
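
To make the two behaviors concrete, here is a minimal host-side C++ model. It is not CuTe code: the function names `copy_if_skip` and `copy_if_zfill` are hypothetical, and the "threads" are just loop iterations. It only sketches the difference between leaving predicated-out elements alone and zero-filling them.

```cpp
#include <cstddef>
#include <vector>

// Host-side model only, not CuTe code. Each of pred.size() "threads" owns
// `vec` contiguous elements; src and dst must hold pred.size() * vec ints.

// Expected semantics without ZFILL: a false predicate leaves dst untouched.
inline void copy_if_skip(const std::vector<bool>& pred,
                         const std::vector<int>& src,
                         std::vector<int>& dst, std::size_t vec) {
  for (std::size_t t = 0; t < pred.size(); ++t) {
    if (!pred[t]) continue;  // predicated out: no write at all
    for (std::size_t v = 0; v < vec; ++v) {
      dst[t * vec + v] = src[t * vec + v];
    }
  }
}

// ZFILL semantics: a false predicate still writes, but stores zeros.
inline void copy_if_zfill(const std::vector<bool>& pred,
                          const std::vector<int>& src,
                          std::vector<int>& dst, std::size_t vec) {
  for (std::size_t t = 0; t < pred.size(); ++t) {
    for (std::size_t v = 0; v < vec; ++v) {
      dst[t * vec + v] = pred[t] ? src[t * vec + v] : 0;
    }
  }
}
```

With the numbers from the repro (32 threads, 4 ints each, only the first two threads active, dst pre-filled with -1), `copy_if_skip` leaves elements 8..127 at -1, while `copy_if_zfill` sets them to 0, which matches the second tensor printed above.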

This is because

https://github.com/NVIDIA/cutlass/blob/f93a69134ec8259fd235f220209d6f8734a5cb06/include/cute/atom/copy_traits_sm80.hpp#L77-L82

silently re-dispatches to the _ZFILL trait. This generally causes a very subtle bug when the user expects an async version of Copy_Atom<UniversalCopy<uint128_t>, T> as a simple substitute, not the ignore-src behavior!

Since the _ZFILL variants exist, this implicit behavior should be removed.
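
The dispatch pattern at issue can be sketched in plain C++. This is a simplified model under my own assumptions, not the actual CuTe source: `AsyncCopy`, `AsyncCopyZfill`, `PlainCopy`, `has_with_bool` and `copy_if_one` are hypothetical stand-ins. The point is that a predicated copy which prefers `with(bool)` whenever the atom offers it will pick up zero-fill behavior as soon as the plain trait exposes `with`.

```cpp
#include <type_traits>
#include <utility>

// Hypothetical stand-in for a ZFILL trait: writes src or zero.
struct AsyncCopyZfill {
  bool pred = true;
  void operator()(const int& src, int& dst) const { dst = pred ? src : 0; }
};

// Hypothetical stand-in for the non-ZFILL trait.
struct AsyncCopy {
  void operator()(const int& src, int& dst) const { dst = src; }
  // The pattern the issue objects to: asking the plain trait for a
  // predicated version silently hands back the ZFILL one.
  AsyncCopyZfill with(bool pred) const { return {pred}; }
};

// Hypothetical stand-in for a copy with no `with` member at all.
struct PlainCopy {
  void operator()(const int& src, int& dst) const { dst = src; }
};

// Detects whether an atom exposes `with(bool)`.
template <class Atom, class = void>
struct has_with_bool : std::false_type {};
template <class Atom>
struct has_with_bool<
    Atom, std::void_t<decltype(std::declval<const Atom&>().with(true))>>
    : std::true_type {};

// Predicated copy of one element: uses `with(pred)` when available,
// otherwise skips the write entirely when pred is false.
template <class Atom>
void copy_if_one(const Atom& atom, bool pred, const int& src, int& dst) {
  if constexpr (has_with_bool<Atom>::value) {
    atom.with(pred)(src, dst);
  } else {
    if (pred) atom(src, dst);
  }
}
```

In this model, `copy_if_one(AsyncCopy{}, false, src, dst)` overwrites `dst` with 0, while `copy_if_one(PlainCopy{}, false, src, dst)` leaves it alone. Removing `with` from `AsyncCopy` would make the two behave the same, which is the change requested here.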

The only workaround is to replace copy_if as follows

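// assuming only one iteration mode
for (int i = 0; i < size(p); i++) {
  if (p(i)) {  // skip predicated-out slices instead of zero-filling them
    copy(tiled_copy, tG(_, i), tS(_, i));
  }
}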

Environment details: CUTLASS commit f93a69134ec8259fd235f220209d6f8734a5cb06

ccecka commented 2 months ago

I agree this is a bug. It appears those .with functions should only be on the ZFILL variants.

cloudhan commented 2 months ago

@thakkarV Does NVIDIA have any plans to address this issue? It is quite subtle and easy to get wrong. Fixing it will be a breaking change for users who rely on the current behavior, so the later it is addressed, the wider the impact will be.

thakkarV commented 2 months ago

@yzhaiustc can we please add this to the docket for 3.6 fixes?

thakkarV commented 2 months ago

@cloudhan will fix in 3.6 which will land in a month or so. is that ok?

yzhaiustc commented 2 months ago

sure. thanks :-)
