agentmorris / MegaDetector

MegaDetector is an AI model that helps conservation folks spend less time doing boring things with camera trap images.

The use of sample weights in calculating loss/accuracy of validation set #138

Closed: anhquan0412 closed this issue 3 months ago

anhquan0412 commented 3 months ago

Hi Dan,

This might not really be an issue; I just want to double-check that my understanding of this code is correct. In this line here: https://github.com/agentmorris/MegaDetector/blob/ab53c1081a894cafeefc9d6ef0bf60eefb501e4a/megadetector/classification/train_classifier.py#L421

The sample weights are used in calculating the loss and accuracy of the validation set, but not of the train set (line 411)? Isn't it the other way around? (I thought you might want to keep the validation set as is, i.e. with no weighted sampling on it.)

Let me know if my thinking is valid, or if I missed something.

Thank you! Quan

agentmorris commented 3 months ago

This is definitely a trip down memory lane... first, a reminder that this code only applies to MegaClassifier training; it is not related to MegaDetector. And MegaClassifier was trained ~four years ago, which is approximately infinity in deep learning years.

As far as inference goes, MegaClassifier is still going strong! But the training code is largely obsolete; I recommend training new camera trap image classifiers via MEWC, or - if you want to take a slightly more DIY approach - via timm's training script.

That said, you've got me curious, so I'll ping Chris, who wrote this code and trained MegaClassifier.

chrisyeh96 commented 3 months ago

Great question @anhquan0412!

TLDR

Class imbalance is accounted for in both the train and val splits, but it is handled differently in each, for a good reason (see the longer explanation below). Running the training script with weighted labels (python train_classifier.py --label-weighted) does the following:

- train split: examples are drawn with class-balanced (weighted) sampling, and the loss function is unweighted
- val split: examples are used as-is (no weighted sampling), and the loss function is class-weighted

Longer explanation

We weight the loss function on the validation split for the straightforward reason that we want to account for class imbalance when measuring the classifier's performance, since we use the validation split to choose the best hyperparameters.

We originally also implemented a class-weighted loss function during training, but this turned out to lead to either unstable training (gradients blow up) for moderate to large learning rates, or a model that doesn't learn at all (gradients too small) for smaller learning rates. The reason is that the class weights are heavily skewed, with the weights between classes differing by a factor on the order of 100x or more. (My memory is a bit hazy on the exact class imbalance, but some classes had only a dozen training examples, while others had thousands.) Thus, for moderate-to-large learning rates, the gradient would spike whenever a rare class was encountered in a minibatch, because the loss for that example would be scaled up by a factor of ~100x relative to the losses for the other examples. But if we lowered the learning rate, then the model wouldn't learn on the more common classes.
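For illustration, here is a minimal sketch (not the actual MegaClassifier code) of the kind of inverse-frequency class-weighted cross-entropy described above; the class counts are made up purely to show the ~100x+ weight skew:

```python
import torch
import torch.nn as nn

# Hypothetical per-class example counts for illustration only; the real
# MegaClassifier counts are not shown in this thread.
class_counts = torch.tensor([12.0, 450.0, 3800.0, 95.0])

# Inverse-frequency class weights, normalized so the average weight is 1.
class_weights = class_counts.sum() / (len(class_counts) * class_counts)

# Class-weighted cross-entropy: the rare class (index 0) gets a weight
# roughly 300x larger than the common class (index 2), which is the kind
# of skew that made training gradients spike.
weighted_ce = nn.CrossEntropyLoss(weight=class_weights)

logits = torch.randn(8, len(class_counts))
labels = torch.randint(0, len(class_counts), (8,))
loss = weighted_ce(logits, labels)
```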

To get around this issue, I instead decided to use class-based weighted sampling of the training set (see code here), with an unweighted loss function. Training examples from a rare class would be sampled more frequently than usual during training, whereas examples from a common class would be sampled less frequently. This keeps the gradients more consistent during training while still mitigating the class imbalance issue.
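A minimal sketch of the same idea using PyTorch's WeightedRandomSampler, with toy tensors standing in for the real training split (the actual implementation lives in the linked code, not here):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy stand-in dataset; in the real pipeline this would be the training
# split of camera trap crops with species labels.
features = torch.randn(1000, 16)
labels = torch.randint(0, 4, (1000,))
dataset = TensorDataset(features, labels)

# Per-example sampling weight = inverse frequency of that example's class,
# so rare classes are drawn more often than their natural frequency.
class_counts = torch.bincount(labels).float()
sample_weights = 1.0 / class_counts[labels]

sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(dataset),
                                replacement=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

# The loss itself stays unweighted, which keeps gradient magnitudes
# roughly consistent across minibatches.
criterion = nn.CrossEntropyLoss()
```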

Because we don't need gradients on the validation set, this instability isn't a concern there, so we can just use a weighted loss!

P.S. @agentmorris This indeed was a trip down memory lane, but I think I still remembered enough to piece together the rationale for this!

agentmorris commented 3 months ago

Thanks, @chrisyeh96! Closing this issue since Chris's explanation was quite comprehensive; @anhquan0412 feel free to comment if this doesn't address your question.

anhquan0412 commented 3 months ago

Wow, I didn't expect such a detailed explanation from you guys. It's clear to me now why you don't use a weighted loss on the train set. Thank you @chrisyeh96 and @agentmorris so much; I hope this trip down memory lane was fun, haha.

I am also in the process of training/finetuning a computer vision classifier for wildlife conservation purposes (in Australia), and I am trying to figure out the best practices for these kinds of datasets.

I have one more question on training the classifier: CrossEntropyLoss has a 'weight' parameter for setting per-class weights as well. I wonder if you tried setting it, and whether it is similar to the class-weighted loss function you implemented?

chrisyeh96 commented 3 months ago

@anhquan0412 I was aware of the weight parameter of the nn.CrossEntropyLoss() function at the time I wrote this training code. However, that parameter can only apply class-based weighting, whereas I also wanted to try weighting the animal species classification loss by the confidence of the MegaDetector animal detector. (Some of the training data for MegaClassifier was automatically annotated based on MegaDetector outputs, so we didn't have 100% confidence in the quality of the labels used to train MegaClassifier.) This type of per-example weighting, which is turned on by the --weight-by-detection-conf flag, is not supported by nn.CrossEntropyLoss(), which is why I did not use that parameter.
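For anyone reading along, a minimal sketch of how per-example weighting can be done in PyTorch when the per-class weight parameter doesn't fit, using reduction='none'; the tensors here are placeholders, and this is not the actual train_classifier.py code:

```python
import torch
import torch.nn as nn

# With reduction='none', the loss is returned per example instead of
# averaged, so arbitrary per-example weights can be applied by hand.
criterion = nn.CrossEntropyLoss(reduction='none')

logits = torch.randn(8, 4)
labels = torch.randint(0, 4, (8,))
det_conf = torch.rand(8)  # stand-in for MegaDetector confidence scores

per_example_loss = criterion(logits, labels)  # shape: (8,)
loss = (per_example_loss * det_conf).mean()   # confidence-weighted loss
```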

anhquan0412 commented 3 months ago

Got it! Thanks so much 😊