keonlee9420 / Stepwise_Monotonic_Multihead_Attention

PyTorch implementation of Stepwise Monotonic Multihead Attention, similar to the attention proposed in Enhancing Monotonicity for Robust Autoregressive Transformer TTS
MIT License
31 stars 6 forks

Apply Pytorch Stepwise_Monotonic_Attention for Tacotron #2

Closed v-nhandt21 closed 2 years ago

v-nhandt21 commented 2 years ago

Can I apply this implementation to replace the location-sensitive attention in your previous Tacotron2 repo?

https://github.com/keonlee9420/tacotron2_MMI/blob/fe0e19fdb00f2554d99b0f33e56d65a1d1956b86/model.py#L227

keonlee9420 commented 2 years ago

Hi @v-nhandt21, theoretically yes, you can replace it with SMA, but please note that the current implementation has only been tested as an aligner in a reference-encoder-like architecture.
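
For concreteness, a rough sketch of what such a swap could look like in a Tacotron2-style Decoder's `__init__` is shown below. The `hparams` fields and the original `Attention` constructor signature follow NVIDIA's Tacotron2 and are assumptions about tacotron2_MMI; `StepwiseMonotonicAttention` stands for whatever SMA module is dropped in (untested sketch, not the repo's code):

```python
import torch.nn as nn

# Hypothetical sketch: only the attention-related part of the Decoder is shown.
# hparams names follow NVIDIA's Tacotron2 and StepwiseMonotonicAttention is
# assumed to accept the same constructor arguments as the original Attention.
class Decoder(nn.Module):
    def __init__(self, hparams):
        super().__init__()
        # ... prenet, attention_rnn, decoder_rnn, projection layers omitted ...
        # was: self.attention_layer = Attention(
        #          hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
        #          hparams.attention_dim, hparams.attention_location_n_filters,
        #          hparams.attention_location_kernel_size)
        self.attention_layer = StepwiseMonotonicAttention(
            hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
            hparams.attention_dim, hparams.attention_location_n_filters,
            hparams.attention_location_kernel_size)
```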

v-nhandt21 commented 2 years ago

Yes, thanks @keonlee9420. I have used Stepwise Monotonic Attention for Tacotron2 and it gives me better (more robust) results than Location Sensitive Attention, Forward Attention, and Dynamic Convolution Attention:

```python
import torch
import torch.nn as nn

from layers import LinearNorm  # Tacotron2's xavier-initialized linear layer


class StepwiseMonotonicAttention(nn.Module):
    """Drop-in replacement for Tacotron2's location-sensitive Attention."""

    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size,
                 sigmoid_noise=2.0):
        super(StepwiseMonotonicAttention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim,
                                       bias=False, w_init_gain='tanh')
        self.v = nn.Linear(attention_dim, 1, bias=False)
        self.score_mask_value = -float("inf")
        self.sigmoid_noise = sigmoid_noise
        self.alignment = None  # alignment from the previous decoder step

    def init_attention(self, processed_memory):
        # All probability mass starts on the first encoder step.
        b, t, c = processed_memory.size()
        self.alignment = processed_memory.new_zeros(b, t)
        self.alignment[:, 0:1] = 1

    def stepwise_monotonic_attention(self, p_i, prev_alignment):
        # alpha_i = prev_alpha_i * p_i + prev_alpha_{i-1} * (1 - p_{i-1}):
        # either stay on the current encoder step or move exactly one step forward.
        pad = prev_alignment.new_zeros(prev_alignment.size(0), 1)
        alignment = prev_alignment * p_i + torch.cat(
            (pad, prev_alignment[:, :-1] * (1.0 - p_i[:, :-1])), dim=1)
        return alignment

    def get_selection_probability(self, e, std):
        # Pre-sigmoid noise during training encourages near-binary selection probabilities.
        if self.training:
            noise = e.new_zeros(e.size()).normal_()
            e = e + noise * std
        return torch.sigmoid(e)

    def get_probabilities(self, energies):
        p_i = self.get_selection_probability(energies, self.sigmoid_noise)
        alignment = self.stepwise_monotonic_attention(p_i, self.alignment)
        self.alignment = alignment  # (batch, max_time)
        return alignment

    def get_energies(self, query, processed_memory, attention_weights_cat):
        # Additive (Bahdanau) energies; attention_weights_cat is unused but kept
        # for signature compatibility with the original Attention class.
        processed_query = self.query_layer(query.unsqueeze(1))
        energies = self.v(torch.tanh(processed_query + processed_memory))
        return energies.squeeze(-1)

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask, log_alpha):
        # log_alpha is unused; it is kept so the decoder-side call does not change.
        alignment = self.get_energies(attention_hidden_state, processed_memory,
                                      attention_weights_cat)  # (batch, max_time)

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        alignment = self.get_probabilities(alignment)

        attention_weights = alignment  # SMA alignment is used directly, no softmax
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)  # (batch, 1, dim)
        attention_context = attention_context.squeeze(1)  # (batch, dim)

        return attention_context, attention_weights
```
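
For anyone trying this, here is a hypothetical smoke test of the module above (shapes follow Tacotron2's default hyperparameters; `LinearNorm` is assumed to come from Tacotron2's `layers.py`). The key point it illustrates is that `init_attention(processed_memory)` must be called once before decoding starts, since `self.alignment` carries state across decoder steps:

```python
import torch

# Hypothetical smoke test, assuming the StepwiseMonotonicAttention class above.
attn = StepwiseMonotonicAttention(
    attention_rnn_dim=1024, embedding_dim=512, attention_dim=128,
    attention_location_n_filters=32, attention_location_kernel_size=31)

batch, max_time, enc_dim = 2, 7, 512
memory = torch.randn(batch, max_time, enc_dim)        # encoder outputs
processed_memory = attn.memory_layer(memory)          # (batch, max_time, attention_dim)
attn.init_attention(processed_memory)                 # must run once before decoding

query = torch.randn(batch, 1024)                      # attention RNN hidden state
weights_cat = torch.zeros(batch, 2, max_time)         # unused by SMA, kept for API parity
context, weights = attn(query, memory, processed_memory,
                        weights_cat, mask=None, log_alpha=None)
print(context.shape, weights.shape)  # torch.Size([2, 512]) torch.Size([2, 7])
```

In a full Decoder, the `init_attention(processed_memory)` call would go wherever the decoder states are (re)initialized for a new utterance (e.g. `initialize_decoder_states` in NVIDIA's layout, which has no such hook by default).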
keonlee9420 commented 2 years ago

Thanks for sharing! Wow, it seems really promising. I think your experiment can help the community move toward the most robust aligner for AR-based TTS models :)