coqui-ai / TTS

πŸΈπŸ’¬ - a deep learning toolkit for Text-to-Speech, battle-tested in research and production
http://coqui.ai
Mozilla Public License 2.0
34.31k stars 4.16k forks

VITS to ONNX [Feature request] #2472

Closed ADD-eNavarro closed 1 year ago

ADD-eNavarro commented 1 year ago

🚀 Feature Description

Your roadmap includes a plan to add exports to TFLite and ONNX. Let's work out together how to export VITS to ONNX format.

Solution

Arrive at a simple way to export a VITS model to ONNX, as a starting point for exporting any model to ONNX.

ADD-eNavarro commented 1 year ago

So, here's the thing: I have trained a VITS model (from phonemes, for Spanish) on an A100 GPU, and now I'm trying to export it to ONNX format, since I intend to use it from C# (.NET). Hopefully the lessons we learn here will be useful for developing a more general model_to_onnx mechanism.

So, I have this script, ToOnnx.py:

import os
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits, VitsAudioConfig
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

import torch
import torch.onnx

#Function to Convert to ONNX 
# Since we're loading a .pth file, we need first to create an empty model and then load in it the weights from the file.
def Convert_ONNX(model_file_name): 
    output_path = os.path.dirname(os.path.abspath(__file__))
    dataset_config = BaseDatasetConfig(
        formatter="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "LJSpeech-1.1/")
    )
    audio_config = VitsAudioConfig(
        sample_rate=22050, win_length=1024, hop_length=256, num_mels=80, mel_fmin=0, mel_fmax=None
    )

    config = VitsConfig(
        audio=audio_config,
        run_name="vits_ljspeech_ES",
        batch_size=32,
        eval_batch_size=16,
        batch_group_size=5,
        num_loader_workers=8,
        num_eval_loader_workers=4,
        run_eval=True,
        test_delay_epochs=-1,
        epochs=1000,
        text_cleaner="phoneme_cleaners",
        use_phonemes=True,
        phoneme_language="es-es",
        phoneme_cache_path=os.path.join(output_path, "phoneme_cache_es"),
        compute_input_seq_cache=True,
        print_step=25,
        print_eval=True,
        mixed_precision=True,
        output_path=output_path,
        datasets=[dataset_config],
        cudnn_benchmark=False
    )

    # Init the audio processor
    ap = AudioProcessor.init_from_config(config)

    # Init the tokenizer
    tokenizer, config = TTSTokenizer.init_from_config(config)

    # Init model
    model = Vits(config, ap, tokenizer, speaker_manager=None)

    # Load pretrained weights into the model.
    # Added map_location to work on CPU
    model.load_state_dict(torch.load(model_file_name, map_location=torch.device('cpu')), strict=False)

    # Set the model to inference mode 
    model.eval() 

    # Let's create a dummy input tensor 
    # 
    # I couldn't find the input shape to use, but I looked into your tests and took this.
    # Code extracted from: github.com/coqui-ai/TTS/blob/dev/tests/tts_tests/test_vits.py, line 148
    # 
    def _create_inputs(config, batch_size=1, device='cpu'):
        print("ToOnnx.py -> _create_inputs")

        # The next line raises an error: it selects the GPU, but some tensors are still on the CPU,
        # and PyTorch can't work on both devices at once. So I switched to pure CPU (as a function
        # parameter); that should be OK, since converting to ONNX is not a heavy task.
        # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # It raises an error 
        input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device)
        print("input_dummy: ", input_dummy.shape)
        input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device)
        input_lengths[-1] = 128

        spec = torch.rand(batch_size, config.audio["fft_size"] // 2 + 1, 30).to(device)
        print("spec: ", spec.shape)

        mel = torch.rand(batch_size, config.audio["num_mels"], 30).to(device)
        print("mel: ", mel.shape)

        spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
        print("spec_lengths: ", spec_lengths)

        spec_lengths[-1] = spec.size(2)
        print("spec_lengths: ", spec_lengths)

        waveform = torch.rand(batch_size, 1, spec.size(2) * config.audio["hop_length"]).to(device)
        return input_dummy, input_lengths, mel, spec, spec_lengths, waveform

    config.model_args.spec_segment_size = 10
    dummy_input, input_lengths, _, spec, spec_lengths, waveform = _create_inputs(config)  

    # Then the actual dummy input (I name it dum_input) takes this form:
    dum_input = (dummy_input, spec, spec_lengths, waveform)

    # Export the model   
    torch.onnx.export(model,        # model being run 
             dum_input,         # model input (or a tuple for multiple inputs) 
             "VITS_to.onnx",       # where to save the model  
             export_params=True,    # store the trained parameter weights inside the model file 
             opset_version=16,      # the ONNX version to export the model to 
             do_constant_folding=True,  # whether to execute constant folding for optimization 
             input_names = ['text'],    # the model's input names 
             output_names = ['wav', 'att_w', 'dur'], # the model's output names 
             dynamic_axes={'text' : {0 : 'batch_size'},   
                           'wav' : {0 : 'batch_size'},
                           'att_w' : {0 : 'batch_size'},
                           'dur' : {0 : 'batch_size'},
                           })  # variable length axes 
    print(" ") 
    print('Model has been converted to ONNX') 

Convert_ONNX("checkpoint_400000.pth")

And my result is this:

(env_YourTTS) ...\YourTTS>python ToOnnx.py
 > Setting up Audio Processor...
 | > sample_rate:22050
 | > resample:False
 | > num_mels:80
 | > log_func:np.log10
 | > min_level_db:0
 | > frame_shift_ms:None
 | > frame_length_ms:None
 | > ref_level_db:None
 | > fft_size:1024
 | > power:None
 | > preemphasis:0.0
 | > griffin_lim_iters:None
 | > signal_norm:None
 | > symmetric_norm:None
 | > mel_fmin:0
 | > mel_fmax:None
 | > pitch_fmin:None
 | > pitch_fmax:None
 | > spec_gain:20.0
 | > stft_pad_mode:reflect
 | > max_norm:1.0
 | > clip_norm:True
 | > do_trim_silence:False
 | > trim_db:60
 | > do_sound_norm:False
 | > do_amp_to_db_linear:True
 | > do_amp_to_db_mel:True
 | > do_rms_norm:False
 | > db_level:None
 | > stats_path:None
 | > base:10
 | > hop_length:256
 | > win_length:1024
ToOnnx.py -> _create_inputs
input_dummy:  torch.Size([1, 128])
spec:  torch.Size([1, 513, 30])
mel:  torch.Size([1, 80, 30])
spec_lengths:  tensor([27])
spec_lengths:  tensor([30])
...\YourTTS\TTS\TTS\tts\layers\vits\networks.py:86: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert x.shape[0] == x_lengths.shape[0]
Traceback (most recent call last):
  File "ToOnnx.py", line 117, in <module>
    Convert_ONNX()
  File "ToOnnx.py", line 109, in Convert_ONNX
    'dur' : {0 : 'batch_size'},
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\env_YourTTS\lib\site-packages\torch\onnx\utils.py", line 519, in export
    export_modules_as_functions=export_modules_as_functions,
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\env_YourTTS\lib\site-packages\torch\onnx\utils.py", line 1539, in _export
    dynamic_axes=dynamic_axes,
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\env_YourTTS\lib\site-packages\torch\onnx\utils.py", line 1111, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\env_YourTTS\lib\site-packages\torch\onnx\utils.py", line 987, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\env_YourTTS\lib\site-packages\torch\onnx\utils.py", line 896, in _trace_and_get_graph_from_model
    _return_inputs_states=True,
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\env_YourTTS\lib\site-packages\torch\jit\_trace.py", line 1184, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\env_YourTTS\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\env_YourTTS\lib\site-packages\torch\jit\_trace.py", line 132, in forward
    self._force_outplace,
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\env_YourTTS\lib\site-packages\torch\jit\_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\env_YourTTS\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\env_YourTTS\lib\site-packages\torch\nn\modules\module.py", line 1182, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\TTS\TTS\tts\models\vits.py", line 1015, in forward
    x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\env_YourTTS\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\env_YourTTS\lib\site-packages\torch\nn\modules\module.py", line 1182, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\TTS\TTS\tts\layers\vits\networks.py", line 103, in forward
    x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)  # [b, 1, t]
  File "C:\Users\Enrique.Navarro\Desktop\Pruebas_TTS\YourTTS\TTS\TTS\tts\utils\helpers.py", line 65, in sequence_mask
    mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
RuntimeError: The size of tensor a (128) must match the size of tensor b (30) at non-singleton dimension 3
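The 30 in this error is the last dimension of spec, which suggests that spec is reaching sequence_mask as the lengths tensor. Tracing calls the model as model(*dum_input), so the tuple's elements bind to forward's parameters by position. A stdlib-only sketch of that binding, assuming Vits.forward in this TTS version takes (x, x_lengths, y, y_lengths, waveform, aux_input) in that order; vits_forward is a hypothetical stand-in, not the real method:

```python
import inspect


# Hypothetical stand-in whose parameter list is ASSUMED to match Vits.forward
# in the TTS version used here. It is only inspected, never called.
def vits_forward(x, x_lengths, y, y_lengths, waveform, aux_input=None):
    raise NotImplementedError


def positional_binding(func, args):
    """Return {parameter_name: arg} as func(*args) would bind them.

    bind_partial is used so the sketch also works when fewer arguments
    are supplied than the function requires.
    """
    return dict(inspect.signature(func).bind_partial(*args).arguments)


# Labels for the tuple built in ToOnnx.py: (dummy_input, spec, spec_lengths, waveform).
# waveform length = 30 frames * hop_length 256 = 7680 samples.
dum_input = ("dummy_input [1, 128]", "spec [1, 513, 30]", "spec_lengths [1]", "waveform [1, 1, 7680]")

for name, value in positional_binding(vits_forward, dum_input).items():
    print(f"{name:10s} <- {value}")
```

With this parameter order, inserting input_lengths as the second element, i.e. (dummy_input, input_lengths, spec, spec_lengths, waveform), would give x_lengths the 1-D tensor it expects. That would still trace the training forward pass, which consumes a spectrogram and a waveform, so for a deployable TTS graph the inference path is the more likely thing to export.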
ADD-eNavarro commented 1 year ago

I add a few more prints here and there to follow the shapes being used. In TTS\TTS\tts\layers\vits\networks.py (class TextEncoder, function forward) I have:

In TTS\TTS\tts\utils\helpers.py (function sequence_mask) I have:

Seeing this, I go back to the _create_inputs function and change the third dimension of my "spec" variable to 128. Now my unsqueezed tensors are:

sequence_length.unsqueeze(1): torch.Size([1, 1, 513, 128])
seq_range.unsqueeze(0): torch.Size([1, 128])

The error in helpers.py is gone, and my error output now ends like this:

  File "...\TTS\TTS\tts\layers\vits\networks.py", line 105, in forward
    x = self.encoder(x * x_mask, x_mask)
RuntimeError: The size of tensor a (192) must match the size of tensor b (513) at non-singleton dimension 3

This is caused by the transposed x being [1, 192, 128] while the mask is [1, 1, 513, 128].

ADD-eNavarro commented 1 year ago

Now, this 513 comes from the audio configuration as fft_size // 2 + 1, coded in _create_inputs:

spec = torch.rand(batch_size, config.audio["fft_size"] // 2 + 1, dim_num).to(device)
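In numbers, with the audio settings from the log above: 513 is the bin count of a one-sided STFT with fft_size 1024, and the dummy waveform length follows from the frame count times hop_length. Both are properties of the audio (training) inputs plus an arbitrary frame count of 30, not of the text input, which is one hint that spec is not the tensor TextEncoder should be receiving:

```python
# Values from the VitsAudioConfig in ToOnnx.py (also shown in the log above).
fft_size = 1024
hop_length = 256
n_frames = 30  # third dimension used for spec in _create_inputs

# A one-sided STFT with fft_size points has fft_size // 2 + 1 frequency bins.
spec_bins = fft_size // 2 + 1
# The dummy waveform spans n_frames hops of hop_length samples each.
waveform_len = n_frames * hop_length

spec_shape = (1, spec_bins, n_frames)
waveform_shape = (1, 1, waveform_len)
print(spec_shape, waveform_shape)  # (1, 513, 30) (1, 1, 7680)
```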

So (you may have noticed by now that I'm stumbling blindly here) I hard-code it to 192, and my error moves to:

  File "...\TTS\TTS\tts\layers\glow_tts\transformer.py", line 417, in forward
    attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
RuntimeError: The size of tensor a (192) must match the size of tensor b (128) at non-singleton dimension 4

This x_mask parameter (class RelativePositionTransformer, function forward) has shape [1, 1, 1, 192, 128].

So I go back to ToOnnx.py and hard-code it to 128 instead, but the error just moves back to networks.py:

File "...\TTS\TTS\tts\layers\vits\networks.py", line 109, in forward
    x = self.encoder(x * x_mask, x_mask)
RuntimeError: The size of tensor a (192) must match the size of tensor b (128) at non-singleton dimension 3

There, in class TextEncoder, self.emb is an embedding layer whose output size is hidden_channels, which is 192 as set in config.json. I believe I shouldn't change config.json, since my model was trained with it.
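All four errors in this thread can be reproduced from shapes alone if spec is standing in for x_lengths, which TextEncoder expects to be 1-D (one length per batch item). The pure-Python sketch below models PyTorch broadcasting and unsqueeze on shape tuples, mirroring PyTorch's error message, and follows the three operations quoted in the tracebacks (sequence_mask, x * x_mask, and x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)). It assumes x is [1, 192, 128] after the embedding and transpose, as reported above; the helper names are made up for this sketch:

```python
# Shape-only model of the tensor operations quoted in the tracebacks.
# No PyTorch needed; shapes are plain tuples.


def unsqueeze(shape, dim):
    """Shape after tensor.unsqueeze(dim); a negative dim counts from the end."""
    if dim < 0:
        dim += len(shape) + 1
    return shape[:dim] + (1,) + shape[dim:]


def broadcast(a, b):
    """Broadcast shapes a and b (tensor a op tensor b), or raise like PyTorch."""
    ndim = max(len(a), len(b))
    a = (1,) * (ndim - len(a)) + a  # align trailing dimensions
    b = (1,) * (ndim - len(b)) + b
    out = []
    for i, (da, db) in enumerate(zip(a, b)):
        if da != db and da != 1 and db != 1:
            raise ValueError(
                f"The size of tensor a ({da}) must match the size of tensor b ({db}) "
                f"at non-singleton dimension {i}"
            )
        out.append(max(da, db))
    return tuple(out)


def text_encoder_shapes(x_shape, x_lengths_shape):
    """Follow x_lengths through sequence_mask, x * x_mask and the attention mask."""
    t = x_shape[2]  # x is [B, hidden_channels, T]
    # sequence_mask: seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
    mask = broadcast((1, t), unsqueeze(x_lengths_shape, 1))
    x_mask = unsqueeze(mask, 1)  # [B, 1, T] when the lengths are 1-D
    x = broadcast(x_shape, x_mask)  # x * x_mask
    # RelativePositionTransformer: x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
    attn_mask = broadcast(unsqueeze(x_mask, 2), unsqueeze(x_mask, -1))
    return x, attn_mask


def trace(x_shape, x_lengths_shape):
    """Return the resulting shapes, or the broadcasting error message."""
    try:
        return text_encoder_shapes(x_shape, x_lengths_shape)
    except ValueError as err:
        return str(err)


x_shape = (1, 192, 128)  # embedded, transposed text: hidden_channels=192, 128 tokens
# spec shapes tried in this thread, then a 1-D lengths tensor of shape [B]
for lengths_shape in [(1, 513, 30), (1, 513, 128), (1, 192, 128), (1, 128, 128), (1,)]:
    print(lengths_shape, "->", trace(x_shape, lengths_shape))
```

Under these assumptions, each spec shape tried here reproduces the corresponding RuntimeError, while a lengths tensor of shape [1] yields x of [1, 192, 128] and an attention mask of [1, 1, 128, 128]. That points at fixing what reaches x_lengths rather than editing spec or config.json; hidden_channels should stay at 192, since the trained embedding weights have that size.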

ADD-eNavarro commented 1 year ago

So, could someone more experienced or technically capable than me lend a hand? I'm not even sure I'm on the right track to find a way to make the models exportable to ONNX.

erogol commented 1 year ago

Thanks for the issue. You can either use https://github.com/rhasspy/piper or look at how they do it. It'd be a great PR though.
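Whichever route is taken, two details of the torch.onnx.export call in ToOnnx.py would need attention: input_names names only the first of the four inputs (the rest would keep auto-generated names), and dynamic_axes marks only axis 0, so the traced text length of 128 would be fixed in the graph. A small sketch that declares each input and output once and derives all three keyword arguments from that declaration, assuming an inference-style signature of token ids plus lengths (the names, axis labels, and output layout are hypothetical, not taken from piper or TTS):

```python
# Hypothetical layout for an inference-style export: token ids plus their
# lengths in, audio out. Names and axis labels are illustrative assumptions.
EXPORT_INPUTS = [
    ("input", {0: "batch_size", 1: "phonemes"}),  # token ids, [B, T_text]
    ("input_lengths", {0: "batch_size"}),         # valid lengths, [B]
]
EXPORT_OUTPUTS = [
    ("output", {0: "batch_size", 2: "time"}),     # assumed waveform, [B, 1, T_wav]
]


def onnx_export_kwargs(inputs, outputs):
    """Derive input_names, output_names and dynamic_axes from one declaration."""
    return {
        "input_names": [name for name, _ in inputs],
        "output_names": [name for name, _ in outputs],
        "dynamic_axes": {name: dict(axes) for name, axes in inputs + outputs},
    }


kwargs = onnx_export_kwargs(EXPORT_INPUTS, EXPORT_OUTPUTS)
print(kwargs["input_names"], kwargs["dynamic_axes"]["input"])
```

These could then be passed as keyword arguments to torch.onnx.export together with a dummy tuple holding exactly one tensor per declared input, in the same order, for a module whose forward takes those two tensors.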

stale[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. You might also look at our discussion channels.