Open Mr-KenLee opened 6 months ago
我用Seq2SeqTrainer对Baichuan2-7B-Base进行LoRA微调,但是很奇怪,我发现在第一次预测后,会出现OOM问题,但是Baichuan2-7B-Chat并不会。 同时,我发现Baichuan2-7B-Base的OOM问题来源于,从预测回归训练后,模型好像会二次加载,使得显存占用翻倍从而OOM。 我对比了Base和Chat的modeling.py文件,发现主要是Base中下面代码的问题:
class NormHead(nn.Module): def __init__(self, hidden_size, vocab_size, bias=False): super().__init__() self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size))) nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) self.first_flag = True def forward(self, hidden_states): if self.training: norm_weight = nn.functional.normalize(self.weight) elif self.first_flag: self.first_flag = False self.weight = nn.Parameter(nn.functional.normalize(self.weight)) norm_weight = self.weight else: norm_weight = self.weight return nn.functional.linear(hidden_states, norm_weight)
而在Chat中则是:
class NormHead(nn.Module): def __init__(self, hidden_size, vocab_size, bias=False): super().__init__() self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size))) nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) self.first_flag = True def forward(self, hidden_states): if self.training: norm_weight = nn.functional.normalize(self.weight) self.first_flag = True elif self.first_flag: self.first_flag = False self.weight.data = nn.functional.normalize(self.weight) norm_weight = self.weight else: norm_weight = self.weight return nn.functional.linear(hidden_states, norm_weight)
将Base中的替换为Chat中的NormHead后问题解决,想请问下这个原因是为什么呢?两个modeling文件是否可以互用?
应该主要是没有self.first_flag = True造成的吧?Base没有这个就会造成从预测转训练的时候,进不到目标分支?
我用Seq2SeqTrainer对Baichuan2-7B-Base进行LoRA微调,但是很奇怪,我发现在第一次预测后,会出现OOM问题,但是Baichuan2-7B-Chat并不会。 同时,我发现Baichuan2-7B-Base的OOM问题来源于,从预测回归训练后,模型好像会二次加载,使得显存占用翻倍从而OOM。 我对比了Base和Chat的modeling.py文件,发现主要是Base中下面代码的问题:
而在Chat中则是:
将Base中的替换为Chat中的NormHead后问题解决,想请问下这个原因是为什么呢?两个modeling文件是否可以互用?