pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

I hope you can provide 3D align-pooling layer directly. Pray. #1678

Open huang229 opened 4 years ago

huang229 commented 4 years ago

I see that 2D align-pooling (RoI Align) has been implemented, but there is no 3D version. I do object detection on 3D voxels, where the input format is (batch, channel, depth, height, width). I wrote a custom 3D align-pooling op, but deployment is a big problem: the custom op has to be registered with TorchScript so that the model can be exported normally and then loaded on the C++ side. I tried for months, but it didn't work. I hope you can provide a 3D align-pooling layer directly. Pray. This is the existing 2D function: torchvision.ops.roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1)
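For comparison, this is how the existing 2D op is called, and what a 3D version might look like (the roi_align_3d name and the 7-value box layout below are just my assumption, matching the kernel I post later in this thread):

    import torch
    from torchvision.ops import roi_align

    # existing 2D op: input is (N, C, H, W), boxes are (K, 5) rows of
    # [batch_index, x1, y1, x2, y2]
    feat = torch.rand(2, 256, 64, 64)
    boxes = torch.tensor([[0.0, 4.0, 4.0, 28.0, 28.0]])
    pooled = roi_align(feat, boxes, output_size=(7, 7),
                       spatial_scale=1.0, sampling_ratio=-1)  # (1, 256, 7, 7)

    # hypothetical 3D version (does NOT exist in torchvision): input would be
    # (N, C, D, H, W) and boxes (K, 7) rows of [batch_index, z1, y1, x1, z2, y2, x2]
    # pooled_3d = roi_align_3d(feat_3d, boxes_3d, output_size=(7, 7, 7))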

fmassa commented 4 years ago

There are currently no plans to add support for this operator, but that doesn't mean we wouldn't be willing to add it. Do you have references for it being used in some research papers already?

huang229 commented 4 years ago

@fmassa I don't like writing papers. Other authors have related papers, such as: http://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_ToothNet_Automatic_Tooth_Instance_Segmentation_and_Identification_From_Cone_Beam_CVPR_2019_paper.pdf In addition, a lot of medical image processing works on volume data. In many medical imaging papers I have seen, when the authors detect 3D objects they use an incomplete Faster R-CNN model: they take the output of the RPN layer and drop the align-pooling layer and everything after it, which makes the model very similar to SSD. I guess they may have difficulty implementing their own 3D align-pooling op, let alone deploying it. I think they have the same expectation as me, and would like a 3D align-pooling op to help them verify small ideas in medical image processing.

huang229 commented 4 years ago

The following code is the 3D align-pooling op that I wrote for PyTorch; only the forward pass is shown here. Because my coding skills are limited, training with it takes a lot of time. To deploy to the C++ side I need to register it with TorchScript; I tried for several months but the registration never worked.

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/script.h>
#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>

__global__ void Align3D_forward_kernel(
    const int nthreads,
    const float* __restrict__ image_data,
    const int* __restrict__ img_dim,
    const int* __restrict__ crop_dim,
    const int channel_num,
    const int image_depth,
    const int image_height,
    const int image_width,
    const float* __restrict__ boxes_data,
    const int num_boxes,
    float* __restrict__ corps_data,
    const int crop_depth,
    const int crop_height,
    const int crop_width,
    const int method)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;

if (index < nthreads)
{
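    // decompose the flat output index into (box n, channel ch, crop_d, crop_h, crop_w)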
    int n = index / (crop_depth * crop_height * crop_width * channel_num);
    int tmpn = index % (crop_depth * crop_height * crop_width * channel_num);
    int ch = tmpn / (crop_depth * crop_height * crop_width);
    int tmpc = tmpn % (crop_depth * crop_height * crop_width);

    int crop_d = tmpc / (crop_height * crop_width);
    int tmpd = tmpc % (crop_height * crop_width);
    int crop_h = tmpd / crop_width;
    int crop_w = tmpd % crop_width;

    // each box stores 7 floats: [batch_index, z1, y1, x1, z2, y2, x2]
    const float *box = boxes_data + n * 7;
    const int bn = int(box[0]);
    const float z1 = box[1];
    const float y1 = box[2];
    const float x1 = box[3];
    const float z2 = box[4];
    const float y2 = box[5];
    const float x2 = box[6];

    // sanity check: the crop index should never exceed the image extent
    if (crop_d >= image_depth || crop_h >= image_height || crop_w >= image_width)
    {
        printf("index oversteps the boundary!\n");
    }

    const float depth_scale = (z2 - z1) / (crop_depth - 1);
    const float height_scale = (y2 - y1) / (crop_height - 1);
    const float width_scale = (x2 - x1) / (crop_width - 1);

    // compute the floating-point sample coordinate and its integer neighbors (floor/ceil)
    const float in_z = (crop_d) * depth_scale + z1;
    const int bottom_z = floorf(in_z);
    const int top_z = ceilf(in_z);

    const float in_y = (crop_h) * height_scale + y1;
    const int bottom_y = floorf(in_y);
    const int top_y = ceilf(in_y);

    const float in_x = (crop_w) * width_scale + x1;
    const int bottom_x = floorf(in_x);
    const int top_x = ceilf(in_x);

    // trilinear interpolation
    float value_interp = 0;
    // fetch the values at the 8 integer-coordinate corners around the sample point
    float coord_0 = image_data[bn * img_dim[0] + ch * img_dim[1] + top_z * img_dim[2] + top_y * image_width + top_x];
    float coord_1 = image_data[bn * img_dim[0] + ch * img_dim[1] + top_z * img_dim[2] + top_y * image_width + bottom_x];
    float coord_2 = image_data[bn * img_dim[0] + ch * img_dim[1] + top_z * img_dim[2] + bottom_y * image_width + top_x];
    float coord_3 = image_data[bn * img_dim[0] + ch * img_dim[1] + top_z * img_dim[2] + bottom_y * image_width + bottom_x];
    float coord_4 = image_data[bn * img_dim[0] + ch * img_dim[1] + bottom_z * img_dim[2] + top_y * image_width + top_x];
    float coord_5 = image_data[bn * img_dim[0] + ch * img_dim[1] + bottom_z * img_dim[2] + top_y * image_width + bottom_x];
    float coord_6 = image_data[bn * img_dim[0] + ch * img_dim[1] + bottom_z * img_dim[2] + bottom_y * image_width + top_x];
    float coord_7 = image_data[bn * img_dim[0] + ch * img_dim[1] + bottom_z * img_dim[2] + bottom_y * image_width + bottom_x];

    //printf("coord_0: %f, coord_1: %f, coord_2: %f, coord_3: %f, coord_4: %f, coord_5: %f\n", coord_0, coord_1, coord_2, coord_3, coord_4, coord_5);
    // interpolate pairwise: first along x, then along y, then along z
    float value01 = (in_x - bottom_x) * (coord_0 - coord_1) + coord_1;
    float value23 = (in_x - bottom_x) * (coord_2 - coord_3) + coord_3;
    float value45 = (in_x - bottom_x) * (coord_4 - coord_5) + coord_5;
    float value67 = (in_x - bottom_x) * (coord_6 - coord_7) + coord_7;

    float value0123 = (in_y - bottom_y) * (value01 - value23) + value23;
    float value4567 = (in_y - bottom_y) * (value45 - value67) + value67;

    value_interp = (in_z - bottom_z) * (value0123 - value4567) + value4567;

    int index_ = n * crop_dim[0] + ch * crop_dim[1] + crop_d * crop_dim[2] + crop_h * crop_width + crop_w;
    corps_data[index_] = value_interp;
}

}

void Align3D_forward_gpu(
    torch::Tensor image,
    const int channel_num,
    const int image_depth,
    const int image_height,
    const int image_width,
    torch::Tensor boxes,
    const int num_boxes,
    torch::Tensor corps,
    const int crop_depth,
    const int crop_height,
    const int crop_width,
    const int method)
{
const int image_elements_dhwc = image_height * image_width * image_depth * channel_num;
const int image_elements_dhw = image_height * image_width * image_depth;
const int image_elements_hw = image_width * image_height;

const int crop_elements_dhwc = crop_depth * crop_height * crop_width * channel_num;
const int crop_elements_dhw = crop_height * crop_width * crop_depth;
const int crop_elements_hw = crop_width * crop_height;

torch::Tensor img_dim = torch::tensor({ image_elements_dhwc, image_elements_dhw, image_elements_hw }, c10::kInt).cuda();
torch::Tensor crop_dim = torch::tensor({ crop_elements_dhwc, crop_elements_dhw, crop_elements_hw }, c10::kInt).cuda();
const int output_size = num_boxes * crop_depth * crop_height * crop_width * channel_num;
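// launch one thread per output element, 512 threads per block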
dim3 grid(THCCeilDiv((long)output_size, 512L));
dim3 block(512);

if (corps.numel() == 0) {
    THCudaCheck(cudaGetLastError());
    return;  // empty output: nothing to launch
}

Align3D_forward_kernel<<<grid, block>>>(
    output_size,
    image.data<float>(),
    img_dim.data<int>(),
    crop_dim.data<int>(),
    channel_num,
    image_depth,
    image_height,
    image_width,
    boxes.data<float>(),
    num_boxes,
    corps.data<float>(),
    crop_depth,
    crop_height,
    crop_width,
    method
    );
cudaError_t cudaStatus = cudaGetLastError();
if (cudaStatus != cudaSuccess)
{
    fprintf(stderr, "Align3D_forward_kernel launch failed: %s\n", cudaGetErrorString(cudaStatus));
}

}

Hellcat1005 commented 4 years ago

I hope so too. I am doing medical image detection at a medical imaging company. Medical images are usually 3D, so we really need 3D operations like nms_3d, roi_align_3d and so on.
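In the meantime, a naive 3D NMS can at least be prototyped in pure PyTorch; a minimal (slow) sketch, assuming (K, 6) boxes laid out as [z1, y1, x1, z2, y2, x2]:

    import torch

    # naive 3D NMS sketch (not a torchvision op); boxes: (K, 6) rows of
    # [z1, y1, x1, z2, y2, x2], scores: (K,); slow, but fine for prototyping
    def nms_3d(boxes, scores, iou_threshold):
        order = scores.argsort(descending=True)
        keep = []
        while order.numel() > 0:
            i = order[0]
            keep.append(i.item())
            if order.numel() == 1:
                break
            rest = boxes[order[1:]]
            # intersection volume of the best box with the remaining boxes
            lo = torch.max(boxes[i, :3], rest[:, :3])
            hi = torch.min(boxes[i, 3:], rest[:, 3:])
            inter = (hi - lo).clamp(min=0).prod(dim=1)
            vol_i = (boxes[i, 3:] - boxes[i, :3]).prod()
            vol_rest = (rest[:, 3:] - rest[:, :3]).prod(dim=1)
            iou = inter / (vol_i + vol_rest - inter)
            order = order[1:][iou <= iou_threshold]
        return torch.tensor(keep, dtype=torch.long)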

huang229 commented 4 years ago

@Hellcat1005 Adding a custom op is simple, but it may not be efficient, which makes training time-consuming and also affects deployment. Even so, I think you can try a custom op first; see the sketch below. https://github.com/pytorch/extension-script
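Following that extension-script example, once the op is compiled into a shared library and registered on the C++ side, the Python side only needs to load the library before scripting. The library name (libalign3d.so) and op name (my_ops::roi_align_3d) here are placeholders:

    import torch

    # Load the compiled extension; this registers the custom op with the
    # dispatcher so TorchScript can serialize calls to it.
    torch.ops.load_library("libalign3d.so")  # hypothetical library name

    class RoIHead(torch.nn.Module):
        def forward(self, feat, boxes):
            # hypothetical op name, registered on the C++ side
            return torch.ops.my_ops.roi_align_3d(feat, boxes, 7, 7, 7)

    scripted = torch.jit.script(RoIHead())
    scripted.save("roi_head.pt")  # can be loaded in C++ via torch::jit::load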

lukasfolle commented 4 years ago

In general, are there any plans to provide support for 3D data or volumes? This would be really nice, especially for torchvision.transforms. However, I assume the backend, e.g. Pillow, would need to support 3D data for that, right?

fmassa commented 4 years ago

@lukasfolle we are working on making the transforms in torchvision support Tensors of multiple dimensions, but that might not directly translate to 3D volumes, because they assume a different structure: we currently assume 2D images, possibly batched.
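As a stopgap, a 2D tensor transform can be applied slice-wise over the depth axis of a volume; a rough sketch (not a torchvision feature), assuming the transform accepts (C, H, W) tensors:

    import torch

    # Apply a 2D tensor transform to each depth slice of a (C, D, H, W)
    # volume. Assumes `transform` maps a (C, H, W) tensor to a (C, H', W')
    # tensor of consistent size.
    def apply_per_slice(volume, transform):
        slices = volume.permute(1, 0, 2, 3)           # (D, C, H, W)
        out = torch.stack([transform(s) for s in slices])
        return out.permute(1, 0, 2, 3)                # (C, D, H', W')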