Closed: HelloNicoo closed this issue 10 months ago
Hi @HelloNicoo, thanks for raising an issue.
We get many issues and feature requests and so need you to help us so that we can get through them at a reasonable pace. Could you:
``` code goes here ```
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info

transformers version: 4.34.1

Who can help?

No response

Information

Tasks

An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
```python
'''
@File    : custom_bert_model.py
@Time    : 2023/09/15 14:37:17
@Author  : Raomoin
@Version : 1.0
@Contact : roamoin0509@gmail.com
@License : (C)Copyright 2023-2024, Liugroup-NLPR-CASIA
@Desc    : None
'''
import warnings
from dataclasses import dataclass, field
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from sklearn.metrics import f1_score, precision_score, recall_score
from transformers import (BertModel, BertPreTrainedModel, BertTokenizer,
                          Trainer, TrainingArguments)
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.trainer_utils import EvalPrediction

warnings.filterwarnings("ignore")

MODEL_NAME = 'bert-base-chinese'
token = BertTokenizer.from_pretrained(MODEL_NAME, local_files_only=True)


@dataclass
class ModelArguments:
    """Model argument definitions."""
    ner_num_labels: int = field(default=2,
                                metadata={"help": "Number of labels to predict"})


def compute_metrics(eval_output):
    """
    Callback invoked by the Trainer during evaluation.
    (When debugging in PyCharm or another IDE, a breakpoint can be set here;
    the function is called while evaluation runs.)
    """
    print('qqqqqqqq')
    preds = eval_output.predictions
    preds = np.argmax(preds, axis=-1).flatten()
    labels = eval_output.label_ids.flatten()
    # A label of 0 indicates ..., so that part needs to be excluded from the computation.


class CustomBertModel(BertPreTrainedModel):
    """Custom model."""


def tokenize_function(examples):
    """Process the data inside datasets.map()."""
    new_data = token(examples['text'], padding='max_length', truncation=True)
    new_data['labels'] = [cate_dict[label] for label in examples["cat_leaf_name_old"]]


if __name__ == '__main__':
    model_args = ModelArguments()
    model = CustomBertModel.from_pretrained(MODEL_NAME, model_args=model_args,
                                            local_files_only=True)
```
Expected behavior
Hello, I would like to ask about a situation where overriding the compute_metrics function does not take effect after customizing the model in the transformers framework. Is there any solution? Here is my code, can you help me? Thanks!
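For reference, below is a minimal sketch of how compute_metrics is typically attached to the Trainer; the `model`, `train_dataset` and `eval_dataset` names are placeholders and are not taken from the code above. compute_metrics is only invoked when evaluation actually runs, so an evaluation strategy has to be configured.

```python
import numpy as np
from sklearn.metrics import f1_score
from transformers import Trainer, TrainingArguments

def compute_metrics(eval_output):
    # eval_output.predictions holds the logits gathered during evaluation;
    # eval_output.label_ids holds the corresponding labels.
    preds = np.argmax(eval_output.predictions, axis=-1).flatten()
    labels = eval_output.label_ids.flatten()
    return {"f1": f1_score(labels, preds, average="macro")}

training_args = TrainingArguments(
    output_dir="./outputs",
    evaluation_strategy="epoch",      # without an eval strategy, evaluate() never runs during training
    per_device_eval_batch_size=16,
)

trainer = Trainer(
    model=model,                      # placeholder: the custom model instance
    args=training_args,
    train_dataset=train_dataset,      # placeholder: tokenized training split
    eval_dataset=eval_dataset,        # placeholder: tokenized evaluation split
    compute_metrics=compute_metrics,  # the override must be passed in here
)
trainer.train()
```

If compute_metrics is passed in this way and still never fires, it may also be worth checking that the custom model returns logits in its output and that the label column is not being dropped (for example via `remove_unused_columns` or `label_names` in TrainingArguments), though that is only a guess without seeing the full training setup.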