cywang97 / StreamingTransformer


bad performance for streaming transformer using trigger #14

Open housebaby opened 3 years ago

housebaby commented 3 years ago

Hello, I trained a streaming transformer with the following config. The loss looks OK, but the decoding performance is bad. Is it necessary to use the prefix decoder? When I use prefix_recognize, the error below occurs; if I don't use prefix_recognize, the performance is bad.

File "/home/storage15/username/tools/espnet/egs/librispeech/asr1/../../../espnet/bin/asr_recog.py", line 368, in main(sys.argv[1:]) File "/home/storage15/username/tools/espnet/egs/librispeech/asr1/../../../espnet/bin/asr_recog.py", line 335, in main recog_v2(args) File "/home/storage15/username/tools/espnet/espnet/asr/pytorch_backend/recog.py", line 174, in recog_v2 best, ids, score = model.prefix_recognize(feat, args, train_args, train_args.char_list, lm) File "/home/storage15/username/tools/espnet/espnet/nets/pytorch_backend/streaming_transformer.py", line 553, in prefix_recognize self.compute_hyps(tmp,i,h_len,enc_output, hat_att[chunk_index], mask, train_args.chunk) File "/home/storage15/username/tools/espnet/espnet/nets/pytorch_backend/streaming_transformer.py", line 776, in compute_hyps enc_output4use, partial_mask4use, cache4use) File "/home/storage15/username/tools/espnet/espnet/nets/pytorch_backend/transformer/decoder.py", line 310, in forward_one_step x, tgt_mask, memory, memory_mask, cache=c File "/home/storage15/username/tools/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in call result = self.forward(*input, **kwargs) File "/home/storage15/username/tools/espnet/espnet/nets/pytorch_backend/transformer/decoder_layer.py", line 94, in forward ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" AssertionError: torch.Size([5, 1, 512]) == (5, 2, 512)

train config:

This configuration requires 4 GPUs with 12GB memory.

```yaml
accum-grad: 1
adim: 512
aheads: 8
batch-bins: 3000000
dlayers: 6
dropout-rate: 0.1
dunits: 2048
elayers: 12
epochs: 120
eunits: 2048
grad-clip: 5
lsm-weight: 0.1
model-module: espnet.nets.pytorch_backend.streaming_transformer:E2E
mtlalpha: 0.3
opt: noam
patience: 0
sortagrad: 0
transformer-attn-dropout-rate: 0.0
transformer-init: pytorch
transformer-input-layer: conv2d
transformer-length-normalized-loss: false
transformer-lr: 1.0
transformer-warmup-steps: 2500
n-iter-processes: 0

enc-init: exp/train_960_pytorch_train_specaug/results/model.val5.avg.best  # /path/to/model
enc-init-mods: encoder,ctc,decoder

streaming: true
chunk: true
chunk-size: 32
```
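As context on `enc-init` / `enc-init-mods` above: they warm-start the streaming model from a non-streaming checkpoint by copying only the parameters whose names begin with the listed module prefixes. The sketch below is an assumption of how such selective loading works in general, not ESPnet's exact helper (the function name and the `strict=False` strategy are mine), and it assumes the checkpoint file is a plain `state_dict`:

```python
# Hedged sketch of prefix-based parameter transfer (illustrative only).
import torch


def load_pretrained_modules(model, ckpt_path, prefixes=("encoder", "ctc", "decoder")):
    """Copy into `model` only the checkpoint tensors whose names start with
    one of `prefixes`, leaving every other parameter untouched."""
    pretrained = torch.load(ckpt_path, map_location="cpu")  # assumed to be a state_dict
    own = model.state_dict()
    filtered = {
        name: tensor
        for name, tensor in pretrained.items()
        if name.startswith(prefixes) and name in own
    }
    # strict=False: parameters missing from `filtered` (e.g. streaming-specific
    # additions) keep their freshly initialized values.
    model.load_state_dict(filtered, strict=False)
    return sorted(filtered)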

decode_config:

```yaml
lm-weight: 0.5
beam-size: 5
penalty: 2.0
maxlenratio: 0.0
minlenratio: 0.0
ctc-weight: 0.5
threshold: 0.0005
ctc-lm-weight: 0.5
prefix-decode: true
```
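These options correspond to the prefix_recognize call visible in the traceback (recog_v2 passes the features, the recognition args, the training args, the char list, and the LM). A rough sketch of how they would be packaged, with the model/feature/LM setup elided and treated as assumptions:

```python
# Rough sketch only: the decode options above as the argparse-style namespace
# that recog_v2 would hand to model.prefix_recognize (dashes become underscores).
from argparse import Namespace

recog_args = Namespace(
    lm_weight=0.5,
    beam_size=5,
    penalty=2.0,
    maxlenratio=0.0,
    minlenratio=0.0,
    ctc_weight=0.5,
    threshold=0.0005,
    ctc_lm_weight=0.5,
    prefix_decode=True,
)

# Call signature taken from the traceback (recog.py, recog_v2); feat, model,
# train_args and lm come from recog_v2's own setup and are not shown here.
# best, ids, score = model.prefix_recognize(
#     feat, recog_args, train_args, train_args.char_list, lm
# )
```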

qzfnihao commented 3 years ago

Which version do you use? It looks like you merged the streaming transformer code into a different ESPnet codebase, because in this repo recog_v2 is never invoked.