goodbai-nlp / AMRBART

Code for our paper "Graph Pre-training for AMR Parsing and Generation" in ACL2022
MIT License

KeyError 'source' when finetuning #6

Closed PhMeier closed 2 years ago

PhMeier commented 2 years ago

Hello, while testing fine-tuning on the example data in a conda environment, I encountered the following exception:

Traceback (most recent call last):
  File "/home/students/meier/AMRBART/fine-tune/run_amrparsing.py", line 154, in <module>
    main(args)
  File "/home/students/meier/AMRBART/fine-tune/run_amrparsing.py", line 129, in main
    trainer.fit(model, datamodule=data_module)
  File "/home/students/meier/anaconda3/envs/my_AMRBART_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in fit
    self._call_and_handle_interrupt(
  File "/home/students/meier/anaconda3/envs/my_AMRBART_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/students/meier/anaconda3/envs/my_AMRBART_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/students/meier/anaconda3/envs/my_AMRBART_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in _run
    self._dispatch()
  File "/home/students/meier/anaconda3/envs/my_AMRBART_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1272, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/home/students/meier/anaconda3/envs/my_AMRBART_env/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/home/students/meier/anaconda3/envs/my_AMRBART_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1282, in run_stage
    return self._run_train()
  File "/home/students/meier/anaconda3/envs/my_AMRBART_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1304, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/home/students/meier/anaconda3/envs/my_AMRBART_env/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1368, in _run_sanity_check
    self._evaluation_loop.run()
  File "/home/students/meier/anaconda3/envs/my_AMRBART_env/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 151, in run
    output = self.on_run_end()
  File "/home/students/meier/anaconda3/envs/my_AMRBART_env/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 130, in on_run_end
    self._evaluation_epoch_end(outputs)
  File "/home/students/meier/anaconda3/envs/my_AMRBART_env/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 235, in _evaluation_epoch_end
    model.validation_epoch_end(outputs)
  File "/home/students/meier/AMRBART/fine-tune/model_interface/model_amrparsing.py", line 320, in validation_epoch_end
    source = flatten_list(x["source"] for x in ori_outputs)
  File "/home/students/meier/AMRBART/fine-tune/common/utils.py", line 109, in flatten_list
    return [x for x in itertools.chain.from_iterable(summary_ids)]
  File "/home/students/meier/AMRBART/fine-tune/common/utils.py", line 109, in <listcomp>
    return [x for x in itertools.chain.from_iterable(summary_ids)]
  File "/home/students/meier/AMRBART/fine-tune/model_interface/model_amrparsing.py", line 320, in <genexpr>
    source = flatten_list(x["source"] for x in ori_outputs)
KeyError: 'source'

Printing out "ori_outputs" shows the following (the output is cut off here):

ori outputs [{'loss': tensor(0.8626, device='cuda:0'), 'gen_time': 8.689491331577301, 'gen_len': 1024.0, 'preds': [[53842, 36, 53069, 51012, 52944, 36, 53070, 171, 4839, 52945, 36, 53071, 14195, 4839, 4839, 53843, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

The key 'source' is missing.
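For anyone debugging a similar failure, the following is a minimal sketch (not the repository's code) of how one might check which keys are absent from the validation outputs before they are flattened. The helper name `missing_keys`, the list of required keys, and the shortened sample dict are assumptions made for illustration; only `flatten_list` and the key names come from the traceback and the printed output above.

```python
import itertools


def flatten_list(nested):
    # Same idea as common/utils.py in the traceback: chain the sub-lists into one list.
    return list(itertools.chain.from_iterable(nested))


def missing_keys(outputs, required):
    # Hypothetical helper: report required keys absent from at least one output dict.
    return sorted({key for key in required for out in outputs if key not in out})


# Shortened, hypothetical stand-in for the printed "ori_outputs" (plain floats, no tensors).
ori_outputs = [
    {"loss": 0.8626, "gen_time": 8.689491331577301, "gen_len": 1024.0,
     "preds": [[53842, 36, 53069]]},
]

# Assumed set of keys that validation_epoch_end reads; only "source" is confirmed by the traceback.
absent = missing_keys(ori_outputs, ("preds", "source"))
if absent:
    print(f"validation outputs are missing keys: {absent}")
else:
    source = flatten_list(x["source"] for x in ori_outputs)
```

A check like this only makes the mismatch between what the validation step returns and what validation_epoch_end expects easier to spot; it does not by itself fix the underlying cause.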

My sh script looks like this:

#!/bin/bash

ROOT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

GPUID=$2
MODEL=$1
eval_beam=5
modelcate=base
modelcate=large

lr=8e-6

datacate=/home/students/meier/AMRBART/examples/ #/home/students/meier/MA/data/ #AMR2.0
# datacate=AMR3.0

Tokenizer=facebook/bart-$modelcate  #../../../data/pretrained-model/bart-$modelcate
export OUTPUT_DIR_NAME=outputs/fine_tune_amrparse #${datacate}-AMRBart-${modelcate}-amrparsing-6taskPLM-5e-5-finetune-lr${lr}

export CURRENT_DIR=${ROOT_DIR}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
cache=~/.cache  #../../../data/.cache/

if [ ! -d "$OUTPUT_DIR" ]; then
  mkdir -p "$OUTPUT_DIR"
else
  echo "${OUTPUT_DIR} already exists, change a new one or delete origin one"
  exit 0
fi

export OMP_NUM_THREADS=10
export CUDA_VISIBLE_DEVICES=${GPUID}
python -u ${ROOT_DIR}/run_amrparsing.py \
    --data_dir=$datacate \
    --train_data_file=$datacate/train.jsonl \
    --eval_data_file=$datacate/val.jsonl \
    --test_data_file=$datacate/test.jsonl \
    --model_type ${MODEL} \
    --model_name_or_path=${MODEL} \
    --tokenizer_name_or_path=${Tokenizer} \
    --val_metric "smatch" \
    --learning_rate=${lr} \
    --max_epochs 20 \
    --max_steps -1 \
    --per_gpu_train_batch_size=4 \
    --per_gpu_eval_batch_size=4 \
    --unified_input \
    --accumulate_grad_batches 2 \
    --early_stopping_patience 10 \
    --gpus 1 \
    --output_dir=${OUTPUT_DIR} \
    --cache_dir ${cache} \
    --num_sanity_val_steps 4 \
    --src_block_size=512 \
    --tgt_block_size=1024 \
    --eval_max_length=1024 \
    --train_num_workers 8 \
    --eval_num_workers 4 \
    --process_num_workers 8 \
    --do_train --do_predict \
    --seed 42 \
    --fp16 \
    --eval_beam ${eval_beam} 2>&1 | tee $OUTPUT_DIR/run.log

I call the script as follows:

srun ~/AMRBART/fine-tune/finetune_AMRbart_amrparsing_large.sh /workspace/students/meier/AMR_Bart_models/AMR-BART-LARGE 0

What can I do to solve the problem? Thanks for reading!

goodbai-nlp commented 2 years ago

Hi, please clone the latest code and try again. If you still get errors, please post here.

PhMeier commented 2 years ago

Thank you very much, this solved the issue!