baichuan-inc / Baichuan2

A series of large language models developed by Baichuan Intelligent Technology
https://huggingface.co/baichuan-inc
Apache License 2.0

GPU memory doubles after resuming training with Baichuan2-7B-Base #387

Open Mr-KenLee opened 6 months ago

Mr-KenLee commented 6 months ago

I am using Seq2SeqTrainer to LoRA fine-tune Baichuan2-7B-Base, and oddly enough I hit an OOM right after the first prediction pass, while Baichuan2-7B-Chat does not. I found that the OOM in Baichuan2-7B-Base comes from the switch back from prediction to training: the model seems to get loaded a second time, so GPU memory usage doubles and the process OOMs. Comparing the modeling.py files of Base and Chat, the problem mainly lies in the following code in Base:

import math

import torch
from torch import nn


class NormHead(nn.Module):
    def __init__(self, hidden_size, vocab_size, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size)))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.first_flag = True

    def forward(self, hidden_states):
        if self.training:
            # Training: normalize on the fly; note first_flag is never reset.
            norm_weight = nn.functional.normalize(self.weight)
        elif self.first_flag:
            # First eval call: rebinds self.weight to a brand-new Parameter
            # object instead of updating the existing one in place.
            self.first_flag = False
            self.weight = nn.Parameter(nn.functional.normalize(self.weight))
            norm_weight = self.weight
        else:
            norm_weight = self.weight
        return nn.functional.linear(hidden_states, norm_weight)

In Chat, on the other hand, it is:

import math

import torch
from torch import nn


class NormHead(nn.Module):
    def __init__(self, hidden_size, vocab_size, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size)))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.first_flag = True

    def forward(self, hidden_states):
        if self.training:
            norm_weight = nn.functional.normalize(self.weight)
            # Difference 1: reset the flag so the next eval pass
            # re-normalizes the freshly updated weights.
            self.first_flag = True
        elif self.first_flag:
            self.first_flag = False
            # Difference 2: write through .data, which keeps the original
            # Parameter object (and every reference to it) intact.
            self.weight.data = nn.functional.normalize(self.weight)
            norm_weight = self.weight
        else:
            norm_weight = self.weight
        return nn.functional.linear(hidden_states, norm_weight)

After replacing the NormHead in Base with the one from Chat, the problem went away. Could you explain why this happens? And can the two modeling files be used interchangeably?

Mr-KenLee commented 6 months ago

I guess the main cause is the missing `self.first_flag = True`? Without it, when Base switches from prediction back to training, it can't get into the intended branch?
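
For reference, here is a minimal standalone sketch of the second difference (toy sizes; the layer and optimizer names are mine, not from modeling.py): rebinding self.weight to a new nn.Parameter leaves the old tensor alive wherever it is still referenced, e.g. the optimizer's param groups, whereas the .data write in Chat updates the existing Parameter in place:

import torch
from torch import nn

# Toy stand-ins: one linear layer plays the role of NormHead's weight,
# SGD plays the role of the Trainer's optimizer.
lin = nn.Linear(8, 8, bias=False)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)
old_param = lin.weight  # the optimizer's param group holds this exact object

with torch.no_grad():  # mirrors the eval-time context of the elif branch
    # Base-style: bind the attribute to a brand-new Parameter.
    lin.weight = nn.Parameter(nn.functional.normalize(old_param))

assert lin.weight is not old_param
assert opt.param_groups[0]["params"][0] is old_param
# Two full copies of the weight are now alive: the optimizer keeps the old
# tensor, the module holds the new one, and later optimizer steps update a
# tensor the module no longer reads.

lin2 = nn.Linear(8, 8, bias=False)
opt2 = torch.optim.SGD(lin2.parameters(), lr=0.1)
old_param2 = lin2.weight

with torch.no_grad():
    # Chat-style: overwrite the storage of the existing Parameter in place.
    lin2.weight.data = nn.functional.normalize(lin2.weight)

assert lin2.weight is old_param2  # same object everywhere; nothing duplicated

For NormHead this tensor is vocab_size x hidden_size, so even one stale extra copy is on the order of a gigabyte in fp16, and whether it fully accounts for the doubling you saw likely depends on what else (optimizer state, DeepSpeed, LoRA wrappers) still references the old Parameter.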