
SMEM mixed bcast/iter with thread binding fails to validate/produce correct code #1418

Open naoyam opened 2 years ago

naoyam commented 2 years ago

I'm not sure what the correct behavior should be, but this fusion currently fails validation at the beginning of lowering:

TEST_F(NVFuserTest, FusionBroadcastConcretization4_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(2);
  fusion.addInput(tv0);

  auto tv1 = sum(tv0, {1});
  auto tv2 = broadcast(tv1, {false, true});
  auto tv3 = add(tv2, tv0);
  fusion.addOutput(tv3);

  // Bind tv1's reduction axis to threadIdx.x (block reduction)
  tv1->axis(1)->parallelize(ParallelType::TIDx);

  // Merge tv2's broadcast axis into its iteration axis, then bind the
  // merged axis to threadIdx.x
  tv2->merge(0, 1);
  tv2->axis(0)->parallelize(ParallelType::TIDx);

  // Same scheme for the output: merge both axes, then bind to threadIdx.x
  tv3->merge(0, 1);
  tv3->axis(0)->parallelize(ParallelType::TIDx);

  fusion.printMath();
  fusion.printKernel();
}
Inputs:
  T0_g[ iS0{i1}, iS1{i2} ], float
Outputs:
  T3_g[ ithreadIdx.x9{( i1 * i2 )} ], float

%kernel_math {
T1_l[ iS2{i1}, rthreadIdx.x3{i2} ] = reduction( T0_g[ iS0{i1}, iS1{i2} ], op = add, initial value = double(0) )
T2_l[ ithreadIdx.x8{( i1 * 1 )} ] = broadcast( T1_l[ iS2{i1}, rthreadIdx.x3{i2} ] )
T3_g[ ithreadIdx.x9{( i1 * i2 )} ]
   = T2_l[ ithreadIdx.x8{( i1 * 1 )} ]
   + T0_g[ iS0{i1}, iS1{i2} ];
}

unknown file: Failure
C++ exception with description "predicated_parallel_types.none()INTERNAL ASSERT FAILED at "../torch/csrc/jit/codegen/cuda/lower_validation.cpp":492, please report a bug to PyTorch. Invalid parallelization of tensor t2. The tensor is parallelized with threadIdx.x, but it's invalid to use the types as the tensor is also predicated with them., thread prd: threadIdx.x

The issue is not (just) the validation. Even if the validation is skipped, invalid code is generated:

__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T3) {
  alignas(4) extern __shared__ char array[];
  void* shared_mem = array;
  float T1[T0.size[0]];
  #pragma unroll 1
  for(nvfuser_index_t i16 = 0; i16 < T0.size[0]; ++i16) {
    T1[i16] = 0;
  }
  #pragma unroll 1
  for(nvfuser_index_t i13 = 0; i13 < T0.size[0]; ++i13) {
    blockReduce<true, false, false>(
      T1[i13],
      T0[(i13 * T0.stride[0]) + (((nvfuser_index_t)threadIdx.x) * T0.stride[1])],
      [](float &a, float b) { a = a + b; },
      threadIdx,
      blockDim,
      static_cast<float*>(shared_mem),
      (((nvfuser_index_t)threadIdx.x) < T0.size[1]),
      float(0));
  }
  float T2[1];
  T2[0]
     = T1[(((nvfuser_index_t)threadIdx.x) / 1)];
  if ((((((nvfuser_index_t)threadIdx.x) < (T0.size[0] * T0.size[1])) && (((nvfuser_index_t)threadIdx.x) < (T0.size[0] * T0.size[1]))) && (((nvfuser_index_t)threadIdx.x) == 0))) {
    T3[(((nvfuser_index_t)threadIdx.x) * 1)]
      = T2[(((nvfuser_index_t)threadIdx.x) / T0.size[1])]
      + T0[((((nvfuser_index_t)threadIdx.x) / T0.size[1]) * T0.stride[0]) + ((((nvfuser_index_t)threadIdx.x) % T0.size[1]) * T0.stride[1])];
  }
}

Notice that there's no blockBroadcast. This is not due to the concrete broadcast domain PR (#1412). It's because merging a broadcast domain with a non-broadcast domain results in a non-broadcast domain, so T2 ends up with no broadcast domain, and no broadcast runtime call is generated.
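To illustrate the merge semantics, a minimal sketch (assuming the static IterDomain::merge API; iter_id and bcast_id stand for the two leaf domains):

// Merging an iteration IterDomain with a broadcast IterDomain produces an
// iteration domain, so the broadcast property is no longer visible on the
// resulting leaf.
IterDomain* merged = IterDomain::merge(iter_id, bcast_id);
TORCH_INTERNAL_ASSERT(!merged->isBroadcast());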

This problem only happens when a broadcast domain requires an actual parallel broadcast, which means the tensor must be predicated with the same parallel type as used on the broadcast domain.

csarofeen commented 2 years ago

blockBroadcast cannot support this type of use case. I would expect parallel validation to fail here instead of the predicate check complaining. This is a change in the parallelization strategy: T3 has a completely different parallel scheme than T2. The only way this can be supported is when T2 is in shared memory; in that case we should do the right thing, which we almost do when that check is removed (see the generated kernel below).
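For reference, staging T2 through shared memory in the repro should be a one-line scheduling change (a sketch, assuming the usual TensorView::setMemoryType API):

// Hypothetical addition to the test above: place T2 in shared memory so
// other threads can read the broadcast value.
tv2->setMemoryType(MemoryType::Shared);

With that and the validation check removed, the generated kernel is: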

__global__ void CUDAGeneratedKernel(Tensor<float, 2> T0, Tensor<float, 2> T3) {
  alignas(4) extern __shared__ char array[];
  unsigned offset = 0;
  void* shared_mem = array;
  offset += ((blockDim.x * blockDim.y * blockDim.z) * sizeof(float));
  float T1[T0.size[0]];
  #pragma unroll 1
  for(nvfuser_index_t i16 = 0; i16 < T0.size[0]; ++i16) {
    T1[i16] = 0;
  }
  offset = alignBufferSize(offset,4);
  float* T2 = reinterpret_cast<float*>(array + offset);
  offset += ((T0.size[0] * 1) * sizeof(float));
  #pragma unroll 1
  for(nvfuser_index_t i13 = 0; i13 < T0.size[0]; ++i13) {
    blockReduce<true, false, false>(
      T1[i13],
      T0[(i13 * T0.stride[0]) + (((nvfuser_index_t)threadIdx.x) * T0.stride[1])],
      [](float &a, float b) { a = a + b; },
      threadIdx,
      blockDim,
      static_cast<float*>(shared_mem),
      (((nvfuser_index_t)threadIdx.x) < T0.size[1]),
      float(0));
  }
  if (((((nvfuser_index_t)threadIdx.x) < (T0.size[0] * 1)) && (((nvfuser_index_t)threadIdx.x) == 0))) {
    T2[((nvfuser_index_t)threadIdx.x)]
       = T1[(((nvfuser_index_t)threadIdx.x) / 1)];
  }
  if ((((((nvfuser_index_t)threadIdx.x) < (T0.size[0] * T0.size[1])) && (((nvfuser_index_t)threadIdx.x) < (T0.size[0] * T0.size[1]))) && (((nvfuser_index_t)threadIdx.x) == 0))) {
    T3[(((nvfuser_index_t)threadIdx.x) * 1)]
      = T2[(((nvfuser_index_t)threadIdx.x) / T0.size[1])]
      + T0[((((nvfuser_index_t)threadIdx.x) / T0.size[1]) * T0.stride[0]) + ((((nvfuser_index_t)threadIdx.x) % T0.size[1]) * T0.stride[1])];
  }
  __barrier_sync(0);
}

The issues here are:

  1. There's a RAW race between writing T2 and its use in T3.
    • I've been thinking that we should detect which tensors actually need RAW protection in a build pass that runs before the RAW sync-insertion pass. Such a pass would be very similar to the parallelization validation pass, and would simply mark what type of communication (smem or gmem) is required to satisfy the parallelization scheme; see the sketch after this list.
  2. The predicate ((nvfuser_index_t)threadIdx.x) == 0 shouldn't exist on the T2 and T3 expressions.
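A rough sketch of what such an analysis pass could look like (all names here are illustrative, not existing nvFuser APIs, and the producer/consumer traversal is simplified):

// Hypothetical pre-RAW analysis: for each producer/consumer pair, decide
// whether the parallelization scheme forces values to cross threads or
// blocks, so the RAW pass knows which tensors need smem/gmem staging.
enum class CommKind { None, SmemExchange, GmemExchange };

CommKind requiredCommunication(TensorView* producer, TensorView* consumer) {
  for (IterDomain* p_id : producer->domain()->domain()) {
    ParallelType pt = p_id->getParallelType();
    if (pt == ParallelType::Serial) {
      continue;
    }
    // If the consumer indexes the producer with a different binding of the
    // same parallel type (as T3 does with T2 here), values must be
    // exchanged across threads or across blocks.
    if (!sameParallelBinding(producer, consumer, p_id)) { // hypothetical
      return isParallelTypeThreadDim(pt) ? CommKind::SmemExchange
                                         : CommKind::GmemExchange;
    }
  }
  return CommKind::None;
}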
csarofeen commented 2 years ago

This bug isn't high priority, as we don't use parallelization schemes like this in practice, but it's not great that it produces such incorrect code.

csarofeen commented 2 years ago

A disabled test case for this issue was added in https://github.com/csarofeen/pytorch/pull/1412.

naoyam commented 2 years ago

I think another thing we need to consider is that once a broadcast domain is merged with a non-broadcast domain, the result is a non-broadcast domain, i.e., IterDomain::isBroadcast() returns false. We often just look at the leaf IDs to see if there's any broadcast (e.g., TensorDomain::hasBroadcast()), which would return false even when the root domain has a broadcast.
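A minimal sketch of a root-domain check (hasRootBroadcast is a hypothetical helper, not an existing API; it assumes TensorDomain::getRootDomain()):

// Hypothetical helper: look for broadcasts in the root domain instead of
// the leaves, since a broadcast root ID that has been merged into an
// iteration ID no longer reports isBroadcast() on the resulting leaf.
bool hasRootBroadcast(const TensorDomain* td) {
  for (IterDomain* id : td->getRootDomain()) {
    if (id->isBroadcast()) {
      return true;
    }
  }
  return false;
}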

What's not really clear to me is what the merged domain of a broadcast and a non-broadcast domain really means. It doesn't matter when the broadcast is trivial, e.g., when it's on shared memory, not concretized, or not predicated. Otherwise, I'm not sure what the right semantics would be. If we don't care, as such schedules won't be necessary in practice, we might just want to disable merging of broadcast and non-broadcast domains when the broadcast is not trivial, along the lines of the sketch below.
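A rough sketch of such a guard (validateMerge and isTrivialBroadcast are hypothetical names, not existing nvFuser APIs):

// Hypothetical check at merge time: reject merging a non-trivial broadcast
// domain with a non-broadcast domain, since the semantics are unclear.
void validateMerge(IterDomain* outer, IterDomain* inner) {
  if (outer->isBroadcast() == inner->isBroadcast()) {
    return; // not a mixed broadcast/iteration merge
  }
  IterDomain* bcast = outer->isBroadcast() ? outer : inner;
  TORCH_CHECK(
      isTrivialBroadcast(bcast), // hypothetical: smem, not concretized, ...
      "Merging a non-trivial broadcast domain with a non-broadcast domain ",
      "is not supported.");
}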