bytedance / lightseq

LightSeq: A High Performance Library for Sequence Processing and Generation
Other
3.19k stars 329 forks source link

en2fr和en2de的模型结构存在差异? #485

Open MeJerry215 opened 1 year ago

MeJerry215 commented 1 year ago

使用ls_fs_transformer_export.py 导出en2fr的时候发现 缺少layernorm参数 en2de

dict_keys(['encoder.embed_tokens.para', 'encoder.layers.0.para', 'encoder.layers.1.para', 'encoder.layers.2.para', 'encoder.layers.3.para', 'encoder.layers.4.para', 'encoder.layers.5.para', 'encoder.layer_norm.weight', 'encoder.layer_norm.bias', 'decoder.embed_tokens.para', 'decoder.layers.0.para', 'decoder.layers.1.para', 'decoder.layers.2.para', 'decoder.layers.3.para', 'decoder.layers.4.para', 'decoder.layers.5.para', 'decoder.layer_norm.weight', 'decoder.layer_norm.bias', 'decoder.output_projection.clip_max'])

en2fr

dict_keys(['encoder.embed_tokens.para', 'encoder.layers.0.para', 'encoder.layers.1.para', 'encoder.layers.2.para', 'encoder.layers.3.para', 'encoder.layers.4.para', 'encoder.layers.5.para', 'decoder.embed_tokens.para', 'decoder.layers.0.para', 'decoder.layers.1.para', 'decoder.layers.2.para', 'decoder.layers.3.para', 'decoder.layers.4.para', 'decoder.layers.5.para', 'decoder.output_projection.clip_max'])

这个训练使用了和en2de一样的参数除了arch参数有一点差异主要是使用了 将原来脚本中ls_fairseq_wmt14en2de.sh的arch 从ls_transformer_wmt_en_de_big_t2t改为ls_transformer_vaswani_wmt_en_fr_big

#!/usr/bin/env bash
set -ex

THIS_DIR=$(dirname $(readlink -f $0))
cd $THIS_DIR/../../..

if [ ! -d "/tmp/wmt14_en_fr" ]; then
    echo "not found valid dataset"
    exit -1
fi

lightseq-train /tmp/wmt14_en_fr/ \
    --task translation \
    --arch ls_transformer_vaswani_wmt_en_fr_big --share-decoder-input-output-embed \
    --optimizer ls_adam --adam-betas '(0.9, 0.98)' \
    --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 --weight-decay 0.0001 \
    --criterion ls_label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 8192 \
    --eval-bleu \
    --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
    --eval-bleu-detok moses \
    --eval-bleu-remove-bpe \
    --eval-bleu-print-samples \
    --best-checkpoint-metric bleu \
    --maximize-best-checkpoint-metric \
    --fp16 \
    --find-unused-parameters \
    --save-dir checkpoints/en2fr/

image

看上去是因为encoder_normalize_before和decoder_normalize_before导致没有layernorm参数的?所以这个是直注释掉导出脚本中的layernorm相关的部分吗?

MeJerry215 commented 1 year ago

当我导出en2fr的时候 推理在加载h5py的时候出错异常,如果我注释掉layernorm 参数

Taka152 commented 1 year ago

你的enfr模型里encoder是post norm。现在fairseq的导出脚本对于post还不支持,可以根据hf_bert的导出来修改,或者用ende的模型结构来重新训练一下