FlagOpen / FlagEmbedding

Retrieval and Retrieval-augmented LLMs
MIT License

[code share] Adding evaluation data to monitor validation-set metrics when fine-tuning BGE #1098

Open 5663015 opened 2 months ago

5663015 commented 2 months ago

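The official code does not report any validation-set metrics during training, which makes it hard to tell whether the model is overfitting. I tried adding compute_metrics, but that does not work either: the logic of Trainer.evaluate is fairly involved and never reaches it, so in the end evaluate itself has to be overridden. Below is a very simple override for reference. It returns the validation-set loss during training; you only need to pass the usual arguments such as do_eval, eval_steps, and evaluation_strategy. You can extend the evaluation logic to fit your own needs.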

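trainer.py

from typing import Dict, List, Optional, Union

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import Trainer


class BiTrainer(Trainer):

    def evaluate(
        self,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """Return the mean loss over the evaluation set, one example at a time."""
        dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        # The original post hard-coded 'npu'; self.args.device follows the training device.
        device = self.args.device
        losses = []
        for example in tqdm(dataset, desc='evaluate'):
            # Collate a single example (one query with its passages) into a batch.
            inputs = self.data_collator([example])
            for side in ('query', 'passage'):
                for key in ('input_ids', 'attention_mask'):
                    inputs[side][key] = inputs[side][key].to(device)
            # Remove fields that are not passed to the model, as in the original post.
            inputs.pop('teacher_scores', None)
            inputs.pop('bi_directions', None)
            with torch.no_grad():  # gradients are not needed for evaluation
                loss = self.compute_loss(self.model, inputs)
            losses.append(loss.mean().item())
        metrics = {f'{metric_key_prefix}_loss': sum(losses) / len(losses)}
        self.log(metrics)
        return metrics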
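
Once eval_loss is being logged, one possible way to turn it into an overfitting signal is a simple patience check on the sequence of logged values. The sketch below is only an illustration: the function name has_stopped_improving and the parameters patience and min_delta are hypothetical and are not part of FlagEmbedding or transformers.

```python
from typing import List


def has_stopped_improving(
    eval_losses: List[float],
    patience: int = 3,
    min_delta: float = 0.0,
) -> bool:
    """Hypothetical helper: True if none of the last `patience` evaluations
    improved on the best earlier eval loss by more than `min_delta`."""
    if len(eval_losses) <= patience:
        # Not enough evaluations yet to make a decision.
        return False
    best_before = min(eval_losses[:-patience])
    recent = eval_losses[-patience:]
    return all(loss > best_before - min_delta for loss in recent)


# Illustrative values only, not real training results.
still_improving = [0.92, 0.71, 0.65, 0.60]
plateaued = [0.92, 0.71, 0.65, 0.66, 0.68, 0.70]
print(has_stopped_improving(still_improving))
print(has_stopped_improving(plateaued))
```

With the Hugging Face Trainer, the values passed to self.log are also appended to trainer.state.log_history, so the list of eval_loss values could be collected from there after or during training.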

clareliu1234 commented 2 months ago

Hello, may I ask how much GPU memory fine-tuning bge-m3 consumes?