microsoft / tutel

Tutel MoE: An Optimized Mixture-of-Experts Implementation
MIT License

[Bug] The function func_fwd is computed inconsistently on CPU and GPU #197

Closed starkhu closed 1 year ago

starkhu commented 1 year ago

The tutel/fast_dispatch.py file calls the function func_fwd, but the CPU and CUDA implementations of this function are inconsistent. Here are the two implementations:

// CPU code
// from tutel/custom/custom_kernel.cpp
    for (int i = 0; i < samples; ++i) {
      if (locations1_s[i] < capacity && indices1_s[i] >= 0) {
        for (int j = 0; j < hidden; ++j) {
          dispatched_input[(indices1_s[i] * capacity + locations1_s[i]) * (hidden) + j] += gates1_s[i] * reshaped_input[i * (hidden) + j];
        }
      }
    }

// CUDA code
// from tutel/jit_kernels/sparse.py
      for (int i = blockIdx.x; i < samples; i += gridDim.x)
          if (locations1_s[i] < capacity && indices1_s[i] >= 0) {
              #pragma unroll
              for (int j = threadIdx.x; j < hidden; j += 1024)
                  dispatched_input[(indices1_s[i] * capacity + locations1_s[i]) * (hidden) + j] = gates1_s[i] * reshaped_input[i * (hidden) + j];
          }

dispatched_input is computed differently in the two implementations. The CPU version accumulates (dispatched_input += gates1_s * reshaped_input), while the CUDA version assigns (dispatched_input = gates1_s * reshaped_input).

ghostplant commented 1 year ago

Thanks! But this is not a bug: indices1_s[i] * capacity + locations1_s[i] is always unique, so = and += do the same thing here. You can change the CPU's += to = if you like, although it makes no difference to the results.
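
The uniqueness argument can be checked with a small standalone sketch. This is not Tutel's code: `dispatch` and `variants_agree` are hypothetical helpers written for illustration, and the sketch assumes the output buffer is zero-initialized before dispatch, which the thread does not state explicitly. Under that assumption, writing each destination slot at most once means `0 + v` and `v` are the same value, so the two operators agree; if two samples mapped to the same slot, they would not.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for func_fwd on flat buffers (not Tutel's API).
// accumulate == true uses `+=` like the CPU loop; false uses `=` like the CUDA loop.
std::vector<float> dispatch(const std::vector<float>& gates,
                            const std::vector<int>& indices,
                            const std::vector<int>& locations,
                            const std::vector<float>& input,
                            int capacity, int hidden, int num_experts,
                            bool accumulate) {
  const std::size_t samples = gates.size();
  assert(indices.size() == samples && locations.size() == samples);
  assert(input.size() == samples * static_cast<std::size_t>(hidden));

  // Assumption: the output starts as all zeros.
  std::vector<float> out(
      static_cast<std::size_t>(num_experts) * capacity * hidden, 0.0f);

  for (std::size_t i = 0; i < samples; ++i) {
    if (locations[i] < capacity && indices[i] >= 0) {
      // Destination slot for sample i; the reply says this is unique per sample.
      const std::size_t slot =
          static_cast<std::size_t>(indices[i]) * capacity + locations[i];
      for (int j = 0; j < hidden; ++j) {
        const float v = gates[i] * input[i * hidden + j];
        if (accumulate) {
          out[slot * hidden + j] += v;
        } else {
          out[slot * hidden + j] = v;
        }
      }
    }
  }
  return out;
}

// Runs both variants on a fixed 3-sample example (2 experts, capacity 2,
// hidden size 2) and reports whether the outputs are identical.
bool variants_agree(const std::vector<int>& indices,
                    const std::vector<int>& locations) {
  const int capacity = 2, hidden = 2, num_experts = 2;
  const std::vector<float> gates = {0.5f, 0.25f, 1.0f};
  const std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  return dispatch(gates, indices, locations, input,
                  capacity, hidden, num_experts, true) ==
         dispatch(gates, indices, locations, input,
                  capacity, hidden, num_experts, false);
}
```

Calling `variants_agree({0, 1, 0}, {0, 0, 1})` uses three distinct slots, so both variants produce the same buffer. Calling `variants_agree({0, 0, 1}, {0, 0, 0})` sends the first two samples to the same slot, and the results diverge: `+=` sums both contributions while `=` keeps only the last one written. That collision case is exactly what the uniqueness guarantee rules out.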