AidenHuen / BERT-BiLSTM-CRF

A Keras implementation of BERT-BiLSTM-CRF

Mask error in the CRF layer #3

Open lan2720 opened 4 years ago

lan2720 commented 4 years ago

This code no longer runs. The CRF layer raises the following error:

Traceback (most recent call last):
  File "/data/jarvixwang/Project/BERT-BiLSTM-CRF/train.py", line 32, in <module>
    train_bert_model(para, use_generator=False)
  File "/data/jarvixwang/Project/BERT-BiLSTM-CRF/train.py", line 13, in train_bert_model
    model = ModelLib.BERT_MODEL(para)
  File "/data/jarvixwang/Project/BERT-BiLSTM-CRF/ModelLib.py", line 21, in BERT_MODEL
    crf_output = crf(repre)
  File "/data/jarvixwang/venv/BERT-BiLSTM-CRF/lib64/python3.6/site-packages/keras/engine/base_layer.py", line 489, in __call__
    output = self.call(inputs, **kwargs)
  File "/data/jarvixwang/venv/BERT-BiLSTM-CRF/lib64/python3.6/site-packages/keras_contrib/layers/crf.py", line 292, in call
    test_output = self.viterbi_decoding(X, mask)
  File "/data/jarvixwang/venv/BERT-BiLSTM-CRF/lib64/python3.6/site-packages/keras_contrib/layers/crf.py", line 564, in viterbi_decoding
    argmin_tables = self.recursion(input_energy, mask, return_logZ=False)
  File "/data/jarvixwang/venv/BERT-BiLSTM-CRF/lib64/python3.6/site-packages/keras_contrib/layers/crf.py", line 516, in recursion
    mask2 = K.cast(K.concatenate([mask, K.zeros_like(mask[:, :1])], axis=1),
  File "/data/jarvixwang/venv/BERT-BiLSTM-CRF/lib64/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2427, in concatenate
    return tf.concat([to_dense(x) for x in tensors], axis)
  File "/data/jarvixwang/venv/BERT-BiLSTM-CRF/lib64/python3.6/site-packages/tensorflow/python/util/dispatch.py", line 180, in wrapper
    return target(*args, **kwargs)
  File "/data/jarvixwang/venv/BERT-BiLSTM-CRF/lib64/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 1299, in concat
    return gen_array_ops.concat_v2(values=values, axis=axis, name=name)
  File "/data/jarvixwang/venv/BERT-BiLSTM-CRF/lib64/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 1256, in concat_v2
    "ConcatV2", values=values, axis=axis, name=name)
  File "/data/jarvixwang/venv/BERT-BiLSTM-CRF/lib64/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 499, in _apply_op_helper
    raise TypeError("%s that don't all match." % prefix)
TypeError: Tensors in list passed to 'values' of 'ConcatV2' Op have types [bool, float32] that don't all match.
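The last frame points at the `mask2` line in `keras_contrib/layers/crf.py`: `mask` is a bool tensor, while the `K.zeros_like(mask[:, :1])` padding comes back as float32 (the `[bool, float32]` in the error message), so `ConcatV2` rejects the mixed dtypes. A minimal sketch of the idea behind a workaround, assuming TensorFlow 2.x in eager mode and using plain `tf` ops rather than the Keras backend: cast the mask to float *before* concatenating, so both inputs share a dtype. The helper name `padded_float_mask` is hypothetical, and this is not an official keras_contrib fix. A local patch along the same lines would wrap `mask` in `K.cast(mask, K.floatx())` inside the `K.concatenate` call; that patch is untested here.

```python
import tensorflow as tf


def padded_float_mask(mask):
    """Append one all-zero timestep to a (batch, time) mask, returned as float32.

    Hypothetical helper mirroring the `mask2` computation in the CRF layer,
    but casting first so tf.concat never sees a bool/float32 mix.
    """
    mask_f = tf.cast(mask, tf.float32)       # bool -> float32 before concatenation
    pad = tf.zeros_like(mask_f[:, :1])       # float32 zeros, shape (batch, 1)
    return tf.concat([mask_f, pad], axis=1)  # shape (batch, time + 1), float32


if __name__ == "__main__":
    mask = tf.constant([[True, True, False], [True, False, False]])
    print(padded_float_mask(mask).numpy())
```

Because the cast happens first, `tf.zeros_like` inherits float32 from its input, so the result does not depend on what dtype a given Keras version's `zeros_like` defaults to.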