traveller59 / spconv

Spatial Sparse Convolution Library
Apache License 2.0

convert output in tv::Tensor format to tensorrt plugin format #536

Open hygxy opened 1 year ago

hygxy commented 1 year ago

I am trying to implement a TensorRT plugin for a single sparse convolution layer and need to write the convolution result back to the output pointers defined in TensorRT's enqueue() interface. Right now, I am doing it this way:

    float* out1 = static_cast<float*>((void*)outputs[0]);
    int* out2 = static_cast<int*>((void*)outputs[1]);

    out1 = (float*)out_features.raw_data();
    out2 = (int*)out_inds.raw_data();

where outputs is defined in the interface of TensorRT's enqueue API, and out_features and out_inds are variables defined in main.cu. It turns out both out1 and out2 contain all-zero values, even when I copy them back to the CPU using cudaMemcpyAsync() followed by cudaStreamSynchronize().

What's the correct way of converting results in tv::Tensor format to the TensorRT plugin format?

FindDefinition commented 1 year ago
  1. You should use tv::from_blob to create a tv::Tensor from the TensorRT pointers (see the sketch after this list):

    // assume you recorded the tensor metas in the configure function
    auto out1_ten = tv::from_blob(out1, out1_shape, out1_dtype, 0 /*GPU*/);

    You need to use data_ptr to get pointers instead of a C-style cast:

    auto ptr = out1_ten.data_ptr<const __half>(); // if the ptr passed to from_blob is const, you must use a const dtype to get a const pointer
  2. You need to split pair (indices) generation and the conv into two plugins: the output of pair (indices) generation can be reused by more than one conv layer, and out_inds should be an input of a conv layer.
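
A minimal sketch of that first point, assuming the input shape was recorded in the configure function and that the tensorview header path and the const overload of tv::from_blob match libspconv (the function and variable names here are illustrative):

    #include <tensorview/tensor.h>

    // Hedged sketch: wrap a TensorRT-owned device buffer in a tv::Tensor
    // without copying. from_blob does not take ownership of the buffer.
    void wrap_input(const void* const* inputs, tv::TensorShape in_shape) {
      auto in_ten = tv::from_blob(inputs[0], in_shape, tv::float32, 0 /*GPU*/);
      // the pointer given to from_blob is const, so request a const dtype:
      const float* ptr = in_ten.data_ptr<const float>();
      (void)ptr; // hand ptr (or in_ten itself) to the spconv ops
    }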
hygxy commented 1 year ago

Since the enqueue API has the interface int enqueue(..., void* const *outputs, ...), if I implement it as follows:

outputs[0] = out_features.data_ptr<const float>();
outputs[1] = out_inds.data_ptr<const int32_t>();

I get the following compilation error, even with const float and const int32_t:

error: assignment of read-only location ‘* outputs’   outputs[0] = out_features.data_ptr<const float>();
error: assignment of read-only location ‘*(outputs + 8)’  outputs[1] = out_inds.data_ptr<const int32_t>();
FindDefinition commented 1 year ago

TensorRT outputs are allocated once when the engine is created, so you don't need to assign a new pointer; just use outputs to create a tv::Tensor with tv::from_blob. You should check a TensorRT plugin tutorial or example first (IPluginV2DynamicExt); you can write a simple plugin to test all the plugin functions.
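
Concretely, a hedged sketch of an enqueue body following this advice; the member fields num_out_ and channels_ (recorded in configurePlugin), and the final copy from the conv result out_features, are assumptions for illustration, not the library's prescribed API:

    // Sketch of an IPluginV2DynamicExt::enqueue body. outputs[0] already
    // points at a valid device buffer owned by TensorRT; wrap it instead of
    // trying to reassign it.
    int32_t enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                    const nvinfer1::PluginTensorDesc* outputDesc,
                    const void* const* inputs, void* const* outputs,
                    void* workspace, cudaStream_t stream) noexcept override {
      auto out_ten = tv::from_blob(outputs[0], {num_out_, channels_},
                                   tv::float32, 0 /*GPU*/);
      // either let the sparse conv write into out_ten directly, or copy an
      // existing result into the TensorRT buffer on the same stream:
      cudaMemcpyAsync(out_ten.raw_data(), out_features.raw_data(),
                      out_features.size() * sizeof(float),
                      cudaMemcpyDeviceToDevice, stream);
      return 0;
    }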

hygxy commented 1 year ago

I now have another question: should the pointer passed to from_blob() as the first parameter be aligned? If yes, to how many bytes?

I calculated the workspace size (originally in main.cu: int workspace_size = SpconvOps::get_indice_gen_workspace_size) in the getWorkspaceSize API that every TensorRT plugin must implement. Then I replaced every tv::empty() with tv::from_blob() and advanced the workspace pointer myself. For example:

//tv::Tensor pair_fwd_padded = tv::empty({KV, pair_fwd_size_padded}, tv::int32, 0);
tv::Tensor pair_fwd_padded = tv::from_blob(workspace, {KV, pair_fwd_size_padded}, tv::int32, 0);
workspace += KV * pair_fwd_size_padded * sizeof(int32_t); // workspace is treated as a byte pointer here

However, I get the following error:

error: 1: [defaultAllocator.cpp::deallocate::35] Error Code 1: Cuda Runtime (misaligned address)
error: 1: [cudaResources.cpp::~ScopedCudaStream::47] Error Code 1: Cuda Runtime (misaligned address)
error: 1: [cudaResources.cpp::~ScopedCudaEvent::24] Error Code 1: Cuda Runtime (misaligned address)
error: 1: [cudaResources.cpp::~ScopedCudaEvent::24] Error Code 1: Cuda Runtime (misaligned address)
error: 1: [cudaResources.cpp::~ScopedCudaEvent::24] Error Code 1: Cuda Runtime (misaligned address)
error: 1: [defaultAllocator.cpp::deallocate::35] Error Code 1: Cuda Runtime (misaligned address)
FindDefinition commented 1 year ago
  1. You should use get_indice_gen_tensors_from_workspace instead of managing the workspace yourself.
  2. If the spatial shape is too large (triggering use_int64_hash_k), a misaligned address may occur, since the int64 hash table then needs 8-byte alignment. You can change hash_size = int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory); in get_indice_gen_workspace_size and get_indice_gen_tensors_from_workspace to hash_size = tv::align_up(int({SPCONV_DIRECT_TABLE_HASH_SIZE_SCALE} * max_act_out_in_theory), 2); to avoid this problem.
  3. You need to find out which kernel causes the misaligned address by running with CUDA_LAUNCH_BLOCKING=1 (see the sketch after this list).
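
If you do keep the hand-rolled slicing, here is a sketch of doing it with explicit alignment; the 16-byte round-up, the take helper, and the use of TensorShape::size() are assumptions, and get_indice_gen_tensors_from_workspace remains the recommended route:

    // Hedged sketch: carve sub-tensors out of the raw workspace pointer while
    // rounding every slice up, so 8-byte (int64) accesses never land on an
    // odd offset.
    uint8_t* ws = static_cast<uint8_t*>(workspace);

    auto take = [&](tv::TensorShape shape, tv::DType dtype, size_t itemsize) {
      tv::Tensor t = tv::from_blob(ws, shape, dtype, 0 /*GPU*/);
      ws += tv::align_up(size_t(shape.size()) * itemsize, size_t(16));
      return t;
    };

    tv::Tensor pair_fwd_padded = take({KV, pair_fwd_size_padded}, tv::int32, sizeof(int32_t));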
hygxy commented 1 year ago

Thanks for your advice! I noticed that convolution with kernel shape 1*1*1 is not supported in libspconv yet (while the Python version already implements it with torch.mm); is that correct?

If yes, which implementation do you suggest for "torch.mm" in C++/CUDA?

FindDefinition commented 1 year ago

If you use TensorRT, you can use the matmul layer in TensorRT (skip the sparse conv when constructing the TensorRT network); there is no need to implement a new one. By the way, torch uses cuBLAS to implement torch.mm. If you need mm in C++, you can implement it with cuBLASLt; see the official example.
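
For reference, a torch.mm-style C = A @ B with plain cuBLAS (a sketch assuming row-major float matrices already on the GPU; the cuBLASLt version is the same idea with more descriptor setup):

    #include <cublas_v2.h>

    // C = A * B for row-major A (m x k), B (k x n), C (m x n).
    // cuBLAS is column-major, so compute C^T = B^T * A^T by swapping the
    // operands: a row-major matrix read as column-major is its transpose.
    void mm_rowmajor(cublasHandle_t handle, const float* A, const float* B,
                     float* C, int m, int n, int k) {
      const float alpha = 1.0f;
      const float beta = 0.0f;
      cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                  n, m, k,
                  &alpha, B, n, A, k,
                  &beta, C, n);
    }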

superpigforever commented 1 year ago

> Thanks for your advice! I noticed that convolution with kernel shape 1*1*1 is not supported in libspconv yet (while the Python version already implements it with torch.mm); is that correct?
>
> If yes, which implementation do you suggest for "torch.mm" in C++/CUDA?

Hi, may I ask how you convert plugin inputs to tv::Tensor? Do I need to cast them in any way?

ArseniuML commented 8 months ago

    // static_num_act_in is just out_inds_num_limit of previous conv layer.
    // for regular conv, the input tensor has static shape, we should save a CPU
    // variable of real num_act_out. here we just use num_act_in.
    int real_num_act_in = real_num_voxels;

@FindDefinition how can I pass the CPU variable real_num_act_in from one TensorRT layer to another?