
conv2d opencl #72

Open ysh329 opened 5 years ago

ysh329 commented 5 years ago

mace conv2d-1x1

mace/conv_2d_1x1_buffer.cl at 35bbc0462a55b337e6a35fe1b93b1b46fc334e2d · XiaoMi/mace https://github.com/XiaoMi/mace/blob/35bbc0462a55b337e6a35fe1b93b1b46fc334e2d/mace/ops/opencl/cl/conv_2d_1x1_buffer.cl

#define CONVERT_STR(value, type) convert_##type((value))

#define CONVERT_TO(value, type) CONVERT_STR(value, type)
#define CONVERT(value) CONVERT_TO(value, DATA_TYPE)
#define CONVERT4(value) CONVERT_TO(value, DATA_TYPE4)
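// Two macro levels are needed so DATA_TYPE4 is expanded *before* token
// pasting: CONVERT4(v) -> CONVERT_TO(v, DATA_TYPE4) -> CONVERT_STR(v, float4)
// -> convert_float4((v)) when built with -DDATA_TYPE4=float4. Calling
// CONVERT_STR directly would paste the unexpanded name (convert_DATA_TYPE4).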

__kernel void conv2d(BUFFER_OUT_OF_RANGE_PARAMS
                     GLOBAL_WORK_GROUP_SIZE_DIM2
                     __global IN_DATA_TYPE *padded_input,
                     __global IN_DATA_TYPE *filter,
#ifdef BIAS
                     __global IN_DATA_TYPE *bias,
#endif
                     __private const int in_height,
                     __private const int in_width,
                     __private const int in_chan,
                     __private const int filter_in_chan,
                     __private const int out_height,
                     __private const int out_width,
                     __private const int out_chan,
                     __private const int stride_h,
                     __private const int stride_w,
                     __private const float relux_max_limit,
                     __private const float leakyrelu_coefficient,
                     __global OUT_DATA_TYPE *output) {
  const int out_wc_blk_idx = get_global_id(0);
  const int out_hb_idx = get_global_id(1);

#ifndef NON_UNIFORM_WORK_GROUP
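  // Without non-uniform work-group support the gws is rounded up to a
  // multiple of the lws, so the padding work-items must bail out here.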
  if (out_wc_blk_idx >= global_size_dim0 ||
      out_hb_idx >= global_size_dim1) {
    return;
  }
#endif
  const int out_chan_blk = (out_chan + 3) >> 2;

  const int out_width_blk_idx = out_wc_blk_idx / out_chan_blk;
  const int out_chan_blk_idx =
      out_wc_blk_idx - mul24(out_width_blk_idx, out_chan_blk);

  const int batch_idx = out_hb_idx / out_height;
  const int out_height_idx = out_hb_idx - mul24(batch_idx, out_height);
  const int out_width_idx = out_width_blk_idx << 1;
  const int out_chan_idx = out_chan_blk_idx << 2;

  const int in_height_idx = mul24(out_height_idx, stride_h);
  const int in_width_idx = mul24(out_width_idx, stride_w);
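  // NHWC layout: the next output pixel reads input offset by stride_w * in_chan elements.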
  const int strided_chan = mul24(in_chan, stride_w);

#ifdef BIAS
  DATA_TYPE4 out0 = CONVERT4(vload4(0, bias + out_chan_idx));
  DATA_TYPE4 out1 = out0;
#else
  DATA_TYPE4 out0 = 0;
  DATA_TYPE4 out1 = 0;
#endif

  int in_offset = mul24(mad24(mad24(batch_idx, in_height, in_height_idx),
      in_width, in_width_idx), in_chan);
  int filter_offset = mul24(out_chan_blk_idx, filter_in_chan) << 2;
  DATA_TYPE4 in0, in1;
  DATA_TYPE4 w0, w1, w2, w3;

  for (int in_chan_idx = 0; in_chan_idx < in_chan; in_chan_idx += 4) {
    w0 = CONVERT4(vload4(0, filter + filter_offset));
    w1 = CONVERT4(vload4(0, filter + filter_offset + 4));
    w2 = CONVERT4(vload4(0, filter + filter_offset + 8));
    w3 = CONVERT4(vload4(0, filter + filter_offset + 12));

    in0 = CONVERT4(vload4(0, padded_input + in_offset));
    in1 = CONVERT4(vload4(0, padded_input + in_offset + strided_chan));

    out0 = mad((DATA_TYPE4)(in0.x), w0, out0);
    out0 = mad((DATA_TYPE4)(in0.y), w1, out0);
    out0 = mad((DATA_TYPE4)(in0.z), w2, out0);
    out0 = mad((DATA_TYPE4)(in0.w), w3, out0);

    out1 = mad((DATA_TYPE4)(in1.x), w0, out1);
    out1 = mad((DATA_TYPE4)(in1.y), w1, out1);
    out1 = mad((DATA_TYPE4)(in1.z), w2, out1);
    out1 = mad((DATA_TYPE4)(in1.w), w3, out1);

    filter_offset += 16;
    in_offset += 4;
  }

#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
  out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient);
  out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient);
#endif

  int out_offset = mad24(mad24(mad24(batch_idx, out_height, out_height_idx),
      out_width, out_width_idx), out_chan, out_chan_idx);

#define WRITE_OUTPUT(i) \
  if (out_chan_idx + 4 > out_chan) {           \
    const int diff = out_chan - out_chan_idx;  \
    switch(diff) {                             \
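      /* intentional fall-through: write the remaining tail channels */ \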
      case 3:                                  \
        output[out_offset + 2] = CONVERT_TO(out##i.z, OUT_DATA_TYPE);     \
      case 2:                                  \
        output[out_offset + 1] = CONVERT_TO(out##i.y, OUT_DATA_TYPE);     \
      case 1:                                  \
        output[out_offset] = CONVERT_TO(out##i.x, OUT_DATA_TYPE);         \
    }                                          \
    CHECK_OUT_OF_RANGE_FOR_BUFFER(out_offset + diff - 1); \
  } else {                                     \
    VSTORE4(CONVERT_TO(out##i, OUT_DATA_TYPE4), output, out_offset);   \
  }

  WRITE_OUTPUT(0);
  if (out_width_idx + 1 >= out_width) return;
  out_offset += out_chan;
  WRITE_OUTPUT(1);
#undef WRITE_OUTPUT

}
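
Each work-item of this kernel computes a 2-wide x 4-channel output tile: global dim 0 enumerates (width-block, channel-block) pairs, global dim 1 enumerates (batch, height) pairs. A minimal host-side replica of the index math (DecodeGid is illustrative, not a MACE API):

#include <cstdio>

// Decode the two global ids the same way the kernel does.
// gid0 = out_width_blk_idx * out_chan_blk + out_chan_blk_idx
// gid1 = batch_idx * out_height + out_height_idx
void DecodeGid(int gid0, int gid1, int out_chan, int out_height) {
  const int out_chan_blk = (out_chan + 3) >> 2;       // ceil(out_chan / 4)
  const int out_width_blk_idx = gid0 / out_chan_blk;
  const int out_chan_blk_idx = gid0 - out_width_blk_idx * out_chan_blk;
  const int batch_idx = gid1 / out_height;
  const int out_height_idx = gid1 - batch_idx * out_height;
  const int out_width_idx = out_width_blk_idx << 1;   // 2 output pixels
  const int out_chan_idx = out_chan_blk_idx << 2;     // 4 output channels
  std::printf("n=%d h=%d w=%d..%d c=%d..%d\n", batch_idx, out_height_idx,
              out_width_idx, out_width_idx + 1, out_chan_idx, out_chan_idx + 3);
}

int main() { DecodeGid(7, 5, /*out_chan=*/10, /*out_height=*/4); }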

conv2d_1x1 op

MaceStatus Conv2d1x1(OpContext *context,
                     cl::Kernel *kernel,
                     const Tensor *padded_input,
                     const Tensor *filter,
                     const Tensor *bias,
                     const int *strides,
                     const ActivationType activation,
                     const float relux_max_limit,
                     const float leakyrelu_coefficient,
                     const bool input_changed,
                     Tensor *output,
                     StatsFuture *future) {
  const index_t batch = output->dim(0);
  const index_t height = output->dim(1);
  const index_t width = output->dim(2);
  const index_t channel = output->dim(3);

  const index_t in_height = padded_input->dim(1);
  const index_t in_width = padded_input->dim(2);

  auto runtime = context->device()->gpu_runtime()->opencl_runtime();
  MACE_OUT_OF_RANGE_DEFINITION;

  if (kernel->get() == nullptr) {
    std::set<std::string> built_options;
    MACE_OUT_OF_RANGE_CONFIG;
    MACE_NON_UNIFORM_WG_CONFIG;
    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("conv2d");
    built_options.emplace("-Dconv2d=" + kernel_name);
    std::string data_dt = DtToCLDt(padded_input->dtype());
    built_options.emplace("-DIN_DATA_TYPE=" + data_dt);
    built_options.emplace("-DOUT_DATA_TYPE=" + DtToCLDt(output->dtype()));
    built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DT_FLOAT));
    built_options.emplace(bias != nullptr ? "-DBIAS" : "");
    switch (activation) {
      case NOOP:
        break;
      case RELU:
        built_options.emplace("-DUSE_RELU");
        break;
      case RELUX:
        built_options.emplace("-DUSE_RELUX");
        break;
      case TANH:
        built_options.emplace("-DUSE_TANH");
        break;
      case SIGMOID:
        built_options.emplace("-DUSE_SIGMOID");
        break;
      case LEAKYRELU:
        built_options.emplace("-DUSE_LEAKYRELU");
        break;
      default:
        LOG(FATAL) << "Unknown activation type: " << activation;
    }

    MACE_RETURN_IF_ERROR(runtime->BuildKernel("conv_2d_1x1_buffer",
                                              kernel_name,
                                              built_options, kernel));
  }

  const uint32_t gws[2] = {static_cast<uint32_t>(
                               RoundUpDiv4(channel) *
                                   RoundUpDiv<index_t>(width, 2)),
                           static_cast<uint32_t>(height * batch)};

  MACE_OUT_OF_RANGE_INIT(*kernel);
  if (input_changed) {
    uint32_t idx = 0;
    MACE_BUFF_OUT_OF_RANGE_SET_ARGS(*kernel, output->size());
    MACE_SET_2D_GWS_ARGS(*kernel, gws);
    kernel->setArg(idx++, *(padded_input->opencl_buffer()));
    kernel->setArg(idx++, *(filter->opencl_buffer()));
    if (bias != nullptr) {
      kernel->setArg(idx++, *(bias->opencl_buffer()));
    }
    kernel->setArg(idx++, static_cast<int32_t>(in_height));
    kernel->setArg(idx++, static_cast<int32_t>(in_width));
    kernel->setArg(idx++, static_cast<int32_t>(padded_input->dim(3)));
    kernel->setArg(idx++,
                   static_cast<int32_t>(filter->buffer_shape()[3]));
    kernel->setArg(idx++, static_cast<int32_t>(height));
    kernel->setArg(idx++, static_cast<int32_t>(width));
    kernel->setArg(idx++, static_cast<int32_t>(channel));
    kernel->setArg(idx++, strides[0]);
    kernel->setArg(idx++, strides[1]);
    kernel->setArg(idx++, relux_max_limit);
    kernel->setArg(idx++, leakyrelu_coefficient);
    kernel->setArg(idx++, *(output->opencl_buffer()));
  }

  std::string tuning_key =
      Concat("conv2d_1x1_buffer", output->dim(0), output->dim(1),
             output->dim(2), output->dim(3));
  std::vector<uint32_t> lws = {16, 4, 0};
  MACE_RETURN_IF_ERROR(TuningOrRun2DKernel(runtime, *kernel, tuning_key, gws,
                                           lws, future));
  MACE_OUT_OF_RANGE_VALIDATION;
  return MaceStatus::MACE_SUCCESS;
}
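
The gws above mirrors the kernel's tiling: dim 0 packs RoundUpDiv4(channel) * RoundUpDiv(width, 2) tiles, dim 1 is height * batch, and {16, 4, 0} is only a seed local size that TuningOrRun2DKernel may replace with a tuned one. A sketch of the rounding helpers, assuming the obvious ceil-division semantics of MACE's RoundUpDiv/RoundUpDiv4:

// Presumed helper semantics, matching the kernel's (out_chan + 3) >> 2 blocking.
template <typename T>
T RoundUpDiv(T v, T factor) { return (v + factor - 1) / factor; }
template <typename T>
T RoundUpDiv4(T v) { return (v + 3) >> 2; }
// gws[0] = RoundUpDiv4(channel) * RoundUpDiv(width, 2);
// gws[1] = height * batch;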
ysh329 commented 5 years ago

#define CONVERT_STR(value, type) convert_##type((value))

#define CONVERT_TO(value, type) CONVERT_STR(value, type)
#define CONVERT(value) CONVERT_TO(value, DATA_TYPE)
#define CONVERT4(value) CONVERT_TO(value, DATA_TYPE4)

__kernel void conv2d(
                     __global CL_DTYPE *padded_input,
                     __global CL_DTYPE *filter,
                     __global CL_DTYPE *bias,
                     __private const int in_height,
                     __private const int in_width,
                     __private const int in_chan,
                     __private const int filter_in_chan,
                     __private const int out_height,
                     __private const int out_width,
                     __private const int out_chan,
                     __private const int stride_h,
                     __private const int stride_w,
                     __private const float relux_max_limit,
                     __private const float leakyrelu_coefficient,
                     __global CL_DTYPE *output) {
  const int out_wc_blk_idx = get_global_id(0);
  const int out_hb_idx = get_global_id(1);

#ifndef NON_UNIFORM_WORK_GROUP
  if (out_wc_blk_idx >= global_size_dim0 ||
      out_hb_idx >= global_size_dim1) {
    return;
  }
#endif
  const int out_chan_blk = (out_chan + 3) >> 2;

  const int out_width_blk_idx = out_wc_blk_idx / out_chan_blk;
  const int out_chan_blk_idx =
      out_wc_blk_idx - mul24(out_width_blk_idx, out_chan_blk);

  const int batch_idx = out_hb_idx / out_height;
  const int out_height_idx = out_hb_idx - mul24(batch_idx, out_height);
  const int out_width_idx = out_width_blk_idx << 1;
  const int out_chan_idx = out_chan_blk_idx << 2;

  const int in_height_idx = mul24(out_height_idx, stride_h);
  const int in_width_idx = mul24(out_width_idx, stride_w);
  const int strided_chan = mul24(in_chan, stride_w);

#ifdef BIAS
  DATA_TYPE4 out0 = CONVERT4(vload4(0, bias + out_chan_idx));
  DATA_TYPE4 out1 = out0;
#else
  DATA_TYPE4 out0 = 0;
  DATA_TYPE4 out1 = 0;
#endif

  int in_offset = mul24(mad24(mad24(batch_idx, in_height, in_height_idx),
      in_width, in_width_idx), in_chan);
  int filter_offset = mul24(out_chan_blk_idx, filter_in_chan) << 2;
  DATA_TYPE4 in0, in1;
  DATA_TYPE4 w0, w1, w2, w3;

  for (int in_chan_idx = 0; in_chan_idx < in_chan; in_chan_idx += 4) {
    w0 = CONVERT4(vload4(0, filter + filter_offset));
    w1 = CONVERT4(vload4(0, filter + filter_offset + 4));
    w2 = CONVERT4(vload4(0, filter + filter_offset + 8));
    w3 = CONVERT4(vload4(0, filter + filter_offset + 12));

    in0 = CONVERT4(vload4(0, padded_input + in_offset));
    in1 = CONVERT4(vload4(0, padded_input + in_offset + strided_chan));

    out0 = mad((DATA_TYPE4)(in0.x), w0, out0);
    out0 = mad((DATA_TYPE4)(in0.y), w1, out0);
    out0 = mad((DATA_TYPE4)(in0.z), w2, out0);
    out0 = mad((DATA_TYPE4)(in0.w), w3, out0);

    out1 = mad((DATA_TYPE4)(in1.x), w0, out1);
    out1 = mad((DATA_TYPE4)(in1.y), w1, out1);
    out1 = mad((DATA_TYPE4)(in1.z), w2, out1);
    out1 = mad((DATA_TYPE4)(in1.w), w3, out1);

    filter_offset += 16;
    in_offset += 4;
  }

#if defined(USE_RELU) || defined(USE_LEAKYRELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
  out0 = do_activation(out0, relux_max_limit, leakyrelu_coefficient);
  out1 = do_activation(out1, relux_max_limit, leakyrelu_coefficient);
#endif

  int out_offset = mad24(mad24(mad24(batch_idx, out_height, out_height_idx),
      out_width, out_width_idx), out_chan, out_chan_idx);

#define WRITE_OUTPUT(i) \
  if (out_chan_idx + 4 > out_chan) {           \
    const int diff = out_chan - out_chan_idx;  \
    switch(diff) {                             \
      case 3:                                  \
        output[out_offset + 2] = CONVERT_TO(out##i.z, OUT_DATA_TYPE);     \
      case 2:                                  \
        output[out_offset + 1] = CONVERT_TO(out##i.y, OUT_DATA_TYPE);     \
      case 1:                                  \
        output[out_offset] = CONVERT_TO(out##i.x, OUT_DATA_TYPE);         \
    }                                          \
    CHECK_OUT_OF_RANGE_FOR_BUFFER(out_offset + diff - 1); \
  } else {                                     \
    VSTORE4(CONVERT_TO(out##i, OUT_DATA_TYPE4), output, out_offset);   \
  }

  WRITE_OUTPUT(0);
  if (out_width_idx + 1 >= out_width) return;
  out_offset += out_chan;
  WRITE_OUTPUT(1);
#undef WRITE_OUTPUT

}
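
Compared with the MACE original above, this variant drops the BUFFER_OUT_OF_RANGE_PARAMS / GLOBAL_WORK_GROUP_SIZE_DIM2 argument macros and renames the buffer element type to CL_DTYPE, so several symbols must still come from build options. An illustrative (assumed, not from the kernel itself) fp32 option set:

// -DCL_DTYPE=float -DDATA_TYPE=float -DDATA_TYPE4=float4
// -DOUT_DATA_TYPE=float -DOUT_DATA_TYPE4=float4
// -DNON_UNIFORM_WORK_GROUP              (compiles out the global_size_dim guard)
// "-DVSTORE4(v,out,off)=vstore4(v,0,(out)+(off))"   (guess at the MACE helper)
// "-DCHECK_OUT_OF_RANGE_FOR_BUFFER(idx)="           (no-op stub)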
ysh329 commented 5 years ago

conv2d-1x1 caffe-opencl

OpenCL-caffe/base_conv_layer.cpp at 41805c86a29de1e7bf1c12bd1d7488a5fed3f997 · amd/OpenCL-caffe https://github.com/amd/OpenCL-caffe/blob/41805c86a29de1e7bf1c12bd1d7488a5fed3f997/src/caffe/layers/base_conv_layer.cpp#L294-L298

template <typename Dtype>
void BaseConvolutionLayer<Dtype>::forward_gpu_gemm(const Dtype* input,
    const Dtype* weights, Dtype* output, bool skip_im2col) {
  const Dtype* col_buff = input;
  if (!is_1x1_) {
    if (!skip_im2col) {
      conv_im2col_gpu(input, col_buffer_.mutable_gpu_data());
    }
    col_buff = col_buffer_.gpu_data();
  }
  for (int g = 0; g < group_; ++g) {
    caffe_gpu_gemm<Dtype>(&(amdDevice.CommandQueue), CblasNoTrans, CblasNoTrans,
        conv_out_channels_ / group_, conv_out_spatial_dim_, kernel_dim_ / group_,
        (Dtype)1., weights, weight_offset_ * g,
        col_buff, is_1x1_ * bottom_offset_ + col_offset_ * g,
        (Dtype)0., output, top_offset_ + output_offset_ * g);
  }
}
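
For a 1x1 kernel with unit stride and no padding, im2col is the identity, so col_buff aliases the input and each group is a plain GEMM: output (Cout/g x HW) = weights (Cout/g x Cin/g) * input (Cin/g x HW). A minimal single-group reference in plain C++ (names illustrative):

// 1x1 convolution in NCHW as a GEMM: Y[m][p] = sum_c W[m][c] * X[c][p].
void Conv1x1Nchw(const float* X, const float* W, float* Y,
                 int C, int M, int HxW) {
  for (int m = 0; m < M; ++m) {
    for (int p = 0; p < HxW; ++p) {
      float acc = 0.f;
      for (int c = 0; c < C; ++c) acc += W[m * C + c] * X[c * HxW + p];
      Y[m * HxW + p] = acc;
    }
  }
}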
ysh329 commented 5 years ago

pytorch conv2d 1x1

pytorch/conv_op_impl.h at ac8d1a1f7618a01a7504818a53d4ff0eb468de5d · pytorch/pytorch https://github.com/pytorch/pytorch/blob/ac8d1a1f7618a01a7504818a53d4ff0eb468de5d/caffe2/operators/conv_op_impl.h#L363-L443

template <typename T, class Context>
bool ConvOp<T, Context>::Run1x1ConvOnDeviceWithOrderNCHW(
    const int N,
    const int C,
    const int HxW,
    const int M,
    const T* X,
    const T* filter,
    const T* bias,
    T* Y) {
  const int G = group_;
  if (G == 1) {
    math::GemmStridedBatched<T, Context>(
        CblasNoTrans,
        CblasNoTrans,
        N,
        M,
        HxW,
        C,
        1.0f,
        filter,
        0,
        X,
        C * HxW,
        0.0f,
        Y,
        M * HxW,
        &context_);
  } else {
    const int batch_size = N * G;
    const int D_X = C / G;
    const int D_Y = M / G;
    const int X_stride = D_X * HxW;
    const int W_stride = D_Y * D_X;
    const int Y_stride = D_Y * HxW;
    std::vector<const T*> X_ptr(N * G);
    std::vector<const T*> W_ptr(N * G);
    std::vector<T*> Y_ptr(N * G);
    for (int i = 0; i < N; ++i) {
      for (int j = 0; j < G; ++j) {
        const int index = i * G + j;
        X_ptr[index] = X + index * X_stride;
        W_ptr[index] = filter + j * W_stride;
        Y_ptr[index] = Y + index * Y_stride;
      }
    }
    math::GemmBatched<T, Context>(
        CblasNoTrans,
        CblasNoTrans,
        batch_size,
        D_Y,
        HxW,
        D_X,
        1.0f,
        W_ptr.data(),
        X_ptr.data(),
        0.0f,
        Y_ptr.data(),
        &context_);
  }
  if (bias != nullptr) {
    const T* bias_multiplier_data = bias_multiplier_.template data<T>();
    math::GemmStridedBatched<T, Context>(
        CblasNoTrans,
        CblasNoTrans,
        N,
        M,
        HxW,
        1,
        1.0f,
        bias,
        0,
        bias_multiplier_data,
        0,
        1.0f,
        Y,
        M * HxW,
        &context_);
  }
  return true;
}
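
In the G == 1 branch the whole mini-batch is one strided-batched GEMM: A_stride = 0 reuses the filter for every image while X and Y advance by one image per batch entry; the bias pass is the same trick with K = 1 and beta = 1, adding bias (M x 1) times the all-ones multiplier (1 x HxW). Loop form of the weight GEMM, as a sketch:

// for (int i = 0; i < N; ++i)
//   Gemm(CblasNoTrans, CblasNoTrans, M, HxW, C, 1.0f,
//        filter,           // A_stride == 0: weights shared across the batch
//        X + i * C * HxW,  // B_stride == C * HxW: next input image
//        0.0f,
//        Y + i * M * HxW); // C_stride == M * HxW: next output image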

pytorch/conv_op_impl.h at ac8d1a1f7618a01a7504818a53d4ff0eb468de5d · pytorch/pytorch https://github.com/pytorch/pytorch/blob/ac8d1a1f7618a01a7504818a53d4ff0eb468de5d/caffe2/operators/conv_op_impl.h#L20-L188

template <typename T, class Context>
bool ConvOp<T, Context>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  auto* Y = Output(0);
  const int N = X.dim32(0);
  const int C = X.dim32(1);
  const int G = group_;
  CAFFE_ENFORCE_EQ(X.dim(), filter.dim());
  const int M = filter.dim32(0);
  CAFFE_ENFORCE_EQ(
      C,
      filter.dim32(1) * G,
      "Convolution op: input channels does not match: # of input channels ",
      C,
      " is not equal to kernel channels * group: ",
      filter.dim32(1),
      "*",
      G);
  CAFFE_ENFORCE_EQ(
      M % G, 0, "The number of output channels is not divisible by group.");

  int kernel_size = 1;
  for (std::size_t i = 0; i < kernel_.size(); ++i) {
    CAFFE_ENFORCE_EQ(filter.dim32(i + 2), kernel_[i]);
    kernel_size *= kernel_[i];
  }
  ConvPoolOpBase<Context>::SetOutputSize(X, Y, M);

  if (N == 0) {
    Y->template mutable_data<T>();
    return true;
  }

  const vector<int> X_dims = GetDims(X);
  const vector<int> Y_dims = GetDims(*Y);
  const int X_HxW = X.numel() / (N * C);
  const int Y_HxW = Y->numel() / (N * M);
  const vector<int> img_shape(X.sizes().cbegin() + 1, X.sizes().cend());
  vector<int> buffer_shape(Y_dims.size() + 1);
  buffer_shape[0] = C * kernel_size;
  std::copy(Y_dims.cbegin(), Y_dims.cend(), buffer_shape.begin() + 1);

  const int buffer_size = C * kernel_size * Y_HxW;

  // The dimension of each kernel
  const int kernel_dim = C / G * kernel_size;
  const int X_stride = C * X_HxW;
  const int Y_stride = M * Y_HxW;
  const int filter_stride = filter.numel() / G;

  // The col buffer is stored in CHW order as well - kernel_dim, and the height
  // and width.
  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* bias_data = nullptr;
  if (InputSize() == 3) {
    const auto& bias = Input(BIAS);
    CAFFE_ENFORCE_EQ(bias.dim(), 1);
    CAFFE_ENFORCE_EQ(bias.dim32(0), M);
    bias_data = bias.template data<T>();
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        Y_HxW, &bias_multiplier_);
  }
  T* Y_data = Y->template mutable_data<T>();

  // Shortcut for 1x1 conv.
  if (kernel_size == 1 && !HasPad() && !HasStride()) {
    return Run1x1ConvOnDeviceWithOrderNCHW(
        N, C, X_HxW, M, X_data, filter_data, bias_data, Y_data);
  }

  const auto func = [&](Tensor* col_buffer) {
    col_buffer->Resize(buffer_shape);
    T* col_buffer_data = col_buffer->template mutable_data<T>();
    // Im2Col, followed by gemm.
    for (int image_id = 0; image_id < N; ++image_id) {
      if (kernel_.size() == 2) {
        math::Im2Col<T, Context, StorageOrder::NCHW>(
            C,
            X_dims[0],
            X_dims[1],
            kernel_h(),
            kernel_w(),
            dilation_h(),
            dilation_w(),
            pad_t(),
            pad_l(),
            pad_b(),
            pad_r(),
            stride_h(),
            stride_w(),
            X_data,
            col_buffer_data,
            &context_);
      } else {
        math::Im2ColNd<T, Context, StorageOrder::NCHW>(
            kernel_.size(),
            C * X_HxW,
            buffer_size,
            img_shape.data(),
            buffer_shape.data(),
            kernel_.data(),
            stride_.data(),
            dilation_.data(),
            pads_.data(),
            X_data,
            col_buffer_data,
            &context_);
      }
      // Weight term
      if (G == 1) {
        math::Gemm<T, Context>(
            CblasNoTrans,
            CblasNoTrans,
            M,
            Y_HxW,
            kernel_dim,
            1.0f,
            filter_data,
            col_buffer_data,
            0.0f,
            Y_data,
            &context_);
      } else {
        math::GemmStridedBatched<T, Context>(
            CblasNoTrans,
            CblasNoTrans,
            G,
            M / G,
            Y_HxW,
            kernel_dim,
            1.0f,
            filter_data,
            filter_stride,
            col_buffer_data,
            buffer_size / G,
            0.0f,
            Y_data,
            Y_stride / G,
            &context_);
      }
      if (bias_data != nullptr) {
        // Bias term can be carried out outside the group definition
        // to be efficient.
        math::Gemm<T, Context>(
            CblasNoTrans,
            CblasNoTrans,
            M,
            Y_HxW,
            1,
            1.0f,
            bias_data,
            bias_multiplier_.template data<T>(),
            1.0f,
            Y_data,
            &context_);
      }
      X_data += X_stride;
      Y_data += Y_stride;
    }
  };
  if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) {
    runWithSharedBuffer<Context>(ws_, func);
  } else {
    func(&col_buffer_);
  }
  return true;
}
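
The general path materializes, per image, a column buffer of shape (C * kernel_size, Y_HxW) and multiplies it by the (M, kernel_dim) filter. A worked shape example with illustrative numbers:

// NCHW, C = 3, 3x3 kernel, stride 1, G = 1, output 224x224:
//   kernel_dim  = C / G * kernel_size = 3 * 9 = 27
//   col buffer  = (C * kernel_size, Y_HxW) = (27, 50176)
//   per-image GEMM: (M x 27) * (27 x 50176) -> (M x 50176)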
ysh329 commented 5 years ago

pytorch/math_gpu.cu at 161566187c278bb0e715a60016e7e74ec94444df · pytorch/pytorch https://github.com/pytorch/pytorch/blob/161566187c278bb0e715a60016e7e74ec94444df/caffe2/utils/math_gpu.cu#L766-L824

template <>
CAFFE2_CUDA_EXPORT void GemmStridedBatched<float, CUDAContext>(
    const CBLAS_TRANSPOSE trans_A,
    const CBLAS_TRANSPOSE trans_B,
    const int batch_size,
    const int M,
    const int N,
    const int K,
    const float alpha,
    const float* A,
    const int A_stride,
    const float* B,
    const int B_stride,
    const float beta,
    float* C,
    const int C_stride,
    CUDAContext* context,
    TensorProto::DataType math_type) {
#if __CUDACC_VER_MAJOR__ < 8 && !defined(__HIP_PLATFORM_HCC__)
  // loop over matrices in the batch
  for (int i = 0; i < batch_size; ++i) {
    Gemm<float, CUDAContext>(
        trans_A, trans_B, M, N, K, alpha, A, B, beta, C, context, math_type);
    A += A_stride;
    B += B_stride;
    C += C_stride;
  }
#else
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  const int lda = (trans_A == CblasNoTrans) ? K : M;
  const int ldb = (trans_B == CblasNoTrans) ? N : K;
  const int ldc = N;
  const cublasOperation_t cu_trans_A =
      (trans_A == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t cu_trans_B =
      (trans_B == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  CUBLAS_ENFORCE(
      cublasSetPointerMode(context->cublas_handle(), CUBLAS_POINTER_MODE_HOST));
  CUBLAS_ENFORCE(cublasSgemmStridedBatched(
      context->cublas_handle(),
      cu_trans_B,
      cu_trans_A,
      N,
      M,
      K,
      &alpha,
      B,
      ldb,
      B_stride,
      A,
      lda,
      A_stride,
      &beta,
      C,
      ldc,
      C_stride,
      batch_size));
#endif
}
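
The operand swap in the cuBLAS call is the usual row-major/column-major trick: a row-major C (M x N) is a column-major C^T (N x M), and C = A * B implies C^T = B^T * A^T, so B is passed first with the M/N dimensions exchanged. Spelled out for the no-transpose case:

// cublasSgemmStridedBatched(handle, opB, opA, N, M, K, &alpha,
//                           B, ldb = N, B_stride,
//                           A, lda = K, A_stride,
//                           &beta, C, ldc = N, C_stride, batch_size)
// computes column-major C^T = B^T * A^T, i.e. row-major C = A * B.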