ultralytics / yolov5

YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite
https://docs.ultralytics.com
GNU Affero General Public License v3.0

mAP drop in yolov5s with v2.0 release #496

Closed hal-314 closed 4 years ago

hal-314 commented 4 years ago

❔Question

First of all, you are doing an excellent job!! Keep doing it :D

I noticed that all YOLOv5 models have increased mAP with the v2.0 release except YOLOv5s, whose mAP decreased from 36.6 to 36.1 while maintaining the same GPU speed and number of parameters. Also, I see that YOLOv5m accuracy is almost exactly the same as before.

So, have the changes in the 2.0 release focused on the big YOLO models? Do you plan to bring YOLOv5s accuracy back to where it was with the 1.0 release?

Thank you.

Additional context

github-actions[bot] commented 4 years ago

Hello @hal-314, thank you for your interest in our work! Please visit our Custom Training Tutorial to get started, and see our Jupyter Notebook Open In Colab, Docker Image, and Google Cloud Quickstart Guide for example environments.

If this is a bug report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we can not help you.

If this is a custom model or data training question, please note that Ultralytics does not provide free personal support. As a leader in vision ML and AI, we do offer professional consulting, from simple expert advice up to delivery of fully customized, end-to-end production solutions for our clients, such as:

For more information please visit https://www.ultralytics.com.

glenn-jocher commented 4 years ago

@hal-314 yes this is correct, very observant! The model architectures are identical to before, we have modified their construction a bit in yolo.py and also modified training to better balance output losses among the 3 layers. This resulted in the mAP changes, which appear to benefit the larger models at the expense of the smallest model as you noticed.

We also trained 2 different architectures which greatly improved the smaller models but also slowed them down, so we opted not to release those two architectures. YOLOv5s should be trainable to the same mAP as before or better with the default settings, or possibly by increasing hyp['obj'] by 20-30%. Smaller models overtrain less than larger ones.

AlexWang1900 commented 4 years ago

@glenn-jocher Hi! May I ask how you decide the layer balance? Obj loss seems 3x larger with the new settings on the Wheat detection dataset. And what about the ratio of box loss to obj loss? Should I weight them so they are almost the same?

glenn-jocher commented 4 years ago

@AlexWang1900 oh interesting. The new balance is based on empirical results on COCO. I measured the obj loss across the 3 layers for COCO and balanced them using the new balance terms in compute_loss(). Only obj losses are balanced due to their unique nature. The small object output layer is weighted 4X now compared to before, so most wheat objects must be matched in the small obj layer if you are seeing a 3X increase in obj loss.

Every dataset is different though, so I can't actually predict if v2.0 will perform better or worse on wheat. On COCO YOLOv5x mAP improved from 48.4 to 49.0 when trained with the new v2.0 updates (no architecture changes).
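For illustration, the per-layer obj balancing described here could be sketched like this (the function name and exact weights are illustrative, not copied from compute_loss()):

```python
def balanced_obj_loss(obj_losses, balance=(4.0, 1.0, 0.4)):
    """Weighted sum of obj losses over the 3 output layers.

    obj_losses: per-layer obj losses, ordered [small-object layer, medium, large].
    The small-object layer gets ~4x weight, as described for v2.0.
    """
    return sum(b * l for b, l in zip(balance, obj_losses))

# dummy per-layer losses: the small-object layer dominates the total
total = balanced_obj_loss([0.5, 0.3, 0.2])  # ≈ 2.38
```

If most of a dataset's objects land in the small-object layer (as with wheat heads), a 4x weight there would explain a roughly 3x jump in total obj loss.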

glenn-jocher commented 4 years ago

@AlexWang1900 you might want to try rebalancing them yourself. The guideline I use is to look for the loss component that overfits first, which is usually obj. Reducing obj sometimes helps delay overfitting when training from scratch; when training from pretrained weights I'm not so sure, as I have less experience there. If you are seeing 3X loss, and obj is now 3X larger than GIoU, you may want to try an experiment:

Train with hyp['obj'] = 1.00 (default v2.0)

Train with hyp['obj'] = 0.67

Train with hyp['obj'] = 0.33

AlexWang1900 commented 4 years ago

Thanks a lot! Great guidelines!

glenn-jocher commented 4 years ago

@AlexWang1900 you're welcome!

The strange thing is that my main challenge with YOLOv5 is not really training to high mAP, it's preventing overtraining. Overtraining on the larger models in particular is terrible; I'm not entirely sure what to do about it. This is the old and new YOLOv5x. Default training is 300 epochs from scratch, but overtraining begins to occur as early as 200 epochs.

You can see GIoU and cls bottom out at 250 epochs, but obj bottoms at 200 epochs. If I could delay it 50 more epochs so all 3 loss components bottom out at the same time, the result would improve greatly, but I'm struggling with exactly how to do that.

(results plot: old vs. new YOLOv5x training curves)

AlexWang1900 commented 4 years ago

From my team's work on Wheat and other classification tasks, here are some thoughts: BN in PyTorch has a running-stats tracker, which doesn't exist in TensorFlow. It makes the BN mean and variance converge to the whole-dataset mean and variance, especially over many epochs. That may explain why PyTorch EfficientDet can't outperform the official TF EfficientDet mAP score.

We believe it's the BN parameters' lack of variance as epochs increase that causes overfitting. So we did these things to overcome it: 1) mixup: as an easy approach, we just mix with background pictures at 50% chance, without changing labels. I think that brings large variance to BN. 2) cutout. 3) disabling the tracker in the torch BN module, but we haven't tried that yet. Without mixup it overfits under 40 epochs; now it can train more than 80 epochs with all validation losses still going down.
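The 50%-chance background mixup described here might look like this (a minimal sketch with dummy images; the function name and fixed 0.5 ratio are assumptions based on the description):

```python
import random
import numpy as np

def background_mixup(img, background, p=0.5, r=0.5):
    """Mix a training image with a background picture at probability p.

    Labels are left unchanged, per the approach described above; the intent
    is to add variance to the BN statistics rather than create new targets.
    """
    if random.random() < p:
        img = (img.astype(np.float32) * r +
               background.astype(np.float32) * (1 - r)).astype(np.uint8)
    return img

# usage with dummy images: forcing the mix averages the pixel values
img = np.full((4, 4, 3), 200, np.uint8)
bg = np.full((4, 4, 3), 100, np.uint8)
mixed = background_mixup(img, bg, p=1.0)
```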


glenn-jocher commented 4 years ago

@AlexWang1900 that's interesting. Do you mixup 50% of the images with a fixed ratio, or with a beta-distributed ratio (alpha=0.3?)? Not having to update the labels is nice. I will try this on COCO then.

About the BN I would assume that as the statistics smooth during later epochs the result would get better, unless the BN stats are not representative of the val images (i.e. due to augmentations perhaps). You can test the model in .train() mode instead of .eval() mode for the BN layers to ignore the stats and use realtime mean and sigma for normalization, but this should always produce worse results I believe.
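Running validation with BN layers in train mode, as suggested above, can be done by flipping only the BN modules after model.eval() (sketch; the small model here is a stand-in):

```python
import torch
import torch.nn as nn

# stand-in model; in practice this would be the detection network
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# eval mode everywhere, then flip only the BN layers back to train mode so
# they normalize with the current batch's mean/sigma instead of the tracked
# running statistics
model.eval()
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.train()

y = model(torch.randn(2, 3, 8, 8))
```

Note that in train mode BN also updates its running stats on each forward pass, so for a pure evaluation experiment one might also set `m.momentum = 0` or disable gradient tracking.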

glenn-jocher commented 4 years ago

I modified a beta distribution for this by skewing the alpha params to one side to obtain ratios. Perhaps doing this simply trends toward an exponential distribution.

    import numpy as np
    import matplotlib.pyplot as plt

    # histogram of 1M samples from a skewed beta distribution
    plt.hist(np.random.beta(0.3, 1.0, 1000000), 300)
    plt.show()

(Figure_1: histogram of beta(0.3, 1.0) samples)

An alternative would be to use the same alphas np.random.beta(0.3, 0.3), and to flip the labels when the ratio exceeds 50%. What do you think?

    if random.random() < 0.5:
        img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
        r = np.random.beta(0.3, 0.3)  # mixup ratio, alpha=beta=0.3
        img = (img * r + img2 * (1 - r)).astype(np.uint8)
        labels = labels2 if r < 0.5 else labels  # keep labels of the dominant image

Figure_1
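The snippet above depends on the dataset class (self, load_mosaic); a self-contained version of the same ratio and label-flip rule, with dummy arrays and hypothetical names, for experimenting outside the dataloader:

```python
import numpy as np

def mixup(img, labels, img2, labels2, alpha=0.3):
    """Mix two images with a beta-distributed ratio; keep hard labels only."""
    r = np.random.beta(alpha, alpha)  # bimodal: usually near 0 or 1
    mixed = (img.astype(np.float32) * r +
             img2.astype(np.float32) * (1 - r)).astype(np.uint8)
    # boxes can only be hard labels: keep whichever image dominates the mix
    return mixed, (labels if r >= 0.5 else labels2)

# usage with dummy black/white images and dummy labels
img, img2 = np.zeros((2, 2, 3), np.uint8), np.full((2, 2, 3), 255, np.uint8)
mixed, lab = mixup(img, [0], img2, [1])
```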

AlexWang1900 commented 4 years ago

We fixed 0.5:0.5 with external pictures. As for in-domain pictures, I think a normal distribution may be good, concentrated between 0.3:0.7 and 0.7:0.3, with both labels on. The beta distribution from the original paper is for classification, which can use a weighted label add with the same weights as the picture mixup weights, but boxes can only be hard labels. And even for classification, we found a uniform distribution is better; I think under some conditions the beta distribution (a=b=1) is a flat line. The 8:2 or 2:8 is conservative, so as not to make cases too hard for the model to learn.


glenn-jocher commented 4 years ago

@AlexWang1900 yes, the beta distribution is very interesting. alpha=1,1 is uniform as you say; alpha=2,1 is a simple ramp. I started training COCO YOLOv5l with alpha=20,4, which gives me this distribution, which I like because it rarely crosses 0.5 and doesn't waste time at 1.0 either. I don't modify the labels at all. This will take a long time to train, about a week.
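The alpha=20,4 choice can be sanity-checked numerically; this sketch (variable names are mine) samples the distribution and confirms it concentrates near 20/24 ≈ 0.83 and rarely crosses 0.5:

```python
import numpy as np

rng = np.random.default_rng(0)
r = rng.beta(20, 4, 100_000)          # candidate mixup ratios
frac_below_half = (r < 0.5).mean()    # fraction where the mixed-in image would dominate
mean_ratio = r.mean()                 # ≈ 20 / (20 + 4)
```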

I had another idea about introducing perspective transforms into the augmentation. I think rotation may be partly hurting mAP on COCO due to the rotated image edges during training, which never happens during testing. So for this I applied extra effort to applying a perspective transform about the mosaic origin, which retains orthogonal lines as in the zero rotation case:

Example 1 Figure_4

Example 2 Figure_5
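A rough numpy sketch of the perspective idea, applying a homography with small perspective terms but zero rotation to box corners (the matrix values and variable names are illustrative, not the repo's random_perspective()):

```python
import numpy as np

# homography with perspective terms only: since there is no rotation, the
# axis-aligned image edges stay axis-aligned up to the perspective divide,
# avoiding the jagged rotated borders described above
P = np.array([[1.0,  0.0,  0.0],
              [0.0,  1.0,  0.0],
              [1e-4, 2e-4, 1.0]])

# box corners in homogeneous coordinates
corners = np.array([[10, 10, 1],
                    [100, 10, 1],
                    [100, 80, 1],
                    [10, 80, 1]], dtype=np.float64)

warped = corners @ P.T
warped = warped[:, :2] / warped[:, 2:3]  # perspective divide

# new axis-aligned box enclosing the warped corners
x1, y1 = warped[:, 0].min(), warped[:, 1].min()
x2, y2 = warped[:, 0].max(), warped[:, 1].max()
```

The same matrix would be applied to the image with a warp (e.g. cv2.warpPerspective) and to all label boxes as above.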

AlexWang1900 commented 4 years ago

I didn't know that mixup takes so much longer in object detection; the classification task must be much easier. As for edges, a random center crop may eliminate them. Perspective change I haven't seen in any paper, so it's a new idea! Also, GridMask is the new best augmentation in classification; maybe it can be used in object detection.


AlexWang1900 commented 4 years ago

About BN: there must be a small distribution shift between train and val/test. To fit the val/test distribution, there must also be some random shift in the training process, like a random search, while the PyTorch BN tracker makes that impossible: as epochs grow, every batch update makes a smaller and smaller contribution to the BN stats. Most augmentations just average out the changes and don't change the distribution, especially with the BN tracker. The tracker may make training more stable by representing the trainset distribution, but some small disturbance each batch brings the chance to fit the val/test set.

As for keeping BN open during validation, many EfficientDet implementations do that and see a large gain. I think you saw worse results because with batch size 1 during testing, the BN mean is too unstable.


github-actions[bot] commented 4 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

ZeKunZhang1998 commented 4 years ago

> I had another idea about introducing perspective transforms into the augmentation. I think rotation may be partly hurting mAP on COCO due to the rotated image edges during training, which never happens during testing.

I think that after rotation, sometimes the box is no longer the best fit for the object.

glenn-jocher commented 4 years ago

@ZeKunZhang1998 yes, the boxes are augmented using the same transforms as the images. This may result in suboptimal boxes at extreme augmentations (e.g. 45-degree rotations).
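The effect can be quantified: the axis-aligned box enclosing a rotated w x h box grows with the rotation angle, reaching 2x the area for a square at 45 degrees. A small sketch (my own helper, not from the repo):

```python
import numpy as np

def aabb_area_after_rotation(w, h, deg):
    """Area of the axis-aligned box enclosing a w x h box rotated by deg degrees."""
    t = np.deg2rad(deg)
    c, s = abs(np.cos(t)), abs(np.sin(t))
    new_w = w * c + h * s  # projected extent on the x axis
    new_h = w * s + h * c  # projected extent on the y axis
    return new_w * new_h

area = aabb_area_after_rotation(10, 10, 45)  # -> 200.0, vs. 100 unrotated
```

This looseness is one reason large rotation augmentations can hurt detection labels even when they help classification.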