ictnlp / OR-NMT

Source Code for ACL2019 paper <Bridging the Gap between Training and Inference for Neural Machine Translation>
41 stars 10 forks source link

OR-NMT: Bridging the Gap between Training and Inference for Neural Machine Translation

Wen Zhang, Yang Feng, Fandong Meng, Di You and Qun Liu. Bridging the Gap between Training and Inference for Neural Machine Translation. In Proceedings of ACL, 2019. [paper][code]

Codes in the two directories are the OR-NMT systems based on the RNNsearch and Transformer models correspondingly

Runtime Environment

This system has been tested in the following environment.

For OR-Transformer:

First, go into the OR-Transformer directory.

Then, the training script is the same with fairseq, except for the following arguments:

By default, the probability is decayed based on the update index.

NOTE: For a new data set, the hyperparameter --decay-k needs to be manually adjusted according to the maximum number of training updates (default) or epochs (--use-epoch-numbers-decay) to ensure that the probability of sampling golden words does not decay so quickly.

For Eq.(11~13) in the paper, is actually the same as . The operation is not needed in the code implementation.

Gumbel noise:

As for the --arch and --criterion arguments, oracle_ should be used as the prefix for OR-NMT training, such as:

Example of the script for word-level training and decaying the probability based on epoch index:

export CUDA_VISIBLE_DEVICES=0,1,2,3
batch_size=4096
accum=2
data_dir=directory_of_data_bin
model_dir=./ckpt
python train.py $data_dir \
    --arch oracle_transformer_vaswani_wmt_en_de_big --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --lr-scheduler inverse_sqrt \
    --warmup-init-lr 1e-07 --warmup-updates 4000 --lr 0.0005 --min-lr 1e-09 \
    --weight-decay 0.0 --criterion oracle_label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens $batch_size --update-freq $accum --no-progress-bar --log-format json --max-update 200000 \
    --log-interval 10 --save-interval-updates 10000 --keep-interval-updates 10 --save-interval 10000 \
    --seed 1111 --skip-invalid-size-inputs-valid-test \
    --distributed-port 28888 --distributed-world-size 4 --ddp-backend=no_c10d \
    --source-lang en --target-lang de --save-dir $model_dir \
    --use-word-level-oracles --use-epoch-numbers-decay --decay-k 10 \
    --use-greed-gumbel-noise --gumbel-noise 0.5 | tee -a $model_dir/training.log

Model settings on NIST Chinese->English (Zh->En) and WMT'14 English->German (En->De).

Models Translation Task #GPUs #Toks. #Freq. Max
Transformer-big Zh->En 8 4096 3 30 epochs
+Word-level Oracle Zh->En 8 4096 3 30 epochs
Transformer-base En->De 8 6144 2 80000 updates (62 epochs)
+Word-level Oracle En->De 8 12288 1 80000 updates (62 epochs)
+Sentence-level Oracle En->De 8 12288 1 40000 updates (62th epoch -> 93th epoch)

#Toks. means batchsize on single GPU.

#Freq. means the times of gradient accumulation.

Max represents the maximum number of training epochs (30) or updates (80k).

Results and Settings on NIST Chinese->English translation task

We calculate the case-insensitive 4-gram tokenized BLEU by script multibleu.perl

Models Dev. (MT02) MT03 MT04 MT05 MT06 MT08 Average
Transformer-big 48.50 47.29 47.79 48.28 47.50 38.50 45.87 30
+Word-level Oracle (==10) 49.18 48.70 48.67 48.69 48.49 39.58 46.83 30
+Word-level Oracle (==15) 49.05 48.57 48.73 48.68 48.59 39.68 46.85 30
+Word-level Oracle (==20) 49.30 48.46 48.57 48.87 48.57 39.46 46.79 30
+Word-level Oracle (==25) 48.88 48.32 48.66 48.74 48.32 39.38 46.68 30
+Word-level Oracle (==30) 48.47 48.37 48.50 48.63 48.07 39.54 46.62 46.74 30

We also evaluate by the case-insensitive 4-gram detokenized BLEU with SacreBLEU, which is calculated the script score.py provided by fairseq: BLEU+case.mixed+lang.en-{de,fr}+numrefs.4+smooth.exp+tok.13a+version.1.4.4

Models Dev. (MT02) MT03 MT04 MT05 MT06 MT08 Average
Transformer-big 48.46 47.41 47.88 48.25 47.52 38.60 45.93 30
+Word-level Oracle (==10) 49.20 48.80 48.77 48.64 48.49 39.79 46.90 30
+Word-level Oracle (==15) 49.07 48.64 48.81 48.63 48.65 39.88 46.92 30
+Word-level Oracle (==20) 49.32 48.54 48.73 48.82 48.51 39.50 46.82 30
+Word-level Oracle (==25) 48.90 48.18 48.70 48.59 47.73 39.14 46.47 30
+Word-level Oracle (==30) 48.53 48.59 48.74 48.58 48.07 39.71 46.74 30

The setting of the NIST Chinese->English:

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
data_bin_dir=directory_of_data_bin
model_dir=./ckpt
python train.py $data_bin_dir \
    --arch oracle_transformer_vaswani_wmt_en_de_big --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --lr-scheduler inverse_sqrt \
    --warmup-init-lr 1e-07 --warmup-updates 4000 --lr 0.0007 --min-lr 1e-09 \
    --weight-decay 0.0 --criterion oracle_label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 --update-freq 3 --no-progress-bar --log-format json --max-epoch 30 \
    --log-interval 10 --save-interval 2 --keep-last-epochs 10 \
    --seed 1111 --use-epoch-numbers-decay \
    --use-word-level-oracles --decay-k 15 --use-greed-gumbel-noise --gumbel-noise 0.5 \
    --distributed-port 32222 --distributed-world-size 8 --ddp-backend=no_c10d \
    --source-lang zh --target-lang en --save-dir $model_dir | tee -a $model_dir/training.log

