keonlee9420 / Comprehensive-E2E-TTS

A non-autoregressive end-to-end text-to-speech model (text-to-wav) supporting a family of state-of-the-art unsupervised duration modeling methods. This project grows with the research community, aiming for the ultimate E2E-TTS.

Question about Differentiable Duration Modeling #4

Open · LEECHOONGHO opened this issue 2 years ago

LEECHOONGHO commented 2 years ago

Hello, I'm trying to implement the Differentiable Duration Modeling (DDM) module introduced in "Differentiable Duration Modeling for End-to-End Text-to-Speech".

I opened this issue to get advice on implementing DDM.

My implementation of the Differentiable Alignment Encoder produces an attention-like output from noise input, but training is far too slow (~10 s/iter); it appears to hang in the backward pass.

Can anyone give me advice on speeding up the recursive tensor operations? Should I write a custom CUDA kernel (e.g., with numba's cuda.jit, as Soft-DTW implementations do), or is there something wrong with the approach itself?

The module's output from noise input, and the code, are shown below.

Thank you.

```python
import torch
import matplotlib.pyplot as plt

dae = DifferentiableAlignmentEncoder()
b = 5
text_max_len = 25
mel_max_len = 85
dim = 256
x_len = torch.randint(1, text_max_len, (b,))
mel_len = torch.randint(2, mel_max_len, (b,))
x = torch.randn(b, int(x_len.max()), dim)
s, l, q, dur = dae(x, x_len, mel_len)

i = 2
plt.imshow(l[i, :x_len[i], :mel_len[i]].detach().numpy())
plt.imshow(q[i, :x_len[i], :mel_len[i]].detach().numpy())
plt.imshow(s[i, :x_len[i], :mel_len[i]].detach().numpy())
plt.plot(dur[i, :x_len[i]])
```

[Plots: L matrix, Q matrix, S (soft attention) matrix, and predicted duration]
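
To see where the time actually goes, the backward pass can be profiled with `torch.profiler` (a minimal sketch reusing the demo tensors above; `x` is recreated with `requires_grad=True` so that `.backward()` has something to flow into):

```python
import torch
from torch.profiler import profile, ProfilerActivity

x = torch.randn(b, int(x_len.max()), dim, requires_grad=True)
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    s, l, q, dur = dae(x, x_len, mel_len)
    s.sum().backward()  # exercise the backward pass through the DDM recursion
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
```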

Code

```python
import torch
import torch.nn as nn

# ConvNorm, LinearNorm, and get_mask_from_lengths are helpers defined elsewhere
# in this repository.

class DifferentiableAlignmentEncoder(nn.Module):
    def __init__(
        self,
        hidden_dim=256,
        conv_kernels=3,
        num_layers=3,
        dropout_p=0.2,
        max_mel_len=1150,  # max mel-spectrogram length (frames) in the training data
    ):
        super().__init__()

        self.conv_layer_blocks = nn.ModuleList([
            nn.Sequential(
                ConvNorm(hidden_dim, hidden_dim, conv_kernels, bias=True, transpose=True),
                nn.ReLU(),
                nn.LayerNorm(hidden_dim),
                nn.Dropout(dropout_p),
            )
            for _ in range(num_layers)
        ])
        self.dur_prob_proj = LinearNorm(hidden_dim, max_mel_len, bias=False)

        self.ddm = DifferentiableDurationModeling()

    def forward(self, x, phon_lens, mel_lens, x_masks=None):
        """
        x         : Tensor[B, T_phon, C_phone]
        phon_lens : LongTensor[B]
        mel_lens  : LongTensor[B]

        Returns:
        s   : soft attention matrix, Tensor[B, T_phon, T_mel]
        l   : per-phoneme duration distributions, Tensor[B, T_phon, T_mel]
        q   : cumulative-duration distributions, Tensor[B, T_phon, T_mel]
        dur : expected durations (detached), Tensor[B, T_phon]
        """
        max_mel_len = int(torch.max(mel_lens))

        for layer in self.conv_layer_blocks:
            if x_masks is not None:
                # assuming x_masks: BoolTensor[B, T_phon], True at padded positions
                x = x * (1 - x_masks.unsqueeze(-1).float())
            x = layer(x)
        x = self.dur_prob_proj(x)

        # inject additive Gaussian noise
        x = x + torch.randn_like(x)

        # per-frame probabilities in (0, 1), truncated to this batch's max mel length
        p = torch.sigmoid(x)
        p = p[:, :, :max_mel_len]

        s, l, q, dur = self.ddm(p, phon_lens, mel_lens)
        dur = dur.detach()

        return s, l, q, dur

class DifferentiableDurationModeling(nn.Module):
    def __init__(self):
        super().__init__()

    def _get_attn_mask(self, phon_lens, mel_lens):
        # Masks are True at valid (non-padded) positions.
        phon_mask = ~get_mask_from_lengths(phon_lens)
        mel_mask = ~get_mask_from_lengths(mel_lens)

        return phon_mask.unsqueeze(-1) * mel_mask.unsqueeze(1), phon_mask

    def forward(self, p, phon_lens, mel_lens):
        attn_mask, phon_mask = self._get_attn_mask(phon_lens, mel_lens)

        p = p * attn_mask

        # per-phoneme duration distributions over frames
        l = self._get_l(p, attn_mask) * attn_mask

        # expected durations (computed under no_grad)
        dur = self._get_duration(l) * phon_mask

        # distributions of the cumulative duration up to each phoneme
        q = self._get_q(l) * attn_mask

        # soft attention matrix
        s = self._get_s(q, l) * attn_mask

        return s, l, q, dur

    def _get_duration(self, l):
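        # Expected duration per phoneme: the weighted sum of l with weights 1..M.
        # Gradients are not needed here, hence no_grad (and the detach() in the caller).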
        with torch.no_grad():
            m = torch.arange(1, l.shape[-1] + 1)[None, :].expand_as(l).to(l.device)
            dur = torch.sum(m * l, dim=-1)
        return dur

    def _get_l(self, p, mask):
        # Computing l directly is numerically unstable for the gradient, so,
        # following the paper's authors, the products are evaluated in log-space.
        _p = torch.log(mask[:, :, 1:].float() - p[:, :, 1:] + 1e-8)  # log(1 - p); mask stands in for 1 so padding maps to log(~0)
        p = torch.log(p + 1e-8)                                      # log(p)
        com = torch.cumsum(_p, dim=-1)
        l_0 = com[:, :, -1].unsqueeze(-1)
        l_1 = p[:, :, 1].unsqueeze(-1)

        l_m = com[:, :, :-1] + p[:, :, 2:]

        l = torch.cat([l_0, l_1, l_m], dim=-1)

        return torch.exp(l)

    def _variable_kernel_size_convolution(self, x, y, length):
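        # Computes out[b, n] = sum_{t=0}^{n} x[b, n - t] * y[b, t]: a 1-D causal
        # convolution truncated to `length`, realized here via anti-diagonal sums.
        # Each of the `length` diagonal/sum ops adds nodes to the autograd graph,
        # which is likely why the backward pass is so slow.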
        matrix = torch.flip(x.unsqueeze(1) * y.unsqueeze(-1), dims=[-1])
        output =  torch.flip(
            torch.cat(
                [
                    torch.sum(
                        torch.diagonal(
                            matrix, offset=idx, dim1=-2, dim2=-1
                        ), dim=1
                    ).unsqueeze(1) 
                    for idx in range(length)
                ],
                dim=1
            ),
            dims=[1] 
        )
        return output

    def _get_q(self, l):
        length = l.shape[-1]
        q = [l[:, 0, :]]
        # sequential recursion over phonemes: q_i = q_{i-1} convolved with l_i
        # (the guard on l.shape[-1] was redundant: the range is empty when T_phon == 1)
        for i in range(1, l.shape[1]):
            q.append(self._variable_kernel_size_convolution(q[i - 1], l[:, i], length))

        return torch.stack(q, dim=1)

    def _reverse_cumsum(self, x):
        # cumulative sum from the right: out[..., m] = sum_{k >= m} x[..., k]
        return torch.flip(torch.cumsum(torch.flip(x, dims=[-1]), dim=-1), dims=[-1])

    def _get_s(self, q, l):
        length = l.shape[-1]
        l_rev_cumsum = self._reverse_cumsum(l)  # tail (reverse-cumulative) sums of each phoneme's l

        s = [l_rev_cumsum[:, 0, :]]
        for i in range(1, q.shape[1]):
            s.append(self._variable_kernel_size_convolution(q[:, i - 1], l_rev_cumsum[:, i], length))

        return torch.stack(s, dim=1)
```
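
For what it's worth, `_variable_kernel_size_convolution(x, y, length)` reduces to `out[b, n] = sum_{t<=n} x[b, n-t] * y[b, t]`, i.e. a linear convolution truncated to `length`, so the `length` diagonal sums could be replaced by a single FFT product per step. A minimal sketch (the name `fft_truncated_convolution` is mine; it assumes entries beyond the valid lengths are zero, which the masking above should ensure, and `torch.fft` supports autograd):

```python
import torch

def fft_truncated_convolution(x, y, length):
    # Same result as _variable_kernel_size_convolution:
    #   out[b, n] = sum_{t <= n} x[b, n - t] * y[b, t]
    # but in O(L log L) via the FFT instead of O(L^2) diagonal sums.
    n_fft = 2 * length  # zero-pad so circular convolution equals linear convolution
    fx = torch.fft.rfft(x, n=n_fft, dim=-1)
    fy = torch.fft.rfft(y, n=n_fft, dim=-1)
    out = torch.fft.irfft(fx * fy, n=n_fft, dim=-1)
    return out[..., :length]  # keep only the first `length` output lags
```

The recursion over phonemes in `_get_q`/`_get_s` stays sequential (each step depends on the previous one), but this should shrink the per-step cost in both the forward and backward passes; if it is still too slow, a custom CUDA kernel along the lines of the Soft-DTW implementations would be the next thing to try.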