k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

error: conflicting declaration ‘typedef struct CUevent_st* #1279

Open binhtranmcs opened 5 months ago

binhtranmcs commented 5 months ago

I am trying to implement a custom k2 backend for tritonserver, but I get this compilation error:

In file included from /usr/local/cuda/include/builtin_types.h:59,
                 from /usr/local/cuda/include/cuda_runtime_api.h:149,
                 from /home/cpu13266/binhtt4/research/asr-end2end-triton/triton_k2/cmake-build-debug/_deps/repo-backend-src/include/triton/backend/backend_common.h:48,
                 from /home/cpu13266/binhtt4/research/asr-end2end-triton/triton_k2/src/libk2.cc:16:
/usr/local/cuda/include/driver_types.h:2792:43: error: conflicting declaration ‘typedef enum cudaError cudaError_t’
 2792 | typedef __device_builtin__ enum cudaError cudaError_t;
      |                                           ^~~~~~~~~~~
In file included from /usr/local/include/k2/csrc/context.h:45,
                 from /usr/local/include/k2/csrc/array.h:30,
                 from /usr/local/include/k2/csrc/ragged.h:28,
                 from /usr/local/include/k2/csrc/fsa.h:26,
                 from /usr/local/include/k2/csrc/intersect_dense_pruned.h:25,
                 from /home/cpu13266/binhtt4/research/asr-end2end-triton/triton_k2/src/libk2.cc:6:
/usr/local/include/k2/csrc/fake_cuda.h:50:7: note: previous declaration as ‘using cudaError_t = int32_t’
   50 | using cudaError_t = int32_t;
      |       ^~~~~~~~~~~
In file included from /usr/local/cuda/include/builtin_types.h:59,
                 from /usr/local/cuda/include/cuda_runtime_api.h:149,
                 from /home/cpu13266/binhtt4/research/asr-end2end-triton/triton_k2/cmake-build-debug/_deps/repo-backend-src/include/triton/backend/backend_common.h:48,
                 from /home/cpu13266/binhtt4/research/asr-end2end-triton/triton_k2/src/libk2.cc:16:
/usr/local/cuda/include/driver_types.h:2797:48: error: conflicting declaration ‘typedef struct CUstream_st* cudaStream_t’
 2797 | typedef __device_builtin__ struct CUstream_st *cudaStream_t;
      |                                                ^~~~~~~~~~~~
In file included from /usr/local/include/k2/csrc/context.h:45,
                 from /usr/local/include/k2/csrc/array.h:30,
                 from /usr/local/include/k2/csrc/ragged.h:28,
                 from /usr/local/include/k2/csrc/fsa.h:26,
                 from /usr/local/include/k2/csrc/intersect_dense_pruned.h:25,
                 from /home/cpu13266/binhtt4/research/asr-end2end-triton/triton_k2/src/libk2.cc:6:
/usr/local/include/k2/csrc/fake_cuda.h:51:7: note: previous declaration as ‘using cudaStream_t = int32_t*’
   51 | using cudaStream_t = int32_t *;
      |       ^~~~~~~~~~~~
In file included from /usr/local/cuda/include/builtin_types.h:59,
                 from /usr/local/cuda/include/cuda_runtime_api.h:149,
                 from /home/cpu13266/binhtt4/research/asr-end2end-triton/triton_k2/cmake-build-debug/_deps/repo-backend-src/include/triton/backend/backend_common.h:48,
                 from /home/cpu13266/binhtt4/research/asr-end2end-triton/triton_k2/src/libk2.cc:16:
/usr/local/cuda/include/driver_types.h:2802:47: error: conflicting declaration ‘typedef struct CUevent_st* cudaEvent_t’
 2802 | typedef __device_builtin__ struct CUevent_st *cudaEvent_t;
      |                                               ^~~~~~~~~~~
In file included from /usr/local/include/k2/csrc/context.h:45,
                 from /usr/local/include/k2/csrc/array.h:30,
                 from /usr/local/include/k2/csrc/ragged.h:28,
                 from /usr/local/include/k2/csrc/fsa.h:26,
                 from /usr/local/include/k2/csrc/intersect_dense_pruned.h:25,
                 from /home/cpu13266/binhtt4/research/asr-end2end-triton/triton_k2/src/libk2.cc:6:
/usr/local/include/k2/csrc/fake_cuda.h:52:7: note: previous declaration as ‘using cudaEvent_t = int32_t*’
   52 | using cudaEvent_t = int32_t *;
      |       ^~~~~~~~~~~
