PaddlePaddle / PaddleNLP

👑 Easy-to-use and powerful NLP and LLM library with 🤗 Awesome model zoo, supporting wide-range of NLP tasks from research to industrial applications, including 🗂Text Classification, 🔍 Neural Search, ❓ Question Answering, ℹ️ Information Extraction, 📄 Document Intelligence, 💌 Sentiment Analysis etc.
https://paddlenlp.readthedocs.io
Apache License 2.0

[Bug]: With amp_master_grad enabled together with recompute, the weight has no main_grad #8365

Closed Wong4j closed 2 months ago

Wong4j commented 4 months ago

Software environment

- paddlepaddle: 
- paddlepaddle-gpu: 2.6
- paddlenlp: 2.7.1.post0

Duplicate issues

Error description

Normally, after enabling --amp_master_grad, every weight has a main_grad.
But with recompute=full, the weight seen in the backward of a custom Python op has no main_grad.

Steps to reproduce & code

Take llama training as an example.

Modify the code at fused_layers.py #L32-L41 to:

    @staticmethod
    def forward(ctx, x, weight, bias=None, name=None):
        y = origin_linear(x, weight, bias)

        ctx.save_for_backward(weight)
        ctx.x = x
        ctx.bias = bias
        return y

    @staticmethod
    def backward(ctx, y_grad):
        weight, = ctx.saved_tensor()  # this weight has no main_grad
        x = ctx.x
        bias = ctx.bias
        if hasattr(weight, "main_grad"):
            print("weight has main_grad")
        else:
            print("weight has no main_grad")

Run llama training; the backward then reports that weight has no main_grad.

However, if ctx.save_for_backward / ctx.saved_tensor() is not used, and ctx.weight = weight / weight = ctx.weight is used instead, the weight does have main_grad.
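
As a standalone sketch of that workaround (a toy PyLayer with illustrative names, not the actual fused_layers.py code; it only shows the attribute-stashing pattern and does not reproduce the recompute path):

    import paddle
    from paddle.autograd import PyLayer

    class LinearKeepMainGrad(PyLayer):
        @staticmethod
        def forward(ctx, x, weight):
            # plain attributes instead of save_for_backward, following the
            # ctx.weight = weight workaround described above
            ctx.x = x
            ctx.weight = weight
            return paddle.matmul(x, weight)

        @staticmethod
        def backward(ctx, y_grad):
            x, weight = ctx.x, ctx.weight
            print("weight has main_grad:", hasattr(weight, "main_grad"))
            x_grad = paddle.matmul(y_grad, weight, transpose_y=True)
            w_grad = paddle.matmul(x, y_grad, transpose_x=True)
            return x_grad, w_grad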

From debugging, this seems to be because, when recompute is enabled, save_for_backward triggers the copy at recompute.py#L340, which copies weight into a tensor named weight.name + "cpy" but does not copy main_grad along with it.
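
A minimal sketch of why such a copy loses main_grad (assuming main_grad is just a Python attribute that amp_master_grad attaches to the parameter, and approximating the recompute copy with paddle.assign):

    import paddle

    w = paddle.create_parameter(shape=[4, 4], dtype="float32")
    w.main_grad = paddle.zeros([4, 4], dtype="float32")  # stand-in for what amp_master_grad attaches

    w_cpy = paddle.assign(w)  # a fresh tensor with the same values, like the "...cpy" tensor above

    print(hasattr(w, "main_grad"))      # True
    print(hasattr(w_cpy, "main_grad"))  # False: the Python attribute does not travel with the copy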

GuoxiaWang commented 4 months ago
        ctx.save_for_backward(weight)
        ctx.x = x
        ctx.bias = bias

Why is this written in separate pieces here? Can you try the following?

ctx.save_for_backward(x, weight, bias) 
x, weight, bias = ctx.saved_tensor()
Wong4j commented 4 months ago

@GuoxiaWang Because the point I care about in this issue is: with recompute enabled, the ctx.save_for_backward(weight) style runs into the weight having no main_grad in backward.

The style you describe is the original code in fused_layers.py. I tested it as well; with recompute=full it hits the strange error below, which would need a separate issue.

    outputs = model(**inputs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/distributed/fleet/meta_parallel/meta_parallel_base.py", line 37, in forward
    output = self._layers(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 1913, in forward
    outputs = self.llama(
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 1664, in forward
    layer_outputs = self.recompute_training_full(
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 1535, in recompute_training_full
    hidden_states = self.recompute_func(
  File "/usr/local/lib/python3.10/dist-packages/paddle/distributed/fleet/utils/__init__.py", line 142, in recompute
    return fleet.recompute.recompute(function, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/distributed/fleet/recompute/recompute.py", line 532, in recompute
    return _recompute_without_reentrant(function, preserve, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/distributed/fleet/recompute/recompute.py", line 399, in _recompute_without_reentrant
    outputs = function(*args, **kwargs)
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 1531, in custom_forward
    return module(*inputs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 1228, in forward
    outputs = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/workspace/PaddleNLP/paddlenlp/transformers/llama/modeling.py", line 901, in forward
    query_states = self.q_proj(hidden_states)
  File "/usr/local/lib/python3.10/dist-packages/paddle/nn/layer/layers.py", line 1429, in __call__
    return self.forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/paddle/distributed/fleet/layers/mpu/mp_layers.py", line 516, in forward
    output_parallel = self.linear(
  File "/workspace/PaddleNLP/llm/fused_layers.py", line 36, in forward
    ctx.save_for_backward(x, weight, bias)
  File "/usr/local/lib/python3.10/dist-packages/paddle/autograd/py_layer.py", line 91, in save_for_backward
    self.container = tensors
ValueError: (InvalidArgument) save_for_backward only support Tensor, list of Tensor, tuple of Tensor. (at /opt/paddle/paddle/paddle/fluid/pybind/eager_py_layer.cc:644)
Wong4j commented 4 months ago

Update: setting reentrant=True for recompute avoids this bug. Only reentrant=False hits it.
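
A minimal sketch of that setting against Paddle's recompute API directly (toy function; assuming the use_reentrant keyword is how the flag reaches fleet's recompute, and not showing how the llama training config plumbs it through):

    import paddle
    from paddle.distributed.fleet.utils import recompute

    def block(x):
        return paddle.nn.functional.relu(x) * 2

    x = paddle.randn([2, 8])
    x.stop_gradient = False
    # use_reentrant=True takes the PyLayer-based path instead of
    # _recompute_without_reentrant, where the save_for_backward copy happens
    y = recompute(block, x, use_reentrant=True)
    y.sum().backward()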

@Xreki Could you ask someone on the Paddle side who is familiar with recompute to take a look?

Xreki commented 4 months ago

![image](https://github.com/PaddlePaddle/PaddleNLP/assets/12538138/84258d77-048e-41a2-9641-6d7a303ba6bf)

@Wong4j This is a known issue when reentrant=False.

github-actions[bot] commented 2 months ago

This issue is stale because it has been open for 60 days with no activity.

github-actions[bot] commented 2 months ago

This issue was closed because it has been inactive for 14 days since being marked as stale.