Hi and thanks for making a great project!👍
When I tried the seq2seq decoder with LocationAwareAttention, I found possibly unexpected broadcasting caused by AddNorm. https://github.com/sooftware/KoSpeech/blob/f90354b565a43217cce580fbbe20629e3d41a174/kospeech/models/acoustic/seq2seq/decoder.py#L124 The expected `context.size()` is (batch, hidden_dim), but I received (batch, batch, hidden_dim).
I think this problem is caused by AddNorm. https://github.com/sooftware/KoSpeech/blob/f90354b565a43217cce580fbbe20629e3d41a174/kospeech/models/acoustic/transformer/sublayers.py#L24-L25 Here `output[0] + residual` is (batch_size, hidden_dim) + (batch_size, 1, hidden_dim), which broadcasts to (batch_size, batch_size, hidden_dim).
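To see the broadcasting in isolation (a minimal standalone snippet, independent of KoSpeech):

```python
import torch

batch_size, hidden_dim = 2, 512
context = torch.rand(batch_size, hidden_dim)      # shape of output[0]
residual = torch.rand(batch_size, 1, hidden_dim)  # shape of residual

# PyTorch right-aligns shapes, so (batch_size, hidden_dim) is treated as
# (1, batch_size, hidden_dim) and the sum broadcasts along both batch axes.
print((context + residual).size())  # torch.Size([2, 2, 512])
```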
Reproduction code

```python
import torch
from kospeech.models.attention import LocationAwareAttention
from kospeech.models.acoustic.transformer.sublayers import AddNorm

hidden_dim = 512
attention = LocationAwareAttention(d_model=hidden_dim)
attention_norm = AddNorm(LocationAwareAttention(d_model=hidden_dim), d_model=hidden_dim)

batch_size = 2
seq_length = 128
output = torch.rand(batch_size, 1, hidden_dim, dtype=torch.float32)
encoder_outputs = torch.rand(batch_size, seq_length, hidden_dim, dtype=torch.float32)
attn = None

context, _ = attention(output, encoder_outputs, attn)
context_norm, _ = attention_norm(output, encoder_outputs, attn)

print(context.size())       # torch.Size([2, 512]) i.e. (batch_size, hidden_dim)
print(context_norm.size())  # torch.Size([2, 2, 512]) i.e. (batch_size, batch_size, hidden_dim)
```
Environment

PyTorch 1.7.0a0+018b4d7, Python 3.7.7
Thank you for letting me know. I fixed this bug in the last commit.
I didn't know about this bug because I don't use location-aware attention very much. If there are any other issues, please let me know. Thank you !!
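For reference, a minimal sketch of how the residual shape could be aligned explicitly inside AddNorm. This is an illustrative rewrite assuming the constructor and forward signature implied by the reproduction code above; it is not necessarily the fix that was actually committed:

```python
import torch.nn as nn

class AddNorm(nn.Module):
    """Illustrative sketch: residual connection followed by layer norm,
    reshaping the residual to match the sublayer output before adding
    instead of relying on broadcasting."""
    def __init__(self, sublayer: nn.Module, d_model: int = 512) -> None:
        super().__init__()
        self.sublayer = sublayer
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, *args):
        residual = args[0]
        output = self.sublayer(*args)

        if isinstance(output, tuple):
            # e.g. LocationAwareAttention returns (context, attn); context is
            # (batch, hidden) while residual is (batch, 1, hidden), so align
            # the shapes explicitly rather than letting the sum broadcast to
            # (batch, batch, hidden).
            context = output[0]
            return self.layer_norm(context + residual.view_as(context)), output[1]

        return self.layer_norm(output + residual)
```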