BrikerMan / Kashgari

Kashgari is a production-level NLP transfer learning framework built on top of tf.keras for text labeling and text classification. It includes Word2Vec, BERT, and GPT2 language embeddings.
http://kashgari.readthedocs.io/
Apache License 2.0
2.39k stars 441 forks

[BUG] Multi-label NER tagging: error when checking F1 in the callback #185

Closed. RichardHWD closed this issue 5 years ago

RichardHWD commented 5 years ago

Error:

  File "cnn_train.py", line 102, in <module>
    callbacks=[eval_callback])
  File "/home/wendong/anaconda2/envs/kashgari/lib/python3.6/site-packages/kashgari/tasks/base_model.py", line 295, in fit
    **fit_kwargs)
  File "/home/wendong/anaconda2/envs/kashgari/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1433, in fit_generator
    steps_name='steps_per_epoch')
  File "/home/wendong/anaconda2/envs/kashgari/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_generator.py", line 331, in model_iteration
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "/home/wendong/anaconda2/envs/kashgari/lib/python3.6/site-packages/tensorflow/python/keras/callbacks.py", line 311, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "/home/wendong/anaconda2/envs/kashgari/lib/python3.6/site-packages/kashgari/callbacks.py", line 50, in on_epoch_end
    precision = metrics.precision_score(y_true, y_pred, average=self.average)
  File "/home/wendong/anaconda2/envs/kashgari/lib/python3.6/site-packages/sklearn/metrics/classification.py", line 1569, in precision_score
    sample_weight=sample_weight)
  File "/home/wendong/anaconda2/envs/kashgari/lib/python3.6/site-packages/sklearn/metrics/classification.py", line 1415, in precision_recall_fscore_support
    pos_label)
  File "/home/wendong/anaconda2/envs/kashgari/lib/python3.6/site-packages/sklearn/metrics/classification.py", line 1239, in _check_set_wise_labels
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
  File "/home/wendong/anaconda2/envs/kashgari/lib/python3.6/site-packages/sklearn/metrics/classification.py", line 72, in _check_targets
    type_true = type_of_target(y_true)
  File "/home/wendong/anaconda2/envs/kashgari/lib/python3.6/site-packages/sklearn/utils/multiclass.py", line 260, in type_of_target
    raise ValueError('You appear to be using a legacy multi-label data'
ValueError: You appear to be using a legacy multi-label data representation. Sequence of sequences are no longer supported; use a binary array or sparse matrix instead - the MultiLabelBinarizer transformer can convert to this format.
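For context, the traceback shows scikit-learn's `precision_score` rejecting `y_true` because it arrived as a sequence of sequences (one tag list per sentence). The following is a minimal sketch of one way to score such data at the token level by flattening the nested lists first. It is not Kashgari's implementation and not the fix proposed in this thread; `flatten_labels` and `token_level_precision` are hypothetical helper names, and the choice of `average="micro"` is an assumption.

```python
from itertools import chain
from typing import List, Sequence


def flatten_labels(sequences: Sequence[Sequence[str]]) -> List[str]:
    """Concatenate per-sentence tag lists into one flat list of token tags."""
    return list(chain.from_iterable(sequences))


def token_level_precision(y_true, y_pred, average="micro"):
    """Token-level precision over nested tag lists (hypothetical helper)."""
    # Imported lazily so flatten_labels stays usable without scikit-learn.
    from sklearn.metrics import precision_score

    flat_true = flatten_labels(y_true)
    flat_pred = flatten_labels(y_pred)
    if len(flat_true) != len(flat_pred):
        raise ValueError("y_true and y_pred must contain the same number of tokens")
    # A flat list of string labels is treated as a multiclass target.
    return precision_score(flat_true, flat_pred, average=average)
```

Token-level scores computed this way are not the same as entity-level NER scores, and any padding tags would need to be removed before flattening.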

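My code follows the tutorial, and training itself runs normally:

`from kashgari.callbacks import EvalCallBack
from kashgari.tasks.labeling import BiLSTM_CRF_Model

model = BiLSTM_CRF_Model(embedding=stack_embedding)
eval_callback = EvalCallBack(kash_model=model,
                             valid_x=valid_x,
                             valid_y=valid_y,
                             step=1)

model.fit(train_x,
          train_y,
          valid_x,
          valid_y,
          batch_size=128,
          epochs=180,
          callbacks=[eval_callback])`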

I'm waiting online for a reply. Thanks a lot!

BrikerMan commented 5 years ago

This looks like a scikit-learn version issue. Please check which version is installed on your machine. Required: scikit-learn>=0.21.1
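A quick way to run that check is sketched below, assuming the only requirement is the `scikit-learn>=0.21.1` bound stated above. The helper names (`parse_version`, `meets_minimum`, `installed_sklearn_version`) are hypothetical. The lookup relies on `importlib.metadata`, which exists only on Python 3.8 and later; on the Python 3.6 environment shown in the traceback it returns `None`, and `pip show scikit-learn` or `sklearn.__version__` would be needed instead.

```python
import re
from typing import Optional, Tuple

# Minimum version quoted in this thread.
MINIMUM = (0, 21, 1)


def parse_version(text: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from strings such as '0.22.2.post1'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", text)
    if match is None:
        raise ValueError(f"unrecognised version string: {text!r}")
    # A missing patch component is treated as 0.
    major, minor, patch = match.groups(default="0")
    return int(major), int(minor), int(patch)


def meets_minimum(installed: str, minimum: Tuple[int, int, int] = MINIMUM) -> bool:
    """Return True when the installed version is at least the minimum."""
    return parse_version(installed) >= minimum


def installed_sklearn_version() -> Optional[str]:
    """Return the installed scikit-learn version, or None if it cannot be found."""
    try:
        from importlib.metadata import PackageNotFoundError, version
    except ImportError:  # Python < 3.8 has no importlib.metadata
        return None
    try:
        return version("scikit-learn")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_sklearn_version()
    if found is None:
        print("Could not determine the scikit-learn version; try `pip show scikit-learn`.")
    elif meets_minimum(found):
        print(f"scikit-learn {found} satisfies >=0.21.1")
    else:
        print(f"scikit-learn {found} is older than 0.21.1; consider upgrading")
```

If the check reports an older version, upgrading with `pip install -U "scikit-learn>=0.21.1"` inside the same environment is the usual next step.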