THUDM / ChatGLM2-6B

ChatGLM2-6B: An Open Bilingual Chat LLM | 开源双语对话语言模型
Other
15.71k stars 1.85k forks source link

[BUG/Help] 使用 adapter 方法微调模型在 forward pass 出现 nan / inf 异常值 #606

Open ZionDoki opened 11 months ago

ZionDoki commented 11 months ago

Is there an existing issue for this?

Current Behavior

尝试使用 adapter 方法微调模型,不论是在 fp16 还是 fp32 环境下都会出现异常值,以下是部分代码示意:


def check_exception(tensor):
    return torch.isnan(tensor).any() or torch.isinf(tensor).any()

class adapter(nn.Module):

    def __init__(self, config, *, **):
            self.wi = nn.Linear(config.hidden_size, config.la_hidden_size)
            self.wo = nn.Linear(config.la_hidden_size, config.hidden_size)
            self.activation = nn.ReLU()
            self.reset_parameters()

    def forward(self, x):
        print("x",  check_exception(x))    # 下图的日志说明异常值产生在 adapter 模块外部,及 glm 本身
        shortcut = x
        print("wi", check_exception(self.wi.weight))
        x = self.wi(x)
        print(2, check_exception(x))
        x = self.activation(x)
        print(3, check_exception(x))
        print("wi", check_exception(self.wo.weight))
        x = self.wo(x)
        print(4, check_exception(x))
        if check_exception(x):
            raise BaseException("!")
        return self.v1 * x + shortcut

# 这些 adapter 插入到 MLP 模块如下

class MLPWithAdapter(MLP):
    def __init__(self, config: ChatGLMConfig, device=None):
        super().__init__(config, device)

        if config.only_adapter_trainable:
            for params in super().parameters():
                params.requires_grad = False

        self.adapteri, self.adaptero = None, None
        if config.use_local_adapter:
            self.adapteri = Adapter(config)
            self.adaptero = Adapter(config)

    def forward(self, hidden_states):
        # [s, b, h]
        if self.adapteri:
            hidden_states = self.adapteri(hidden_states)

        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        output = self.dense_4h_to_h(intermediate_parallel)

        if self.adaptero:
            hidden_states = self.adapteri(hidden_states)
        return output

表现出来的结果是 loss 一直为 0 ,排查发现在 forward pass 过程中出现了异常值。如下图所示,异常值来自 adapter 的输入

image

我使用的是 ChatGLMForConditionalGeneration 进行微调,输入的数据格式未按照官方要求,但问题出现在 forward pass 过程且前几个模块中并未出现异常值,考虑应该不是输入数据的问题。

Expected Behavior

希望可以正常微调

Steps To Reproduce

不确定是否能准确复现

Environment

- OS: centos 7
- Python: 3.10.13
- Transformers: 4.33.2
- PyTorch: 2.0.1
- CUDA Support (`python -c "import torch; print(torch.cuda.is_available())"`) :True

Anything else?

No response