mymusise / ChatGLM-Tuning

基于ChatGLM-6B + LoRA的Fintune方案
MIT License
3.73k stars 440 forks source link

有一个代码上的问题 #212

Closed wujohns closed 1 year ago

wujohns commented 1 year ago

在本工程中采用了以下方式重写 Trainer 的 save_model 方法:

class ModifiedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        return model(
            input_ids=inputs["input_ids"],
            labels=inputs["labels"],
        ).loss

    def save_model(self, output_dir=None, _internal_call=False):
        from transformers.trainer import TRAINING_ARGS_NAME

        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
        saved_params = {
            k: v.to("cpu") for k, v in self.model.named_parameters() if v.requires_grad
        }
        torch.save(saved_params, os.path.join(output_dir, "adapter_model.bin"))

但依据 peft 的实现原理,应该也可以采用以下更简便的方式来对 Trainer 的 save_mode 方法做重写:

class ModifiedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        return model(
            input_ids=inputs["input_ids"],
            labels=inputs["labels"],
        ).loss

    def save_model(self, output_dir=None, _internal_call=False):
        self.model.save_pretrained(output_dir)

这边想问一下这两者有什么区别吗,手动采取 torch 来存储 lora 模型参数和配置是由于这个场景下 save_pretrained 表现会有异常? PS: 这里按照 self.model.save_pretrained(output_dir) 的方式(即第二种方式)跑了一下,发现没有训练部分以及训练后的推理都没有出现异常,所以比较好奇想问一下

mymusise commented 1 year ago

应该几乎等价,后者应该更好,每个checkpoint里面还会存下config.json

wujohns commented 1 year ago

应该几乎等价,后者应该更好,每个checkpoint里面还会存下config.json

OK,感谢说明,lora 的训练效果还是挺不错的,loss 降低到 1 ~ 3 时风格迁移挺明显的,训练速度也挺快,不过对数据集的要求还是挺高的

dongteng commented 1 year ago

借楼,请问这个class ModifiedTrainer(Trainer): def compute_loss(self, model, inputs, return_outputs=False): return model( input_ids=inputs["input_ids"], labels=inputs["labels"], ).loss loss的具体计算方式该怎么看呀

wujohns commented 1 year ago

借楼,请问这个class ModifiedTrainer(Trainer): def compute_loss(self, model, inputs, return_outputs=False): return model( input_ids=inputs["input_ids"], labels=inputs["labels"], ).loss loss的具体计算方式该怎么看呀

这块逻辑貌似是在 chatglm 的源码实现中来做的,具体应该可以看下 chatglm 的 huggingface 版本的 model 代码部分

ssgg-code commented 7 months ago

非常奇怪的一点,我直接改用self.model.save_pretrained(output_dir)有效,但是用源代码重写的save_model()保存的adapter.bin去加载lora模型,其生成结果和chatglm本身没有差距。 我发现有不少人都遇到了相同的问题,不太能理解这种问题为什么对部分人存在。