PaddlePaddle / Paddle

PArallel Distributed Deep LEarning: Machine Learning Framework from Industrial Practice (『飞桨』核心框架,深度学习&机器学习高性能单机、分布式训练和跨平台部署)
http://www.paddlepaddle.org/
Apache License 2.0
22.23k stars 5.58k forks source link

AssertionError: When Variable is used as the condition of if/while , Variable can only contain one element. #53991

Closed LyTingHub closed 5 months ago

LyTingHub commented 1 year ago

bug描述 Describe the Bug

代码:

class LSTM_Model(nn.Layer):
    """Bidirectional-LSTM token classifier with an optional linear-chain CRF loss.

    Pipeline: embedding -> bidirectional LSTM -> dropout -> linear projection
    to ``num_labels`` scores per token.  When ``forward`` is given a target,
    it additionally computes a CRF cost with
    ``fluid.layers.linear_chain_crf``.
    """

    def __init__(self,vocab_num, emb_size, hidden_size, num_layers, num_labels, dropout):
        '''Build the sub-layers.

        Parameter meanings: vocab_num is the vocabulary size of the
        word-embedding model, emb_size is the embedding dimension, and
        num_labels is the size of the label set.  hidden_size, num_layers
        and dropout are forwarded to ``nn.LSTM``; dropout is also used for
        the ``nn.Dropout`` applied to the LSTM output.
        '''
        super(LSTM_Model, self).__init__()
        self.embedding = nn.Embedding(vocab_num, emb_size)   # character-embedding lookup table
        self.lstm = nn.LSTM(emb_size, hidden_size, num_layers=num_layers, direction='bidirect', dropout=dropout)
        # NOTE(review): attention_linear is created here but never used in forward().
        self.attention_linear = nn.Linear(hidden_size * 2, hidden_size)
        # Input width is hidden_size * 2 because the LSTM is bidirectional.
        self.linear = nn.Linear(hidden_size * 2, num_labels)
        self.dropout = nn.Dropout(dropout)

    def forward(self,input_ids,label_length,target):
        """Compute per-token label scores and, if a target is given, the CRF cost.

        Args:
            input_ids: token ids fed to the embedding layer.
            label_length: passed unchanged as ``length`` to
                ``fluid.layers.linear_chain_crf``.
            target: gold labels, or ``None`` to skip the CRF cost.

        Returns:
            A ``(logits, avg_cost)`` pair.  ``avg_cost`` is the mean CRF cost
            when ``target`` is not ``None``, otherwise the integer ``0``.

        Note:
            ``max_len`` and ``label_num`` are not defined in this class; they
            are presumably module-level globals set elsewhere -- TODO confirm.
        """
        token_emb = self.embedding(input_ids)
        sequence_output, (hidden, cell) = self.lstm(token_emb)  # [batch_size,time_steps,num_directions * hidden_size]
        sequence_output = self.dropout(sequence_output)
        logits = self.linear(sequence_output)  # per-token scores predicted by the LSTM (CRF emission features)

        ## feature_out = fluid.layers.fc(input=hidden_1, size=len(label_dict), act='tanh')
        # Call the built-in CRF function and do state-transition decoding.

        if target is not None:
            print(label_length)
            # fluid.enable_dygraph()
            emission = paddle.clone(logits)
            emission = paddle.reshape(emission,[-1, max_len, label_num])
            # label = paddle.clone(target)
            label = paddle.reshape(target,[-1, max_len, 1])
            # paddle.enable_static()
            # label_length = paddle.ones([batch_size, 1], dtype='int64')
            # NOTE(review): this call is where the reported AssertionError is
            # raised.  linear_chain_crf evaluates `length` in a boolean context
            # (`... if length else ...`), and a dygraph Tensor with more than
            # one element cannot be converted to bool.
            crf_cost = fluid.layers.linear_chain_crf(
                input=emission, 
                label=label,
                param_attr=fluid.ParamAttr(name='crfw3', learning_rate=0.0001),
                length=label_length
                )
            # crf_decode = paddle.static.nn.crf_decoding(input=emission,param_attr=paddle.ParamAttr(name="crfw"))
            avg_cost = fluid.layers.mean(crf_cost)
        else:
            # No target supplied: report a zero cost.
            avg_cost = 0

        # avg_cost = 0
        return logits, avg_cost

