kaldi-asr / kaldi

kaldi-asr/kaldi is the official location of the Kaldi project.
http://kaldi-asr.org
Other

BatchedThreadedNnet3CudaPipeline2 Initialization: cudaError_t 11 : "invalid argument" #4112

Open pskrunner14 opened 4 years ago

pskrunner14 commented 4 years ago

I am using the BatchedThreadedNnet3CudaPipeline2 pipeline similar to how it's used in cudadecoderbin/batched-wav-nnet3-cuda2.cc in a custom application. On running the modified code, I got the following error:

LOG ([5.5.0~1-da93]:RemoveOrphanNodes():nnet-nnet.cc:948) Removed 1 orphan nodes.
LOG ([5.5.0~1-da93]:RemoveOrphanComponents():nnet-nnet.cc:847) Removing 2 orphan components.
LOG ([5.5.0~1-da93]:Collapse():nnet-utils.cc:1472) Added 1 components, removed 2
# Word Embeddings (RNNLM): 97396
LOG ([5.5.0~1-da93]:CompileLooped():nnet-compile-looped.cc:345) Spent 0.00422883 seconds in looped compilation.
LOG ([5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:183) Computing derived variables for iVector extractor
LOG ([5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:204) Done.
LOG ([5.5.0~1-da93]:CompileLooped():nnet-compile-looped.cc:345) Spent 0.0332868 seconds in looped compilation.
LOG ([5.5.0~1-da93]:SelectGpuId():cu-device.cc:223) CUDA setup operating under Compute Exclusive Mode.
LOG ([5.5.0~1-da93]:FinalizeActiveGpu():cu-device.cc:308) The active GPU is [0]: Tesla M60      free:7437M, used:181M, total:7618M, free/total:0.97621 version 5.2
LOG ([5.5.0~1-da93]:CheckAndFixConfigs():nnet3/nnet-am-decodable-simple.h:129) Increasing --frames-per-chunk from 50 to 63 due to --frame-subsampling-factor=3 and nnet shift-invariance modulus = 21
LOG ([5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:183) Computing derived variables for iVector extractor
LOG ([5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:204) Done.
ERROR ([5.5.0~1-da93]:CopyFromVec():cu-vector.cc:1086) cudaError_t 11 : "invalid argument" returned from 'cudaMemcpyAsync(data_, src.data_, src.dim_ * sizeof(Real), cudaMemcpyDeviceToDevice, cudaStreamPerThread)'

[ Stack-Trace: ]
/opt/kaldi/src/lib/libkaldi-base.so(kaldi::MessageLogger::LogMessage() const+0x82c) [0x7fd770ac52aa]
/opt/kaldi/src/lib/libkaldi-matrix.so(kaldi::MessageLogger::LogAndThrow::operator=(kaldi::MessageLogger const&)+0x21) [0x7fd773a52153]
/opt/kaldi/src/lib/libkaldi-cudamatrix.so(kaldi::CuVectorBase<double>::CopyFromVec(kaldi::CuVectorBase<double> const&)+0x186) [0x7fd77256fd98]
/opt/kaldi/src/lib/libkaldi-nnet3.so(kaldi::nnet3::NonlinearComponent::NonlinearComponent(kaldi::nnet3::NonlinearComponent const&)+0x57) [0x7fd7720a330d]
/opt/kaldi/src/lib/libkaldi-nnet3.so(kaldi::nnet3::RectifiedLinearComponent::Copy() const+0x21) [0x7fd7720c71a9]
/opt/kaldi/src/lib/libkaldi-nnet3.so(kaldi::nnet3::Nnet::Nnet(kaldi::nnet3::Nnet const&)+0x3cf) [0x7fd772125749]
/opt/kaldi/src/lib/libkaldi-cudadecoder.so(kaldi::cuda_decoder::BatchedThreadedNnet3CudaOnlinePipeline::ReadParametersFromModel()+0x5e6) [0x7fd7533bc720]
/opt/kaldi/src/lib/libkaldi-cudadecoder.so(kaldi::cuda_decoder::BatchedThreadedNnet3CudaOnlinePipeline::Initialize(fst::Fst<fst::ArcTpl<fst::TropicalWeightTpl<float> > > const&)+0x11) [0x7fd7533bcd21]
/usr/local/lib/libkaldiserve.so(kaldi::cuda_decoder::BatchedThreadedNnet3CudaPipeline2::BatchedThreadedNnet3CudaPipeline2(kaldi::cuda_decoder::BatchedThreadedNnet3CudaPipeline2Config const&, fst::Fst<fst::ArcTpl<fst::TropicalWeightTpl<float> > > const&, kaldi::nnet3::AmNnetSimple const&, kaldi::TransitionModel const&)+0xaf2) [0x7fd771c0e8e2]
/usr/local/lib/libkaldiserve.so(kaldiserve::BatchDecoder::start_decoding()+0xf3) [0x7fd771c0b4b3]
./bin/build/batched-gpu-decoder(main+0x7e9) [0x415e99]
/lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0xf0) [0x7fd77157b830]
./bin/build/batched-gpu-decoder(_start+0x29) [0x416629]

terminate called after throwing an instance of 'kaldi::KaldiFatalError'
  what():  kaldi::KaldiFatalError
Aborted (core dumped)

From what I can gather, it has something to do with CUDA not being able to copy an nnet3 component's data on the GPU (the failing call is a device-to-device cudaMemcpyAsync) at cudamatrix/cu-vector.cc#L1086, reached from the top-level call at cudadecoder/batched-threaded-nnet3-cuda-online-pipeline.cc#L407.

I also tried the batched-wav-nnet3-cuda2 binary to check whether the problem was with the model, but it ran fine:

LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:SelectGpuId():cu-device.cc:223) CUDA setup operating under Compute Exclusive Mode.
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:FinalizeActiveGpu():cu-device.cc:308) The active GPU is [0]: Tesla M60       free:7437M, used:181M, total:7618M, free/total:0.97621 version 5.2
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:RemoveOrphanNodes():nnet-nnet.cc:948) Removed 1 orphan nodes.
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:RemoveOrphanComponents():nnet-nnet.cc:847) Removing 2 orphan components.
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:Collapse():nnet-utils.cc:1472) Added 1 components, removed 2
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:CheckAndFixConfigs():nnet3/nnet-am-decodable-simple.h:123) Increasing --frames-per-chunk from 50 to 51 to make it a multiple of --frame-subsampling-factor=3
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:183) Computing derived variables for iVector extractor
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:204) Done.
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:183) Computing derived variables for iVector extractor
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:204) Done.
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:183) Computing derived variables for iVector extractor
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:204) Done.
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:183) Computing derived variables for iVector extractor
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:204) Done.
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:AllocateNewRegion():cu-allocator.cc:506) About to allocate new memory region of 1935671296 bytes; current memory info is: free:3691M, used:3927M, total:7618M, free/total:0.484533
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:183) Computing derived variables for iVector extractor
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:204) Done.
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:183) Computing derived variables for iVector extractor
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:204) Done.
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:183) Computing derived variables for iVector extractor
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:ComputeDerivedVars():ivector-extractor.cc:204) Done.
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:main():batched-wav-nnet3-cuda2.cc:180) Decoded 2 utterances, 0 with errors.
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:main():batched-wav-nnet3-cuda2.cc:182) Overall likelihood per frame was -nan per frame over 0 frames.
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:main():batched-wav-nnet3-cuda2.cc:185) Overall:  Aggregate Total Time: 1.60955 Total Audio: 10.2378 RealTimeX: 6.36064
LOG (batched-wav-nnet3-cuda2[5.5.0~1-da93]:~CachingOptimizingCompiler():nnet-optimize.cc:710) 0.0357 seconds taken in nnet3 compilation total (breakdown: 0.0213 compilation, 0.00266 optimization, 0.00969 shortcut expansion, 0.000512 checking, 0.000253 computing indexes, 0.00123 misc.) + 0 I/O.

I would appreciate some help with this issue. Here is a link to the code for reference: https://github.com/Vernacular-ai/kaldi-serve/blob/gpu-decoder/src/decoder/decoder-batch.cpp

btiplitz commented 4 years ago

Looks OK, but I don't have access to my implementation of the NVIDIA code. I'll try to take a look next week.

btiplitz commented 4 years ago

@pskrunner14 Looking at your code, I see that in other places you call AdvanceDecoding directly. I'd be careful there, as NVIDIA changes that code regularly. Also, in your code the lambda callbacks can run in parallel, so it looks like you might be missing a lock in the callback.
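To illustrate the locking concern: if the pipeline can invoke several completion callbacks at once from different threads, any state those callbacks share needs a lock. The generic C++ sketch below assumes, hypothetically, that each callback stores one per-utterance result in a shared map; `ResultSink`, `Add`, and `SimulateParallelCallbacks` are illustrative names, not part of Kaldi or kaldi-serve.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Hypothetical shared sink that decode callbacks write into.
// Callbacks may run concurrently, so every access takes the mutex.
class ResultSink {
 public:
  void Add(const std::string &utt_id, const std::string &text) {
    std::lock_guard<std::mutex> lock(mutex_);
    results_[utt_id] = text;
  }

  std::size_t Size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return results_.size();
  }

 private:
  mutable std::mutex mutex_;
  std::map<std::string, std::string> results_;
};

// Simulates several worker threads firing callbacks at the same time.
inline void SimulateParallelCallbacks(ResultSink &sink, int num_threads,
                                      int per_thread) {
  std::vector<std::thread> workers;
  for (int t = 0; t < num_threads; ++t) {
    workers.emplace_back([&sink, t, per_thread]() {
      for (int i = 0; i < per_thread; ++i) {
        // Sharing `sink` by reference is safe only because Add() locks.
        sink.Add("utt-" + std::to_string(t) + "-" + std::to_string(i),
                 "hypothesis");
      }
    });
  }
  for (auto &w : workers) w.join();
}
```

Without the `std::lock_guard` in `Add`, concurrent inserts into `std::map` are a data race and can corrupt the container or crash.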

The API I've used seems pretty simple:

1. Init
2. CreateTaskGroup
3. Call DecodeWithCallback
4. WaitForGroup
5. DestroyTaskGroup

I know you are not using the task group feature. It was added for continuous streams of data, where you want the library to tell you when a batch of processing has completed.
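The create/decode/wait/destroy flow above depends on being able to block until every callback in a group has fired. The sketch below shows one generic way to build that kind of wait in plain C++, using a per-group counter of outstanding tasks plus a condition variable. `TaskGroupTracker`, `RunGroup`, and their methods are hypothetical; they are not Kaldi's implementation or its method signatures.

```cpp
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Hypothetical tracker mirroring a Create/Wait/Destroy task-group lifecycle.
// It counts outstanding tasks per group; Wait() blocks until the count is 0.
class TaskGroupTracker {
 public:
  void Create(const std::string &group) {
    std::lock_guard<std::mutex> lock(mutex_);
    pending_[group] = 0;
  }
  void Submit(const std::string &group) {
    std::lock_guard<std::mutex> lock(mutex_);
    ++pending_.at(group);
  }
  // Called from a completion callback, possibly on another thread.
  void Done(const std::string &group) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      --pending_.at(group);
    }
    cv_.notify_all();
  }
  void Wait(const std::string &group) {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [&]() { return pending_.at(group) == 0; });
  }
  void Destroy(const std::string &group) {
    std::lock_guard<std::mutex> lock(mutex_);
    pending_.erase(group);
  }
  bool Exists(const std::string &group) const {
    std::lock_guard<std::mutex> lock(mutex_);
    return pending_.count(group) > 0;
  }

 private:
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  std::map<std::string, int> pending_;
};

// Submits `n` fake tasks to `group`, each completing on its own thread,
// waits for the whole group, then destroys it. Returns tasks completed.
inline int RunGroup(TaskGroupTracker &tracker, const std::string &group,
                    int n) {
  std::atomic<int> completed{0};
  tracker.Create(group);
  std::vector<std::thread> workers;
  for (int i = 0; i < n; ++i) {
    tracker.Submit(group);  // count the task before it can finish
    workers.emplace_back([&]() {
      ++completed;  // stands in for per-utterance callback work
      tracker.Done(group);
    });
  }
  tracker.Wait(group);  // returns once every task has called Done()
  for (auto &w : workers) w.join();
  tracker.Destroy(group);
  return completed.load();
}
```

Calling `Submit` before starting each task keeps the counter from briefly reaching zero while work is still outstanding.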

Your error is within the GPU code, but I'd want to rule out a threading issue first. The copy is a DMA call that should only fail if the DMA parameters are wrong, if the GPU is not active, or if the memory types are wrong. I believe Kaldi ensures the GPU is always active, so that should not be possible without some other complicating factor.
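One way the "memory types are wrong" case can arise is sketched below as a pure C++ simulation, not Kaldi code. `FakeDevice` and `MockVector` are hypothetical. The sketch assumes that a vector chooses host or device storage when it is allocated, but chooses its copy path from the device state at copy time. Under that assumption, a buffer created before the GPU is activated and copied afterwards would take the device-to-device path with a host pointer, which a real `cudaMemcpyAsync` would reject.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a global "is the GPU active?" flag.
struct FakeDevice {
  static bool &Enabled() {
    static bool enabled = false;
    return enabled;
  }
};

enum class CopyStatus { kOk, kInvalidArgument };

// Hypothetical vector: storage location is fixed at allocation time,
// but the copy path is chosen from the device state at copy time.
class MockVector {
 public:
  explicit MockVector(std::size_t dim)
      : on_device_(FakeDevice::Enabled()), data_(dim, 0.0) {}

  bool on_device() const { return on_device_; }

  CopyStatus CopyFrom(const MockVector &src) {
    if (src.data_.size() != data_.size()) return CopyStatus::kInvalidArgument;
    if (FakeDevice::Enabled() && !(on_device_ && src.on_device_)) {
      // A real device-to-device copy would be handed a host pointer here.
      return CopyStatus::kInvalidArgument;
    }
    data_ = src.data_;
    return CopyStatus::kOk;
  }

 private:
  bool on_device_;
  std::vector<double> data_;
};
```

Whether this applies to the failure above is unverified. One difference worth checking: in the custom application's log, the `SelectGpuId()` line appears after the model-loading lines (`RemoveOrphanNodes()`, `Collapse()`, `CompileLooped()`), whereas in the `batched-wav-nnet3-cuda2` log it appears first.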

btiplitz commented 4 years ago

@pskrunner14 have you looked at this more ?

pskrunner14 commented 4 years ago

@btiplitz I'll take a look this week or next.

stale[bot] commented 4 years ago

This issue has been automatically marked as stale by a bot solely because it has not had recent activity. Please add any comment (simply 'ping' is enough) to prevent the issue from being closed for 60 more days if you believe it should be kept open.