THU-KEG / OmniEvent

A comprehensive, unified and modular event extraction toolkit.
https://omnievent.readthedocs.io/
MIT License
341 stars 33 forks source link

readme里面的运行步骤 写到一个py文件中 #45

Closed BudBudding closed 10 months ago

BudBudding commented 12 months ago

"""Run OmniEvent event-detection (ED) training with a seq2seq T5 model on DuEE 1.0.

Standalone version of the README quick-start: load the YAML config, build the
model/tokenizer, create the datasets, train (optionally resuming from a
checkpoint), and run the unified evaluation.
"""

from OmniEvent.arguments import DataArguments, ModelArguments, TrainingArguments, ArgumentParser
from OmniEvent.input_engineering.seq2seq_processor import EDSeq2SeqProcessor, type_start, type_end
from OmniEvent.model.model import get_model
from OmniEvent.evaluation.metric import compute_seq_F1
from OmniEvent.trainer_seq2seq import Seq2SeqTrainer
from OmniEvent.evaluation.utils import predict, get_pred_s2s
from OmniEvent.evaluation.convert_format import get_trigger_detection_s2s
from transformers import T5ForConditionalGeneration, T5TokenizerFast


def main():
    # Step 2: Set up the customized configurations.
    parser = ArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_yaml_file(
        yaml_file="config/all-datasets/ed/s2s/duee.yaml")
    training_args.output_dir = 'output/duee/ED/seq2seq/t5-base/'
    # Marker tokens the tokenizer must keep intact when splitting.
    data_args.markers = ["<event>", "</event>", type_start, type_end]
    print('==================================step2 数据集配置文件yaml结束==================================')

    # Step 3: Initialize the model and tokenizer.
    # NOTE(review): paths below are absolute — adjust to your local layout.
    model_args.model_name_or_path = '/pretrained_model/t5'
    backbone = T5ForConditionalGeneration.from_pretrained(model_args.model_name_or_path)
    tokenizer = T5TokenizerFast.from_pretrained(model_args.model_name_or_path,
                                                never_split=data_args.markers)
    # Wrap the raw T5 backbone into the OmniEvent model for the ED paradigm.
    model = get_model(model_args, backbone)
    print("======================step3 模型初始化结束====================================")

    # Step 4: Initialize the dataset and evaluation metric.
    data_args.train_file = '/data/processed/DuEE1.0/train.unified.jsonl'
    data_args.test_file = "/data/processed/DuEE1.0/test.unified.jsonl"
    data_args.validation_file = "/data/processed/DuEE1.0/valid.unified.jsonl"
    train_dataset = EDSeq2SeqProcessor(data_args, tokenizer, data_args.train_file)
    eval_dataset = EDSeq2SeqProcessor(data_args, tokenizer, data_args.validation_file)
    metric_fn = compute_seq_F1

    # Step 5: Define Trainer and train.
    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=metric_fn,
        data_collator=train_dataset.collate_fn,
        tokenizer=tokenizer,
    )
    # Resume from a previous checkpoint if one is configured; set to None
    # (or an empty string) to train from scratch.
    resume_from_checkpoint = 'OmniEvent-main/output/duee/ED/seq2seq/t5-base/checkpoint-7440'
    if resume_from_checkpoint:
        trainer.train(resume_from_checkpoint)
    else:
        trainer.train()
    print('*****************************************训练结束********************************************')

    # Step 6: Unified Evaluation on the held-out test split.
    logits, labels, metrics, test_dataset = predict(trainer=trainer, tokenizer=tokenizer,
                                                    data_class=EDSeq2SeqProcessor,
                                                    data_args=data_args,
                                                    data_file=data_args.test_file,
                                                    training_args=training_args)
    # NOTE(review): the issue author reports the metric key is actually
    # "micro_f1" (and valued 0) in their OmniEvent version — confirm the key
    # name against compute_seq_F1 before relying on this line.
    print("{} test performance before converting: {}".format(test_dataset.dataset_name,
                                                             metrics["test_micro_f1"]))

    # Convert raw seq2seq generations to the unified prediction format and
    # re-evaluate with paradigm-independent trigger-detection metrics.
    preds = get_pred_s2s(logits, tokenizer)
    pred_labels = get_trigger_detection_s2s(preds, labels, data_args.test_file, data_args, None)
    print("{} test performance after converting: {}".format(test_dataset.dataset_name,
                                                            pred_labels["test_micro_f1"]))


if __name__ == "__main__":
    main()

您好,我在尝试将您readme里面的例子,用duee数据集,写成了py的格式。但是遇到了一些问题,例如metrics["test_micro_f1"]里为metrics["micro_f1"]、并且这里为0。请问您那边是否有这个的py文件,是否方便提供一下

h-peng17 commented 10 months ago

在examples/文件夹下有样例代码,超参数在config/文件夹下。