# Build the model; the hyper-parameters are defined in code not shown here.
model = LSTM_Model(vocab_num, emb_size, hidden_size, num_layers, num_labels, dropout)
# The parameter setup and data loading in between are omitted.
model.train()
for epoch in range(num_epoch):
        for idx, (input_ids,labels,seq_lens) in enumerate(train_loader):
            # Debug output of the batch contents.
            print(input_ids)
            print(labels)
            print(seq_lens)
            # seq_lens is passed as `label_length` and labels as `target`.
            logits,_= model(input_ids,seq_lens, labels)  # this line raises the error

报错信息:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/tmp/ipykernel_4072/1945018549.py in <module>
      5 init_ckpt = './data_/{}/final.pdparams'.format(base_dir)
      6 
----> 7 do_train()
      8 # do_predict()

/tmp/ipykernel_4072/2064799652.py in do_train()
    100             print(seq_lens)
    101             # print(input_ids.size(),labels.size(),seq_lens.size())
--> 102             logits,_= model(input_ids,seq_lens, labels)
    103             probs_ids = paddle.argmax(logits, -1).numpy()
    104             # print(logits.shape,labels.shape)

/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py in __call__(self, *inputs, **kwargs)
    928             return self.forward(*inputs, **kwargs)
    929         else:
--> 930             return self._dygraph_call_func(*inputs, **kwargs)
    931 
    932     def forward(self, *inputs, **kwargs):

/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py in _dygraph_call_func(self, *inputs, **kwargs)
    913                 outputs = self.forward(*inputs, **kwargs)
    914         else:
--> 915             outputs = self.forward(*inputs, **kwargs)
    916 
    917         for forward_post_hook in self._forward_post_hooks.values():

/tmp/ipykernel_4072/3829405380.py in forward(self, input_ids, label_length, target)
     31                 label=label,
     32                 param_attr=fluid.ParamAttr(name='crfw3', learning_rate=0.0001),
---> 33                 length=label_length
     34                 )
     35             # crf_decode = paddle.static.nn.crf_decoding(input=emission,param_attr=paddle.ParamAttr(name="crfw"))

/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/nn.py in linear_chain_crf(input, label, param_attr, length)
    883     check_variable_and_dtype(label, 'label', ['int64'], 'linear_chain_crf')
    884     helper = LayerHelper('linear_chain_crf', **locals())
--> 885     size = input.shape[2] if length else input.shape[1]
    886     transition = helper.create_parameter(
    887         attr=helper.param_attr,

/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/varbase_patch_methods.py in __bool__(self)
    667 
    668     def __bool__(self):
--> 669         return self.__nonzero__()
    670 
    671     def __array__(self, dtype=None):

/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/varbase_patch_methods.py in __nonzero__(self)
    657     def __nonzero__(self):
    658         numel = np.prod(self.shape)
--> 659         assert numel == 1, "When Variable is used as the condition of if/while , Variable can only contain one element."
    660         if framework._in_eager_mode_:
    661             assert self._is_initialized(), "tensor not initialized"

AssertionError: When Variable is used as the condition of if/while , Variable can only contain one element.

其他补充信息 Additional Supplementary Information

代码原本是可以运行的,忽然出现了这个报错,很崩溃。检查了很多遍,还是不太明白:为什么已经传入了 label_length 的信息,还会出现没有初始化的报错?下面的截图依次是 input_ids、labels、seq_lens:[image]。麻烦帮忙看看是什么问题,非常感谢!

lyuwenyu commented 1 year ago

你这是把动态图 API 和静态图 API 混合使用了吗?

paddle-bot[bot] commented 5 months ago

Since you haven't replied for more than a year, we have closed this issue/pr. If the problem is not solved or there is a follow-up one, please reopen it at any time and we will continue to follow up. 由于您超过一年未回复,我们将关闭这个issue/pr。 若问题未解决或有后续问题,请随时重新打开,我们会继续跟进。