lix19937 / tensorrt-insight

Deep insight tensorrt, including but not limited to qat, ptq, plugin, triton_inference, cuda
12 stars 0 forks source link

reducesum+nonzero impl #18

Open lix19937 opened 5 months ago

lix19937 commented 5 months ago

#include <sys/time.h>
#include <time.h>

#include <cstdio>
#include <iostream>

#include <cuda_runtime.h>
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>

// https://forums.developer.nvidia.com/t/how-to-use-thrust-for-each-with-cuda-streams/177797/7
// https://forums.developer.nvidia.com/t/using-thrust-copy-if-with-a-parameter/119735/7
// https://forums.developer.nvidia.com/t/all-non-zero-element-indexes/164658
//     thrust::device_ptr<int> dev_ptr = thrust::device_pointer_cast(raw_ptr);
//     thrust::device_ptr<int> dev_ptr = thrust::device_malloc<int>(N);
//     int * raw_ptr = thrust::raw_pointer_cast(dev_ptr);

/// Stencil predicate for thrust::copy_if over uchar4 elements.
///
/// Adds the four lanes of `x` and reports whether that sum differs from
/// THRESHOLD. The lanes are promoted to int before the addition, so the sum
/// cannot wrap around. With THRESHOLD == 0 the predicate is true exactly when
/// at least one lane is nonzero.
///
/// @tparam THRESHOLD  lane sum that is treated as "not selected".
template <int THRESHOLD>
struct is_nonzero {
  /// @param x  element to classify.
  /// @return   true when x.x + x.y + x.z + x.w != THRESHOLD.
  ///
  /// const-qualified: the functor is stateless, and this lets it be invoked
  /// through a const object or const reference.
  __host__ __device__ bool operator()(const uchar4 x) const {
    const int lane_sum = x.x + x.y + x.z + x.w;
    return lane_sum != THRESHOLD;
  }
};

// Number of microseconds in one second.
#define USECPSEC 1000000ULL

/// Wall-clock time in microseconds, relative to `start`.
///
/// Reads the current time with gettimeofday(), converts it to microseconds
/// and subtracts `start`. Pass 0 to obtain a timestamp; pass a previously
/// returned timestamp to obtain the microseconds elapsed since it was taken.
unsigned long long dtime_usec(unsigned long long start) {
  timeval now;
  gettimeofday(&now, 0);
  const unsigned long long whole_seconds_us = now.tv_sec * USECPSEC;
  const unsigned long long now_us = whole_seconds_us + now.tv_usec;
  return now_us - start;
}

/// Micro-benchmark for "reduce-sum + nonzero" implemented with thrust::copy_if.
///
/// Builds a uchar4 stencil, then repeatedly collects the indices of the
/// elements whose lane sum is nonzero and prints the total wall-clock time of
/// the timed loop in microseconds as "dt <usec>".
///
/// @return 0 on success; 1 if a CUDA synchronisation fails or the number of
///         indices found does not match the stencil set up below.
int main() {
  // Number of uchar4 elements in the stencil (40000 elements * 4 lanes).
  constexpr int kStencilSize = 40000;
  // Number of leading stencil elements scanned by copy_if. The index buffer
  // has the same capacity, so the output cannot overflow even if every
  // scanned element is selected.
  constexpr int kScanCount = 6400;
  // Number of timed copy_if invocations.
  constexpr int kIterations = 1000;
  // Nonzero elements written into the stencil below: indices 1, 2, 5 and 7.
  constexpr long long kExpectedNonzero = 4;

  static_assert(kScanCount <= kStencilSize,
                "copy_if would read past the end of the stencil");

  // Synchronise with the device and report any pending CUDA error.
  auto device_idle = []() -> bool {
    const cudaError_t status = cudaDeviceSynchronize();
    if (status != cudaSuccess) {
      fprintf(stderr, "cudaDeviceSynchronize failed: %s\n", cudaGetErrorString(status));
      return false;
    }
    return true;
  };

  // Sequence of zero and nonzero values. The explicit fill value makes every
  // element that is not assigned below a known all-zero uchar4.
  thrust::device_vector<uchar4> stencil(kStencilSize, make_uchar4(0, 0, 0, 0));
  stencil[0] = {0, 0, 0, 0};
  stencil[1] = {1, 0, 0, 0};
  stencil[2] = {1, 0, 0, 0};
  stencil[3] = {0, 0, 0, 0};
  stencil[4] = {0, 0, 0, 0};
  stencil[5] = {1, 0, 0, 0};
  stencil[6] = {0, 0, 0, 0};
  stencil[7] = {1, 0, 0, 0};

  // Storage for the nonzero indices.
  thrust::device_vector<int> indices(kScanCount);

  typedef thrust::device_vector<int>::iterator IndexIterator;

  // Warm-up call, kept outside the timed region.
  // The counting iterators define the candidate index sequence [0, kScanCount).
  IndexIterator indices_end = thrust::copy_if(
      thrust::device,
      thrust::make_counting_iterator(0),
      thrust::make_counting_iterator(kScanCount),
      stencil.begin(),
      indices.begin(),
      is_nonzero<0>());

  // Start the clock only once the device has finished the warm-up work.
  if (!device_idle()) {
    return 1;
  }

  unsigned long long dt0 = dtime_usec(0);
  for (int i = 0; i < kIterations; ++i) {
    indices_end = thrust::copy_if(
        thrust::device,
        thrust::make_counting_iterator(0),
        thrust::make_counting_iterator(kScanCount),
        stencil.begin(),
        indices.begin(),
        is_nonzero<0>());
  }
  // Stop the clock only after all timed device work has completed.
  if (!device_idle()) {
    return 1;
  }
  unsigned long long dt1 = dtime_usec(dt0);
  printf("dt %llu\n", dt1);

  // [indices.begin(), indices_end) now holds the selected indices: 1 2 5 7.
  // To inspect them:
  //   thrust::copy(indices.begin(), indices_end,
  //                std::ostream_iterator<int>(std::cout, " "));
  const long long found = static_cast<long long>(indices_end - indices.begin());
  if (found != kExpectedNonzero) {
    fprintf(stderr, "unexpected nonzero count: got %lld, expected %lld\n",
            found, kExpectedNonzero);
    return 1;
  }
  return 0;
}

// nvcc -std=c++14 -arch=sm_86 -O2 ./test_nonzero.cu