csarofeen / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
http://pytorch.org

Vector load from SMEM #1391

Closed naoyam closed 2 years ago

naoyam commented 2 years ago

Vector load from SMEM seems to have a problem.

TEST_F(NVFuserTest, FusionSmemVectorize_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeContigTensor(1);
  fusion.addInput(tv0);

  auto tv1 = set(tv0);
  auto tv2 = set(tv1);
  auto tv3 = set(tv2);
  fusion.addOutput(tv3);

  tv1->setMemoryType(MemoryType::Shared);

  tv3->split(-1, 4);
  TransformPropagator::from(tv3);

  tv1->computeAt(tv3, -2);

  tv2->axis(-1)->parallelize(ParallelType::Vectorize);

  FusionExecutor fe;
  fe.compileFusion(&fusion);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::manual_seed(0);
  // Fails when the size is not divisible by the vector length
  auto t0 = at::randn({199}, options);
  auto cg_outputs = fe.runFusion({t0});

  auto ref = t0;

  testValidate(&fusion, cg_outputs, {t0}, {ref}, __LINE__, __FILE__);
}

The generated kernel:

__global__ void kernel1(Tensor<float, 1> T0, Tensor<float, 1> T3) {
  NVFUSER_DEFINE_MAGIC_ZERO
  __shared__ float T1[4];
  #pragma unroll 1
  for(nvfuser_index_t ki49 = 0; ki49 < (ceilDiv(T0.size[0], 4)); ++ki49) {
    #pragma unroll
    for(nvfuser_index_t ki59 = 0; ki59 < 4; ++ki59) {
      if ((((ki49 * 4) + (ki59 + nvfuser_zero)) < T0.size[0])) {
        T1[ki59]
           = T0[(((ki49 * 4) + (ki59 + nvfuser_zero)) * 1)];
      }
    }
    NVFUSER_UPDATE_MAGIC_ZERO
    __barrier_sync(0);
    float T2[4];
    if ((((ki49 * 4) + 3) < T0.size[0])) {
      *reinterpret_cast<Array<float, 4>*>(&T2[0]) = *reinterpret_cast<Array<float, 4>*>(&T1[0]);
    }
    #pragma unroll
    for(nvfuser_index_t ki51 = 0; ki51 < 4; ++ki51) {
      if ((((ki49 * 4) + (ki51 + nvfuser_zero)) < T0.size[0])) {
        T3[(((ki49 * 4) + (ki51 + nvfuser_zero)) * 1)]
           = T2[ki51];
      }
    }
    NVFUSER_UPDATE_MAGIC_ZERO
    __barrier_sync(0);
  }
}

Test output:

C++ exception with description "aten_output_tensor.allclose( fusion_output_tensor.to(aten_output_tensor.dtype()), tolerance_values.second, tolerance_values.first)INTERNAL ASSERT FAILED at "../test/cpp/jit/test_gpu_validator.h":420, please report a bug to PyTorch.

Validation error in output 0 on line 19724 in file ../test/cpp/jit/test_gpu.cpp.
  Detected abs error of: 1.54733
    absolute tolerance was set to 1.51992e-06
    and relative tolerance set to 2.23704e-06

The vectorized load from T1 to T2 works fine when T0.size[0] is divisible by 4, but otherwise the last remaining elements are never loaded.

This doesn't seem like a new problem, but I don't remember exactly whether it was supposed to have been fixed or whether it's still a known issue.

rdspring1 commented 2 years ago

For ParallelType::Vectorize we assume the inner-most dimension is evenly divisible by the vector size. It should fail at this runtime check.

ParallelType::MisalignedVectorization should handle the remaining elements.

naoyam commented 2 years ago

I updated the test above with the generated kernel and its output. As far as I can see, the issue is not caught by the validation. I'll look into it.

rdspring1 commented 2 years ago

The runtime check only looks at the input and output tensors. If the tensor is an intermediate / shared memory, it isn't checked.