k2-fsa / sherpa

Speech-to-text server framework with next-gen Kaldi
https://k2-fsa.github.io/sherpa
Apache License 2.0
521 stars 105 forks

add use gpu(torch.jit.script to zipformer encoder) #346

Open whaozl opened 1 year ago

ravi-mr commented 1 year ago

@whaozl @csukuangfj Following https://k2-fsa.github.io/sherpa/cpp/pretrained_models/online_transducer.html#icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29 with --use-gpu=true, even with this pull request I ended up with the following error.

./pruned_transducer_stateless7_streaming/jit_trace_export.py(135): export_encoder_model_jit_trace
./pruned_transducer_stateless7_streaming/jit_trace_export.py(298): main
/star-zw/env/k2_icefall/lib/python3.8/site-packages/torch/autograd/grad_mode.py(28): decorate_context
./pruned_transducer_stateless7_streaming/jit_trace_export.py(313): <module>
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Aborted (core dumped)

I tried tweaking jit_trace_export.py, but it didn't work!

whaozl commented 1 year ago

@ravi-mr, you should change https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py#L135 from:

traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))

to:

traced_model = torch.jit.script(encoder_model)
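To see why this change sidesteps the device error: torch.jit.trace runs the model on concrete example inputs, so every tensor the model touches during tracing must be on one device, while torch.jit.script compiles the model from source and needs no example inputs at all. A minimal sketch with a toy module (not the real Zipformer encoder; TinyEncoder and its shapes are made up for illustration):

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Stand-in for the streaming encoder: projects features, passes lengths through."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
        return self.proj(x), x_lens


encoder = TinyEncoder().eval()

# trace: requires example inputs; model weights and these tensors must
# all live on the same device, otherwise the RuntimeError above appears.
x = torch.randn(2, 5, 8)
x_lens = torch.tensor([5, 3])
traced = torch.jit.trace(encoder, (x, x_lens))

# script: compiled from the Python source, no example inputs involved,
# so there is no cuda:0 vs cpu mismatch at export time.
scripted = torch.jit.script(encoder)

out, lens = scripted(x, x_lens)
print(out.shape)  # torch.Size([2, 5, 8])
```

The trade-off is that scripting requires the model code to be TorchScript-compatible (type annotations, no unsupported Python constructs), whereas tracing only records the operations executed on the example inputs.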

csukuangfj commented 1 year ago

@whaozl Could you try https://github.com/k2-fsa/icefall/pull/1005 ?

We now support models exported by torch.jit.script().

I think you can export the model on CPU with torch.jit.script() and then run it on GPU within sherpa.
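That workflow can be sketched as follows (a toy module again, saved to an in-memory buffer instead of a file; the real export would go through the icefall export script and sherpa would do the loading):

```python
import io

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Placeholder module standing in for the scripted encoder."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


# 1. Export on CPU with torch.jit.script (no example inputs, no device pinning).
scripted = torch.jit.script(TinyEncoder().cpu().eval())
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)

# 2. At runtime, load the scripted model directly onto whatever device
#    is available; map_location moves the weights in one step.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.jit.load(buffer, map_location=device)

x = torch.randn(2, 4, device=device)
y = model(x)
print(y.shape)  # torch.Size([2, 4])
```

Because the scripted module is not traced against any particular device, the same exported file can be run on CPU or GPU without re-exporting.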

Please also see https://github.com/k2-fsa/sherpa/pull/365