csukuangfj / kaldifeat

Kaldi-compatible online & offline feature extraction with PyTorch, supporting CUDA, batch processing, chunk processing, and autograd - Provide C++ & Python API
https://csukuangfj.github.io/kaldifeat

any example about C++ API using cuda? #57

Closed wienerjier closed 2 months ago

csukuangfj commented 2 years ago

Yes, please have a look at https://github.com/k2-fsa/k2/blob/b8a45acfa16464324f8ba6cc10b6a9dd3bba0ccc/k2/torch/bin/ctc_decode.cu#L126

  kaldifeat::FbankOptions fbank_opts;
  fbank_opts.frame_opts.samp_freq = FLAGS_sample_rate;
  fbank_opts.frame_opts.dither = 0;
  fbank_opts.frame_opts.frame_shift_ms = FLAGS_frame_shift_ms;
  fbank_opts.frame_opts.frame_length_ms = FLAGS_frame_length_ms;
  fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
  fbank_opts.device = device;

You need to assign a GPU device to fbank_opts.device.


https://github.com/k2-fsa/k2/blob/b8a45acfa16464324f8ba6cc10b6a9dd3bba0ccc/k2/torch/csrc/features.cu#L49

  torch::Tensor features = fbank.ComputeFeatures(strided, /*vtln_warp*/ 1.0f);

You need to move strided to the same device as fbank_opts.

wienerjier commented 2 years ago

Following the links above, I tested it on an ARM CPU, but I got a c10 error: what(): self must be a matrix. Is it caused by the wave_data reader or by libtorch? The error is shown below:

  terminate called after throwing an instance of c10::Error
    what(): self must be a matrix
  Exception raised from meta at /media/nvidia/NVME/pytorch/pytorch-v1.10.0/aten/src/ATen/native/LinearAlgebra.cpp:46 (most recent call first):
  frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits, std::allocator >) + 0xa0 (0x7f7ed3a508 in /home/asr/.local/lib/python3.6/site-packages/torch/lib/libc10.so)
  frame #1: c10::detail::torchCheckFail(char const, char const, unsigned int, char const*) + 0xb4 (0x7f7ed369b4 in /home/asr/.local/lib/python3.6/site-packages/torch/lib/libc10.so)
  frame #2: at::meta::structured_mm::meta(at::Tensor const&, at::Tensor const&) + 0x398 (0x7f7fab2370 in /home/asr/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
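
For context, frame #2 of the trace is at::meta::structured_mm::meta, the shape check that libtorch runs before a matrix multiply, so the message suggests that the first operand of some mm call was not 2-D. The following is a plain C++ imitation of that kind of check on shape vectors, not libtorch code; MmShapeError is a hypothetical helper, the third message is paraphrased, and the shapes in the note below are made up for illustration. It does not show where in the program the offending call happens.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper: imitates the dimension checks a 2-D matrix multiply
// performs on its operands. Returns an empty string when the shapes are
// acceptable, otherwise an error message.
std::string MmShapeError(const std::vector<int64_t> &self,
                         const std::vector<int64_t> &mat2) {
  // Both operands must have exactly two dimensions.
  if (self.size() != 2) return "self must be a matrix";
  if (mat2.size() != 2) return "mat2 must be a matrix";
  // The inner dimensions must agree.
  if (self[1] != mat2[0]) return "shapes cannot be multiplied";
  return "";
}
```

With these checks, a 1-D shape such as {258000} as the first operand is rejected with "self must be a matrix", while {1611, 257} times {257, 80} is accepted.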

My wave_data looks like this. Is it the right type for k2::ComputeFeatures?

-0.0054 -0.0063 -0.0063 -0.0042 -0.0025 -0.0025 -0.0036 -0.0062 [ CPUFloatType{258000} ]

My code is as follows:

#include "../kaldifeat/csrc/feature-fbank.h"
#include "k2wav/wave_reader.h"
#include "k2wav/features.h"
#include <cmath>

using namespace std;

int main() {
  torch::Device device(torch::kCPU);
  kaldifeat::FbankOptions fbank_opts;
  fbank_opts.frame_opts.samp_freq = 16000;
  fbank_opts.frame_opts.dither = 0;
  fbank_opts.frame_opts.frame_shift_ms = 10.0;
  fbank_opts.frame_opts.frame_length_ms = 25.0;
  fbank_opts.mel_opts.num_bins = 80;
  fbank_opts.device = torch::Device(torch::kCPU);
  kaldifeat::Fbank fbank(fbank_opts);

  std::cout << fbank_opts.device << std::endl;

  std::vector<std::string> wave_filenames(1);
  wave_filenames[0] = "/home/asr/vadspeaker/kaldifeat/testalan/hj.wav";
  auto wave_data = k2::ReadWave(wave_filenames, 16000);

  for (auto &w : wave_data) {
    w = w.to(device);
  }

  std::cout << wave_data << std::endl;
  std::cout << typeid(wave_data).name() << std::endl;
  std::vector<int64_t> num_frames;
  auto features_vec = k2::ComputeFeatures(fbank, wave_data, &num_frames);
}
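
As a sanity check on what a successful run should return, the expected number of frames can be worked out from the options above. This is a minimal sketch that assumes snip_edges is left at its Kaldi default of true (it is not set in the code above); MsToSamples and NumFrames are hypothetical helper names, not part of kaldifeat or k2.

```cpp
#include <cstdint>

// Converts a duration in milliseconds to a number of samples,
// truncating toward zero as Kaldi does for window size and shift.
int64_t MsToSamples(double samp_freq, double ms) {
  return static_cast<int64_t>(samp_freq * 0.001 * ms);
}

// Number of frames when snip_edges is true: only windows that fit
// entirely inside the waveform are kept.
int64_t NumFrames(int64_t num_samples, int64_t frame_length,
                  int64_t frame_shift) {
  if (num_samples < frame_length) return 0;
  return 1 + (num_samples - frame_length) / frame_shift;
}
```

For the 258000-sample waveform shown earlier, at 16 kHz with a 25 ms window and a 10 ms shift, this gives 400-sample frames, a 160-sample shift, and 1611 frames, so with num_bins = 80 the feature matrix would be expected to have shape 1611 x 80.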

csukuangfj commented 2 years ago

Can you use gdb to get the stack trace?