Wen Zhang, Yang Feng, Fandong Meng, Di You and Qun Liu. Bridging the Gap between Training and Inference for Neural Machine Translation. In Proceedings of ACL, 2019.
The code in the two directories implements the OR-NMT systems based on the RNNsearch and Transformer models, respectively.
This system has been tested in the following environment.
For OR-Transformer:
First, go into the OR-Transformer directory.
Then, the training script is the same as in fairseq, except for the following arguments:

- `--use-word-level-oracles`: train the Transformer with word-level oracles.
- `--use-sentence-level-oracles`: train the Transformer with sentence-level oracles.

By default, the probability of sampling golden words is decayed based on the update index.

- `--use-epoch-numbers-decay`: decay based on the epoch index instead.
- `--decay-k`: controls the speed of the inverse sigmoid decay in Eq.(15) of the paper. Recommended values:
  - `8~15` when decaying based on the epoch index
  - `3000~8000` when decaying based on the update index

NOTE: For a new data set, the hyperparameter `--decay-k` needs to be adjusted manually according to the maximum number of training updates (the default) or epochs (with `--use-epoch-numbers-decay`), to ensure that the probability of sampling golden words does not decay too quickly.

For Eq.(11~13) in the paper, is actually the same as . The operation is not needed in the code implementation.
Gumbel noise:

- `--use-greed-gumbel-noise`: sample the word-level oracle with Gumbel noise.
- `--use-bleu-gumbel-noise`: sample the sentence-level oracle with Gumbel noise.
- `--gumbel-noise`: the hyper-parameter used in the calculation of the Gumbel noise.
- `--oracle-search-beam-size`: the beam size used in length-constrained decoding.

As for the `--arch` and `--criterion` arguments, `oracle_` should be used as the prefix for OR-NMT training, for example:

- `--arch transformer_vaswani_wmt_en_de_big` -> `--arch oracle_transformer_vaswani_wmt_en_de_big`
- `--criterion label_smoothed_cross_entropy` -> `--criterion oracle_label_smoothed_cross_entropy`
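As a hedged illustration of word-level oracle selection with Gumbel noise, the sketch below assumes Eq.(11~13) follow the usual Gumbel-Max form: η = -log(-log u) with u ~ Uniform(0, 1), noisy logits (o + η) / τ, a softmax, and an argmax that picks the oracle word. The function names and the temperature `tau` are illustrative; how `--gumbel-noise` maps onto them in the repository is not shown here. Because τ > 0 and softmax is monotonic, neither step changes which word is selected, which the check at the end compares directly.

```python
import math
import random
from typing import List, Sequence


def gumbel_noise(n: int, rng: random.Random) -> List[float]:
    """Draw n standard Gumbel samples: eta = -log(-log u), u ~ Uniform(0, 1)."""
    samples = []
    for _ in range(n):
        u = rng.random()
        while u == 0.0:  # random() returns values in [0, 1); log(0) is undefined
            u = rng.random()
        samples.append(-math.log(-math.log(u)))
    return samples


def softmax(xs: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(xs: Sequence[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(xs)), key=lambda i: xs[i])


def word_oracle(logits: Sequence[float], tau: float, rng: random.Random) -> int:
    """Assumed Eq.(11~13): argmax of softmax((logits + eta) / tau)."""
    if tau <= 0:
        raise ValueError("tau must be positive")
    eta = gumbel_noise(len(logits), rng)
    noisy = [(o + e) / tau for o, e in zip(logits, eta)]
    return argmax(softmax(noisy))


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.5, -1.0]
    # Reusing the seed reproduces the same noise, so both selections are comparable.
    with_softmax = word_oracle(logits, tau=0.5, rng=random.Random(0))
    eta = gumbel_noise(len(logits), random.Random(0))
    without_softmax = argmax([o + e for o, e in zip(logits, eta)])
    print(with_softmax == without_softmax)
```

Under this assumed form the selected word is a sample from softmax(o), whatever the value of τ.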
Example script for word-level training, with the probability decayed based on the epoch index:

```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
batch_size=4096
accum=2
data_dir=directory_of_data_bin
model_dir=./ckpt
# create the checkpoint directory first so that tee can open the log file
mkdir -p $model_dir
python train.py $data_dir \
--arch oracle_transformer_vaswani_wmt_en_de_big --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --lr-scheduler inverse_sqrt \
--warmup-init-lr 1e-07 --warmup-updates 4000 --lr 0.0005 --min-lr 1e-09 \
--weight-decay 0.0 --criterion oracle_label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens $batch_size --update-freq $accum --no-progress-bar --log-format json --max-update 200000 \
--log-interval 10 --save-interval-updates 10000 --keep-interval-updates 10 --save-interval 10000 \
--seed 1111 --skip-invalid-size-inputs-valid-test \
--distributed-port 28888 --distributed-world-size 4 --ddp-backend=no_c10d \
--source-lang en --target-lang de --save-dir $model_dir \
--use-word-level-oracles --use-epoch-numbers-decay --decay-k 10 \
--use-greed-gumbel-noise --gumbel-noise 0.5 | tee -a $model_dir/training.log
```
Models | Translation Task | #GPUs | #Toks. | #Freq. | Max |
---|---|---|---|---|---|
Transformer-big | Zh->En | 8 | 4096 | 3 | 30 epochs |
+Word-level Oracle | Zh->En | 8 | 4096 | 3 | 30 epochs |
Transformer-base | En->De | 8 | 6144 | 2 | 80000 updates (62 epochs) |
+Word-level Oracle | En->De | 8 | 12288 | 1 | 80000 updates (62 epochs) |
+Sentence-level Oracle | En->De | 8 | 12288 | 1 | 40000 updates (62nd epoch -> 93rd epoch) |
- #Toks. is the batch size in tokens (`--max-tokens`) on a single GPU.
- #Freq. is the number of gradient accumulation steps (`--update-freq`).
- Max is the maximum number of training epochs or updates.
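Since #Toks. is a per-GPU limit and #Freq. accumulates gradients over several batches, the number of tokens per optimizer update is bounded by #GPUs × #Toks. × #Freq. A small sketch of that arithmetic with the values from the table (the helper name is hypothetical):

```python
from typing import Dict, Tuple

# (#GPUs, #Toks., #Freq.) copied from the table above
SETTINGS: Dict[str, Tuple[int, int, int]] = {
    "Transformer-big Zh->En": (8, 4096, 3),
    "+Word-level Oracle Zh->En": (8, 4096, 3),
    "Transformer-base En->De": (8, 6144, 2),
    "+Word-level Oracle En->De": (8, 12288, 1),
    "+Sentence-level Oracle En->De": (8, 12288, 1),
}


def tokens_per_update(gpus: int, max_tokens: int, update_freq: int) -> int:
    """Upper bound on tokens per update: --max-tokens caps each GPU batch."""
    return gpus * max_tokens * update_freq


for name, cfg in SETTINGS.items():
    print(f"{name}: {tokens_per_update(*cfg)}")
```

All five settings share the same bound of 98,304 tokens per update: the En->De oracle runs double #Toks. and halve #Freq. relative to the baseline instead of changing the effective batch size.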
We compute case-insensitive 4-gram tokenized BLEU with the script multi-bleu.perl.
Models | Dev. (MT02) | MT03 | MT04 | MT05 | MT06 | MT08 | Average (MT03~MT08) | #Epochs |
---|---|---|---|---|---|---|---|---|
Transformer-big | 48.50 | 47.29 | 47.79 | 48.28 | 47.50 | 38.50 | 45.87 | 30 |
+Word-level Oracle (decay-k=10) | 49.18 | 48.70 | 48.67 | 48.69 | 48.49 | 39.58 | 46.83 | 30 |
+Word-level Oracle (decay-k=15) | 49.05 | 48.57 | 48.73 | 48.68 | 48.59 | 39.68 | 46.85 | 30 |
+Word-level Oracle (decay-k=20) | 49.30 | 48.46 | 48.57 | 48.87 | 48.57 | 39.46 | 46.79 | 30 |
+Word-level Oracle (decay-k=25) | 48.88 | 48.32 | 48.66 | 48.74 | 48.32 | 39.38 | 46.68 | 30 |
+Word-level Oracle (decay-k=30) | 48.47 | 48.37 | 48.50 | 48.63 | 48.07 | 39.54 | 46.62 | 30 |
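The Average column is the mean over the five test sets MT03~MT08; the dev set MT02 is not included. A quick check of that arithmetic with two rows copied from the table:

```python
from statistics import mean

# Test-set BLEU (MT03, MT04, MT05, MT06, MT08) copied from the multi-bleu.perl table
rows = {
    "Transformer-big": [47.29, 47.79, 48.28, 47.50, 38.50],
    "+Word-level Oracle (decay-k=15)": [48.57, 48.73, 48.68, 48.59, 39.68],
}

for name, scores in rows.items():
    print(f"{name}: {mean(scores):.2f}")
```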
We also evaluate with case-insensitive 4-gram detokenized BLEU computed by SacreBLEU through the script score.py provided by fairseq, with the signature `BLEU+case.mixed+lang.en-{de,fr}+numrefs.4+smooth.exp+tok.13a+version.1.4.4`.
Models | Dev. (MT02) | MT03 | MT04 | MT05 | MT06 | MT08 | Average (MT03~MT08) | #Epochs |
---|---|---|---|---|---|---|---|---|
Transformer-big | 48.46 | 47.41 | 47.88 | 48.25 | 47.52 | 38.60 | 45.93 | 30 |
+Word-level Oracle (decay-k=10) | 49.20 | 48.80 | 48.77 | 48.64 | 48.49 | 39.79 | 46.90 | 30 |
+Word-level Oracle (decay-k=15) | 49.07 | 48.64 | 48.81 | 48.63 | 48.65 | 39.88 | 46.92 | 30 |
+Word-level Oracle (decay-k=20) | 49.32 | 48.54 | 48.73 | 48.82 | 48.51 | 39.50 | 46.82 | 30 |
+Word-level Oracle (decay-k=25) | 48.90 | 48.18 | 48.70 | 48.59 | 47.73 | 39.14 | 46.47 | 30 |
+Word-level Oracle (decay-k=30) | 48.53 | 48.59 | 48.74 | 48.58 | 48.07 | 39.71 | 46.74 | 30 |
Settings for NIST Chinese->English:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
data_bin_dir=directory_of_data_bin
model_dir=./ckpt
# create the checkpoint directory first so that tee can open the log file
mkdir -p $model_dir
python train.py $data_bin_dir \
--arch oracle_transformer_vaswani_wmt_en_de_big --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --lr-scheduler inverse_sqrt \
--warmup-init-lr 1e-07 --warmup-updates 4000 --lr 0.0007 --min-lr 1e-09 \
--weight-decay 0.0 --criterion oracle_label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 4096 --update-freq 3 --no-progress-bar --log-format json --max-epoch 30 \
--log-interval 10 --save-interval 2 --keep-last-epochs 10 \
--seed 1111 --use-epoch-numbers-decay \
--use-word-level-oracles --decay-k 15 --use-greed-gumbel-noise --gumbel-noise 0.5 \
--distributed-port 32222 --distributed-world-size 8 --ddp-backend=no_c10d \
--source-lang zh --target-lang en --save-dir $model_dir | tee -a $model_dir/training.log
```
Following Eq.(15) in the paper, the probability of sampling golden words decays with the number of epochs.
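Assuming Eq.(15) takes the inverse sigmoid form p = k / (k + exp(e / k)), with e the epoch index (starting from 0, as in the paper) and k the value of `--decay-k`, the schedule of the Zh->En script above can be tabulated as below. fairseq numbers epochs from 1, so the repository may be offset by one epoch; treat the values as approximate.

```python
import math


def decay_prob(epoch: int, k: float) -> float:
    """Assumed form of Eq.(15): probability of feeding the golden word."""
    return k / (k + math.exp(epoch / k))


# --decay-k 15 with --use-epoch-numbers-decay, as in the NIST Zh->En script
for epoch in range(0, 31, 5):
    print(epoch, round(decay_prob(epoch, 15), 3))
```

The probability starts just below 1 (k / (k + 1) at e = 0) and decreases monotonically, so a larger `--decay-k` keeps more golden words in the decoder input for longer.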
We compute case-sensitive 4-gram tokenized BLEU with the script multi-bleu.perl.
Models | newstest2014 | #Updates |
---|---|---|
Transformer-base | 27.54 | 80000 |
+Word-level Oracle (decay-k=50, gumbel-noise=0.8) | 28.01 | 80000 |
+Sentence-level Oracle (decay-k=5800, gumbel-noise=0.5, beam size=4) | 28.45 | 40000 |
We also evaluate with case-sensitive 4-gram detokenized BLEU computed by SacreBLEU through the script score.py provided by fairseq, with the signature `BLEU+case.mixed+lang.en-{de,fr}+numrefs.1+smooth.exp+tok.13a+version.1.4.4`.
Models | newstest2014 | #Updates |
---|---|---|
Transformer-base | 26.45 | 80000 |
+Word-level Oracle (decay-k=50, gumbel-noise=0.8) | 26.86 | 80000 |
+Sentence-level Oracle (decay-k=5800, gumbel-noise=0.5, beam size=4) | 27.24 | 40000 |
Setting of the word-level oracle for the WMT'14 English->German dataset:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
data_bin_dir=directory_of_data_bin
model_dir=./ckpt
# create the checkpoint directory first so that tee can open the log file
mkdir -p $model_dir
python train.py $data_bin_dir \
--arch oracle_transformer_wmt_en_de --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --lr-scheduler inverse_sqrt \
--warmup-init-lr 1e-07 --warmup-updates 4000 --lr 0.0007 --min-lr 1e-09 \
--weight-decay 0.0 --criterion oracle_label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 12288 --update-freq 1 --no-progress-bar --log-format json --max-update 80000 \
--log-interval 10 --save-interval-updates 4000 --keep-interval-updates 10 --save-interval 10000 \
--seed 1111 --use-epoch-numbers-decay \
--use-word-level-oracles --decay-k 50 --use-greed-gumbel-noise --gumbel-noise 0.8 \
--distributed-port 31111 --distributed-world-size 8 --ddp-backend=no_c10d \
--source-lang en --target-lang de --save-dir $model_dir | tee -a $model_dir/training.log
```
Following Eq.(15) in the paper, the probability of sampling golden words decays with the number of epochs.
In order to save training time, we use the sentence-level oracle method to finetune the best base model.
Setting of the sentence-level oracle for the WMT'14 English->German dataset:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
data_bin_dir=directory_of_data_bin
model_dir=./ckpt
# create the checkpoint directory first so that tee can open the log file
mkdir -p $model_dir
python train.py $data_bin_dir \
--arch oracle_transformer_wmt_en_de --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --lr-scheduler inverse_sqrt \
--warmup-init-lr 1e-07 --warmup-updates 4000 --lr 0.0007 --min-lr 1e-09 \
--weight-decay 0.0 --criterion oracle_label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 12288 --update-freq 1 --no-progress-bar --log-format json --max-update 40000 \
--log-interval 10 --save-interval-updates 2000 --keep-interval-updates 10 --save-interval 10000 \
--seed 1111 --reset-optimizer --reset-meters \
--use-sentence-level-oracles --decay-k 5800 --use-bleu-gumbel-noise --gumbel-noise 0.5 --oracle-search-beam-size 4 \
--distributed-port 31111 --distributed-world-size 8 --ddp-backend=no_c10d \
--source-lang en --target-lang de --save-dir $model_dir | tee -a $model_dir/training.log
```
Following Eq.(15) in the paper, the probability of sampling golden words decays with the number of updates.
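With update-based decay (no `--use-epoch-numbers-decay`), the same assumed inverse sigmoid formula is evaluated at the update index. The sketch below tabulates p for `--decay-k 5800` and shows, with a hypothetical helper `mix_inputs`, how a decoder input sequence could be assembled by keeping each golden word with probability p and otherwise taking the oracle word. The example sentence is made up, and whether the update index counts from the start of fine-tuning or from the restored checkpoint is not stated here.

```python
import math
import random
from typing import List, Sequence


def decay_prob(num_updates: int, k: float) -> float:
    """Assumed form of Eq.(15) with the update index in place of the epoch index."""
    return k / (k + math.exp(num_updates / k))


def mix_inputs(gold: Sequence[str], oracle: Sequence[str], p: float,
               rng: random.Random) -> List[str]:
    """Hypothetical helper: keep each golden word with probability p, else the oracle word."""
    if len(gold) != len(oracle):
        # length-constrained decoding makes the sentence-level oracle match the reference length
        raise ValueError("gold and oracle must have the same length")
    return [g if rng.random() < p else o for g, o in zip(gold, oracle)]


# --decay-k 5800 (update-based decay, as in the sentence-level script)
for step in (0, 10000, 20000, 40000):
    print(step, round(decay_prob(step, 5800), 3))

gold = "wir sehen uns morgen".split()
oracle = "wir treffen uns morgen".split()
print(mix_inputs(gold, oracle, decay_prob(40000, 5800), random.Random(1)))
```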
`--use-epoch-numbers-decay` and `--decay-k` need to be adapted to different training data. The `prob` field in the training log is the decayed probability of sampling golden words.

We measure training speed and GPU memory usage on the IWSLT De->En training set:
Model Name | Memory Usage (G) | Training Speed (upd/s) |
---|---|---|
Transformer | 4.39 | 2.65 |
Word-level training | 4.57 | 2.25 |
Sentence-level training (decay_prob=1, beam_size=4) | 4.75 | 0.59 |
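Relative costs follow from dividing by the Transformer baseline; the sketch below does this arithmetic, taking the memory ratio directly and the time per update as the inverse of the speed ratio.

```python
# (memory in G, training speed in updates per second) copied from the table above
measurements = {
    "Transformer": (4.39, 2.65),
    "Word-level training": (4.57, 2.25),
    "Sentence-level training": (4.75, 0.59),
}

base_mem, base_speed = measurements["Transformer"]
for name, (mem, speed) in measurements.items():
    print(f"{name}: memory x{mem / base_mem:.2f}, time per update x{base_speed / speed:.2f}")
```

This gives roughly 4% more memory and 18% more time per update for word-level training, and about 8% more memory and 4.5 times the time per update for sentence-level training with beam size 4.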
Please cite as:
@inproceedings{zhang2019bridging,
title = "Bridging the Gap between Training and Inference for Neural Machine Translation",
author = "Zhang, Wen and Feng, Yang and Meng, Fandong and You, Di and Liu, Qun",
booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
month = jul,
year = "2019",
address = "Florence, Italy",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/P19-1426",
doi = "10.18653/v1/P19-1426",
pages = "4334--4343",
}
Fairseq(-py) is a sequence modeling toolkit that allows researchers and developers to train custom models for translation, summarization, language modeling and other text generation tasks.
Fairseq provides reference implementations of various sequence-to-sequence models, including:
Additionally:
We also provide pre-trained models for translation and language modeling with a convenient `torch.hub` interface:

```python
import torch

# downloads the pre-trained WMT'19 En->De single model on first use
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
en2de.translate('Hello world', beam=5)
# 'Hallo Welt'
```
See the PyTorch Hub tutorials for translation and RoBERTa for more examples.
For faster training, install NVIDIA's apex library with the `--cuda_ext` and `--deprecated_fused_adam` options.

To install fairseq:
```bash
pip install fairseq
```

On MacOS:

```bash
CFLAGS="-stdlib=libc++" pip install fairseq
```
If you use Docker, make sure to increase the shared memory size either with `--ipc=host` or `--shm-size` as command line options to `nvidia-docker run`.
Installing from source
To install fairseq from source and develop locally:

```bash
git clone https://github.com/pytorch/fairseq
cd fairseq
pip install --editable .
```
The full documentation contains instructions for getting started, training new models and extending fairseq with new model types and tasks.
We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below, as well as example training and evaluation commands.
We also have more detailed READMEs to reproduce results from specific papers:
fairseq(-py) is MIT-licensed. The license applies to the pre-trained models as well.
Please cite as:
@inproceedings{ott2019fairseq,
title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
year = {2019},
}