zhulifengsheng / fairseq_mmt

MIT License

Multimodal machine translation (MMT)

Dependencies

Install fairseq

cd fairseq_mmt
pip install --editable ./

Multi30k data & Flickr30k entities

Multi30k data from here and here
Flickr30k Entities data from here
We obtain the Multi30k text data from Revisit-MMT.

cd fairseq_mmt
git clone https://github.com/BryanPlummer/flickr30k_entities.git
cd flickr30k_entities
unzip annotations.zip
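
After unzipping, the Annotations/ and Sentences/ folders are in place. As a quick sanity check, you can read one entity-annotated caption with the utilities shipped in flickr30k_entities (a hedged sketch, run inside the flickr30k_entities directory; the image ID is a hypothetical example):

from flickr30k_entities_utils import get_sentence_data

# read the entity-annotated captions of one (hypothetical) image ID
sentences = get_sentence_data('Sentences/36979.txt')
for sent in sentences:
    print(sent['sentence'])
    for phrase in sent['phrases']:
        # each phrase records its text and its word-level start position
        print(' ', phrase['phrase'], '@ word', phrase['first_word_index'])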

# download the data and arrange it in a directory (any location) as follows:
flickr30k
├─ flickr30k-images
├─ test2017-images
├─ test_2016_flickr.txt
├─ test_2017_flickr.txt
├─ test_2017_mscoco.txt
├─ test_2018_flickr.txt
├─ testcoco-images
├─ train.txt
└─ val.txt
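
A few lines of Python can confirm that the layout matches the tree above (just a convenience check, not part of the repo; run it from the parent directory of flickr30k):

import os

expected = [
    'flickr30k-images', 'test2017-images', 'testcoco-images',
    'train.txt', 'val.txt', 'test_2016_flickr.txt',
    'test_2017_flickr.txt', 'test_2017_mscoco.txt', 'test_2018_flickr.txt',
]
missing = [p for p in expected if not os.path.exists(os.path.join('flickr30k', p))]
print('missing:', missing or 'none')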

Extract image features

1. Vision Transformer

(figure: image feature shapes)

  # please read scripts/README.md and modify the timm code first!
  # ⬆ ⬆ ⬆ ⬆ ⬆ ⬆ ⬆ ⬆
  python3 scripts/get_img_feat.py --dataset train --model vit_base_patch16_384 --path ../flickr30k
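
For orientation, here is a minimal sketch of what the extraction boils down to with timm (this is not the repo's get_img_feat.py; the image path is a placeholder). On stock older timm versions, forward_features returns only the pooled CLS token, which is exactly why scripts/README.md asks you to patch timm so the patch tokens are exposed:

import torch
import timm
from PIL import Image
from timm.data import resolve_data_config, create_transform

model = timm.create_model('vit_base_patch16_384', pretrained=True).eval()
transform = create_transform(**resolve_data_config({}, model=model))

img = Image.open('../flickr30k/flickr30k-images/36979.jpg').convert('RGB')  # placeholder path
x = transform(img).unsqueeze(0)       # (1, 3, 384, 384)
with torch.no_grad():
    feat = model.forward_features(x)  # after the patch: (1, 577, 768) = CLS + 24x24 patch tokens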


2. DETR

(figure: where the image feature is extracted in DETR)

  # please run scripts/get_img_feat_detr.py to download the official DETR code and model first
  # then modify detr.py (in the official DETR code) to return the image feature, as shown in the figure above
  # ⬆ ⬆ ⬆ ⬆ ⬆ ⬆ ⬆ ⬆
  python3 scripts/get_img_feat_detr.py --dataset train --path ../flickr30k
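
A hedged sketch of the idea, using the official model from torch.hub and a forward hook on the transformer encoder as a stand-in for the detr.py modification (hook point and shapes are assumptions):

import torch

# official DETR model from torch.hub (downloads weights on first use)
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True).eval()

feats = {}
def save_encoder_output(module, inputs, output):
    feats['memory'] = output  # encoder memory, shape (num_positions, batch, 256)

model.transformer.encoder.register_forward_hook(save_encoder_output)

x = torch.randn(1, 3, 800, 1066)  # dummy image batch
with torch.no_grad():
    model(x)
print(feats['memory'].shape)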


Create masking data

pip3 install stanfordcorenlp 
wget https://nlp.stanford.edu/software/stanford-corenlp-latest.zip
unzip stanford-corenlp-latest.zip
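
To verify the CoreNLP setup, the stanfordcorenlp wrapper can be exercised directly (the unzipped directory name depends on the version you downloaded); nouns are the tokens tagged NN/NNS/NNP/NNPS:

from stanfordcorenlp import StanfordCoreNLP

nlp = StanfordCoreNLP('./stanford-corenlp-latest')  # adjust to the actual unzipped dir name
print(nlp.pos_tag('A man in a blue shirt rides a horse.'))
# e.g. [('A', 'DT'), ('man', 'NN'), ..., ('horse', 'NN'), ('.', '.')]
nlp.close()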

cd fairseq_mmt
cat data/multi30k/train.en data/multi30k/valid.en data/multi30k/test.2016.en > train_val_test2016.en
python3 get_and_record_noun_from_f30k_entities.py # record each noun and its position in every sentence, using flickr30k_entities
python3 record_color_people_position.py

cd data/masking
# create en-de masking data
python3 match_origin2bpe_position.py en-de
python3 create_masking_multi30k.py en-de         # create mask1-4 & color & people data 
# create en-fr masking data
python3 match_origin2bpe_position.py en-fr
python3 create_masking_multi30k.py en-fr         # create mask1-4 & color & people data 
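
match_origin2bpe_position.py aligns word-level positions (as recorded from flickr30k_entities) with positions in the BPE-segmented text. A minimal illustration of the idea, assuming subword-nmt's '@@' continuation markers (not the repo's actual implementation):

def origin2bpe(bpe_tokens):
    # map each original word index to the indices of its BPE pieces
    mapping, word_idx = {}, 0
    for i, tok in enumerate(bpe_tokens):
        mapping.setdefault(word_idx, []).append(i)
        if not tok.endswith('@@'):  # last piece of the current word
            word_idx += 1
    return mapping

print(origin2bpe('a snow@@ board@@ er jumps'.split()))
# {0: [0], 1: [1, 2, 3], 2: [4]}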

sh preprocess_mmt.sh

Train and Test

1. Preprocess (mask1 as an example)

src='en'
tgt='de'
mask=mask1  # mask1, mask2, mask3, mask4, maskc (color), maskp (people)
TEXT=data/multi30k-en-$tgt.$mask

fairseq-preprocess --source-lang $src --target-lang $tgt \
  --trainpref $TEXT/train \
  --validpref $TEXT/valid \
  --testpref $TEXT/test.2016,$TEXT/test.2017,$TEXT/test.coco \
  --destdir data-bin/multi30k.en-$tgt.$mask \
  --workers 8 --joined-dictionary \
  --srcdict data/dict.en2de_$mask.txt

Run sh preprocess.sh to generate the data without masking.

2. Train (mask1 as an example)

mask_data=mask1
data_dir=multi30k.en-de.mask1
src_lang='en'
tgt_lang='de'
image_feat=vit_base_patch16_384
tag=$image_feat/$image_feat-$mask_data
save_dir=checkpoints/multi30k-en2de/$tag
image_feat_path=data/$image_feat
image_feat_dim=768

criterion=label_smoothed_cross_entropy
fp16=1
lr=0.005
warmup=2000
max_tokens=4096
update_freq=1
keep_last_epochs=10
patience=10
max_update=8000
dropout=0.3

arch=image_multimodal_transformer_SA_top
SA_attention_dropout=0.1
SA_image_dropout=0.1
SA_text_dropout=0

CUDA_VISIBLE_DEVICES=0,1 fairseq-train data-bin/$data_dir \
  --save-dir $save_dir \
  --distributed-world-size 2 -s $src_lang -t $tgt_lang \
  --arch $arch \
  --dropout $dropout \
  --criterion $criterion --label-smoothing 0.1 \
  --task image_mmt --image-feat-path $image_feat_path --image-feat-dim $image_feat_dim \
  --optimizer adam --adam-betas '(0.9, 0.98)' \
  --lr $lr --min-lr 1e-09 --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates $warmup \
  --max-tokens $max_tokens --update-freq $update_freq --max-update $max_update \
  --find-unused-parameters \
  --share-all-embeddings \
  --patience $patience \
  --keep-last-epochs $keep_last_epochs \
  --SA-image-dropout $SA_image_dropout \
  --SA-attention-dropout $SA_attention_dropout \
  --SA-text-dropout $SA_text_dropout

Alternatively, you can run train_mmt.sh instead of the script above.

3. Test (mask1 as an example)

# usage: sh translation_mmt.sh <mask> <image_feat>
sh translation_mmt.sh mask1 vit_base_patch16_384  # the original, unmasked text is mask0


Visualization

# uncomment lines 429-431 and 487-488 in fairseq/models/image_multimodal_transformer_SA.py
# decode again; the attention tensors are written to the checkpoint dir
# prepare the files required by visualization/visualization.py
cd visualization
python3 visualization.py
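
The exact inputs of visualization.py are repo-specific, but the core step can be sketched as a heatmap over the ViT patch grid (the tensor below is a stand-in; 576 patches match the 24x24 grid of vit_base_patch16_384):

import torch
import matplotlib.pyplot as plt

attn = torch.rand(12, 576)      # stand-in for a dumped (tgt_len, n_patches) attention tensor
grid = attn[0].reshape(24, 24)  # 576 patches -> 24x24 spatial grid
plt.imshow(grid, cmap='viridis')
plt.title('attention of target token 0 over image patches')
plt.savefig('attn_token0.png')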