sooftware / kospeech

Open-Source Toolkit for End-to-End Korean Automatic Speech Recognition leveraging PyTorch and Hydra.
https://sooftware.github.io/kospeech/
Apache License 2.0
603 stars 191 forks

Unexpected broadcasting? #43

Closed deyunote closed 4 years ago

deyunote commented 4 years ago

Hi, and thanks for making a great project! 👍

When I tried the seq2seq decoder with LocationAwareAttention, I found what looks like unexpected broadcasting caused by AddNorm. https://github.com/sooftware/KoSpeech/blob/f90354b565a43217cce580fbbe20629e3d41a174/kospeech/models/acoustic/seq2seq/decoder.py#L124 I expected context.size() to be (batch, hidden_dim), but I received (batch, batch, hidden_dim).

I think the problem comes from these lines in AddNorm: https://github.com/sooftware/KoSpeech/blob/f90354b565a43217cce580fbbe20629e3d41a174/kospeech/models/acoustic/transformer/sublayers.py#L24-L25 In output[0] + residual, output[0] has shape (batch_size, hidden_dim) while residual has shape (batch_size, 1, hidden_dim), so the sum is broadcast to (batch_size, batch_size, hidden_dim).
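To make the broadcasting step concrete, here is a minimal sketch with plain PyTorch tensors (no kospeech code involved). The sizes are illustrative and mirror the reproduction script; `context` and `residual` simply stand in for output[0] and residual in AddNorm.

```python
import torch

batch_size, hidden_dim = 2, 512

# Stand-in for the attention context returned by the sublayer: (batch, hidden_dim)
context = torch.rand(batch_size, hidden_dim)
# Stand-in for the residual (the decoder query): (batch, 1, hidden_dim)
residual = torch.rand(batch_size, 1, hidden_dim)

# Broadcasting right-aligns the shapes: (2, 512) is treated as (1, 2, 512),
# so adding it to (2, 1, 512) expands to (2, 2, 512) instead of raising an error.
broadcast_sum = context + residual
print(broadcast_sum.shape)  # torch.Size([2, 2, 512])

# Giving the context an explicit length-1 time axis keeps the intended shape.
aligned_sum = context.unsqueeze(1) + residual
print(aligned_sum.shape)  # torch.Size([2, 1, 512])
```

Each slice broadcast_sum[i, j] is residual[i, 0] + context[j], so the extra axis silently mixes examples from different positions in the batch. With batch_size = 1 the shapes broadcast to (1, 1, hidden_dim), so the problem would not be visible at batch size 1.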

Reproduction code

import torch

from kospeech.models.attention import LocationAwareAttention
from kospeech.models.acoustic.transformer.sublayers import AddNorm

hidden_dim = 512
attention = LocationAwareAttention(d_model=hidden_dim)
attention_norm = AddNorm(LocationAwareAttention(d_model=hidden_dim), d_model=hidden_dim)

batch_size = 2
seq_length = 128

output = torch.rand(batch_size, 1, hidden_dim, dtype=torch.float32)
encoder_outputs = torch.rand(batch_size, seq_length, hidden_dim, dtype=torch.float32)
attn = None

context, _ = attention(output, encoder_outputs, attn)
context_norm, _ = attention_norm(output, encoder_outputs, attn)

print(context.size())  # torch.Size([2, 512]) i.e. (batch_size, hidden_dim)
print(context_norm.size())  # torch.Size([2, 2, 512]) i.e. (batch_size, batch_size, hidden_dim)
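One way to guard against this kind of bug is to have the residual wrapper align shapes explicitly instead of relying on broadcasting. The sketch below is hypothetical: ShapeSafeAddNorm and DummyAttention are illustrative names, not kospeech's AddNorm and not the fix the maintainer committed. It assumes, based on the output[0] indexing quoted above, that the sublayer returns a (context, attn) tuple. DummyAttention uses plain dot-product scoring as a stand-in for LocationAwareAttention and keeps only its (batch, hidden_dim) context shape.

```python
import torch
import torch.nn as nn


class DummyAttention(nn.Module):
    """Hypothetical stand-in sublayer (dot-product, not location-aware) that
    returns a (batch, hidden_dim) context plus attention weights."""

    def forward(self, query, values, last_attn=None):
        # query: (batch, 1, hidden_dim), values: (batch, seq_len, hidden_dim)
        scores = torch.bmm(query, values.transpose(1, 2))  # (batch, 1, seq_len)
        attn = torch.softmax(scores, dim=-1)
        context = torch.bmm(attn, values).squeeze(1)       # (batch, hidden_dim)
        return context, attn.squeeze(1)                    # attn: (batch, seq_len)


class ShapeSafeAddNorm(nn.Module):
    """Hypothetical residual + LayerNorm wrapper that never lets the
    residual add broadcast across the batch."""

    def __init__(self, sublayer: nn.Module, d_model: int):
        super().__init__()
        self.sublayer = sublayer
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, *args):
        residual = args[0]                   # query, e.g. (batch, 1, hidden_dim)
        output, attn = self.sublayer(*args)  # context, e.g. (batch, hidden_dim)
        if residual.dim() == output.dim() + 1 and residual.size(1) == 1:
            # (batch, 1, hidden_dim) -> (batch, hidden_dim): drop the length-1
            # time axis so the add below is element-wise.
            residual = residual.squeeze(1)
        if residual.shape != output.shape:
            raise ValueError(
                f"residual {tuple(residual.shape)} does not match "
                f"sublayer output {tuple(output.shape)}"
            )
        return self.layer_norm(output + residual), attn


batch_size, seq_length, hidden_dim = 2, 128, 512
layer = ShapeSafeAddNorm(DummyAttention(), d_model=hidden_dim)
query = torch.rand(batch_size, 1, hidden_dim)
encoder_outputs = torch.rand(batch_size, seq_length, hidden_dim)
context, attn = layer(query, encoder_outputs, None)
print(context.shape)  # torch.Size([2, 512])
```

Raising on any remaining mismatch, rather than reshaping arbitrarily, turns a silent shape bug like the one in this report into an immediate error.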

Environment

PyTorch 1.7.0a0+018b4d7, Python 3.7.7

sooftware commented 4 years ago

Thank you for letting me know.
I fixed this bug in the last commit.

I hadn't noticed this bug because I don't use location-aware attention much.
If there are any other issues, please let me know. Thank you!!