Open 5663015 opened 2 months ago
官方代码在训练时没有添加验证集指标,不太容易监测是否过拟合。经过尝试,增加
compute_metrics
也不行,Trainer
的evaluate
逻辑有点复杂走不到这里,最终还是得重构一下evaluate
。下面分享一个很简单的重构供参考,训练过程中返回验证集的损失,只需正常添加do_eval
、eval_steps
、evaluation_strategy
等参数就像。可以根据自己的需求完善验证的逻辑。
trainer.py
里class BiTrainer(Trainer):
def evaluate( self, test_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval", ) -> Dict[str, float]: losses = [] for inputs in tqdm(self.eval_dataset, desc='evaluate'): inputs = self.data_collator([inputs]) inputs['query']['input_ids'] = inputs['query']['input_ids'].to('npu') inputs['query']['attention_mask'] = inputs['query']['attention_mask'].to('npu') inputs['passage']['input_ids'] = inputs['passage']['input_ids'].to('npu') inputs['passage']['attention_mask'] = inputs['passage']['attention_mask'].to('npu') inputs.pop('teacher_scores') inputs.pop('bi_directions') loss = self.compute_loss(self.model, inputs) loss = loss.mean().detach().item() losses.append(loss) metrics = {'eval_loss': sum(losses) / len(losses)} self.log(metrics) return metrics
您好,想请问微调bge-m3的显存消耗是多少?
官方代码在训练时没有添加验证集指标,不太容易监测是否过拟合。经过尝试,增加
compute_metrics
也不行,Trainer
的evaluate
逻辑有点复杂走不到这里,最终还是得重构一下evaluate
。下面分享一个很简单的重构供参考,训练过程中返回验证集的损失,只需正常添加do_eval
、eval_steps
、evaluation_strategy
等参数就像。可以根据自己的需求完善验证的逻辑。trainer.py
里class BiTrainer(Trainer):