loolzaaa / faster-rcnn-pytorch

A PyTorch implementation of Faster R-CNN
MIT License

Is it possible to change the roi pooling to 3D? #10

Open lihaolin88 opened 9 months ago

lihaolin88 commented 9 months ago

Hello, I modified roi_pool_kernel.cu so that it accepts 3D input, but I'm not very familiar with CUDA code. Could anyone help me check whether I made any mistakes? Much appreciated!

My input shape is (B, C, H, W, D) and the ROI tensor shape is (num_of_roi, 7), where each row is ordered as (label, min_width, min_depth, min_height, max_width, max_depth, max_height). The output I expect is (num_of_roi, C, pool_size, pool_size, pool_size). (I don't know why GitHub broke my code into multiple parts; sorry for the inconvenience.)
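To make the expected interface concrete, here is a minimal shape sketch (not part of the repo; the sizes and ROI values are made up, and it assumes the modified kernel file below is compiled and linked into the program). Since the scale ratios are currently hardcoded inside the kernel, I just pass 1.0 for spatial_scale:

```cpp
// Shape sketch only: requires a CUDA device and the roi_pool_kernel.cu below
// to be compiled and linked; sizes and ROI values are invented for illustration.
#include <torch/torch.h>
#include <tuple>
#include <iostream>

// Declared in the modified roi_pool_kernel.cu shown below.
std::tuple<torch::Tensor, torch::Tensor> roi_pool_forward3d_cuda(
    const torch::Tensor& input,
    const torch::Tensor& rois,
    const float spatial_scale,
    const int output_size);

int main() {
  // Feature map laid out as (B, C, H, W, D).
  auto input = torch::randn({2, 16, 32, 32, 32},
                            torch::dtype(torch::kFloat32).device(torch::kCUDA));

  // One ROI row: (label/batch index, min_w, min_d, min_h, max_w, max_d, max_h).
  auto rois = torch::tensor({0.f, 3.f, 10.f, 5.f, 20.f, 25.f, 18.f},
                            torch::dtype(torch::kFloat32).device(torch::kCUDA))
                  .unsqueeze(0);  // -> (num_of_roi, 7)

  auto result = roi_pool_forward3d_cuda(input, rois, /*spatial_scale=*/1.0f,
                                        /*output_size=*/7);
  // Expected output shape: (num_of_roi, C, pool, pool, pool) == [1, 16, 7, 7, 7].
  std::cout << std::get<0>(result).sizes() << std::endl;
  return 0;
}
```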

```cuda
#include <torch/extension.h>
#include <THC/THCAtomics.cuh>
#include <cfloat>  // for FLT_MAX
#include "cuda_helpers.h"

template <typename T>
__global__ void RoIPoolForward(
    const int nthreads,
    const T* input,
    const T spatial_scale,
    const int channels,
    const int height,
    const int width,
    const int depth,
    const int pooled_height,
    const int pooled_width,
    const int pooled_depth,
    const T* rois,
    T* output,
    int* argmax_data) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw, pd) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int pd = (index / (pooled_width * pooled_height)) % pooled_depth;
    int c = (index / (pooled_width * pooled_height * pooled_depth)) % channels;
    int n = index / (pooled_width * pooled_height * pooled_depth * channels);

    const T* offset_rois = rois + n * 7;
    int roi_batch_ind = offset_rois[0];
    int roi_start_w = round(offset_rois[1] * (284 / 62));   // spatial_scale);  // for spatial need to change
    int roi_start_h = round(offset_rois[3] * (266 / 60));   // spatial_scale);  // different side need different number
    int roi_start_d = round(offset_rois[2] * (316 / 124));  // spatial_scale);
    int roi_end_w = round(offset_rois[4] * (284 / 62));     // spatial_scale);
    int roi_end_h = round(offset_rois[6] * (266 / 60));     // spatial_scale);
    int roi_end_d = round(offset_rois[5] * (316 / 124));    // spatial_scale);

    // Force malformed ROIs to be 1x1x1
    int roi_width = max(roi_end_w - roi_start_w + 1, 1);
    int roi_height = max(roi_end_h - roi_start_h + 1, 1);
    int roi_depth = max(roi_end_d - roi_start_d + 1, 1);
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
    T bin_size_d = static_cast<T>(roi_depth) / static_cast<T>(pooled_depth);

    int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
    int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
    int dstart = static_cast<int>(floor(static_cast<T>(pd) * bin_size_d));
    int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
    int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
    int dend = static_cast<int>(ceil(static_cast<T>(pd + 1) * bin_size_d));

    // Add roi offsets and clip to input boundaries
    hstart = min(max(hstart + roi_start_h, 0), height);
    hend = min(max(hend + roi_start_h, 0), height);
    wstart = min(max(wstart + roi_start_w, 0), width);
    wend = min(max(wend + roi_start_w, 0), width);
    dstart = min(max(dstart + roi_start_d, 0), depth);
    dend = min(max(dend + roi_start_d, 0), depth);
    bool is_empty = (hend <= hstart) || (wend <= wstart) || (dend <= dstart);

    // Define an empty pooling region to be zero
    T maxval = is_empty ? 0 : -FLT_MAX;
    // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
    int maxidx = -1;
    const T* offset_input =
        input + (roi_batch_ind * channels + c) * height * width * depth;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        for (int d = dstart; d < dend; ++d) {
          int input_index = d * width * height + h * width + w;  // h * depth * width + w * depth + d;
          if (offset_input[input_index] > maxval) {
            maxval = offset_input[input_index];
            maxidx = input_index;
          }
        }
      }
    }
    output[index] = maxval;
    argmax_data[index] = maxidx;

  }
}

template <typename T>
__global__ void RoIPoolBackward(
    const int nthreads,
    const T* grad_output,
    const int* argmax_data,
    const int channels,
    const int height,
    const int width,
    const int depth,
    const int pooled_height,
    const int pooled_width,
    const int pooled_depth,
    T* grad_input,
    const T* rois,
    const int n_stride,
    const int c_stride,
    const int h_stride,
    const int w_stride,
    const int d_stride) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw, pd) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int pd = (index / (pooled_width * pooled_height)) % pooled_depth;
    int c = (index / (pooled_width * pooled_height * pooled_depth)) % channels;
    int n = index / (pooled_width * pooled_height * pooled_depth * channels);
    // int c = (index / pooled_width / pooled_height) % channels;
    // int n = index / pooled_width / pooled_height / channels;

    const T* offset_rois = rois + n * 7;
    int roi_batch_ind = offset_rois[0];
    T* grad_input_offset =
        grad_input + ((roi_batch_ind * channels + c) * height * width * depth);

    int output_offset = n * n_stride + c * c_stride;
    const int* argmax_data_offset =
        argmax_data + (n * channels + c) * pooled_height * pooled_width * pooled_depth;
    int argmax = argmax_data_offset[pd * pooled_height * pooled_width + ph * pooled_width + pw];

    if (argmax != -1) {
      atomicAdd(
          grad_input_offset + argmax,
          static_cast<T>(
              grad_output[output_offset + ph * h_stride + pw * w_stride + pd * d_stride]));
    }

  }
}

std::tuple<torch::Tensor, torch::Tensor> roi_pool_forward3d_cuda(
    const torch::Tensor& input,
    const torch::Tensor& rois,
    const float spatial_scale,
    const int output_size) {
  AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
  AT_ASSERTM(rois.is_cuda(), "rois must be a CUDA tensor");

  const int num_rois = rois.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);
  const int depth = input.size(4);

  const int pooling_width = output_size;
  const int pooling_height = output_size;
  const int pooling_depth = output_size;

  const auto total_size = num_rois * pooling_height * pooling_width * pooling_depth * channels;

  auto output = torch::empty(
      {num_rois, channels, pooling_height, pooling_width, pooling_depth}, input.options());
  auto argmax = torch::zeros(
      {num_rois, channels, pooling_height, pooling_width, pooling_depth},
      input.options().dtype(torch::kInt));

  const dim3 grid(std::min((total_size + 512 - 1) / 512, 4 * 4096));
  const dim3 block(512);

  if (output.numel() == 0) {
    return std::make_tuple(output, argmax);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "RoIPool_forward", [&] {
    RoIPoolForward<scalar_t><<<grid, block>>>(
        total_size,
        input.contiguous().data_ptr<scalar_t>(),
        spatial_scale,
        channels,
        height,
        width,
        depth,
        pooling_width,
        pooling_height,
        pooling_depth,
        rois.contiguous().data_ptr<scalar_t>(),
        output.data_ptr<scalar_t>(),
        argmax.data_ptr<int>());
  });

  return std::make_tuple(output, argmax);
}

torch::Tensor roi_pool_backward3d_cuda(
    const torch::Tensor& grad,
    const torch::Tensor& argmax,
    const torch::Tensor& input_size,
    const torch::Tensor& rois) {
  // Check if input tensors are CUDA tensors
  AT_ASSERTM(grad.is_cuda(), "grad must be a CUDA tensor");
  AT_ASSERTM(rois.is_cuda(), "rois must be a CUDA tensor");
  AT_ASSERTM(argmax.is_cuda(), "argmax must be a CUDA tensor");

  auto input_size_a = input_size.accessor<int, 1>();
  const int batch_size = input_size_a[0];
  const int channels = input_size_a[1];
  const int height = input_size_a[2];
  const int width = input_size_a[3];
  const int depth = input_size_a[4];

  const int num_rois = argmax.size(0);

  const int pooling_width = argmax.size(2);
  const int pooling_height = argmax.size(3);
  const int pooling_depth = argmax.size(4);

  const auto total_size = num_rois * pooling_height * pooling_width * pooling_depth * channels;

  auto grad_input =
      torch::zeros({batch_size, channels, width, depth, height}, grad.options());

  const dim3 grid(std::min((total_size + 512 - 1) / 512, 4 * 4096));
  const dim3 block(512);

  // handle possibly empty gradients
  if (grad.numel() == 0) {
    return grad_input;
  }

  // get stride values to ensure indexing into gradients is correct.
  int n_stride = grad.stride(0);
  int c_stride = grad.stride(1);
  int h_stride = grad.stride(2);
  int w_stride = grad.stride(3);
  int d_stride = grad.stride(4);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "RoIPool_backward", [&] {
    RoIPoolBackward<scalar_t><<<grid, block>>>(
        grad.numel(),
        grad.data_ptr<scalar_t>(),
        argmax.contiguous().data_ptr<int>(),
        channels,
        height,
        width,
        depth,
        pooling_width,
        pooling_height,
        pooling_depth,
        grad_input.data_ptr<scalar_t>(),
        rois.contiguous().data_ptr<scalar_t>(),
        n_stride,
        c_stride,
        h_stride,
        w_stride,
        d_stride);
  });

  return grad_input;
}
```