from OmniEvent.arguments import DataArguments, ModelArguments, TrainingArguments, ArgumentParser
from OmniEvent.input_engineering.seq2seq_processor import EDSeq2SeqProcessor, type_start, type_end
from OmniEvent.backbone.backbone import get_backbone
from OmniEvent.model.model import get_model
from OmniEvent.evaluation.metric import compute_seq_F1
from OmniEvent.trainer_seq2seq import Seq2SeqTrainer
from OmniEvent.evaluation.utils import predict, get_pred_s2s
from OmniEvent.evaluation.convert_format import get_trigger_detection_s2s
from transformers import T5ForConditionalGeneration, T5TokenizerFast
from ipdb import set_trace
def main() -> None:
    """Fine-tune T5 for sequence-to-sequence event detection (ED) on DuEE 1.0.

    Pipeline (step numbering follows the OmniEvent README example):
      2. parse the YAML config into model/data/training argument objects,
      3. build the T5 backbone, tokenizer and OmniEvent model,
      4. build the train/dev datasets,
      5. train with ``Seq2SeqTrainer`` (optionally resuming from a checkpoint),
      6. run the unified evaluation on the test split and print micro-F1
         before and after converting predictions to the unified format.
    """
    # Step 2: Set up the customized configurations
    parser = ArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_yaml_file(
        yaml_file="config/all-datasets/ed/s2s/duee.yaml")
    training_args.output_dir = 'output/duee/ED/seq2seq/t5-base/'
    # Marker tokens that wrap triggers/type spans in the generated sequences.
    data_args.markers = ["<event>", "</event>", type_start, type_end]
    print('==================================step2 数据集配置文件yaml结束==================================')

    # Step 3: Initialize the model and tokenizer
    # NOTE(review): hard-coded path — point this at a local T5 checkpoint.
    model_args.model_name_or_path = '/pretrained_model/t5'
    backbone = T5ForConditionalGeneration.from_pretrained(model_args.model_name_or_path)
    # NOTE(review): ``never_split`` is a BERT-tokenizer option; T5TokenizerFast
    # may silently ignore it — confirm the markers survive tokenization (e.g.
    # via ``additional_special_tokens``).
    tokenizer = T5TokenizerFast.from_pretrained(model_args.model_name_or_path,
                                                never_split=data_args.markers)
    model = get_model(model_args, backbone)
    print("======================step3 模型初始化结束====================================")

    # Step 4: Initialize the dataset and evaluation metric
    data_args.train_file = '/data/processed/DuEE1.0/train.unified.jsonl'
    data_args.test_file = "/data/processed/DuEE1.0/test.unified.jsonl"
    data_args.validation_file = "/data/processed/DuEE1.0/valid.unified.jsonl"
    train_dataset = EDSeq2SeqProcessor(data_args, tokenizer, data_args.train_file)
    eval_dataset = EDSeq2SeqProcessor(data_args, tokenizer, data_args.validation_file)
    metric_fn = compute_seq_F1

    # Step 5: Define Trainer and train
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=metric_fn,
        data_collator=train_dataset.collate_fn,
        tokenizer=tokenizer,
    )
    # Resume from an earlier run when a checkpoint path is configured; set to
    # None (or '') to train from scratch.
    resume_from_checkpoint = 'OmniEvent-main/output/duee/ED/seq2seq/t5-base/checkpoint-7440'
    if resume_from_checkpoint:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    else:
        trainer.train()
    print('*****************************************训练结束********************************************')

    # Step 6: Unified Evaluation
    logits, labels, metrics, test_dataset = predict(
        trainer=trainer, tokenizer=tokenizer, data_class=EDSeq2SeqProcessor,
        data_args=data_args, data_file=data_args.test_file,
        training_args=training_args)
    # paradigm-dependent metrics
    # Depending on the OmniEvent version the metric key is "test_micro_f1" or
    # "micro_f1" — look up both instead of crashing with a KeyError.
    micro_f1 = metrics.get("test_micro_f1", metrics.get("micro_f1"))
    print("{} test performance before converting: {}".format(test_dataset.dataset_name, micro_f1))
    preds = get_pred_s2s(logits, tokenizer)
    # convert to the unified prediction and evaluate
    # NOTE(review): this assumes the conversion result is a dict carrying
    # "test_micro_f1" — verify against the installed OmniEvent version.
    pred_labels = get_trigger_detection_s2s(preds, labels, data_args.test_file, data_args, None)
    print("{} test performance after converting: {}".format(test_dataset.dataset_name, pred_labels["test_micro_f1"]))
# Script entry point. (This tail previously contained a corrupted one-line
# duplicate of the whole import block, an empty duplicate ``def main():``,
# and a guard written as ``if name == "main"`` — which never fires because
# the dunder ``__name__``/"__main__" spelling is required and ``name`` is
# undefined.)
if __name__ == "__main__":
    main()
# NOTE: pasted issue text (original in Chinese): "Hello, I tried to turn the
# example from your README into a .py script using the DuEE dataset, but ran
# into some problems — e.g. the key metrics['test_micro_f1'] is actually
# metrics['micro_f1'], and its value is 0. Do you have a .py version of this
# example that you could share?"