make[3]: *** [CMakeFiles/triton-k2-backend.dir/build.make:79: CMakeFiles/triton-k2-backend.dir/src/libk2.cc.o] Error 1
make[3]: Leaving directory '/home/cpu13266/binhtt4/research/asr-end2end-triton/triton_k2/cmake-build-debug'
make[2]: *** [CMakeFiles/Makefile2:255: CMakeFiles/triton-k2-backend.dir/all] Error 2
make[2]: Leaving directory '/home/cpu13266/binhtt4/research/asr-end2end-triton/triton_k2/cmake-build-debug'
make[1]: *** [CMakeFiles/Makefile2:262: CMakeFiles/triton-k2-backend.dir/rule] Error 2
make[1]: Leaving directory '/home/cpu13266/binhtt4/research/asr-end2end-triton/triton_k2/cmake-build-debug'
make: *** [Makefile:172: triton-k2-backend] Error 2

This is the CMakeLists.txt I use:

cmake_minimum_required(VERSION 3.24)

project(tritonk2backend)

set(languages CXX)
set(CMAKE_CXX_STANDARD 17)

set(CMAKE_CUDA_ARCHITECTURES 86)
set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc)

list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)

#### TRITON REPOS ####
set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo")
set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo")
set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo")

if(NOT CMAKE_BUILD_TYPE)
    set(CMAKE_BUILD_TYPE Release)
endif()

set(CMAKE_VERBOSE_MAKEFILE ON)
set(CUDA_VERBOSE_BUILD ON)

include(FetchContent)

FetchContent_Declare(
    repo-common
    GIT_REPOSITORY https://github.com/triton-inference-server/common.git
    GIT_TAG ${TRITON_COMMON_REPO_TAG}
    GIT_SHALLOW ON)

FetchContent_Declare(
    repo-core
    GIT_REPOSITORY https://github.com/triton-inference-server/core.git
    GIT_TAG ${TRITON_CORE_REPO_TAG}
    GIT_SHALLOW ON)

FetchContent_Declare(
    repo-backend
    GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
    GIT_TAG ${TRITON_BACKEND_REPO_TAG}
    GIT_SHALLOW ON)

FetchContent_MakeAvailable(repo-common repo-core repo-backend)
########

find_package(Torch REQUIRED)
find_package(k2 REQUIRED)

#### TRITON BACKEND ####
set(K2_SRC src/libk2.cc src/utils.cc)

add_library(triton-k2-backend SHARED ${K2_SRC})

add_library(TritonKaldifeatBackend::triton-k2-backend ALIAS triton-k2-backend)

set_target_properties(triton-k2-backend PROPERTIES
    POSITION_INDEPENDENT_CODE ON
    OUTPUT_NAME triton_k2)
########

#### link libraries ####
target_include_directories(triton-k2-backend PUBLIC ${K2_INCLUDE_DIRS})

target_link_libraries(triton-k2-backend PRIVATE
    TritonCore::triton-core-serverapi    # from repo-core
    TritonCore::triton-core-backendapi   # from repo-core
    TritonCore::triton-core-serverstub   # from repo-core
    TritonBackend::triton-backend-utils  # from repo-backend
    ${TORCH_LIBRARIES}
    ${K2_LIBRARIES}
    )

The problem seems to be that I am linking against the CPU-only build of k2. What is the proper way to link a CUDA build of k2 into my C++ project using CMake? Please help me with this.

Thanks in advance!

csukuangfj commented 5 months ago

Please only use the following header file in your project: https://github.com/k2-fsa/k2/blob/master/k2/torch/csrc/torch_api.h

All other header files and C++ source files should NOT be used in your project.

You can refer to https://github.com/k2-fsa/sherpa/blob/master/cmake/k2.cmake for an example of how to use k2 in another CMake-based project.

binhtranmcs commented 5 months ago

> Please only use https://github.com/k2-fsa/k2/blob/master/k2/torch/csrc/torch_api.h in your project.
>
> All other header files and C++ source files should NOT be used in your project.

Thanks a lot @csukuangfj, this fixes the problem.

binhtranmcs commented 5 months ago

@csukuangfj, k2/torch/csrc/torch_api.h does not expose the interfaces k2::SymbolTable, k2::DecodeStateInfo, and k2::OnlineDenseIntersecter. I can include k2/torch/csrc/symbol_table.h to get k2::SymbolTable, but I cannot get the other two this way. I think k2/torch/csrc/torch_api.h needs updating. Please take a further look.

csukuangfj commented 5 months ago

As I said before, please only use k2/torch/csrc/torch_api.h.

We have examples in k2-fsa/sherpa; please refer to them.

Remember that using any other header file from k2 is not supported.

The core functions from k2 have already been added to k2/torch/csrc/torch_api.h.

binhtranmcs commented 5 months ago

Thanks @csukuangfj, I understand that, but torch_api.h does not expose interfaces for online decoding or n-best decoding, and I believe sherpa does not have examples for those either.

csukuangfj commented 5 months ago

I am unsure whether @pkufool has time to add online decoding to torch_api.h.

By the way, we support streaming HLG decoding in sherpa-onnx, though HLG decoding in sherpa-onnx runs only on CPU.

pkufool commented 5 months ago

@binhtranmcs I will start this work after the May Day holiday.