PaddlePaddle / PaddleNLP

👑 Easy-to-use and powerful NLP and LLM library with 🤗 Awesome model zoo, supporting wide-range of NLP tasks from research to industrial applications, including 🗂Text Classification, 🔍 Neural Search, ❓ Question Answering, ℹ️ Information Extraction, 📄 Document Intelligence, 💌 Sentiment Analysis etc.
https://paddlenlp.readthedocs.io
Apache License 2.0

ERNIE+CRF model dynamic-to-static export error #458

Closed · Wakp closed this issue 1 year ago

Wakp commented 3 years ago

Saving the model fails with the error below. I followed the ernie+crf example, but the problem remains. Please help.

ValueError: We don't support to define layer with parameters in the function decorated by @declarative. Because that will re-defined parameters every time when you run the function. But we found parameter(linear_73.w_0) was created in the decorated function. Please define the layer with parameters in __init__ function.

import paddle
import paddle.nn as nn
from paddlenlp.transformers import BertModel, BertTokenizer, ErnieModel, ErnieTokenizer
from paddlenlp.layers import crf
from paddle.static import InputSpec
import numpy as np

class MutilErnielinear(nn.Layer):
    def __init__(self, label_nums=[], model_name_or_path='bert-wwm-chinese', dropout=None):
        super(MutilErnielinear, self).__init__()
        assert label_nums != []
        self.label_num = len(label_nums)
        if 'bert' in model_name_or_path:
            self.ernie = BertModel.from_pretrained(model_name_or_path)  # load the pretrained encoder
        elif 'ernie' in model_name_or_path:
            self.ernie = ErnieModel.from_pretrained(model_name_or_path)  # load the pretrained encoder
        else:
            raise ValueError("unsupported model_name_or_path")
        self.linear_to_label = [nn.Linear(769 if i > 0 else 768, label_nums[i]) for i in range(self.label_num)]
        self.relu = [nn.ReLU6() for i in range(self.label_num)]
        self.dropout = [nn.Dropout(dropout if dropout is not None else 0.1) for _ in range(self.label_num)]
        self.loss = [paddle.nn.CrossEntropyLoss() for _ in range(self.label_num)]
        self.out = []
        self.logits = []

    def forward(self, X, lengths):
        self.out = []
        self.logits = []
        sequence_output = self.ernie(X)[0]
        for i in range(self.label_num):
            if i == 0:
                feature_drop = self.dropout[i](sequence_output)
                logit = self.linear_to_label[i](feature_drop)

            else:
                feature = paddle.concat([sequence_output, last], axis=-1)
                feature_drop = self.dropout[i](feature)
                feature_drop = self.relu[i](feature_drop)
                logit = self.linear_to_label[i](feature_drop)
            self.logits.append(logit)
            batch_path = paddle.argmax(logit, axis=-1)
            last = paddle.unsqueeze(batch_path, -1)
            last = last.astype("float32")
            self.out.append(batch_path)
        return self.out, self.logits

    def compute_loss(self, logits, Labels, length):
        loss = []
        for i in range(self.label_num):
            if Labels is not None:
                l = self.loss[i](logits[i], Labels[i])
                loss.append(paddle.mean(l))
        return loss

model = MutilErnielinear([8, 12])
paddle.jit.to_static(model, input_spec=[
    InputSpec(shape=[None, None], dtype="int64", name='token_ids'),
    InputSpec(shape=[None], dtype="int64", name='length')])

paddle.jit.save(model, "test")
joey12300 commented 3 years ago

@Wakp Hello, after modifying your code it runs. Any list made up of layers must be wrapped in nn.LayerList for the conversion to work.

import paddle
import paddle.nn as nn
from paddlenlp.transformers import BertModel, BertTokenizer, ErnieModel, ErnieTokenizer
from paddlenlp.layers import crf
from paddle.static import InputSpec
import numpy as np

class MutilErnielinear(nn.Layer):
    def __init__(self, label_nums=[], model_name_or_path='bert-wwm-chinese', dropout=None):
        super(MutilErnielinear, self).__init__()
        assert label_nums != []
        self.label_num = len(label_nums)
        if 'bert' in model_name_or_path:
            self.ernie = BertModel.from_pretrained(model_name_or_path)  # load the pretrained encoder
        elif 'ernie' in model_name_or_path:
            self.ernie = ErnieModel.from_pretrained(model_name_or_path)  # load the pretrained encoder
        else:
            raise ValueError("unsupported model_name_or_path")
        # A Python list of nn.Layer objects must be wrapped in nn.LayerList
        self.linear_to_label = nn.LayerList([nn.Linear(769 if i > 0 else 768, label_nums[i]) for i in range(self.label_num)])
        self.relu = nn.LayerList([nn.ReLU6() for i in range(self.label_num)])
        self.dropout = nn.LayerList([nn.Dropout(dropout if dropout is not None else 0.1) for _ in range(self.label_num)])
        self.loss = nn.LayerList([paddle.nn.CrossEntropyLoss() for _ in range(self.label_num)])
        self.out = []
        self.logits = []

    def forward(self, X, lengths):
        self.out = []
        self.logits = []
        sequence_output = self.ernie(X)[0]
        for i in range(self.label_num):
            if i == 0:
                feature_drop = self.dropout[i](sequence_output)
                logit = self.linear_to_label[i](feature_drop)

            else:
                feature = paddle.concat([sequence_output, last], axis=-1)
                feature_drop = self.dropout[i](feature)
                feature_drop = self.relu[i](feature_drop)
                logit = self.linear_to_label[i](feature_drop)
            self.logits.append(logit)
            batch_path = paddle.argmax(logit, axis=-1)
            last = paddle.unsqueeze(batch_path, -1)
            last = last.astype("float32")
            self.out.append(batch_path)
        # Dynamic-to-static conversion requires every input to be connected to an output;
        # `lengths` was unused in the original code, which triggered the error.
        return self.out, self.logits, lengths

    def compute_loss(self, logits, Labels, length):
        loss = []
        for i in range(self.label_num):
            if Labels is not None:
                l = self.loss[i](logits[i], Labels[i])
                loss.append(paddle.mean(l))
        return loss

model = MutilErnielinear([8, 12])
static_model = paddle.jit.to_static(model, input_spec=[
    InputSpec(shape=[None, None], dtype="int64", name='token_ids'),
    InputSpec(shape=[None], dtype="int64", name='length')])

paddle.jit.save(static_model, "test")
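The fix works because `nn.Layer` registers sublayers when they are assigned as attributes; a plain Python `list` bypasses that hook, so its layers never show up in `parameters()` or the saved state. A minimal pure-Python sketch of the mechanism (a simplified stand-in for Paddle's actual implementation, not its real code):

```python
class Layer:
    # Simplified stand-in for paddle.nn.Layer: sublayers are registered
    # only when assigned directly as attributes.
    def __init__(self):
        object.__setattr__(self, "_sublayers", {})
        object.__setattr__(self, "_params", {})

    def __setattr__(self, name, value):
        if isinstance(value, Layer):
            self._sublayers[name] = value  # the registration hook
        object.__setattr__(self, name, value)

    def parameters(self):
        out = list(self._params.values())
        for sub in self._sublayers.values():
            out += sub.parameters()
        return out


class Linear(Layer):
    def __init__(self):
        super().__init__()
        self._params["w"] = 0.0


class LayerList(Layer):
    # Container that registers each element as a sublayer,
    # mirroring what nn.LayerList does in Paddle.
    def __init__(self, layers):
        super().__init__()
        for i, layer in enumerate(layers):
            self._sublayers[str(i)] = layer


class WithPlainList(Layer):
    def __init__(self):
        super().__init__()
        self.heads = [Linear(), Linear()]  # a list is not a Layer: never registered


class WithLayerList(Layer):
    def __init__(self):
        super().__init__()
        self.heads = LayerList([Linear(), Linear()])


print(len(WithPlainList().parameters()))  # 0 -- the heads are invisible
print(len(WithLayerList().parameters()))  # 2 -- the heads are tracked
```

This is why the original code trained without crashing: the layers in the plain list still execute in `forward`, they are just invisible to everything that walks the layer tree, including dynamic-to-static export.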
Wakp commented 3 years ago

Thanks, my problem is solved. One more note: without nn.LayerList, training raises no error and everything looks normal, but paddle.save() saves the model's initialization parameters rather than the trained ones, again without any error at any point. The user never notices. It would be helpful to emit a warning here, because this is an easy trap for beginners. Thanks.

joey12300 commented 3 years ago

> Thanks, my problem is solved. One more note: without nn.LayerList, training raises no error and everything looks normal, but paddle.save() saves the model's initialization parameters rather than the trained ones, again without any error at any point. The user never notices. It would be helpful to emit a warning here, because this is an easy trap for beginners. Thanks.

Yes, layers stored in a Python list must be wrapped in nn.LayerList; otherwise the forward and backward passes run normally during training, but the parameters are never updated. The error reporting does need improvement; thanks for the suggestion.
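The silent failure described above follows directly from the registration gap: the optimizer is built from `model.parameters()`, so layers hidden in a plain list compute gradients but are never touched by the update step, and a later save writes their untouched initial values. A toy illustration (plain Python, not Paddle API; the parameter names are hypothetical):

```python
# Parameters the optimizer was handed (visible via model.parameters())
registered = {"linear_0.w": 1.0}
# Parameters of a layer stored in a plain list: invisible to the optimizer
hidden = {"linear_in_list.w": 1.0}

def sgd_step(params, grads, lr=0.1):
    # Plain SGD: only the parameters the optimizer knows about move.
    for name in params:
        params[name] -= lr * grads.get(name, 0.0)

# Both weights receive gradients during backward...
grads = {"linear_0.w": 0.5, "linear_in_list.w": 0.5}
for _ in range(10):
    sgd_step(registered, grads)  # ...but the update never sees `hidden`

print(registered["linear_0.w"])    # moved away from 1.0 by training
print(hidden["linear_in_list.w"])  # still 1.0: its init value gets saved
```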

github-actions[bot] commented 1 year ago

This issue is stale because it has been open for 60 days with no activity.

github-actions[bot] commented 1 year ago

This issue was closed because it has been inactive for 14 days since being marked as stale.