k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

argmax_internal::PairInputIterator and argmax_internal::PairOutputIterator don't implement the random access iterator concept #1277

Open galv opened 2 months ago

galv commented 2 months ago

Hi all, a colleague recently reached out to me asking for help because he couldn't build k2 from source. He hit the following cub error while using a newer version of cub than the one you build with:

  Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"

  /usr/local/cuda/include/cub/device/dispatch/dispatch_reduce.cuh(390): error: no operator "*" matches these operands
              operand types are: * k2::argmax_internal::PairOutputIterator<float>
          *(d_out + blockIdx.x) = init;
          ^
            detected during:
              instantiation of "void k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceSegmentedReduceKernel<ChainedPolicyT,InputIteratorT,OutputIteratorT,BeginOffsetIteratorT,EndOffsetIteratorT,OffsetT,ReductionOpT,InitT,AccumT>(InputIteratorT, OutputIteratorT, BeginOffsetIteratorT, EndOffsetIteratorT, int, ReductionOpT, InitT) [with ChainedPolicyT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>::Policy600, InputIteratorT=k2::argmax_internal::PairInputIterator<float>, OutputIteratorT=k2::argmax_internal::PairOutputIterator<float>, BeginOffsetIteratorT=const int32_t *, EndOffsetIteratorT=const int32_t *, OffsetT=int, ReductionOpT=k2::argmax_internal::PairMaxOp<float>, InitT=k2::argmax_internal::Pair<float>, AccumT=k2::argmax_internal::Pair<float>]" at line 1268
              instantiation of "cudaError_t k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DispatchSegmentedReduce<InputIteratorT, OutputIteratorT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT, ReductionOpT, InitT, AccumT, SelectedPolicy>::Invoke<ActivePolicyT>() [with InputIteratorT=k2::argmax_internal::PairInputIterator<float>, OutputIteratorT=k2::argmax_internal::PairOutputIterator<float>, BeginOffsetIteratorT=const int32_t *, EndOffsetIteratorT=const int32_t *, OffsetT=int, ReductionOpT=k2::argmax_internal::PairMaxOp<float>, InitT=k2::argmax_internal::Pair<float>, AccumT=k2::argmax_internal::Pair<float>, SelectedPolicy=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>, ActivePolicyT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>::Policy300]" at line 705 of /usr/local/cuda/include/cub/block/../iterator/../util_device.cuh
              instantiation of "cudaError_t k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::ChainedPolicy<PTX_VERSION, PolicyT, PolicyT>::Invoke(int, FunctorT &) [with PTX_VERSION=300, PolicyT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>::Policy300, FunctorT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DispatchSegmentedReduce<k2::argmax_internal::PairInputIterator<float>, k2::argmax_internal::PairOutputIterator<float>, const int32_t *, const int32_t *, int, k2::argmax_internal::PairMaxOp<float>, k2::argmax_internal::Pair<float>, k2::argmax_internal::Pair<float>, k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>>]" at line 688 of /usr/local/cuda/include/cub/block/../iterator/../util_device.cuh
              instantiation of "cudaError_t k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::ChainedPolicy<PTX_VERSION, PolicyT, PrevPolicyT>::Invoke(int, FunctorT &) [with PTX_VERSION=350, PolicyT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>::Policy350, PrevPolicyT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>::Policy300, FunctorT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DispatchSegmentedReduce<k2::argmax_internal::PairInputIterator<float>, k2::argmax_internal::PairOutputIterator<float>, const int32_t *, const int32_t *, int, k2::argmax_internal::PairMaxOp<float>, k2::argmax_internal::Pair<float>, k2::argmax_internal::Pair<float>, k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>>]" at line 688 of /usr/local/cuda/include/cub/block/../iterator/../util_device.cuh
              instantiation of "cudaError_t k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::ChainedPolicy<PTX_VERSION, PolicyT, PrevPolicyT>::Invoke(int, FunctorT &) [with PTX_VERSION=600, PolicyT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>::Policy600, PrevPolicyT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>::Policy350, FunctorT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DispatchSegmentedReduce<k2::argmax_internal::PairInputIterator<float>, k2::argmax_internal::PairOutputIterator<float>, const int32_t *, const int32_t *, int, k2::argmax_internal::PairMaxOp<float>, k2::argmax_internal::Pair<float>, k2::argmax_internal::Pair<float>, k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>>]" at line 1362
              instantiation of "cudaError_t k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DispatchSegmentedReduce<InputIteratorT, OutputIteratorT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT, ReductionOpT, InitT, AccumT, SelectedPolicy>::Dispatch(void *, size_t &, InputIteratorT, OutputIteratorT, int, BeginOffsetIteratorT, EndOffsetIteratorT, ReductionOpT, InitT, cudaStream_t) [with InputIteratorT=k2::argmax_internal::PairInputIterator<float>, OutputIteratorT=k2::argmax_internal::PairOutputIterator<float>, BeginOffsetIteratorT=const int32_t *, EndOffsetIteratorT=const int32_t *, OffsetT=int, ReductionOpT=k2::argmax_internal::PairMaxOp<float>, InitT=k2::argmax_internal::Pair<float>, AccumT=k2::argmax_internal::Pair<float>, SelectedPolicy=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>]" at line 233 of /usr/local/cuda/include/cub/device/device_segmented_reduce.cuh
              instantiation of "cudaError_t k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceSegmentedReduce::Reduce(void *, size_t &, InputIteratorT, OutputIteratorT, int, BeginOffsetIteratorT, EndOffsetIteratorT, ReductionOp, T, cudaStream_t) [with InputIteratorT=k2::argmax_internal::PairInputIterator<float>, OutputIteratorT=k2::argmax_internal::PairOutputIterator<float>, BeginOffsetIteratorT=const int32_t *, EndOffsetIteratorT=const int32_t *, ReductionOp=k2::argmax_internal::PairMaxOp<float>, T=k2::argmax_internal::Pair<float>]" at line 664 of /tmp/pip-install-ebrvnn_m/k2_61be9d7db2244726a8b700f5fb090e48/k2/csrc/ragged_ops_inl.h
              instantiation of "void k2::ArgMaxPerSublist(k2::Ragged<T> &, T, k2::Array1<int32_t> *) [with T=float]" at line 1576 of /tmp/pip-install-ebrvnn_m/k2_61be9d7db2244726a8b700f5fb090e48/k2/csrc/fsa_utils.cu
              instantiation of "k2::Array1<FloatType> k2::BackpropGetBackwardScores(k2::FsaVec &, k2::Ragged<int32_t> &, k2::Ragged<int32_t> &, __nv_bool, const k2::Array1<FloatType> &, const k2::Array1<FloatType> &) [with FloatType=float]" at line 1626 of /tmp/pip-install-ebrvnn_m/k2_61be9d7db2244726a8b700f5fb090e48/k2/csrc/fsa_utils.cu

  /usr/local/cuda/include/cub/device/dispatch/dispatch_reduce.cuh(87): error: no operator "*" matches these operands
              operand types are: * k2::argmax_internal::PairOutputIterator<float>
      *d_out = reduction_op(init, block_aggregate);
      ^
            detected during:
              instantiation of "void k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::detail::reduce::finalize_and_store_aggregate(OutputIteratorT, ReductionOpT, InitT, AccumT) [with OutputIteratorT=k2::argmax_internal::PairOutputIterator<float>, ReductionOpT=k2::argmax_internal::PairMaxOp<float>, InitT=k2::argmax_internal::Pair<float>, AccumT=k2::argmax_internal::Pair<float>]" at line 407
              instantiation of "void k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceSegmentedReduceKernel<ChainedPolicyT,InputIteratorT,OutputIteratorT,BeginOffsetIteratorT,EndOffsetIteratorT,OffsetT,ReductionOpT,InitT,AccumT>(InputIteratorT, OutputIteratorT, BeginOffsetIteratorT, EndOffsetIteratorT, int, ReductionOpT, InitT) [with ChainedPolicyT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>::Policy600, InputIteratorT=k2::argmax_internal::PairInputIterator<float>, OutputIteratorT=k2::argmax_internal::PairOutputIterator<float>, BeginOffsetIteratorT=const int32_t *, EndOffsetIteratorT=const int32_t *, OffsetT=int, ReductionOpT=k2::argmax_internal::PairMaxOp<float>, InitT=k2::argmax_internal::Pair<float>, AccumT=k2::argmax_internal::Pair<float>]" at line 1268
              instantiation of "cudaError_t k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DispatchSegmentedReduce<InputIteratorT, OutputIteratorT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT, ReductionOpT, InitT, AccumT, SelectedPolicy>::Invoke<ActivePolicyT>() [with InputIteratorT=k2::argmax_internal::PairInputIterator<float>, OutputIteratorT=k2::argmax_internal::PairOutputIterator<float>, BeginOffsetIteratorT=const int32_t *, EndOffsetIteratorT=const int32_t *, OffsetT=int, ReductionOpT=k2::argmax_internal::PairMaxOp<float>, InitT=k2::argmax_internal::Pair<float>, AccumT=k2::argmax_internal::Pair<float>, SelectedPolicy=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>, ActivePolicyT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>::Policy300]" at line 705 of /usr/local/cuda/include/cub/block/../iterator/../util_device.cuh
              instantiation of "cudaError_t k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::ChainedPolicy<PTX_VERSION, PolicyT, PolicyT>::Invoke(int, FunctorT &) [with PTX_VERSION=300, PolicyT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>::Policy300, FunctorT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DispatchSegmentedReduce<k2::argmax_internal::PairInputIterator<float>, k2::argmax_internal::PairOutputIterator<float>, const int32_t *, const int32_t *, int, k2::argmax_internal::PairMaxOp<float>, k2::argmax_internal::Pair<float>, k2::argmax_internal::Pair<float>, k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>>]" at line 688 of /usr/local/cuda/include/cub/block/../iterator/../util_device.cuh
              instantiation of "cudaError_t k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::ChainedPolicy<PTX_VERSION, PolicyT, PrevPolicyT>::Invoke(int, FunctorT &) [with PTX_VERSION=350, PolicyT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>::Policy350, PrevPolicyT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>::Policy300, FunctorT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DispatchSegmentedReduce<k2::argmax_internal::PairInputIterator<float>, k2::argmax_internal::PairOutputIterator<float>, const int32_t *, const int32_t *, int, k2::argmax_internal::PairMaxOp<float>, k2::argmax_internal::Pair<float>, k2::argmax_internal::Pair<float>, k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>>]" at line 688 of /usr/local/cuda/include/cub/block/../iterator/../util_device.cuh
              instantiation of "cudaError_t k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::ChainedPolicy<PTX_VERSION, PolicyT, PrevPolicyT>::Invoke(int, FunctorT &) [with PTX_VERSION=600, PolicyT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>::Policy600, PrevPolicyT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>::Policy350, FunctorT=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DispatchSegmentedReduce<k2::argmax_internal::PairInputIterator<float>, k2::argmax_internal::PairOutputIterator<float>, const int32_t *, const int32_t *, int, k2::argmax_internal::PairMaxOp<float>, k2::argmax_internal::Pair<float>, k2::argmax_internal::Pair<float>, k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>>]" at line 1362
              instantiation of "cudaError_t k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DispatchSegmentedReduce<InputIteratorT, OutputIteratorT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT, ReductionOpT, InitT, AccumT, SelectedPolicy>::Dispatch(void *, size_t &, InputIteratorT, OutputIteratorT, int, BeginOffsetIteratorT, EndOffsetIteratorT, ReductionOpT, InitT, cudaStream_t) [with InputIteratorT=k2::argmax_internal::PairInputIterator<float>, OutputIteratorT=k2::argmax_internal::PairOutputIterator<float>, BeginOffsetIteratorT=const int32_t *, EndOffsetIteratorT=const int32_t *, OffsetT=int, ReductionOpT=k2::argmax_internal::PairMaxOp<float>, InitT=k2::argmax_internal::Pair<float>, AccumT=k2::argmax_internal::Pair<float>, SelectedPolicy=k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceReducePolicy<k2::argmax_internal::Pair<float>, int, k2::argmax_internal::PairMaxOp<float>>]" at line 233 of /usr/local/cuda/include/cub/device/device_segmented_reduce.cuh
              instantiation of "cudaError_t k2::cub::CUB_200200_520_600_610_700_720_750_800_860_870_900_NS::DeviceSegmentedReduce::Reduce(void *, size_t &, InputIteratorT, OutputIteratorT, int, BeginOffsetIteratorT, EndOffsetIteratorT, ReductionOp, T, cudaStream_t) [with InputIteratorT=k2::argmax_internal::PairInputIterator<float>, OutputIteratorT=k2::argmax_internal::PairOutputIterator<float>, BeginOffsetIteratorT=const int32_t *, EndOffsetIteratorT=const int32_t *, ReductionOp=k2::argmax_internal::PairMaxOp<float>, T=k2::argmax_internal::Pair<float>]" at line 664 of /tmp/pip-install-ebrvnn_m/k2_61be9d7db2244726a8b700f5fb090e48/k2/csrc/ragged_ops_inl.h
              instantiation of "void k2::ArgMaxPerSublist(k2::Ragged<T> &, T, k2::Array1<int32_t> *) [with T=float]" at line 1576 of /tmp/pip-install-ebrvnn_m/k2_61be9d7db2244726a8b700f5fb090e48/k2/csrc/fsa_utils.cu
              instantiation of "k2::Array1<FloatType> k2::BackpropGetBackwardScores(k2::FsaVec &, k2::Ragged<int32_t> &, k2::Ragged<int32_t> &, __nv_bool, const k2::Array1<FloatType> &, const k2::Array1<FloatType> &) [with FloatType=float]" at line 1626 of /tmp/pip-install-ebrvnn_m/k2_61be9d7db2244726a8b700f5fb090e48/k2/csrc/fsa_utils.cu

