```diff
--- a/otrans/model/lm.py
+++ b/otrans/model/lm.py
@@ -143,12 +143,13 @@ class TransformerLanguageModel(BaseLM):
     def predict(self, targets, last_frame=True):
         dec_output = self.embedding(targets)
         dec_output, _ = self.pos_embedding(dec_output)
         dec_mask = get_seq_mask(targets)
         for _, block in enumerate(self.blocks):
```
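The unpacking matters because the positional-embedding module returns a tuple rather than a bare tensor, so without it a tuple flows on into the attention projection and ultimately into `F.linear`. A minimal standalone reproduction of that failure mode (toy module and sizes, not OpenTransformer's actual classes):

```python
import torch
import torch.nn as nn

class TuplePositionalEncoding(nn.Module):
    """Toy stand-in for a positional-embedding layer that returns (output, pos)."""
    def forward(self, x):
        pos = torch.zeros_like(x)      # placeholder for a real sinusoidal table
        return x + pos, pos            # a tuple, not a tensor

proj = nn.Linear(8, 8)                 # plays the role of qvk_proj in the attention module
pos_emb = TuplePositionalEncoding()
x = torch.randn(2, 5, 8)

try:
    proj(pos_emb(x))                   # forgetting to unpack: a tuple reaches F.linear
except (AttributeError, TypeError) as e:
    print(e)                           # e.g. 'tuple' object has no attribute 'dim'

out, _ = pos_emb(x)                    # unpack first, keep only the tensor
print(proj(out).shape)                 # torch.Size([2, 5, 8])
```

On the PyTorch build shown in the trace below, the first thing `F.linear` does is call `input.dim()`, which is exactly the `AttributeError` reported.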
```diff
--- a/otrans/recognize/speech2text.py
+++ b/otrans/recognize/speech2text.py
@@ -99,10 +99,11 @@ class SpeechToTextRecognizer(Recognizer):
         batch_log_probs, dec_cache, dec_attn_weights = self.decode(preds, memory, memory_mask, cache['decoder'])
         if self.lm is not None:
             batch_lm_log_probs, lm_hidden = self.lm_decode(preds, cache['lm'])
             batch_lm_log_probs = batch_lm_log_probs.squeeze(1)
```
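The added `squeeze(1)` is a shape fix: with `last_frame=True` the LM keeps a length-1 time axis, so its scores come back as `(batch*beam, 1, vocab)` while the decoder's `batch_log_probs` are `(batch*beam, vocab)`. A small sketch of why the extra axis matters when the two are fused (`lm_weight` and the sizes below are made-up illustrations, not the recognizer's exact values):

```python
import torch

n_hyp, vocab = 10, 4233                 # hypothetical number of hypotheses and vocab size
lm_weight = 0.1                         # hypothetical shallow-fusion weight

batch_log_probs = torch.randn(n_hyp, vocab).log_softmax(dim=-1)        # decoder scores
batch_lm_log_probs = torch.randn(n_hyp, 1, vocab).log_softmax(dim=-1)  # LM scores, last frame only

# Without the squeeze, broadcasting silently inflates the result:
print((batch_log_probs + lm_weight * batch_lm_log_probs).shape)        # torch.Size([10, 10, 4233])

# With squeeze(1), the fused scores keep the expected per-hypothesis shape:
fused = batch_log_probs + lm_weight * batch_lm_log_probs.squeeze(1)
print(fused.shape)                                                      # torch.Size([10, 4233])
```

Whether the recognizer combines the two scores exactly this way isn't shown in the hunk, but the shape mismatch is the reason the squeeze is needed before any per-vocabulary combination.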
Hello, I get some errors when I run the command below:

```
python eval.py -m egs/aishell/exp/transformer_baseline/model.average.from50to59.pt -lm egs/aishell/exp/transformer_lm_baseline/model.epoch.59.pt
```

The error logs:
```
Traceback (most recent call last):
  File "eval.py", line 236, in <module>
    main(cmd_args)
  File "eval.py", line 138, in main
    preds, scores = recognizer.recognize(enc_inputs, enc_mask)
  File "/data/home/xxx/e2e_asr/OpenTransformer/otrans/recognize/speech2text.py", line 63, in recognize
    preds, cache, scores, ending_flag = self.decode_step(
  File "/data/home/xxx/e2e_asr/OpenTransformer/otrans/recognize/speech2text.py", line 103, in decode_step
    batch_lm_log_probs, lm_hidden = self.lm_decode(preds, cache['lm'])
  File "/data/home/xxx/e2e_asr/OpenTransformer/otrans/recognize/base.py", line 33, in lm_decode
    log_probs = self.lm.predict(preds, last_frame=True)
  File "/data/home/xxx/e2e_asr/OpenTransformer/otrans/model/lm.py", line 151, in predict
    dec_output, dec_mask = block(dec_output, dec_mask)
  File "/data/home/xxx/miniconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/home/xxx/e2e_asr/OpenTransformer/otrans/encoder/transformer.py", line 49, in forward
    slf_attn_out, slf_attn_weights = self.slf_attn(x, mask)
  File "/data/home/xxx/miniconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/home/xxx/e2e_asr/OpenTransformer/otrans/module/attention.py", line 68, in forward
    x = self.qvk_proj(x)
  File "/data/home/xxx/miniconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/data/home/xxx/miniconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 87, in forward
    return F.linear(input, self.weight, self.bias)
  File "/data/home/xxx/miniconda3/envs/espnet/lib/python3.8/site-packages/torch/nn/functional.py", line 1608, in linear
    if input.dim() == 2 and bias is not None:
AttributeError: 'tuple' object has no attribute 'dim'
```
Could you please take a look at it?
Thanks!