You can refer to the training script here. If you could not figure it out, please also provide the exact shell command you used for training.
Actually, I am using that training script, cloned from this repo with no modifications. For debugging, I downloaded the data from the Google Drive link, and the error log is still the same. Here is my shell code for starting training:
set -e
export BART_PATH=/scratch/tw2112/codes/models/bart.large/model.pt
export DATA=/scratch/tw2112/codes/cliff/prepared_data/data
export TRAINED_MODELS=/scratch/tw2112/models
# XSum
cd scripts/bart
CUDA_VISIBLE_DEVICES=0,1 sh train_xsum_single_neg.sh \
$DATA/xsum_synthetic/negative_syslowcon $TRAINED_MODELS/bart_xsum/syslowcon
And here is the train_xsum_single_neg.sh script it calls:
set -e
TOTAL_NUM_UPDATES=15000
WARMUP_UPDATES=500
LR=3e-05
MAX_TOKENS=1024
UPDATE_FREQ=16
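# Effective batch size is roughly N_GPUS x MAX_TOKENS x UPDATE_FREQ tokens per optimizer step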
NEG_DIR=$1
SAVE_PATH=$2
POS_DIR=$DATA/xsum_synthetic/positive_bt_filter
DATA_DIR=$DATA/xsum_binarized
USER_DIR=../../models/bart
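# --user-dir makes fairseq load this repo's custom task, architecture, and criterion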
fairseq-train $DATA_DIR --pos-data $POS_DIR --neg-data $NEG_DIR \
--restore-file $BART_PATH --save-dir $SAVE_PATH \
--max-tokens $MAX_TOKENS \
--task contrastive_translation --mlp 1024 \
--source-lang source --target-lang target \
--truncate-source \
--layernorm-embedding \
--share-all-embeddings \
--share-decoder-input-output-embed \
--reset-optimizer --reset-dataloader --reset-meters \
--required-batch-size-multiple 1 \
--arch contrastive_bart_large \
--criterion contrastive_loss \
--label-smoothing 0.1 \
--fixed-validation-seed 7 \
--dropout 0.1 --attention-dropout 0.1 \
--weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
--clip-norm 0.1 \
--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
--fp16 --update-freq $UPDATE_FREQ \
--skip-invalid-size-inputs-valid-test --max-epoch 5 \
--no-save-optimizer-state --no-epoch-checkpoints \
--find-unused-parameters \
--user-dir $USER_DIR;
I think it is because you are using a newer version of fairseq. Can you try:
git clone https://github.com/pytorch/fairseq.git
cd fairseq
git checkout 0db28cd
pip install -e .
After that, if you encounter errors related to hydra, please do:
pip install hydra-core==1.0.6
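A quick sanity check (just a suggestion) that the editable install and the pinned hydra version are the ones actually in use:
# Print the fairseq version from the editable checkout, then the installed hydra version
python -c "import fairseq; print(fairseq.__version__)"
pip show hydra-core | grep Version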
Thanks for your lightning-fast reply! The problem is solved and training is running now. But I've hit another error log from fairseq:
--- Logging error ---
Traceback (most recent call last):
File "/ext3/miniconda3/envs/fs/lib/python3.8/logging/__init__.py", line 1085, in emit
msg = self.format(record)
File "/ext3/miniconda3/envs/fs/lib/python3.8/logging/__init__.py", line 929, in format
return fmt.format(record)
File "/ext3/miniconda3/envs/fs/lib/python3.8/logging/__init__.py", line 668, in format
record.message = record.getMessage()
File "/ext3/miniconda3/envs/fs/lib/python3.8/logging/__init__.py", line 373, in getMessage
msg = msg % self.args
TypeError: not all arguments converted during string formatting
Call stack:
File "/ext3/miniconda3/envs/fs/bin/fairseq-train", line 33, in <module>
sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')())
File "/home/zp2053/wts/cliff/fairseq/fairseq_cli/train.py", line 392, in cli_main
distributed_utils.call_main(cfg, main)
File "/home/zp2053/wts/cliff/fairseq/fairseq/distributed_utils.py", line 334, in call_main
main(cfg, **kwargs)
File "/home/zp2053/wts/cliff/fairseq/fairseq_cli/train.py", line 117, in main
extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
File "/home/zp2053/wts/cliff/fairseq/fairseq/checkpoint_utils.py", line 192, in load_checkpoint
extra_state = trainer.load_checkpoint(
File "/home/zp2053/wts/cliff/fairseq/fairseq/trainer.py", line 340, in load_checkpoint
self.get_model().load_state_dict(
File "/home/zp2053/wts/cliff/fairseq/fairseq/models/fairseq_model.py", line 113, in load_state_dict
self.upgrade_state_dict(state_dict)
File "/home/zp2053/wts/cliff/fairseq/fairseq/models/fairseq_model.py", line 119, in upgrade_state_dict
self.upgrade_state_dict_named(state_dict, "")
File "/home/zp2053/wts/cliff/fairseq/fairseq/models/bart/model.py", line 280, in upgrade_state_dict_named
logger.info("Overwriting", prefix + "classification_heads." + k)
Message: 'Overwriting'
Arguments: ('classification_heads.contrast.dense.weight',)
and there are 4 more similar ones about loading weights. Those error messages don't interrupt fairseq-train. Is that OK for training?
That's okay for training. I actually commented out that logger.info line in the fairseq code. Also, I found that some training scripts did not have the correct configuration and fixed them just now. You might want to check whether your script is using the correct config.
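If you want to keep that log line instead of commenting it out, a minimal sketch of an alternative fix (my own suggestion, assuming the pinned fairseq checkout from above) is to add the missing %s placeholder so the logging call can consume its extra argument:
# Run from the root of the fairseq clone; patches the malformed call
# logger.info("Overwriting", ...) in fairseq/models/bart/model.py in place
sed -i 's/logger.info("Overwriting",/logger.info("Overwriting %s",/' fairseq/models/bart/model.py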
BTW (or should I open a new issue thread?): how should I configure TOTAL_NUM_UPDATES and UPDATE_FREQ for other datasets? I want to apply your work to Gigaword and wikiHow.
We basically set those hyperparameters following the original BART paper. You can see if Fairseq provides some instructions.
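For a rough starting point (a back-of-the-envelope sketch with placeholder numbers, not values from the paper): fairseq's --max-tokens is a per-GPU batch size in tokens and --update-freq accumulates gradients, so you can derive TOTAL_NUM_UPDATES from your corpus size and a target epoch count:
# Tokens consumed per optimizer step with the XSum settings on 2 GPUs
N_GPUS=2; MAX_TOKENS=1024; UPDATE_FREQ=16
TOKENS_PER_UPDATE=$((N_GPUS * MAX_TOKENS * UPDATE_FREQ))   # 32768
CORPUS_TOKENS=150000000   # placeholder: total training tokens in your dataset
EPOCHS=5
echo $((CORPUS_TOKENS / TOKENS_PER_UPDATE * EPOCHS))       # rough TOTAL_NUM_UPDATES
# WARMUP_UPDATES is commonly a few percent of the total (500/15000 in the XSum script)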
Recently I've been working on applying this work to Gigaword. When I gathered the data and started training, fairseq printed an error log. I checked contrastive_translation.py and found that it also defines a parameter named "--max-source-positions". Could anyone tell me the correct way to run the training process?
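One hedged way to debug this (untested here): fairseq registers task-specific options once it resolves --task, so printing the help for contrastive_translation should show whether --max-source-positions ends up registered twice. Assuming you run from scripts/bart as in the XSum script above:
fairseq-train --user-dir ../../models/bart --task contrastive_translation --help | grep -B1 -A2 'max-source-positions'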