Unfortunately, a recent change to cub's segmented reduce implementation switched from using operator[] to using operator+ and operator*.

The root cause is that these two types https://github.com/k2-fsa/k2/blob/e8158de00bcf91eb01294a96f850e21b5b11e6e3/k2/csrc/ragged_ops_inl.h#L550 and https://github.com/k2-fsa/k2/blob/e8158de00bcf91eb01294a96f850e21b5b11e6e3/k2/csrc/ragged_ops_inl.h#L580 don't properly implement the random access iterator "concept". cub expects its iterator types to be random access iterators, so the existing code only "happened to work" because older cub used [] instead of *.

This page handily lists the methods that must be implemented to satisfy the random access iterator concept: https://en.cppreference.com/w/cpp/iterator/random_access_iterator

Anyway, the point is that a few more methods need to be implemented on those types (and maybe others) to fix this error. I could probably do this myself, but I haven't kept up with k2 work recently, so if someone would be willing to do it themselves, that would be much appreciated. Unfortunately, I'm not sure of a way to "statically assert" that a specific type implements the random access iterator concept right now...

Anyway, you can reproduce the error by building with the HEAD version of cub, or with the cub version included in the latest CUDA 12.3 toolkit.

csukuangfj commented 2 months ago

Thank you for reporting.

