tsurumeso / vocal-remover

Vocal Remover using Deep Neural Networks
MIT License

Can the model be exported to ONNX? #168

Open wxbool opened 6 months ago

wxbool commented 6 months ago

I attempted to export the pre-trained model baseline.pth and encountered some issues.

export.py

import torch
from lib import nets

model = nets.CascadedNet(n_fft=2048, hop_length=1024, is_complex=False)

model.load_state_dict(torch.load('models/baseline.pth', map_location=torch.device('cpu')))

model.eval()

dummy_input = torch.randn(1, 2, model.max_bin, 100)  # (batch, channels, freq bins, time frames)

input_names = ["input_waves"]
output_names = ["output_waves"]

torch.onnx.export(model, dummy_input, "model.onnx",
                  verbose=False, input_names=input_names,
                  output_names=output_names,
                  opset_version=12) 

The following error occurred:

E:\Space\script\python\vocal-remover\lib\spec_utils.py:12: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if h1_shape[3] == h2_shape[3]:
E:\Space\script\python\vocal-remover\lib\spec_utils.py:14: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  elif h1_shape[3] < h2_shape[3]:
============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Traceback (most recent call last):
  File "E:\Space\script\python\vocal-remover\export.py", line 27, in <module>
    torch.onnx.export(model, dummy_input, "model.onnx",
  File "G:\ProgramData\Anaconda3\envs\audio-separator\lib\site-packages\torch\onnx\utils.py", line 506, in export
    _export(
  File "G:\ProgramData\Anaconda3\envs\audio-separator\lib\site-packages\torch\onnx\utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "G:\ProgramData\Anaconda3\envs\audio-separator\lib\site-packages\torch\onnx\utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "G:\ProgramData\Anaconda3\envs\audio-separator\lib\site-packages\torch\onnx\utils.py", line 989, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "G:\ProgramData\Anaconda3\envs\audio-separator\lib\site-packages\torch\onnx\utils.py", line 893, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "G:\ProgramData\Anaconda3\envs\audio-separator\lib\site-packages\torch\jit\_trace.py", line 1268, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "G:\ProgramData\Anaconda3\envs\audio-separator\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "G:\ProgramData\Anaconda3\envs\audio-separator\lib\site-packages\torch\jit\_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "G:\ProgramData\Anaconda3\envs\audio-separator\lib\site-packages\torch\jit\_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "G:\ProgramData\Anaconda3\envs\audio-separator\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "G:\ProgramData\Anaconda3\envs\audio-separator\lib\site-packages\torch\nn\modules\module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "E:\Space\script\python\vocal-remover\lib\nets.py", line 91, in forward
    l1 = self.stg1_low_band_net(l1_in)
  File "G:\ProgramData\Anaconda3\envs\audio-separator\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "G:\ProgramData\Anaconda3\envs\audio-separator\lib\site-packages\torch\nn\modules\module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "G:\ProgramData\Anaconda3\envs\audio-separator\lib\site-packages\torch\nn\modules\container.py", line 217, in forward
    input = module(input)
  File "E:\Space\script\python\vocal-remover\lib\nets.py", line 35, in __call__
    h = self.dec4(h, e4)
  File "E:\Space\script\python\vocal-remover\lib\layers.py", line 55, in __call__
    skip = spec_utils.crop_center(skip, x)
  File "E:\Space\script\python\vocal-remover\lib\spec_utils.py", line 15, in crop_center
    raise ValueError('h1_shape[3] must be greater than h2_shape[3]')
ValueError: h1_shape[3] must be greater than h2_shape[3]
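
As far as I can tell from the traceback, the failure is in spec_utils.crop_center, which trims the encoder skip connection along the time axis (dim 3) to match the decoder's upsampled tensor and raises when the skip is the narrower of the two. A rough reconstruction from the quoted lines (paraphrased from the error messages, not copied from the repo):

def crop_center(h1, h2):
    # h1: encoder skip connection, h2: decoder's upsampled tensor;
    # dim 3 is the time axis.
    h1_shape = h1.size()
    h2_shape = h2.size()
    if h1_shape[3] == h2_shape[3]:
        return h1
    elif h1_shape[3] < h2_shape[3]:
        raise ValueError('h1_shape[3] must be greater than h2_shape[3]')
    # Otherwise crop the skip symmetrically so the time axes line up.
    s_time = (h1_shape[3] - h2_shape[3]) // 2
    e_time = s_time + h2_shape[3]
    return h1[:, :, :, s_time:e_time]

If the encoder halves the time axis with rounding up (100 → 50 → 25 → 13 → 7) and the decoder doubles it (7 → 14), the upsampled tensor ends up one frame wider than the 13-frame skip, which would explain why a 100-frame dummy input hits this error.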

I'm not very familiar with exporting models to ONNX. Could you provide a script that can export to ONNX?
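
One direction that might be worth trying, purely as an untested sketch (the power-of-two frame count and the dynamic time axis are assumptions, not an official export recipe): pick a frame count that halves cleanly through the encoder, e.g. 512, so crop_center never has to crop, and mark the time axis dynamic so the exported graph is not pinned to that exact length.

import torch
from lib import nets

# Same model construction as the script above.
model = nets.CascadedNet(n_fft=2048, hop_length=1024, is_complex=False)
model.load_state_dict(torch.load('models/baseline.pth', map_location='cpu'))
model.eval()

# Assumption: 512 frames halve cleanly (512 -> 256 -> 128 -> 64 -> 32),
# so the decoder's upsampled tensors match the skip connections exactly.
n_frames = 512
dummy_input = torch.randn(1, 2, model.max_bin, n_frames)

torch.onnx.export(
    model, dummy_input, 'model.onnx',
    input_names=['input'], output_names=['output'],
    # The traced crop/slice logic may still bake in the shapes seen here
    # (see the TracerWarnings above), so the dynamic axis should be
    # verified against other input lengths before relying on it.
    dynamic_axes={'input': {3: 'n_frames'}, 'output': {3: 'n_frames'}},
    opset_version=12,
)

If the dynamic axis turns out not to generalize, exporting a fixed-size graph and padding inputs to n_frames on the caller's side is the usual fallback.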

MarinoMing commented 6 months ago

Have you fixed this issue yet?