cmnfriend / O-LoRA

MIT License
126 stars 12 forks source link

可能可以简化代码? #6

Open WuNein opened 7 months ago

WuNein commented 7 months ago

此处https://github.com/cmnfriend/O-LoRA/blob/ff73694786e8a5de5149a9bfb55ad2cedb66fdd1/src/uie_trainer_lora.py#L91

由于这边是跟没有梯度的lora(old)来计算正交,那直接在上一步把lora(old)save为pth是不是可以避免修改peft库了

import torch

# 假设 self.model 是你的模型
stacked_params = {}

for name, param in self.model.named_parameters():
    if "lora_" in name:
        stacked_params[name] = param.data.clone()  # 使用 clone() 复制参数并避免共享内存

# 保存堆叠的参数到文件
torch.save(stacked_params, "path/to/stacked_params.pth")

然后在trainer类里面加载

# 初始化一个字典来存储匹配的模块和对应的参数
matched_modules = {} #load pth

for name, param in self.model.named_parameters():
    if "lora_A" in name:
          # 匹配的模块名称和对应的参数
          param_ = matched_modules[name]

          orthogonal_loss += torch.abs(torch.mm(param, param_.T)).sum()  # [r * dim] * [dim * r]
          break  # target modules have been matched

大致这个意思

是不是就可以避免修改PEFT代码,方便很多?

cmnfriend commented 7 months ago

可以的!👍

WuNein commented 7 months ago

哦对,有个问题我不懂就问了:)懒得再翻您改的PEFT代码了(不是 既然说是当前LoRA在之前LoRA的正交方向上更新的;那么当前的LoRA大概率是merge之前LoRA,以此为基础继续训练的吧?我没理解错吧

DumoeDss commented 7 months ago

哦对,有个问题我不懂就问了:)懒得再翻您改的PEFT代码了(不是 既然说是当前LoRA在之前LoRA的正交方向上更新的;那么当前的LoRA大概率是merge之前LoRA,以此为基础继续训练的吧?我没理解错吧

训练完会进行merge https://github.com/cmnfriend/O-LoRA/issues/5#issuecomment-1803532686

WuNein commented 7 months ago

哦对,有个问题我不懂就问了:)懒得再翻您改的PEFT代码了(不是 既然说是当前LoRA在之前LoRA的正交方向上更新的;那么当前的LoRA大概率是merge之前LoRA,以此为基础继续训练的吧?我没理解错吧

训练完会进行merge #5 (comment)

我的疑惑在新的task的lora初始化上面,既然说是最后合并的,我姑且认为是随机初始化的~毕竟代码上loss要保证两个lora_a是正交的。

DumoeDss commented 7 months ago

话说照着你这样修改的话,原本的l2_loss就没有了吗? 最终的loss = loss + orthogonal_loss * lamda_1吗?

WuNein commented 7 months ago

话说照着你这样修改的话,原本的l2_loss就没有了吗? 最终的loss = loss + orthogonal_loss * lamda_1吗?

你自己加上就好了,又不冲突…… 只是我懒得写了

DumoeDss commented 7 months ago

话说照着你这样修改的话,原本的l2_loss就没有了吗? 最终的loss = loss + orthogonal_loss * lamda_1吗?

你自己加上就好了,又不冲突…… 只是我懒得写了

是直接用matched_modules进行计算吗?

l2_loss = 0.
        for name, param in matched_modules:
            l2_loss += torch.norm(param, p=2)
WuNein commented 7 months ago

话说照着你这样修改的话,原本的l2_loss就没有了吗? 最终的loss = loss + orthogonal_loss * lamda_1吗?

你自己加上就好了,又不冲突…… 只是我懒得写了

是直接用matched_modules进行计算吗?

l2_loss = 0.
        for name, param in matched_modules:
            l2_loss += torch.norm(param, p=2)

完全不对吧,

# l2-normalization for loranew_A/B
        l2_loss = 0.
        for name, param in self.model.named_parameters():
            if "loranew_" in name:
                l2_loss += torch.norm(param, p=2)

原本代码里面写的是新的loranew,那么简化代码以后目标是

# l2-normalization for loranew_A/B
        l2_loss = 0.
        for name, param in self.model.named_parameters():
            if "lora_" in name:
                l2_loss += torch.norm(param, p=2)

lora_ 就是原本的lora_new啊,l2正则肯定是对现在task的参数进行的啊