zzh8829 / yolov3-tf2

YoloV3 Implemented in Tensorflow 2.0
MIT License

Yolo loss binary_crossentropy version #396

Open · ZXTFINAL opened this issue 2 years ago

ZXTFINAL commented 2 years ago

Change the class loss from:

# TODO: use binary_crossentropy instead
class_loss = obj_mask * sparse_categorical_crossentropy(
    true_class_idx, pred_class)

to a summed per-class binary cross-entropy:

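# one-hot encode the class index: (..., anchors, 1) -> (..., anchors, 1, classes)
true_class_onehot = tf.one_hot(
    tf.cast(true_class_idx, tf.int64), depth=classes, axis=-1)
# reshape labels and predictions to (batch, grid, grid, anchors, classes, 1) so that
# binary_crossentropy, which averages over the last axis, yields one loss per class
true_class_binary = tf.reshape(
    true_class_onehot,
    (tf.shape(y_true)[0], grid_size, grid_size, tf.shape(y_true)[3], -1, 1))
pred_class_binary = tf.reshape(
    pred_class,
    (tf.shape(y_true)[0], grid_size, grid_size, tf.shape(y_true)[3], -1, 1))
# sum the per-class losses, keeping only anchors that contain an object
class_loss = obj_mask * tf.reduce_sum(
    binary_crossentropy(true_class_binary, pred_class_binary), axis=-1)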
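To make the difference between the two class losses concrete, here is a minimal pure-Python sketch of both for a single anchor. It is written without TensorFlow so the arithmetic is visible. The helper names `sparse_cce_class_loss` and `bce_class_loss` are hypothetical. Two assumptions are also labeled in the code: probabilities are clipped to `[1e-7, 1 - 1e-7]` (the Keras default epsilon), and for the categorical version the sigmoid class scores are assumed to be renormalized into one distribution before the negative log of the true class is taken. That second point is an assumption about Keras internals, not something stated in the issue.

```python
import math

# Assumption: clip probabilities like Keras does with its default epsilon.
EPS = 1e-7


def _clip(p):
    """Keep a probability strictly inside (0, 1) so log() is defined."""
    return min(max(p, EPS), 1 - EPS)


def sparse_cce_class_loss(true_class_idx, pred_class, obj_mask):
    """Current version (hypothetical helper): single-label categorical loss.

    Assumption: the per-class sigmoid scores are renormalized to sum to 1,
    then the negative log of the true class's share is taken.
    """
    clipped = [_clip(p) for p in pred_class]
    total = sum(clipped)
    return obj_mask * -math.log(clipped[true_class_idx] / total)


def bce_class_loss(true_class_idx, pred_class, obj_mask):
    """Proposed version (hypothetical helper): summed per-class BCE.

    Mirrors the snippet: one-hot encode the label, treat every class as an
    independent binary problem, then sum the per-class losses.
    """
    # equivalent of tf.one_hot(true_class_idx, depth=classes)
    onehot = [1.0 if c == true_class_idx else 0.0 for c in range(len(pred_class))]
    loss = 0.0
    for y, p in zip(onehot, pred_class):
        p = _clip(p)
        # binary cross-entropy of one class (the trailing axis of size 1)
        loss += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    # obj_mask zeroes the loss for anchors that contain no object
    return obj_mask * loss


# One anchor, three classes, ground-truth class 0, one sigmoid score per class.
scores = [0.9, 0.8, 0.1]
print(sparse_cce_class_loss(0, scores, 1.0))  # about 0.693, i.e. -log(0.9 / 1.8)
print(bce_class_loss(0, scores, 1.0))         # about 1.820
```

The categorical loss only sees the true class's share of the total score, whereas the summed BCE also penalizes the 0.8 score on the wrong class directly, which fits the independent sigmoid class scores that `pred_class` holds in this repo. The `reshape` to a trailing axis of size 1 in the snippet keeps one loss value per class, which `reduce_sum` then adds up.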