FunAudioLLM / SenseVoice

Multilingual Voice Understanding Model
https://funaudiollm.github.io/

[about ONNX export] #78

Open LateLinux opened 1 month ago

LateLinux commented 1 month ago

❓ Questions and Help: Exporting from pt to ONNX with export.py fails

What is your question?

I get the error below when running `python export.py`. The error message is: `RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)`
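A device mismatch like this one (an embedding lookup whose weights and indices appear to sit on different devices) can be narrowed down by listing which tensors live where. The helper below is a minimal diagnostic sketch, not part of SenseVoice or FunASR; `group_by_device` is a hypothetical name, and it only assumes each object exposes a `.device` attribute, as torch tensors and parameters do.

```python
from collections import defaultdict


def group_by_device(named_tensors):
    """Map each device string to the names of the tensors stored on it.

    Accepts any iterable of (name, obj) pairs where obj has a `.device`
    attribute, which is how torch tensors and parameters expose placement.
    """
    groups = defaultdict(list)
    for name, tensor in named_tensors:
        # str() turns a torch.device such as device('cuda:0') into "cuda:0".
        groups[str(tensor.device)].append(name)
    return dict(groups)
```

With a loaded model (here a hypothetical variable `model`), one could pass `list(model.named_parameters()) + list(model.named_buffers())` together with the named dummy inputs used for export. More than one key in the result means the tensors are split across devices.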

    /miniconda3/envs/sensevoice/bin/python /mnt/sdb/SenseVoice-main/export.py
    /miniconda3/envs/sensevoice/lib/python3.10/site-packages/rotary_embedding_torch/rotary_embedding_torch.py:35: FutureWarning: torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cuda', args...) instead.
      @autocast(enabled = False)
    /miniconda3/envs/sensevoice/lib/python3.10/site-packages/rotary_embedding_torch/rotary_embedding_torch.py:262: FutureWarning: torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cuda', args...) instead.
      @autocast(enabled = False)
    Loading remote code successfully: model
    /miniconda3/envs/sensevoice/lib/python3.10/site-packages/funasr/train_utils/load_pretrained_model.py:38: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
      src_state = torch.load(path, map_location=map_location)
    Traceback (most recent call last):
      File "/mnt/sdb/SenseVoice-main/export.py", line 30, in <module>
        export_dir = export_utils.export(model=rebuilt_model, **kwargs)
      File "/mnt/sdb/SenseVoice-main/utils/export_utils.py", line 17, in export
        _onnx(
      File "/mnt/sdb/SenseVoice-main/utils/export_utils.py", line 43, in _onnx
        torch.onnx.export(
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/onnx/utils.py", line 551, in export
        _export(
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/onnx/utils.py", line 1648, in _export
        graph, params_dict, torch_out = _model_to_graph(
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/onnx/utils.py", line 1170, in _model_to_graph
        graph, params, torch_out, module = _create_jit_graph(model, args)
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/onnx/utils.py", line 1046, in _create_jit_graph
        graph, torch_out = _trace_and_get_graph_from_model(model, args)
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/onnx/utils.py", line 950, in _trace_and_get_graph_from_model
        trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/jit/_trace.py", line 1497, in _get_trace_graph
        outs = ONNXTracedModule(
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
        return forward_call(*args, **kwargs)
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/jit/_trace.py", line 141, in forward
        graph, out = torch._C._create_graph_by_tracing(
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/jit/_trace.py", line 132, in wrapper
        outs.append(self.inner(*trace_inputs))
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
        return forward_call(*args, **kwargs)
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
        result = self.forward(*input, **kwargs)
      File "/mnt/sdb/SenseVoice-main/export_meta.py", line 32, in export_forward
        language_query = self.embed(language.to(speech.device)).unsqueeze(1)
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
        return forward_call(*args, **kwargs)
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
        result = self.forward(*input, **kwargs)
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 164, in forward
        return F.embedding(
      File "/miniconda3/envs/sensevoice/lib/python3.10/site-packages/torch/nn/functional.py", line 2267, in embedding
        return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
    RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

    Process finished with exit code 1

Code

What have you tried?

  1. Downloaded the SenseVoiceSmall model locally, to <SenseVoice/model/SenseVoiceSmall>.

  2. Changed the model directory in `export.py`: commented out `model_dir = "iic/SenseVoiceSmall"` and set `model_dir = os.path.join(os.path.dirname("__file__"), "model", "SenseVoiceSmall")`.

  3. Changed the inference verification input: commented out `wav_or_scp = "/Users/shixian/Downloads/asr_example_hotword.wav"` and set `wav_or_scp = os.path.join(os.path.dirname("__file__"), "input", "bad1.wav")`.
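One detail in the paths above: `os.path.dirname("__file__")` is called on a string literal rather than on the `__file__` variable, so it returns an empty string and the resulting paths are relative to the current working directory. That works when the script is launched from the repository root but is unlikely to be related to the device error. A minimal sketch of a location-independent alternative, where `resolve_model_dir` is a hypothetical helper and not something from the repository:

```python
import os
from pathlib import Path

# The literal "__file__" has no directory component, so dirname() yields "".
literal_dir = os.path.dirname("__file__")


def resolve_model_dir(script_path, *parts):
    """Join `parts` onto the directory that contains `script_path`."""
    return Path(script_path).resolve().parent.joinpath(*parts)
```

Inside `export.py` this could be called as `resolve_model_dir(__file__, "model", "SenseVoiceSmall")`, which anchors the path to the script's own directory.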

What's your environment?

LauraGPT commented 1 month ago

Make sure your installed PyTorch version satisfies torch<=2.3.
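To check an environment against that constraint before running the export, the version string can be compared directly. This is a minimal sketch using only the standard library; `torch_version_at_most` is a hypothetical helper name, and it assumes the usual `major.minor.patch+build` format of PyTorch version strings.

```python
import re


def torch_version_at_most(version, max_major=2, max_minor=3):
    """Return True if a version string such as "2.3.1+cu121" is <= max_major.max_minor."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    # Tuple comparison orders by major first, then minor.
    return (major, minor) <= (max_major, max_minor)
```

In practice this would be called as `torch_version_at_most(torch.__version__)` after `import torch`; a `False` result suggests installing an older release before retrying `python export.py`.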