jaywalnut310 / vits

VITS: Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech
https://jaywalnut310.github.io/vits-demo/index.html
MIT License
6.86k stars 1.26k forks source link

Problem with exporting the model to ONNX #156

Open JoanisTriandafilidi opened 1 year ago

JoanisTriandafilidi commented 1 year ago

Hello. Can anyone help me export the model to ONNX? I looked at a couple of issues on this topic, but they did not answer all of my questions. I also tried one of the pull requests, but it produced errors. This is my first time doing this, and I would really appreciate any help.

kafan1986 commented 1 year ago

@JoanisTriandafilidi Any progress?

tusharhchauhan commented 4 months ago
class SynthesizerTrn_inf(nn.Module):
    """
    Inference-only VITS synthesizer, intended for tracing / ONNX export.

    Only the modules used at inference time are built: text encoder
    (``enc_p``), decoder (``dec``), flow (``flow``), duration predictor
    (``dp``) and, for multi-speaker models, the speaker embedding
    (``emb_g``). There is no posterior encoder.

    Extra configuration keys passed via ``**kwargs`` are accepted and ignored,
    so the whole ``hps.model`` section can be splatted into the constructor.
    """

    def __init__(self,
                 n_vocab,
                 spec_channels,
                 segment_size,
                 inter_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout,
                 resblock,
                 resblock_kernel_sizes,
                 resblock_dilation_sizes,
                 upsample_rates,
                 upsample_initial_channel,
                 upsample_kernel_sizes,
                 n_speakers=0,
                 gin_channels=0,
                 use_sdp=True,
                 **kwargs):

        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels

        self.use_sdp = use_sdp

        self.enc_p = TextEncoder(n_vocab,
                                 inter_channels,
                                 hidden_channels,
                                 filter_channels,
                                 n_heads,
                                 n_layers,
                                 kernel_size,
                                 p_dropout)
        self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
        # NOTE(review): the literal sizes below (flow 5/1/4, duration
        # predictor 192/256, 3, 0.5, 4) must presumably match the model the
        # checkpoint was trained with, or the weights will not load.
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

        if use_sdp:
            self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
        else:
            self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)

        # The speaker embedding only exists for multi-speaker models;
        # forward() uses the same `> 1` test so the two always agree.
        if n_speakers > 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)

    def forward(self, x, sid=None, noise_scale=1, length_scale=1,
                noise_scale_w=1., max_len=None):
        """
        Synthesize audio from a batch of token-id sequences.

        Args:
            x: token ids indexed as [batch, text_len]. Every row is treated
                as full length; no padding mask is derived from the values.
            sid: speaker ids, one per batch row. This is required when the
                model was built with ``n_speakers > 1`` and ignored otherwise.
            noise_scale: scale of the noise added to the prior mean
                (0 gives a deterministic latent).
            length_scale: multiplier on the predicted durations. Values above
                1 give longer (slower) output.
            noise_scale_w: noise scale passed to the stochastic duration
                predictor. It is only used when ``use_sdp`` is True.
            max_len: optional cap on the number of latent frames decoded.

        Returns:
            The decoder output for the sampled latent.

        Raises:
            ValueError: if the model is multi-speaker and ``sid`` is None.
        """
        # One length per batch row, all equal to the padded length. The
        # tensor is created on x's device so the encoder's mask matches x.
        x_lengths = torch.full((x.size(0),), x.size(1),
                               dtype=torch.long, device=x.device)
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)

        # Use the same condition as __init__: emb_g exists only when
        # n_speakers > 1.
        if self.n_speakers > 1:
            if sid is None:
                raise ValueError("sid is required for a multi-speaker model")
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None

        if self.use_sdp:
            logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
        else:
            logw = self.dp(x, x_mask, g=g)

        # Durations per token -> output-frame lengths and masks.
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        # Expand the per-token prior statistics to per-frame via the alignment.
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']

        # Reparameterized sample with mean m_p and std exp(logs_p), scaled by
        # noise_scale (same form as upstream VITS inference). randn_like keeps
        # m_p's device/dtype. The previous torch.ones(...) added a fixed
        # one-std offset instead of noise.
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
        return o
# --- ONNX export script -----------------------------------------------------
# Dummy input: 295 token ids, all 0 except some odd positions (up to a random
# `size`), which get random ids in [1, 160).
size = torch.randint(0, 255, (1,)).item()
tensor = torch.zeros((295,), dtype=torch.int32)
hps = utils.get_hparams_from_file("./configs/ljs_base.json")

net_g = SynthesizerTrn_inf(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model)
_ = utils.load_checkpoint("pretrained_ljs.pth", net_g, None)
# Set random values at the odd indices
for i in range(1, size, 2):
    tensor[i] = torch.randint(1, 160, (1,)).item()
# Build the batched example only after the data is filled in (batch of 1).
x_tst = tensor.unsqueeze(0)
net_g.eval()

# Previously referenced but never defined (NameError). Input axis 1 is the
# text length. NOTE(review): output axis 2 assumes the decoder returns
# [batch, channels, samples]; confirm against Generator. Also verify with
# onnxruntime on a different-length input, because lengths computed from
# Python ints during tracing may still be baked into the graph.
dynamic_axes = {
    "input_0": {1: "text_len"},
    "output_0": {2: "audio_len"},
}

# torch.onnx.export traces the nn.Module itself, so no separate
# torch.jit.trace pass is needed. Everything runs under no_grad.
with torch.no_grad():
    torch.onnx.export(net_g, x_tst, "/VITS.onnx",
                      export_params=True,
                      keep_initializers_as_inputs=True,
                      opset_version=11,
                      do_constant_folding=True,
                      operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
                      verbose=True, input_names=["input_0"], output_names=["output_0"], dynamic_axes=dynamic_axes)

This code is tested and working. You can change the opset version according to your requirements.