NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Compilation failure in `StagedReduction` test #2846

Closed: jacobhinkle closed this issue 2 months ago

jacobhinkle commented 2 months ago

The test `PipelineTestStagedReduction.StagedReduction` is currently failing. To reproduce, run

mpirun -np 1 build/test_multidevice --gtest_filter='PipelineTestStagedReduction.StagedReduction/Manual'

which fails to compile the generated kernel:

__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 3, 3> T0, Tensor<float, 1, 1> T2, Tensor<float, 2, 2> T1, Tensor<int64_t, 1, 1> T5) {
  alignas(16) extern __shared__ char array[];
  void* shared_mem = array;
  NVFUSER_DEFINE_MAGIC_ZERO;
  nvfuser_index_t i0;
  i0 = (T0.alloc_stride[2LL] * ((nvfuser_index_t)threadIdx.x)) + (T0.alloc_stride[1LL] * ((nvfuser_index_t)blockIdx.x));
  nvfuser_index_t i1;
  i1 = 32LL * T0.alloc_stride[2LL];
  nvfuser_index_t i2;
  i2 = -64LL + ((nvfuser_index_t)threadIdx.x);
  bool b3;
  b3 = (((nvfuser_index_t)blockIdx.x) == 0LL) && (((nvfuser_index_t)threadIdx.x) == 0LL);
  // Allocate global tensor T1
  *(volatile float*)&T1[((nvfuser_index_t)blockIdx.x)] = 0.000000000e+00f;
  float T3[4LL];
  #pragma unroll
  for(nvfuser_index_t i4 = 0LL; i4 < 4LL; ++i4) {
    T3[i4] = 0.000000000e+00f;
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  if (((((nvfuser_index_t)threadIdx.x) + 96LL) < 64LL)) {
    #pragma unroll
    for(nvfuser_index_t i4 = 0LL; i4 < 4LL; ++i4) {
      T3[i4]
        = T3[i4]
        + T0[(i0 + (i1 * (i4 + nvfuser_zero)))];
    }
  } else {
    #pragma unroll
    for(nvfuser_index_t i4 = 0LL; i4 < 4LL; ++i4) {
      nvfuser_index_t i5;
      i5 = i4 + nvfuser_zero;
      if ((i2 < (-(32LL * i5)))) {
        T3[i4]
          = T3[i4]
          + T0[(i0 + (i1 * i5))];
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  float T4[1LL];
  T4[0LL] = 0.000000000e+00f;
  #pragma unroll
  for(nvfuser_index_t i6 = 0LL; i6 < 4LL; ++i6) {
    T4[0LL]
      = T4[0LL]
      + T3[i6];
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  blockReduce<true, false, false, true>(*(volatile float*)&T1[((nvfuser_index_t)blockIdx.x)], T4[0LL], [](float &a, float b) { a = a + b; }, static_cast<float*>(shared_mem), true, float(0.000000000e+00f));
  // Allocate global tensor T5
  grid_sync::sync<true, false, false, true, true>(T5[index_utils::maskedOffset<false, true, true>(blockIdx, gridDim)], index_utils::maskedSize<true, false, false>(gridDim));
  #pragma unroll
  for(nvfuser_index_t i7 = 0LL; i7 < 8LL; ++i7) {
    nvfuser_index_t i8;
    i8 = i7 + nvfuser_zero;
    if (b3) {
      T2[i8]
         = *(volatile float*)&T1[i8];
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
}

This fails with

CUDA NVRTC compile error: __tmp_kernel_none_f0_c0_r0_g0.cu(10749): error: no instance of overloaded function "<unnamed>::blockReduce" matches the argument list
            argument types are: (volatile float, float, lambda [](float &, float)->void, float *, __nv_bool, float)
    blockReduce<true, false, false, true>(*(volatile float*)&T1[((nvfuser_index_t)blockIdx.x)], T4[0LL], [](float &a, float b) { a = a + b; }, static_cast<float*>(shared_mem), true, float(0.000000000e+00f));
    ^
__tmp_kernel_none_f0_c0_r0_g0.cu(5373): note #3327-D: candidate function template "<unnamed>::blockReduce<X_REDUCE,Y_REDUCE,Z_REDUCE,Aligned,T,Func>(T &, const T &, Func, T *, __nv_bool, T)" failed deduction
  __device__ void blockReduce(
                  ^
__tmp_kernel_none_f0_c0_r0_g0.cu(5227): note #3322-D: number of parameters of function template "<unnamed>::blockReduce<X_REDUCE,Y_REDUCE,Z_REDUCE,Aligned,T,Func>(T &, const T &, Func, T *, __nv_bool, __nv_bool, T)" does not match the call
  __device__ void blockReduce(
                  ^

1 error detected in the compilation of "__tmp_kernel_none_f0_c0_r0_g0.cu".

I think this is the overload we're targeting in this case, but the first argument is a `volatile float`. `T` is deduced as `volatile float` from that argument and as `float` from the remaining ones, so deduction fails: https://github.com/NVIDIA/Fuser/blob/6dba9a837deb14b82bb87db8c8e2a07fb02cad60/runtime/block_reduction.cu#L163-L177
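The sketch below reproduces the deduction conflict on the host. It copies only the parameter list of the overload named in the NVRTC note (`T&, const T&, Func, T*, bool, T`). The function body, `AddOp`, `CanCallWith`, and `reduceIntoVolatile` are hypothetical stand-ins, not nvFuser runtime code.

The two `static_assert`s show that a plain `float` lvalue deduces `T = float` consistently, while a `volatile float` lvalue does not. The workaround reduces into a non-volatile local and then stores it through a `volatile` pointer. It is only an assumption and not necessarily how #2257 was resolved.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Hypothetical host-side stand-in with the same parameter list as the
// blockReduce candidate in the NVRTC note. The body just applies the op once.
template <bool X_REDUCE, bool Y_REDUCE, bool Z_REDUCE, bool Aligned,
          typename T, typename Func>
void blockReduce(T& out, const T& inp_val, Func reduction_op,
                 T* /*shared_mem*/, bool /*read_pred*/, T /*init_val*/) {
  reduction_op(out, inp_val);
}

// Named functor instead of the generated lambda, so it can appear in decltype.
struct AddOp {
  void operator()(float& a, float b) const { a = a + b; }
};

// True if the call compiles when the first argument is an lvalue of type A.
template <typename A, typename = void>
struct CanCallWith : std::false_type {};

template <typename A>
struct CanCallWith<A, std::void_t<decltype(blockReduce<true, false, false, true>(
                          std::declval<A&>(), std::declval<const float&>(),
                          AddOp{}, std::declval<float*>(), true, 0.0f))>>
    : std::true_type {};

// Every argument deduces T = float: the candidate is viable.
static_assert(CanCallWith<float>::value, "float should deduce T = float");
// First argument deduces T = volatile float, the rest deduce T = float:
// conflicting deductions, matching the "failed deduction" note.
static_assert(!CanCallWith<volatile float>::value,
              "volatile float should fail deduction");

// Assumed workaround: reduce into a plain local, then store it through a
// volatile pointer, the way the generated kernel accesses T1.
float reduceIntoVolatile(float* global_slot, float partial) {
  float shared_mem[1] = {0.0f};
  float result = 0.0f;
  blockReduce<true, false, false, true>(result, partial, AddOp{}, shared_mem,
                                        true, 0.0f);
  *(volatile float*)global_slot = result;
  return *(volatile float*)global_slot;
}
```

Explicitly writing `T = float` would not help either, because a `float&` parameter cannot bind to a `volatile float` lvalue.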

The other case, `PipelineTestStagedReduction.StagedReduction/Automatic`, fails to segment this fusion:

Fusion IR after pre-segmenter optimization passes:
Inputs:
  T0_g[ bdeviceIdx.x0{1}, iS1{8}, iS2{64} ] (DeviceMesh{0}), float
Outputs:
  T2_g[ iS6{8} ] (DeviceMesh{0}), float

%kernel_math {
T1_l[ bdeviceIdx.x3{1}, iS4{8}, rS5{64} ] (DeviceMesh{0})
   = reduction( T0_g[ bdeviceIdx.x0{1}, iS1{8}, iS2{64} ] (DeviceMesh{0}), op = add, initial value = float(0), allreduce = false )
T3_l[ bS10{1}, iS11{8} ] (DeviceMesh{0})
   = SegmenterSet( T1_l[ bdeviceIdx.x3{1}, iS4{8}, rS5{64} ] (DeviceMesh{0}) )
T2_g[ iS6{8} ] (DeviceMesh{0})
   = squeeze( T3_l[ bS10{1}, iS11{8} ] (DeviceMesh{0}) )
} // %kernel_math

That seems like a separate issue, but it might point to a common problem with this test.

wujingyue commented 2 months ago

The failure in `PipelineTestStagedReduction.StagedReduction` is a duplicate of #2257. It only happens with `-np 1`, not with multiple GPUs, so CI hasn't complained about it.

What's the problem with PipelineTestStagedReduction.StagedReduction/Automatic?

jacobhinkle commented 2 months ago

> duplicate of #2257

Ah, thanks! I somehow missed that before filing this issue. Yes, the Manual failure is definitely the same as #2257. The Automatic failure is a segmentation failure: the segmenter proposes a segment containing a single `SegmenterSet` expression, and none of the schedulers accepts it:

**Segmenter** Considering fusion:
T3_g[ bS10{1}, iS11{8} ] (DeviceMesh{0})
   = SegmenterSet( T1_g[ bdeviceIdx.x3{1}, iS4{8}, rS5{64} ] (DeviceMesh{0}) )

Scheduler _expr_eval_ ***rejected*** because : Fusion is resharding.
Scheduler _no_op_ ***rejected*** because : output has a concrete dimension
Scheduler _matmul_ ***rejected*** because : No matmul patterns were found
Scheduler _reduction_ ***rejected*** because : Fusion is resharding.
Scheduler _transpose_ ***rejected*** because : Fusion is resharding.
Scheduler _pointwise_ ***rejected*** because : Fusion is resharding.
Scheduler _inner_persistent_ ***rejected*** because : Fusion is resharding.
Scheduler _outer_persistent_ ***rejected*** because : Fusion is resharding.
Scheduler _inner_outer_persistent_ ***rejected*** because : Fusion is resharding.
unknown file: Failure
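The scheduler log above can be read as a first-match search that comes up empty. The sketch below models that reading. `SegmentInfo`, `Candidate`, `pickScheduler`, and `makeCandidates` are hypothetical names, not nvFuser APIs. The in-order "first scheduler that accepts wins" rule is an assumption, and only the rejection messages are copied from the log.

It shows why a segment holding only the resharding `SegmenterSet` gets stuck. Most schedulers are gated on resharding, and `no_op` and `matmul` reject the segment for their own reasons.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical summary of the segment under consideration.
struct SegmentInfo {
  bool is_resharding;            // e.g. bdeviceIdx.x3{1} in T1 becomes bS10{1} in T3
  bool output_has_concrete_dim;  // e.g. iS11{8} in T3
  bool has_matmul_pattern;
};

// A candidate scheduler returns "" to accept, otherwise a rejection reason.
struct Candidate {
  std::string name;
  std::function<std::string(const SegmentInfo&)> reject_reason;
};

// Assumed selection rule: try candidates in order and take the first that accepts.
std::optional<std::string> pickScheduler(const std::vector<Candidate>& candidates,
                                         const SegmentInfo& seg,
                                         std::vector<std::string>& rejections) {
  for (const Candidate& c : candidates) {
    std::string reason = c.reject_reason(seg);
    if (reason.empty()) {
      return c.name;
    }
    rejections.push_back(c.name + " rejected because: " + reason);
  }
  return std::nullopt;  // no scheduler accepted the segment: segmentation fails
}

// Reasons copied from the log. transpose and the persistent schedulers are
// omitted because they reject with the same resharding reason.
std::vector<Candidate> makeCandidates() {
  auto resharding = [](const SegmentInfo& s) {
    return s.is_resharding ? std::string("Fusion is resharding.") : std::string();
  };
  return {
      {"expr_eval", resharding},
      {"no_op",
       [](const SegmentInfo& s) {
         return s.output_has_concrete_dim
                    ? std::string("output has a concrete dimension")
                    : std::string();
       }},
      {"matmul",
       [](const SegmentInfo& s) {
         return s.has_matmul_pattern
                    ? std::string()
                    : std::string("No matmul patterns were found");
       }},
      {"reduction", resharding},
      {"pointwise", resharding},
  };
}
```

For example, `pickScheduler(makeCandidates(), SegmentInfo{true, true, false}, rejections)` returns `std::nullopt` after collecting one rejection per candidate, mirroring the log.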