bojone / bert_in_keras

Some examples of fine-tuning BERT in Keras.

Hello, a question: in the nl2sql example, pcsel has shape=(?, ?, ?, ?), which makes the pcsel_loss computation raise the error below. Could you advise how to handle this? Thanks! #11

Closed: axdwss closed this issue 5 years ago

axdwss commented 5 years ago

TypeError                                 Traceback (most recent call last)
in ()
     60 pcop_loss = K.sparse_categorical_crossentropy(cop_in, pcop)
     61 pcop_loss = K.sum(pcop_loss * xm) / K.sum(xm)
---> 62 pcsel_loss = K.sparse_categorical_crossentropy(csel_in, pcsel)
     63 pcsel_loss = K.sum(pcsel_loss * xm * cm) / K.sum(xm * cm)
     64 loss = psel_loss + pconn_loss + pcop_loss + pcsel_loss

/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in sparse_categorical_crossentropy(target, output, from_logits, axis)
   3343     output_shape = output.get_shape()
   3344     targets = cast(flatten(target), 'int64')
-> 3345     logits = tf.reshape(output, [-1, int(output_shape[-1])])
   3346     res = tf.nn.sparse_softmax_cross_entropy_with_logits(
   3347         labels=targets,

TypeError: __int__ returned non-int (type NoneType)
bojone commented 5 years ago

https://kexue.fm/archives/6771#%E5%AE%9E%E9%AA%8C%E7%BB%93%E6%9E%9C

Modify the Keras source code as described there.
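For illustration only (the linked post is not reproduced here): the traceback fails because `int(output_shape[-1])` reads the *static* class dimension, which is `None` when pcsel has shape (?, ?, ?, ?). One common workaround, assumed here and not necessarily the exact change the post describes, is to read that dimension at run time with `tf.shape`, i.e. in the installed `tensorflow_backend.py` change `logits = tf.reshape(output, [-1, int(output_shape[-1])])` to `logits = tf.reshape(output, [-1, tf.shape(output)[-1]])`. The sketch below is a standalone re-implementation written against TensorFlow 2 ops, not the Keras source; `epsilon` stands in for `keras.backend.epsilon()` and `loss_fn` is a hypothetical wrapper used only to trace the function with a fully unknown static shape.

```python
import tensorflow as tf


def sparse_categorical_crossentropy(target, output, from_logits=False,
                                    epsilon=1e-7):
    """Sparse CE that tolerates an unknown (None) static class dimension.

    Sketch only: a re-implementation of the Keras backend function from the
    traceback, with the class count read at run time via tf.shape().
    """
    if not from_logits:
        # Keras passes probabilities; turn them into log-probabilities,
        # which the TF op accepts as logits.
        output = tf.clip_by_value(output, epsilon, 1.0 - epsilon)
        output = tf.math.log(output)

    targets = tf.cast(tf.reshape(target, [-1]), tf.int64)
    # The failing line used int(output.get_shape()[-1]), which raises a
    # TypeError when the static last dimension is None. tf.shape() returns
    # the actual size once the tensor has a concrete value.
    num_classes = tf.shape(output)[-1]
    logits = tf.reshape(output, [-1, num_classes])
    res = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets,
                                                         logits=logits)
    # Restore the leading (batch, time, ...) dimensions of the output.
    return tf.reshape(res, tf.shape(output)[:-1])


# Hypothetical wrapper: trace with every static dimension unknown, like pcsel.
@tf.function(input_signature=[tf.TensorSpec([None, None], tf.int32),
                              tf.TensorSpec([None, None, None], tf.float32)])
def loss_fn(y_true, y_pred):
    return sparse_categorical_crossentropy(y_true, y_pred)


probs = tf.constant([[[0.2, 0.3, 0.5], [0.1, 0.6, 0.3]]], tf.float32)
labels = tf.constant([[2, 1]], tf.int32)
print(loss_fn(labels, probs).numpy())  # per-position loss, shape (1, 2)
```

Because each row of `probs` already sums to 1, the loss at each position reduces to `-log(p[label])`, which makes the result easy to check by hand.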

axdwss commented 5 years ago

Thanks!