As Eq.(15) in the paper, the probability of sampling golden words decays with the number of epochs as follows:

Results and Settings on WMT'14 English->German translation task

We calculate the case-sensitive 4-gram tokenized BLEU by script multibleu.perl

Models newstest2014 #update
Transformer-base 27.54 80000
+Word-level Oracle (==50, ==0.8) 28.01 80000
+Sentence-level Oracle (==5800, ==0.5, beam_size==4) 28.45 40000

We also evaluate by the case-sensitive 4-gram detokenized BLEU with SacreBLEU, which is calculated the script score.py provided by fairseq: BLEU+case.mixed+lang.en-{de,fr}+numrefs.1+smooth.exp+tok.13a+version.1.4.4

Models newstest2014 #update
Transformer-base 26.45 80000
+Word-level Oracle (==50, ==0.8) 26.86 80000
+Sentence-level Oracle (==5800, ==0.5, beam_size==4) 27.24 40000

Setting of the word-level oracle for the WMT'14 English->German dataset:

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
data_bin_dir=directory_of_data_bin
model_dir=./ckpt
python train.py $data_bin_dir \
    --arch oracle_transformer_wmt_en_de --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --lr-scheduler inverse_sqrt \
    --warmup-init-lr 1e-07 --warmup-updates 4000 --lr 0.0007 --min-lr 1e-09 \
    --weight-decay 0.0 --criterion oracle_label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 12288 --update-freq 1 --no-progress-bar --log-format json --max-update 80000 \
    --log-interval 10 --save-interval-updates 4000 --keep-interval-updates 10 --save-interval 10000 \
    --seed 1111 --use-epoch-numbers-decay \
    --use-word-level-oracles --decay-k 50 --use-greed-gumbel-noise --gumbel-noise 0.8 \
    --distributed-port 31111 --distributed-world-size 8 --ddp-backend=no_c10d \
    --source-lang en --target-lang de --save-dir $model_dir | tee -a $model_dir/training.log

As Eq.(15) in the paper, the probability of sampling golden words decays with the number of epochs as follows:

In order to save training time, we use the sentence-level oracle method to finetune the best base model.

Setting of the sentence-level oracle for the WMT'14 English->German dataset:

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
data_bin_dir=directory_of_data_bin
model_dir=./ckpt
python train.py $data_bin_dir \
    --arch oracle_transformer_wmt_en_de --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --lr-scheduler inverse_sqrt \                                                                                                                
    --warmup-init-lr 1e-07 --warmup-updates 4000 --lr 0.0007 --min-lr 1e-09 \
    --weight-decay 0.0 --criterion oracle_label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 12288 --update-freq 1 --no-progress-bar --log-format json --max-update 40000 \
    --log-interval 10 --save-interval-updates 2000 --keep-interval-updates 10 --save-interval 10000 \
    --seed 1111 --reset-optimizer --reset-meters \
    --use-sentence-level-oracles --decay-k 5800 --use-bleu-gumbel-noise --gumbel-noise 0.5 --oracle-search-beam-size 4 \ 
    --distributed-port 31111 --distributed-world-size 8 --ddp-backend=no_c10d \
    --source-lang en --target-lang de --save-dir $model_dir | tee -a $model_dir/training.log

As Eq.(15) in the paper, the probability of sampling golden words decays with the number of udpates as follows:

NOTE

Test training speed and GPU memory usage on iwslt de2en training set

Model Name Memory Usage (G) Training Speed (upd/s)
Transformer 4.39 2.65
Word-level training 4.57 2.25
Sentence-level training (decay_prob=1, beam_size=4) 4.75 0.59

Citation

please cite as:

@inproceedings{zhang2019bridging,
    title = "Bridging the Gap between Training and Inference for Neural Machine Translation",
    author = "Zhang, Wen and Feng, Yang and Meng, Fandong and You, Di and Liu, Qun",
    booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
    month = jul,
    year = "2019",
    address = "Florence, Italy",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/P19-1426",
    doi = "10.18653/v1/P19-1426",
    pages = "4334--4343",
}



MIT License Latest Release Build Status Documentation Status


Fairseq(-py) is a sequence modeling toolkit that allows researchers and developers to train custom models for translation, summarization, language modeling and other text generation tasks.

What's New:

Features:

Fairseq provides reference implementations of various sequence-to-sequence models, including:

Additionally:

We also provide pre-trained models for translation and language modeling with a convenient torch.hub interface:

en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
en2de.translate('Hello world', beam=5)
# 'Hallo Welt'

See the PyTorch Hub tutorials for translation and RoBERTa for more examples.

Model

Requirements and Installation

To install fairseq:

pip install fairseq

On MacOS:

CFLAGS="-stdlib=libc++" pip install fairseq

If you use Docker make sure to increase the shared memory size either with --ipc=host or --shm-size as command line options to nvidia-docker run.

Installing from source

To install fairseq from source and develop locally:

git clone https://github.com/pytorch/fairseq
cd fairseq
pip install --editable .

Getting Started

The full documentation contains instructions for getting started, training new models and extending fairseq with new model types and tasks.

Pre-trained models and examples

We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below, as well as example training and evaluation commands.

We also have more detailed READMEs to reproduce results from specific papers:

Join the fairseq community

License

fairseq(-py) is MIT-licensed. The license applies to the pre-trained models as well.

Citation

Please cite as:

@inproceedings{ott2019fairseq,
  title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
  author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
  booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
  year = {2019},
}