Open ZionDoki opened 11 months ago
尝试使用 adapter 方法微调模型,不论是在 fp16 还是 fp32 环境下都会出现异常值,以下是部分代码示意:
fp16
fp32
def check_exception(tensor): return torch.isnan(tensor).any() or torch.isinf(tensor).any() class adapter(nn.Module): def __init__(self, config, *, **): self.wi = nn.Linear(config.hidden_size, config.la_hidden_size) self.wo = nn.Linear(config.la_hidden_size, config.hidden_size) self.activation = nn.ReLU() self.reset_parameters() def forward(self, x): print("x", check_exception(x)) # 下图的日志说明异常值产生在 adapter 模块外部,及 glm 本身 shortcut = x print("wi", check_exception(self.wi.weight)) x = self.wi(x) print(2, check_exception(x)) x = self.activation(x) print(3, check_exception(x)) print("wi", check_exception(self.wo.weight)) x = self.wo(x) print(4, check_exception(x)) if check_exception(x): raise BaseException("!") return self.v1 * x + shortcut # 这些 adapter 插入到 MLP 模块如下 class MLPWithAdapter(MLP): def __init__(self, config: ChatGLMConfig, device=None): super().__init__(config, device) if config.only_adapter_trainable: for params in super().parameters(): params.requires_grad = False self.adapteri, self.adaptero = None, None if config.use_local_adapter: self.adapteri = Adapter(config) self.adaptero = Adapter(config) def forward(self, hidden_states): # [s, b, h] if self.adapteri: hidden_states = self.adapteri(hidden_states) intermediate_parallel = self.dense_h_to_4h(hidden_states) intermediate_parallel = self.activation_func(intermediate_parallel) output = self.dense_4h_to_h(intermediate_parallel) if self.adaptero: hidden_states = self.adapteri(hidden_states) return output
表现出来的结果是 loss 一直为 0 ,排查发现在 forward pass 过程中出现了异常值。如下图所示,异常值来自 adapter 的输入
我使用的是 ChatGLMForConditionalGeneration 进行微调,输入的数据格式未按照官方要求,但问题出现在 forward pass 过程且前几个模块中并未出现异常值,考虑应该不是输入数据的问题。
ChatGLMForConditionalGeneration
希望可以正常微调
不确定是否能准确复现
- OS: centos 7 - Python: 3.10.13 - Transformers: 4.33.2 - PyTorch: 2.0.1 - CUDA Support (`python -c "import torch; print(torch.cuda.is_available())"`) :True
No response
Is there an existing issue for this?
Current Behavior
尝试使用 adapter 方法微调模型,不论是在
fp16
还是fp32
环境下都会出现异常值,以下是部分代码示意:表现出来的结果是 loss 一直为 0 ,排查发现在 forward pass 过程中出现了异常值。如下图所示,异常值来自 adapter 的输入
我使用的是
ChatGLMForConditionalGeneration
进行微调,输入的数据格式未按照官方要求,但问题出现在 forward pass 过程且前几个模块中并未出现异常值,考虑应该不是输入数据的问题。Expected Behavior
希望可以正常微调
Steps To Reproduce
不确定是否能准确复现
Environment
Anything else?
No response