facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

Too low BLEU score in reproducing simultaneous speech translation (MuST-C en-de) #3978

Open keiouok opened 2 years ago

keiouok commented 2 years ago

I tried to train simultaneous speech translation following simul_mustc_example.md. I trained the simultaneous ST model at commit bc3bd55ec98c39af45ff7323ae49bcbdf93acc36 (because on the main branch, at commit 1ef3d6a1a2cb7fa9937233c8bf796957871bfc94, an "arch not found" error occurred; preprocessing and ASR pretraining were done on the main branch).

However, the SimulEval BLEU result was terribly low (the documentation reports a BLEU of about 13):

{
    "Quality": {
        "BLEU": 0.010281760678875936
    },
    "Latency": {
        "AL": 1299.5488329014138,
        "AL_CA": 1662.8518504975802,
        "AP": 0.45365599087862923,
        "AP_CA": 0.5802261617796486,
        "DAL": 1460.296257343413,
        "DAL_CA": 1764.66547917616
    }
}

And most of the predictions in instance.log were like "(Applaus)" or "(Musik)":

{"index": 0, "prediction": "(Applaus) </s>", "delays": [1680.0, 1680.0], "elapsed": [2029.341630935669, 2033.8196086883545], "prediction_length": 2, "reference": "Diese Durchbr\u00fcche m\u00fcssen wir mit Vollgas verfolgen und das k\u00f6nnen wir messen: in Firmenzahlen, in Pilotprojekten und Regulierungs\u00e4nderungen.", "source": ["/*/dev/ted_767_0.wav", "samplerate: 16000 Hz", "channels: 1", "duration: 8.600 s", "format: WAV (Microsoft) [WAV]", "subtype: Signed 16 bit PCM [PCM_16]"], "source_length": 8600.0, "reference_length": 18, "metric": {"sentence_bleu": 0.44439199324869233, "latency": {"AL": 1453.6842041015625, "AP": 0.1953488439321518, "DAL": 1680.0}, "latency_ca": {"AL": 1805.264892578125, "AP": 0.2362302988767624, "DAL": 2029.341796875}}}
{"index": 1, "prediction": "(Applaus) </s>", "delays": [1680.0, 1680.0], "elapsed": [2007.0959854125977, 2011.3920497894287], "prediction_length": 2, "reference": "Es gibt viele gro\u00dfartige B\u00fccher zu diesem Thema.", "source": ["/*/dev/ted_767_1.wav", "samplerate: 16000 Hz", "channels: 1", "duration: 2.530 s", "format: WAV (Microsoft) [WAV]", "subtype: Signed 16 bit PCM [PCM_16]"], "source_length": 2530.0, "reference_length": 8, "metric": {"sentence_bleu": 2.4675789207681893, "latency": {"AL": 1539.4444580078125, "AP": 0.6640316247940063, "DAL": 1680.0}, "latency_ca": {"AL": 1868.6884765625, "AP": 0.7941675782203674, "DAL": 2007.095947265625}}}

If there are any solutions, please let me know. Thank you.

Code

Pre-trained ASR: checkpoint_best.pt, trained with the following command:

fairseq-train ${OUT_ROOT} \
  --config-yaml config_asr.yaml --train-subset train_asr --valid-subset dev_asr \
  --save-dir ${TMP} --num-workers 1 --max-tokens 20000 --max-update 100000 \
  --task speech_to_text --criterion label_smoothed_cross_entropy --report-accuracy \
  --arch convtransformer_espnet --optimizer adam --lr 0.0005 --lr-scheduler inverse_sqrt \
  --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 16 --patience 4

Simultaneous speech translation

fairseq-train ${OUT_ROOT} \
       --config-yaml config_st.yaml --train-subset train_st --valid-subset dev_st \
       --save-dir ${TMP} --num-workers 1 \
       --optimizer adam --lr 0.0001 --lr-scheduler inverse_sqrt --clip-norm 10.0 \
       --criterion label_smoothed_cross_entropy \
       --warmup-updates 4000 --max-update 100000 --max-tokens 20000 --seed 1 \
       --load-pretrained-encoder-from ${ASR_SAVE_DIR}/checkpoint_best.pt \
       --task simul_speech_to_text \
       --arch convtransformer_simul_trans_espnet \
       --simul-type waitk_fixed_pre_decision \
       --waitk-lagging 3 \
       --fixed-pre-decision-ratio 7 \
       --update-freq 16 \
       --patience 4
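
For context on the policy flags above: a rough sketch (plain Python, not fairseq's code) of the read/write schedule that --waitk-lagging 3 combined with --fixed-pre-decision-ratio 7 is intended to produce. Encoder states are grouped into blocks of 7, the model reads k = 3 blocks before its first write, then alternates one write per additional block; the function name and arguments below are only illustrative.

def waitk_schedule(num_encoder_states, num_target_tokens, k=3, ratio=7):
    # Group encoder states into fixed-size blocks (the "pre-decision" step),
    # then apply the wait-k policy at block granularity.
    schedule, written = [], 0
    for block in range(1, num_encoder_states // ratio + 1):
        schedule.append(("READ", ratio))   # consume one more block of encoder states
        if block >= k and written < num_target_tokens:
            schedule.append(("WRITE", 1))  # emit the next target token
            written += 1
    while written < num_target_tokens:     # source exhausted: write the remaining tokens
        schedule.append(("WRITE", 1))
        written += 1
    return schedule

# e.g. 28 encoder states (4 blocks) and 5 target tokens:
# READ, READ, READ, WRITE, READ, WRITE, WRITE, WRITE, WRITE
print(waitk_schedule(28, 5))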

What's your environment?

dev_st loss

[image: dev_st loss curve]

EricLina commented 2 years ago

Do you use SimulEval? Could you share in detail how you use it? There can be big differences in the agent.py used.

EricLina commented 2 years ago

@keiouok

keiouok commented 2 years ago

@duj12 Thank you for your empirical advice. Training with s2t_transformer may be more efficient; I'll try implementing simultaneous ST with s2t_transformer.

@1190301804 Yes, I followed the SimulEval scripts in simul_mustc_example.md.

And it turned out that my extracted fbank features were probably not suitable for the default simultaneous ST setup. In this issue, I got the poor result above because I reused the fbank features extracted with the offline MuST-C ST documentation settings (vocab size 8000, no global CMVN) on the main branch at that time. After that, I succeeded with simultaneous ST (BLEU 11-12) using global-CMVN features, following simul_mustc_example.md at commit 436166a00c2ecd1215df258f022608947cca2aa8 (for both preprocessing and training). However, training on features without global CMVN still failed (almost all outputs were "(Applaus)").

I can hardly believe that the difference between global CMVN and per-utterance CMVN features causes such different results... I'll check the details when I have time. Anyway, we could reproduce the result. Thank you.
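
For anyone hitting the same problem, the difference I mean is roughly the following (a minimal numpy sketch, not the actual fairseq feature-extraction code; the corpus-level mean/std arrays are assumed to come from the global-CMVN statistics file written during preprocessing):

import numpy as np

# feats: (num_frames, n_mels) log-mel filterbank matrix for one utterance.
def utterance_cmvn(feats: np.ndarray) -> np.ndarray:
    # Per-utterance CMVN: normalize each utterance by its own mean/std.
    mean = feats.mean(axis=0)
    std = np.maximum(feats.std(axis=0), 1e-8)
    return (feats - mean) / std

def global_cmvn(feats: np.ndarray, gcmvn_mean: np.ndarray, gcmvn_std: np.ndarray) -> np.ndarray:
    # Global CMVN: normalize every utterance by corpus-level statistics,
    # so the same normalization is applied in training and streaming inference.
    return (feats - gcmvn_mean) / np.maximum(gcmvn_std, 1e-8)

One plausible reason the mismatch matters so much: during simultaneous inference the full utterance is not yet available, so per-utterance statistics are unstable and inconsistent with training, while global statistics behave identically in both settings.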

EricLina commented 2 years ago

Sorry, by the way, have you used the MMA model (text-to-text)? Do you know how to write the agent.py for an En-De dataset?

keiouok commented 2 years ago

@1190301804 Sorry, I have not used the MMA model...

duj12 commented 2 years ago

> It turned out that my extracted fbank features were probably not suitable for the default simultaneous ST setup. [...] However, training on features without global CMVN still failed (almost all outputs were "(Applaus)"). [...] Anyway, we could reproduce the result. Thank you.

Hi @keiouok. I misused the 'non-global CMVN' config at first, too: I trained with global CMVN but ran inference with per-utterance CMVN, and the performance was poor. Then I used global CMVN for evaluation as well and averaged the 5 best checkpoints on the development set, but the result remained poor (BLEU = 0.05; the training set is 80 hours of Chinese-English BSTC + CommonVoice data, the pretrained ASR has a CER of 13%, and the offline ST BLEU is 8.6). The translations in instance.log are totally irrelevant to the audio.
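
(The averaging step looks roughly like this. fairseq ships scripts/average_checkpoints.py for this purpose; the sketch below is only an illustration, assuming the weights are stored under the "model" key of each checkpoint, and the paths are placeholders.)

import torch

# Placeholder paths for the 5 best dev-set checkpoints.
paths = [f"st_save_dir/checkpoint.best_{i}.pt" for i in range(1, 6)]

avg, last_ckpt = None, None
for p in paths:
    last_ckpt = torch.load(p, map_location="cpu")
    state = last_ckpt["model"]  # assumed key holding the model weights
    if avg is None:
        avg = {k: v.clone().double() for k, v in state.items()}
    else:
        for k, v in state.items():
            avg[k] += v.double()

# Divide by the number of checkpoints and restore the original dtypes.
avg = {k: (v / len(paths)).to(last_ckpt["model"][k].dtype) for k, v in avg.items()}
last_ckpt["model"] = avg
torch.save(last_ckpt, "st_save_dir/checkpoint.avg5.pt")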

I thought the learning rate might be too small, so I increased it to 1e-3 / 1e-2 and kept the other parameters unchanged, but the result got worse; lr = 1e-4 seems more appropriate for this task. I also noticed that the --task in your script differs from simul_mustc_example.md (simul_speech_to_text vs. speech_to_text), but this doesn't actually matter.

Now I am confused about how to reproduce the simultaneous ST results. Is the dataset I used too small?

keiouok commented 2 years ago

Thank you, @duj12. I have only succeeded on the En-De dataset. Now I'm trying the MuST-C En-Ja dataset from https://iwslt.org/2022/offline#allowed-training-data. However, the result was poor even for offline ST (the pretrained ASR WER was about 14). The offline ST BLEU was about 0.1 (the dev-st accuracy was about 30 and never surpassed 40; the early stopping patience was 16).

Now I'm tuning the learning rate (the convtransformer default simulst lr was 0.0005), but it hasn't improved. With a larger lr such as 0.02 / 0.002, the accuracy got worse; with a smaller lr such as 0.0001, the loss and accuracy converged more slowly, but the dev-st accuracy still never surpassed 30. I also tried s2t_transformer, but the result was the same. Needless to say, simultaneous ST also scored terribly, even with wait-k (k = 100).

I'm also confused about the results... I'm sorry I can't help you right now. If I find anything, I'll let you know.

EricLina commented 2 years ago

It's even worse for me, following simul_mustc_example.md for preprocessing and training. On the full MuST-C En-De dataset (69 GB), I trained the ASR model for about 120 hours on 8 GPUs for 900 epochs (I shut it down because it was too slow), and the ST model for about 70 epochs (I also stopped it early because it was too slow). For evaluation, I used seg_mustc_data.py to split the dataset and evaluated on 100 sentences of it (SimulEval raises "Connection refused" when the test set is large). The result is very poor...

Does anyone have any suggestions? Thank you!

2022-03-09 20:34:51 | INFO     | simuleval.cli    | Evaluation results:
{
    "Quality": {
        "BLEU": 0.2027780041409297
    },
    "Latency": {
        "AL": 1248.1308325195312,
        "AL_CA": 15497.652229003907,
        "AP": 0.3861502431333065,
        "AP_CA": 39.99524466373026,
        "DAL": 1411.803270072937,
        "DAL_CA": 19719.095799560546
    }
}

duj12 commented 2 years ago

After debugging the training code, I found the reason. Take an example: encoder_state has 14 frames, text_sequence has 4 tokens, and batch_size = 1. In wait-k (k = 3) mode, p_choose (in p_choose_strategy.py) is formulated as:

[[[0 0 1 0 0 0 0 0 0 0 0 0 0 0]
  [0 0 0 1 0 0 0 0 0 0 0 0 0 0]
  [0 0 0 0 1 0 0 0 0 0 0 0 0 0]
  [0 0 0 0 0 1 0 0 0 0 0 0 0 0]]]

alpha is the same as p_choose, and beta is something like:

[[[0.3 0.3 0.4 0 0 0 0 0 0 0 0 0 0 0]
  [0.2 0.3 0.3 0.2 0 0 0 0 0 0 0 0 0 0]
  [0.2 0.2 0.1 0.2 0.3 0 0 0 0 0 0 0 0 0]
  [0.1 0.2 0.2 0.1 0.2 0.2 0 0 0 0 0 0 0 0]]]

We can see that only the context within frames [0~7), i.e. [0~k+text_seq_len), is included in the weighted sum, while the total context length equals the length of encoder_state.

So in this implementation, too much context near the tail is ignored, which may be the reason for the poor performance.
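
To make the example above concrete, here is a minimal sketch (plain PyTorch, not the fairseq implementation) that reproduces the hard wait-k alignment for 4 target tokens, 14 encoder frames and k = 3, assuming target token i (0-indexed) attends up to source frame k + i - 1:

import torch

def waitk_p_choose(tgt_len: int, src_len: int, k: int) -> torch.Tensor:
    # Hard wait-k alignment: token i places its single attention point
    # at source frame min(k + i - 1, src_len - 1) (0-indexed).
    p_choose = torch.zeros(1, tgt_len, src_len)
    for i in range(tgt_len):
        p_choose[0, i, min(k + i - 1, src_len - 1)] = 1.0
    return p_choose

print(waitk_p_choose(tgt_len=4, src_len=14, k=3))

The single 1 in each row lands in columns 2, 3, 4 and 5, matching the p_choose above; the soft attention weights (beta) are then confined to the frames up to that point, so the trailing encoder frames (columns 6-13) never contribute to the weighted sum during training.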