I am trying to reproduce it locally and will try to fix it if I can reproduce it.

csukuangfj commented 2 months ago

Sorry, I could not reproduce it with the latest commit of cub.

Here is my change to cub.cmake:

diff --git a/cmake/cub.cmake b/cmake/cub.cmake
index dd66606a..51e85b67 100644
--- a/cmake/cub.cmake
+++ b/cmake/cub.cmake
@@ -20,18 +20,22 @@ function(download_cub)

   include(FetchContent)

-  set(cub_URL  "https://github.com/NVlabs/cub/archive/1.15.0.tar.gz")
-  set(cub_URL2 "https://hub.nuaa.cf/NVlabs/cub/archive/1.15.0.tar.gz")
-  set(cub_HASH "SHA256=1781ee5eb7f00acfee5bff88e3acfc67378f6b3c24281335e18ae19e1f2ff685")
+  # set(cub_URL  "https://github.com/NVlabs/cub/archive/1.15.0.tar.gz")
+  # set(cub_URL2 "https://hub.nuaa.cf/NVlabs/cub/archive/1.15.0.tar.gz")
+  # set(cub_HASH "SHA256=1781ee5eb7f00acfee5bff88e3acfc67378f6b3c24281335e18ae19e1f2ff685")
+
+  set(cub_URL  "https://github.com/NVlabs/cub/archive/0fc3c3701632a4be906765b73be20a9ad0da603d.zip")
+  set(cub_URL2 "https://hub.nuaa.cf/NVlabs/cub/archive/0fc3c3701632a4be906765b73be20a9ad0da603d.zip")
+  set(cub_HASH "SHA256=88dc9f86564f4a76f4407cdc98eec2dd1cfdca9d92fcf6e1d2a51f6456e118b5")

   # If you don't have access to the Internet,
   # please pre-download cub
   set(possible_file_locations
-    $ENV{HOME}/Downloads/cub-1.15.0.tar.gz
-    ${CMAKE_SOURCE_DIR}/cub-1.15.0.tar.gz
-    ${CMAKE_BINARY_DIR}/cub-1.15.0.tar.gz
-    /tmp/cub-1.15.0.tar.gz
-    /star-fj/fangjun/download/github/cub-1.15.0.tar.gz
+    $ENV{HOME}/Downloads/cub-0fc3c3701632a4be906765b73be20a9ad0da603d.zip
+    ${CMAKE_SOURCE_DIR}/cub-0fc3c3701632a4be906765b73be20a9ad0da603d.zip
+    ${CMAKE_BINARY_DIR}/cub-0fc3c3701632a4be906765b73be20a9ad0da603d.zip
+    /tmp/cub-0fc3c3701632a4be906765b73be20a9ad0da603d.zip
+    /star-fj/fangjun/download/github/cub-0fc3c3701632a4be906765b73be20a9ad0da603d.zip
   )

   foreach(f IN LISTS possible_file_locations)

I am using torch 1.13.0, cuda 11.6


There are no PyTorch builds for CUDA 12.3, and we don't have such an environment locally.

k2 can be built for all currently available versions of PyTorch.

galv commented 2 months ago

There are no PyTorch versions for cuda 12.3 and we don't have such an environment locally.

Just to note, there is no problem with mixing CUDA toolkit versions within a single application. The main requirement is that your CUDA driver (i.e., libcuda.so and the associated .ko kernel modules) is new enough for every toolkit version in use. There are a few instances in NeMo where we depend on, e.g., CUDA 12.3 or CUDA 12.4 features, but use PyTorch versions that link against older versions of the CUDA toolkit.

galv commented 2 months ago

I can take a look at some point, but it's not high priority for me right now. If someone else runs into this issue, they should make a comment here.