allanzelener / YAD2K

YAD2K: Yet Another Darknet 2 Keras

is it possible to change the classification loss to crossentropy? #124

Open LaTournesol opened 6 years ago

LaTournesol commented 6 years ago

Hi there,

I'm using YAD2K to train YOLO on my own data set.

I'm getting satisfactory bounding box accuracy, but sometimes boxes are assigned to the wrong class, even though I only have 2 classes.

That's when I took a closer look at the yolo model, and I found that the yolo_loss uses squared error for classification and not crossentropy.

Can I ask why that is?

I tried changing the classification loss to cross-entropy, but val_loss was nan at first, then decreased nicely for a while, and then went back to nan.

Here is what I've tried:

# classification_loss = (class_scale * detectors_mask * K.square(matching_classes - pred_class_prob))

classification_loss = (class_scale * detectors_mask * calculate_crossentropy_loss(matching_classes, pred_class_prob))

classification_loss_sum = K.sum(classification_loss) * (1/207)

def calculate_crossentropy_loss(true_label, pred_label):
    return -(true_label * K.log(pred_label))

I am hard-coding the number of samples (207) because I'm not sure how to get that number given those tensors.

Can you give me some ideas on how to modify this? Thank you!
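On the hard-coded 207: one possible way to avoid it is to derive the normaliser from the tensors themselves. The sketch below uses NumPy as a stand-in for the Keras backend and made-up shapes, assuming detectors_mask is 1 for every anchor matched to a ground-truth box and 0 elsewhere. The Keras calls named in the comments (K.shape, K.sum) exist in the backend, but whether per-image or per-box normalisation is the right choice here is an assumption, not something the repository prescribes.

```python
import numpy as np

# Hypothetical shapes for illustration: 2 images, 3x3 grid, 2 anchors, 2 classes.
batch, grid_h, grid_w, num_anchors, num_classes = 2, 3, 3, 2, 2

# Assumed layout: 1 where an anchor is matched to a ground-truth box, else 0.
detectors_mask = np.zeros((batch, grid_h, grid_w, num_anchors, 1), dtype=np.float32)
detectors_mask[0, 1, 1, 0, 0] = 1.0
detectors_mask[1, 0, 2, 1, 0] = 1.0
detectors_mask[1, 2, 0, 0, 0] = 1.0

# Stand-in for the per-element classification loss (all ones, to keep the sums obvious).
per_element_loss = np.ones(
    (batch, grid_h, grid_w, num_anchors, num_classes), dtype=np.float32
)
masked_loss = detectors_mask * per_element_loss

# Number of images in the batch. Keras: K.shape(pred_class_prob)[0], cast to float.
batch_size = masked_loss.shape[0]

# Number of matched anchors in the batch. Keras: K.sum(detectors_mask).
num_matched = detectors_mask.sum()

loss_per_image = masked_loss.sum() / batch_size
# Guard against batches with no matched anchors to avoid dividing by zero.
loss_per_box = masked_loss.sum() / max(num_matched, 1.0)
```

Either normaliser only rescales the loss, so it changes the effective learning rate rather than the location of the minimum.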

LaTournesol commented 6 years ago

Hi, I think I successfully changed it.

I was getting nan for the validation loss because the model predicts the label too well: a predicted probability reaches exactly zero, so K.log() returns -inf, and multiplying that by a zero label gives nan. After adding a small epsilon inside the log, the nan went away and the model trained successfully. The 1/207 factor was also redundant, so I removed it.
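The nan can be reproduced outside the model. This is a minimal NumPy sketch with toy values (not taken from the actual training run) showing that a perfectly confident prediction produces nan in the unprotected formula and a finite value once a small epsilon is added inside the log.

```python
import numpy as np

# Toy one-hot label and a perfectly confident, correct prediction.
true_label = np.array([1.0, 0.0])
pred_label = np.array([1.0, 0.0])

# Unprotected cross-entropy: log(0) is -inf, and 0 * -inf is nan.
with np.errstate(divide="ignore", invalid="ignore"):
    raw = -(true_label * np.log(pred_label))

# Same formula with a small epsilon inside the log, as in the fix described above.
eps = 1e-8
stable = -(true_label * np.log(pred_label + eps))

print(raw)     # second entry is nan
print(stable)  # both entries are finite and close to zero
```

Summing raw with K.sum would propagate the single nan to the whole loss, which matches the symptom described above.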

The results were pretty good on my test set. Both classification and localization accuracy are close to 100%.

This is what I have now:

classification_loss = (calculate_crossentropy_loss(matching_classes, pred_class_prob))

classification_loss_sum = K.sum(classification_loss)

def calculate_crossentropy_loss(true_label, pred_label):
    return -(true_label * K.log(pred_label + 1e-8))
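Note that this final version no longer multiplies by class_scale or detectors_mask, unlike the first attempt. If that was not intended, a masked variant might look like the sketch below. It is written in NumPy with toy data, and masked_crossentropy is a hypothetical helper, not a function from the repository. It also clips the prediction instead of adding an epsilon, which is an alternative technique; the Keras backend has K.clip and K.epsilon for the same purpose.

```python
import numpy as np

def masked_crossentropy(true_label, pred_label, detectors_mask,
                        class_scale=1.0, eps=1e-7):
    """Cross-entropy summed only over anchors selected by detectors_mask."""
    # Clip so the log never sees exactly 0 (Keras: K.clip(pred, K.epsilon(), 1.0)).
    clipped = np.clip(pred_label, eps, 1.0)
    per_class = -(true_label * np.log(clipped))
    # The mask has a trailing axis of size 1 and broadcasts over the class axis.
    return np.sum(class_scale * detectors_mask * per_class)

# Toy data: two anchors, two classes; only the first anchor is matched to a box.
true_label = np.array([[1.0, 0.0], [1.0, 0.0]])  # second row is a placeholder label
pred_label = np.array([[0.5, 0.5], [0.0, 1.0]])
detectors_mask = np.array([[1.0], [0.0]])

loss = masked_crossentropy(true_label, pred_label, detectors_mask)
unmasked = masked_crossentropy(true_label, pred_label, np.ones_like(detectors_mask))
```

With the mask, only the matched anchor contributes to the loss. Without it, the unmatched anchor is penalised against its placeholder label, which may or may not be what you want.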