allanzelener / YAD2K

YAD2K: Yet Another Darknet 2 Keras

is it possible to change the classification loss to crossentropy? #124

Open LaTournesol opened 6 years ago

LaTournesol commented 6 years ago

Hi there,

I'm using retrain_yolo.py to train yolo on my own data set.

I'm getting satisfying bounding-box accuracy, but boxes are sometimes assigned to the wrong class (even though I only have two classes).

That's when I took a closer look at the YOLO model, and I found that yolo_loss uses squared error for classification rather than crossentropy.

May I ask why that is?

I tried changing the classification loss to crossentropy, but I got nan for val_loss at first, then some seemingly nice decreases, and then nan again.

Here is what I've tried:

```python
# classification_loss = (class_scale * detectors_mask *
#                        K.square(matching_classes - pred_class_prob))
classification_loss = (class_scale * detectors_mask *
                       calculate_crossentropy_loss(matching_classes, pred_class_prob))
classification_loss_sum = K.sum(classification_loss) * (1 / 207)

def calculate_crossentropy_loss(true_label, pred_label):
    return -(true_label * K.log(pred_label))
```

I am hard-coding the number of samples (207) because I'm not sure how to get that number given those tensors.
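One possible way to avoid hard-coding the 207 is to read the batch size off the tensor itself at run time. A minimal numpy sketch of the idea (variable names here are illustrative, not from retrain_yolo.py), assuming the batch is the leading axis; with Keras tensors the analogous call would be `K.cast(K.shape(x)[0], K.floatx())`:

```python
import numpy as np

# Hypothetical per-detector classification losses for a batch of
# 4 samples, each with 3 detectors (shapes chosen for illustration).
classification_loss = np.ones((4, 3))

# Derive the batch size from the tensor instead of hard-coding it.
batch_size = classification_loss.shape[0]

classification_loss_sum = classification_loss.sum() / batch_size
print(classification_loss_sum)  # 3.0
```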

Could you give me some ideas on how to modify this? Thank you!

LaTournesol commented 6 years ago

Hi, I think I successfully changed it.

I was getting nan for the validation loss because the model predicts the labels too well: the term inside K.log() becomes zero, and 0 * log(0) evaluates to nan. After adding a small bias inside the log function I got rid of the nan and trained the model successfully. The 1/207 term was also redundant, so I removed it.
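The failure mode described above can be reproduced in a few lines of plain numpy (a sketch of the arithmetic, not the actual yolo_loss code): when the predicted probability is exactly 0 for a class whose one-hot label is also 0, the term `true * log(pred)` becomes `0 * (-inf)`, which is nan; a small bias inside the log keeps every term finite.

```python
import numpy as np

true_label = np.array([1.0, 0.0])   # one-hot ground truth
pred_label = np.array([1.0, 0.0])   # a "perfect" prediction

# Without the bias: the second term is 0 * log(0) = 0 * -inf = nan
with np.errstate(divide="ignore", invalid="ignore"):
    naive = -(true_label * np.log(pred_label))
# naive now contains a nan in its second entry

# With a small bias inside the log, every term stays finite
# and the total loss is approximately 0, as expected here.
stable = -(true_label * np.log(pred_label + 1e-8))
```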

The results were pretty good on my test set: both classification and localization accuracy are close to 100%.

This is what I have now:

```python
classification_loss = calculate_crossentropy_loss(matching_classes, pred_class_prob)
classification_loss_sum = K.sum(classification_loss)

def calculate_crossentropy_loss(true_label, pred_label):
    return -(true_label * K.log(pred_label + 1e-8))
```
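One further hardening worth considering (my suggestion, not something from this thread): clipping the probabilities on both sides before the log, which is also what Keras's built-in categorical crossentropy does internally. This additionally guards against predicted values slightly above 1 producing negative loss terms. A numpy sketch of that variant:

```python
import numpy as np

EPS = 1e-7  # same order of magnitude as the default Keras epsilon

def crossentropy_loss(true_label, pred_label, eps=EPS):
    # Clip from both sides so log() never sees 0, and values
    # slightly above 1 cannot yield negative loss terms.
    pred_label = np.clip(pred_label, eps, 1.0 - eps)
    return -(true_label * np.log(pred_label))

true_label = np.array([1.0, 0.0])
pred_label = np.array([0.0, 1.0])   # worst-case wrong prediction

loss = crossentropy_loss(true_label, pred_label)
# every entry is finite even for this degenerate prediction
```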