yuenshome / yuenshome.github.io

https://yuenshome.github.io
MIT License
81 stars 15 forks source link

OpenCL elementwise_add #68

Open ysh329 opened 4 years ago

ysh329 commented 4 years ago
/// Reference (plain CPU loop) implementation of a broadcast elementwise
/// operation with an optional fused activation.
///
/// X is treated as a contiguous [batch, channels, num] buffer, where
///   - batch    = product of X dims before `axis`,
///   - channels = product of all Y dims,
///   - num      = product of X dims from index `axis + rank(Y)` onwards.
/// Each element Y[j] is combined with every one of the `num` elements of X
/// in channel j of every batch; the result is written to Out at the same
/// offset.
///
/// @param param    Supplies X, Y, Out and `axis`. A negative `axis` is
///                 replaced by rank(X) - rank(Y).
/// @param elt_type "add" or "sub"; any other value is reported through
///                 LOG(FATAL).
/// @param act_type Empty for no activation, "relu" to clamp negative outputs
///                 to zero; any other non-empty value is reported through
///                 LOG(FATAL).
template <typename dtype>
void elementwise_compute_ref(const operators::ElementwiseParam& param,
                             const std::string elt_type,
                             const std::string act_type) {
  const dtype* x_data = param.X->data<const dtype>();
  const dtype* y_data = param.Y->data<const dtype>();
  dtype* out_data = param.Out->mutable_data<dtype>();
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();

  // Read the ranks once as int so the index loops below do not compare a
  // signed index against size() directly.
  const int x_rank = static_cast<int>(x_dims.size());
  const int y_rank = static_cast<int>(y_dims.size());

  int axis = param.axis;
  if (axis < 0) {
    axis = x_rank - y_rank;
  }

  int batch = 1;
  int channels = 1;
  int num = 1;
  for (int i = 0; i < axis; ++i) {
    batch *= x_dims[i];
  }
  for (int i = 0; i < y_rank; ++i) {
    channels *= y_dims[i];
  }
  for (int i = y_rank + axis; i < x_rank; ++i) {
    num *= x_dims[i];
  }

  // do elementwise add/sub
  const bool is_add = (elt_type == "add");
  const bool is_sub = (elt_type == "sub");
  if (is_add || is_sub) {
    for (int i = 0; i < batch; ++i) {
      for (int j = 0; j < channels; ++j) {
        const int offset = (i * channels + j) * num;
        const dtype* din_ptr = x_data + offset;
        dtype* dout_ptr = out_data + offset;
        const dtype diny_data = y_data[j];  // broadcast over `num` elements
        if (is_add) {
          for (int k = 0; k < num; ++k) {
            dout_ptr[k] = din_ptr[k] + diny_data;
          }
        } else {
          for (int k = 0; k < num; ++k) {
            dout_ptr[k] = din_ptr[k] - diny_data;
          }
        }
      }
    }
  } else {
    LOG(FATAL) << "unsupported Elementwise type: " << elt_type;
  }

  // do activation relu
  if (act_type.size() > 0) {
    if (act_type == "relu") {
      // The output is one contiguous run of batch * channels * num elements,
      // so a single flat loop visits exactly the same elements as the
      // three-level nest would.
      const int total = batch * channels * num;
      for (int idx = 0; idx < total; ++idx) {
        out_data[idx] = out_data[idx] > 0.0f ? out_data[idx] : 0.0f;
      }
    } else {
      // Report the activation name here, not the elementwise type.
      LOG(FATAL) << "unsupported Activation type: " << act_type;
    }
  }
}
ysh329 commented 4 years ago
// Buffer-based elementwise add with per-channel broadcast of y.
//
// Each work-item handles one (batch, channel) pair: it adds y_data[c] to the
// `num` consecutive elements of x_data starting at (b * channels + c) * num
// and stores the sums at the same offset in out_data.
//
//   x_data   : input buffer, indexed as [batch][channels][num]
//   y_data   : one value per channel, indexed by c
//   out_data : output buffer, same indexing as x_data
//   batch, channels, num : extents used for bounds checks and offsets
//
// Global id 0 selects the channel and global id 1 selects the batch;
// work-items outside [0, channels) x [0, batch) return without writing.
__kernel void elementwise_add_buffer(__global const float* x_data,
                                     __global const float* y_data,
                                     __global float* out_data,
                                     const int batch,
                                     const int channels,
                                     const int num) {
  const int c = get_global_id(0);  // c: [0, channels)
  const int b = get_global_id(1);  // b: [0, batch)

  if ((c >= channels) || (b >= batch)) {
    return;
  }

  const int offset = (b * channels + c) * num;

  // Pointers derived from __global arguments must themselves be declared
  // __global: an unqualified pointer is __private in OpenCL C 1.x and cannot
  // hold a __global address.
  __global const float* din_ptr = x_data + offset;
  __global float* dout_ptr = out_data + offset;
  const float diny_data = y_data[c];

  for (int n = 0; n < num; ++n) {  // n: [0, num)
    dout_ptr[n] = din_ptr[n] + diny_data;
  }
}
ysh329 commented 4 years ago
[==========] Running 8 tests from 1 test case.
[----------] Global test environment set-up.
[----------] 8 tests from cl_test
[ RUN      ] cl_test.runtime_test
[F  7/18  9:42:47.716 ...opencl/cl_wrapper.cc InitFunctions:164] Check failed: clCreateCommandQueueWithProperties_ != nullptr: Cannot load clCreateCommandQueueWithProperties!