ZhengkunTian / OpenTransformer

A No-Recurrence Sequence-to-Sequence Model for Speech Recognition
MIT License

LM Shallow Fusion error #39

Closed Robinatp closed 3 years ago

Robinatp commented 3 years ago

Hello, I get an error when I run the command below:

python eval.py -m egs/aishell/exp/transformer_baseline/model.average.from50to59.pt -lm egs/aishell/exp/transformer_lm_baseline/model.epoch.59.pt

the error logs:

Traceback (most recent call last):
  File "eval.py", line 236, in <module>
    main(cmd_args)
  File "eval.py", line 138, in main
    preds, scores = recognizer.recognize(enc_inputs, enc_mask)
  File "/data/home/xxx/e2e_asr/OpenTransformer/otrans/recognize/speech2text.py", line 63, in recognize
    preds, cache, scores, ending_flag = self.decode_step(
  File "/data/home/xxx/e2e_asr/OpenTransformer/otrans/recognize/speech2text.py", line 103, in decode_step
    batch_lm_log_probs, lm_hidden = self.lm_decode(preds, cache['lm'])
  File "/data/home/xxx/e2e_asr/OpenTransformer/otrans/recognize/base.py", line 33, in lm_decode
    log_probs = self.lm.predict(preds, last_frame=True)
  File "/data/home/xxx/e2e_asr/OpenTransformer/otrans/model/lm.py", line 151, in predict
    dec_output, dec_mask = block(dec_output, dec_mask)
  File "/data/home/xxx/miniconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/home/xxx/e2e_asr/OpenTransformer/otrans/encoder/transformer.py", line 49, in forward
    slf_attn_out, slf_attn_weights = self.slf_attn(x, mask)
  File "/data/home/xxx/miniconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/home/xxx/e2e_asr/OpenTransformer/otrans/module/attention.py", line 68, in forward
    x = self.qvk_proj(x)
  File "/data/home/xxx/miniconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/home/xxx/miniconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/data/home/xxx/miniconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/functional.py", line 1608, in linear
    if input.dim() == 2 and bias is not None:
AttributeError: 'tuple' object has no attribute 'dim'

Could you please take a look at it?


THX

Robinatp commented 3 years ago

--- a/otrans/model/lm.py
+++ b/otrans/model/lm.py
@@ -143,12 +143,13 @@ class TransformerLanguageModel(BaseLM):
     def predict(self, targets, last_frame=True):

     dec_output = self.embedding(targets)

--- a/otrans/recognize/speech2text.py
+++ b/otrans/recognize/speech2text.py
@@ -99,10 +99,11 @@ class SpeechToTextRecognizer(Recognizer):

     batch_log_probs, dec_cache, dec_attn_weights = self.decode(preds, memory, memory_mask, cache['decoder'])
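
For readers hitting the same `AttributeError`, here is a minimal sketch of one way a tuple can end up where a tensor is expected. It is only an illustration: `FakeTensor`, `linear`, `Block`, `run_buggy` and `run_fixed` are hypothetical stand-ins, not code from OpenTransformer or PyTorch, and the real return signature of the blocks in `lm.py` is an assumption here. The idea is that if a layer returns a nested tuple and the caller unpacks it one level too shallow, the next layer receives a tuple and fails on `.dim()`.

```python
# Stand-in types only: nothing here comes from OpenTransformer or PyTorch.

class FakeTensor:
    """Minimal stand-in for a tensor; it only provides dim()."""
    def __init__(self, ndim):
        self.ndim = ndim

    def dim(self):
        return self.ndim


def linear(x):
    # Like torch.nn.functional.linear in the traceback, this calls x.dim(),
    # so passing a tuple raises AttributeError.
    x.dim()
    return x


class Block:
    """Hypothetical layer that returns ((output, mask), extra)."""
    def forward(self, x, mask):
        out = linear(x)
        return (out, mask), "attn_weights"


def run_buggy(blocks, x, mask):
    for block in blocks:
        # Bug: x is bound to the inner tuple (out, mask), not to the tensor.
        x, _ = block.forward(x, mask)
    return x


def run_fixed(blocks, x, mask):
    for block in blocks:
        # Unpack according to what the block actually returns.
        (x, mask), _ = block.forward(x, mask)
    return x


blocks = [Block(), Block()]

message = None
try:
    run_buggy(blocks, FakeTensor(3), mask=None)
except AttributeError as err:
    # The second block receives a tuple and fails inside linear().
    message = str(err)

result = run_fixed(blocks, FakeTensor(3), mask=None)
```

In a real model, printing `type(dec_output)` just before the failing call is a quick way to check whether the value passed between layers is still a tensor.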