@ChristopherSTAN wow, that's a super effective augmentation modification with very few lines of additional code, very impressive.
If I remember correctly, this type of augmentation benefits most from a beta distribution of ratios, with peaks around 0.1 and 0.9, and maybe also a relative adjustment of confidences to match the ratios, but I'm not really sure.
Can you upload some train_batch0.jpg images this creates?
@glenn-jocher I used the tutorial notebook and got this:
Yes, I realized there is a bug here. The uint8 sums overflow past 255 (NumPy wraps them around rather than clamping), so rather than:
img = (img + r_img) // 2
you want:
img = img // 2 + r_img // 2
Or maybe better yet, a random fraction b sampled from the beta distribution used in the paper:
img = img * b + r_img * (1 - b)
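A small NumPy check of the fixes above, using illustrative pixel values (not from the thread). Adding two `uint8` arrays wraps modulo 256 instead of saturating, which is why `(img + r_img) // 2` corrupts bright pixels, while halving each operand first stays in range. For the ratio-weighted blend, this sketch casts to float before mixing and rounds back to `uint8`; the fixed `b = 0.3` is an arbitrary example value standing in for a beta-distributed sample.

```python
import numpy as np

img = np.array([200, 100, 255], dtype=np.uint8)
r_img = np.array([100, 100, 255], dtype=np.uint8)

naive = (img + r_img) // 2       # uint8 + uint8 wraps modulo 256 before the division
halved = img // 2 + r_img // 2   # each term is at most 127, so the sum stays <= 254

b = 0.3  # example ratio; in practice sampled from a beta distribution
blended = np.rint(img.astype(np.float32) * b + r_img.astype(np.float32) * (1 - b)).astype(np.uint8)

print(naive, halved, blended)
```

Halving loses the low bit of each image, so the float blend is the more accurate of the two fixes.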
@glenn-jocher Thanks a lot! I think tuning the parameter will take a lot of time.
@ChristopherSTAN I'd need to review the paper again, but this should be pretty easy. Let's see, from https://en.wikipedia.org/wiki/Beta_distribution we want alpha = beta = 0.5 roughly.
We can use numpy for the random number generator: https://numpy.org/doc/1.18/reference/random/generated/numpy.random.beta.html
So something like np.random.beta(0.5, 0.5, 1) should give you the b parameter I mentioned above.
import numpy as np
import matplotlib.pyplot as plt
x = np.random.beta(0.5, 0.5, 1000000)  # 1M samples of the mixup ratio
plt.hist(x, 500)  # 500-bin histogram, peaked near 0 and 1
plt.show()
Or perhaps a=b=0.3 for a slightly more peaked bimodal distribution. I'm sure they tuned this parameter in the paper to a good value.
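To compare how bimodal the candidate settings are without plotting, a sketch can measure the fraction of Beta(alpha, alpha) samples that land within 0.1 of either end. The helper name `edge_mass` and the 0.1 threshold are illustrative choices, not something from the paper.

```python
import numpy as np

def edge_mass(alpha: float, n: int = 200_000, seed: int = 0) -> float:
    """Fraction of Beta(alpha, alpha) samples lying within 0.1 of 0 or 1."""
    rng = np.random.default_rng(seed)
    x = rng.beta(alpha, alpha, n)
    return float(np.mean((x < 0.1) | (x > 0.9)))

for a in (0.3, 0.5, 1.0, 2.0):
    print(f"alpha = beta = {a}: {edge_mass(a):.3f} of ratios near 0 or 1")
```

Smaller alpha = beta puts more mass near 0 and 1 (one image dominates); alpha = beta = 1 is uniform, so about 0.2 of samples fall in those two bands.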
@glenn-jocher I think a rough implementation of the augmentation part is easy:
mixup_ratio = np.random.beta(0.3, 0.3) # alpha = beta = 0.3
...
img = img * mixup_ratio + r_img * (1 - mixup_ratio)
...
I have not read the paper carefully, but I read some blogs and found that the loss calculation may need to be updated, which may touch a large amount of code. So I decided to do it tomorrow (it's 1:53 am now).
This is my first trial:
In train.py:
https://github.com/ChristopherSTAN/yolov5/blob/a4e01e4b290982a2aab1eee5b1a22cac5c877cf5/train.py#L234-L266
In utils/datasets.py, return one more output from __getitem__():
https://github.com/ChristopherSTAN/yolov5/blob/a4e01e4b290982a2aab1eee5b1a22cac5c877cf5/utils/datasets.py#L527
@glenn-jocher I took a look at some implementations of MixUp: https://github.com/dmlc/gluon-cv/blob/428ee05d7ae4f2955ef00380a1b324b05e6bc80f/scripts/detection/faster_rcnn/train_faster_rcnn.py#L187 https://github.com/dmlc/gluon-cv/blob/49be01910a8e8424b017ed3df65c4928fc918c67/gluoncv/data/mixup/detection.py#L65
I am not sure whether the loss should be scaled by the mixup ratio, so for now I ignore loss scaling.
Shorter code to pick a random index other than the current one:
r_index = np.random.choice(np.delete(np.arange(len(self.img_files)), index))
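The `np.delete(np.arange(...))` one-liner builds a full index array on every call. A sketch of an O(1) alternative (a hypothetical helper, not code from either repo) draws from n - 1 slots and shifts past the current index, which keeps every other index equally likely:

```python
import random

def random_other_index(index: int, n: int) -> int:
    """Return a uniformly random index in [0, n) that differs from `index` (requires n >= 2)."""
    r = random.randrange(n - 1)         # uniform over the n - 1 allowed slots
    return r + 1 if r >= index else r   # shift up to skip over `index`
```

For example, `random_other_index(index, len(self.img_files))` would replace the `np.random.choice(np.delete(...))` call above.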
@ChristopherSTAN I worked on this a bit and came up with this code. Results look pretty good: you can see a tennis player overlaid with an airplane in image 1.
if self.mosaic:
    # Load mosaic
    img, labels = load_mosaic(self, index)
    shapes = None

    # Apply MixUp
    if random.random() < 0.99:
        img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
        r = np.random.beta(0.3, 0.3)  # mixup ratio, alpha = beta = 0.3
        img = (img * r + img2 * (1 - r)).astype(np.uint8)
        labels = np.concatenate((labels, labels2), axis=0)
EDIT: The missing part is that the object confidences need to be multiplied by their associated mixup ratio. For example, if image 1 has a ratio of 0.10 and image 2 has 0.90, then I believe all of the objects from image 1 and image 2 should have their obj loss multiplied accordingly.
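The `__getitem__` snippet above depends on YOLOv5's `load_mosaic` and dataset state. As a self-contained sketch of the same blend, assuming both images are already the same size and labels are `(n, 5)` arrays of `[class, x1, y1, x2, y2]`, the step can be written as a plain function (`mixup` here is an illustrative name):

```python
import numpy as np

def mixup(img, labels, img2, labels2, alpha=0.3, rng=None):
    """Blend two same-sized uint8 images and concatenate their labels.

    Assumes labels are (n, 5) arrays of [class, x1, y1, x2, y2]; the ratio
    r ~ Beta(alpha, alpha) weights the first image.
    """
    rng = np.random.default_rng() if rng is None else rng
    r = rng.beta(alpha, alpha)  # mixup ratio
    blended = img.astype(np.float32) * r + img2.astype(np.float32) * (1 - r)
    mixed = np.clip(np.rint(blended), 0, 255).astype(np.uint8)  # round, then back to uint8
    return mixed, np.concatenate((labels, labels2), axis=0), r
```

Returning `r` lets a caller apply the ratio-based confidence weighting described in the EDIT, if desired.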
@glenn-jocher From what I can see, not many models are trained with both mosaic and MixUp. I had considered multiplying the total loss by the ratio; your approach is much more concise.
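One hypothetical way to carry the per-image ratio from the EDIT above into training is to append a weight column to the merged labels: objects from the first image get weight r and objects from the second get 1 - r. This is a sketch of the idea under that assumption, not an existing YOLOv5 feature; the loss would still need to read the extra column.

```python
import numpy as np

def mixup_with_weights(labels, labels2, r):
    """Concatenate two (n, 5) label arrays and append a per-object weight column (hypothetical scheme).

    Objects from the first image get weight r, objects from the second get 1 - r,
    matching each image's pixel contribution to the blend.
    """
    w1 = np.full((len(labels), 1), r, dtype=np.float32)
    w2 = np.full((len(labels2), 1), 1.0 - r, dtype=np.float32)
    return np.concatenate((np.hstack((labels, w1)), np.hstack((labels2, w2))), axis=0)
```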
@ChristopherSTAN another option is to not add any new labels, and to use a random mixup ratio 0.9 < r < 1.0. This would have the effect of adding a slight 0-10% background to all images to increase variety.
EDIT: One problem with these experiments is that they will only show a benefit at the very end of a long training run (i.e. COCO epoch 250+). I tested on coco128 and it showed worse results, but this makes sense: mixup makes it harder to overtrain, so learning progress is slower. Seeing the true effect requires a full training run, which takes a lot of time and resources.
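The label-free variant suggested above (ratio between 0.9 and 1.0, no new labels) can be sketched as follows. The function name is illustrative, and `r` is taken as the weight of the original image, so the second image contributes at most 10% as background.

```python
import numpy as np

def background_mix(img, img2, rng=None, low=0.9, high=1.0):
    """Blend a faint second image into `img` as background; `img`'s labels stay unchanged.

    r ~ Uniform(low, high) weights the original image, so img2 contributes 0-10% by default.
    """
    rng = np.random.default_rng() if rng is None else rng
    r = rng.uniform(low, high)
    blended = img.astype(np.float32) * r + img2.astype(np.float32) * (1 - r)
    return blended.astype(np.uint8), r
```

Because no labels are added, objects in `img2` become unlabeled background, which is only reasonable while its weight stays small.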
@glenn-jocher Update: my MixUp implementation helped improve my score in the wheat competition.
@ChristopherSTAN wow awesome!! I wonder how well that might generalize to COCO though. In the meantime I've added your implementation into the master branch, commented out for now. https://github.com/ultralytics/yolov5/blob/e169edfcf54797456e11c2d0e5c3d43b64beadb8/utils/datasets.py#L471-L483
@glenn-jocher After a few competitions, maybe I can run experiments on it, given my limited Colab usage... (poor student) :(
@ChristopherSTAN yes, it is very hard to experiment on COCO unfortunately. I recently started working with VOC a bit more, now that a PR has made it nice and easy to use.
With Colab Pro VOC will train to 300 epochs in <24 hours for YOLOv5s, so it may be a better way to experiment.
EDIT: For fastest training you also want to --cache your images, which is not possible with COCO, but can be done on most other datasets.
@glenn-jocher After installing Apex, I set batch_size as large as possible, making full use of Colab's P100 (16.1 GB of memory). Is it also fine to add --cache?
@ChristopherSTAN yes, --cache simply loads up all images into system RAM (not GPU RAM). This removes the overhead required to read images from the hard drive each epoch. If you train for 300 epochs for example, you will be reading each image from the hard drive up to 1200 times (300 times normally, plus about 900 more when training with mosaic). These cv2 read operations are very expensive, even with multiple dataloader workers. --cache solves this:
python train.py --cache
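The 1200-read figure comes from 300 epochs times roughly 4 reads per image per epoch under mosaic. A minimal sketch of the caching idea (a hypothetical class, not YOLOv5's implementation, which stores resized images) shows each image hitting the disk once no matter how often it is requested:

```python
from typing import Callable, Dict

import numpy as np

class ImageCache:
    """Read each image from disk once, then serve it from system RAM."""

    def __init__(self, loader: Callable[[str], np.ndarray]):
        self.loader = loader  # e.g. cv2.imread in a real pipeline (assumption)
        self.images: Dict[str, np.ndarray] = {}
        self.disk_reads = 0

    def get(self, path: str) -> np.ndarray:
        if path not in self.images:  # first request: read from disk and keep it
            self.images[path] = self.loader(path)
            self.disk_reads += 1
        return self.images[path]
```

Requesting the same path 1200 times (300 epochs x 4 mosaic reads) then costs a single disk read, at the price of holding every image in RAM.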
Hi, @ChristopherSTAN, @glenn-jocher !
I noticed that the current ultralytics implementation doesn't scale the loss according to the mixup ratio. Do we need to scale it like @ChristopherSTAN did in his implementation?
@xevolesi yes I think we'd want to either scale the obj loss, or scale the obj target value according to the mix ratio for that object.
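Scaling "the obj target value according to the mix ratio" could look like the NumPy sketch below: each matched object's objectness target (here an IoU-based soft target) is multiplied by the ratio of the image it came from, then fed to a standard binary cross-entropy on logits. Both helpers are hypothetical illustrations; YOLOv5's actual loss is in PyTorch and, per this thread, does not do this.

```python
import numpy as np

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy on raw logits (numerically stable form)."""
    return float(np.mean(np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))))

def mixup_obj_targets(iou, ratio):
    """Soft objectness targets: IoU-based target scaled by each object's mixup ratio (hypothetical)."""
    return np.clip(iou, 0.0, 1.0) * ratio
```

With soft targets, the loss is minimized when the predicted objectness equals the scaled target, so a faint (low-ratio) object is pushed toward a low confidence rather than 1.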
@xevolesi @glenn-jocher I forgot to update my implementation. It has a bug.
Hi bro @ChristopherSTAN, are you still working on mosaic && mixup?
@summeryumyee Yes, I am using it, but I set r = 0.5, because I think if r is too large or too small, the mixed image and boxes become too obscure. So I mix half and half and just use mixup to make the model more robust to overlapping objects.
@ChristopherSTAN ah you are using ratio=0.5 for all images? Are you appending both image labels together like this?
# MixUp https://arxiv.org/pdf/1710.09412.pdf
if random.random() < 0.5:
    img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
    r = 0.5  # np.random.beta(0.3, 0.3)  # mixup ratio, alpha=beta=0.3
    img = (img * r + img2 * (1 - r)).astype(np.uint8)
    labels = np.concatenate((labels, labels2), 0)
@glenn-jocher Yes, I mix them 50/50. The original MixUp was proposed for image classification, so I am not sure it is a good idea to choose r from a beta distribution. But I found some blogs about training tricks in object detection competitions that use r = 0.5, so I just followed them. I am not sure mixup also performs well on COCO.
I also think some datasets suit this approach especially well, for example X-ray detection, where overlapping objects occur naturally.
[Images: X-ray samples with naturally overlapping objects, and the same samples after MixUp]
I use this purely as data augmentation and do nothing to the loss function. I did a brief survey of others' work and found nothing more, so I am not sure I am on the right track. But the val mAP is normal, and I think it may reduce overfitting on the val set a little.
What do you think?
Hi, guys! If I understood correctly, the authors of https://arxiv.org/pdf/1902.04103.pdf used MixUp with B(1.5, 1.5) and mixed all labels together. The result is shown in Fig. 3 of the paper: they obtained about +2% mAP@0.5 on PASCAL VOC with a YOLOv3 model.
Unfortunately I didn't test this MixUp configuration on the Global Wheat Detection dataset.
@xevolesi thanks for the paper link!
I suppose @ChristopherSTAN is essentially using B(inf, inf), which collapses to a Dirac delta at r = 0.5 in every case. @ChristopherSTAN you are right, your dataset is already uniquely semi-transparent and overlapping, very interesting. Mixup results on this dataset may not generalize well to more typical datasets.
I've seen this paper before, but unfortunately their YOLOv3 baseline is very poor, since they started from the original darknet YOLOv3, which only produces 33 mAP. I'm running an experiment with B(20, 4), which blends in about 10-20% of a second image but does not add the second image's labels. So far this run is training slightly worse than the default, however.
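The "10-20% of a second image" for B(20, 4) and the delta-function limit of B(inf, inf) both follow from the Beta moments: mean a / (a + b) and variance ab / ((a + b)^2 (a + b + 1)). A short sketch, assuming r weights the first image as in the snippets in this thread:

```python
import math
from typing import Tuple

def beta_mean_std(a: float, b: float) -> Tuple[float, float]:
    """Mean and standard deviation of Beta(a, b)."""
    mean = a / (a + b)
    var = a * b / ((a + b) ** 2 * (a + b + 1))
    return mean, math.sqrt(var)

# B(20, 4): r weights the first image, so the second image contributes 1 - r
mean_r, std_r = beta_mean_std(20.0, 4.0)
print(f"B(20, 4): second image weight {1 - mean_r:.3f} +/- {std_r:.3f}")  # 0.167 +/- 0.075

# B(a, a) is centred on 0.5 and narrows as a grows, approaching a point mass at r = 0.5
for a in (1.0, 8.0, 1000.0):
    print(f"B({a:g}, {a:g}): std {beta_mean_std(a, a)[1]:.4f}")
```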
@xevolesi after reviewing the paper again, I think you are correct: the authors find the best object detection results with B(1.5, 1.5), appending both images' labels together with no tricks on the labels (such as assigning them confidences based on the ratio).
I will try this scenario then also with VOC using YOLOv5m. I should have some results by tomorrow.
Ok, the results are in, it works! I trained YOLOv5m for 50 epochs on VOC, starting from the pretrained weights. mAP@0.5 increased from 84.9 to 85.9 (+1.0) with the addition of mixup B(2.0, 2.0). Obj val loss in particular overfit less. mAP@0.5:0.95 increased +0.3.
How should the label weights be handled? Just concatenate all the labels together?
Glad to hear, Glenn!
@ZJU-lishuang, hi! The paper https://arxiv.org/pdf/1902.04103.pdf suggests concatenating the labels into a new label array. But I think it strongly depends on the dataset. I advise you to try some of the options listed in this issue and pick the best.
That's good news. I used the issue author's code and fixed some bugs, but training didn't work out very well. Will you open-source the modified mosaic & mixup code?
Have you tried yolov5 v1 and v2? Which is better on the wheat dataset? Thanks a lot!
Did you remove the difficult objects from the VOC dataset?
@clw5180 Hi, bro. I tried v1 on the wheat dataset and got a best LB score of 0.7721. I have not tried v2.
@clw5180 no did not remove any objects from VOC. Used the default data/voc.yaml: https://github.com/ultralytics/yolov5/blob/7f8471eaebe4b192c5e6ab4e5c821d91e43cb4fe/data/voc.yaml#L1-L13
@ChristopherSTAN based on the successful VOC finetuning, I tried increasing the mixup rate to 0.9 (i.e. 90% of the images would be mixed up), up from the default 0.5, but got worse results. Based on this I started training the 4 models (s, m, l, x) from scratch with B(8.0, 8.0) at a 0.50 rate.
After only about 30 epochs it's clear that the mixup effect greatly favors the larger models over the smaller ones. The smaller models are performing worse than default, while YOLOv5x is performing slightly better than default.
In general it's becoming apparent to me that the default training settings/hyps are optimal when training from scratch for a model about the size of YOLOv5m. Larger models overfit too early, and miss out on later gains. 5s never overfits, and thus also leaves some mAP on the table. This may simply have to do with batch-size though. Despite my best intentions, accumulating smaller batches does not produce the same result as using a larger batch. Smaller batches lead to earlier gains and earlier overfitting, which is precisely what the larger models do. The weight decay may be playing a part here that I don't fully understand.
If my theory is correct then larger model overfitting problem may be partly addressed by simply training with multi-gpu on a larger batch size, and conversely the smaller models may benefit from a reduction in weight decay.
@glenn-jocher In your latest repo, the mixup ratio defaults to 0, doesn't it? So during training, are both mixup and mosaic active to augment the data?
The mixup hyperparameter sets the probability of mixup being used in any given image. Default is zero. https://github.com/ultralytics/yolov5/blob/93684531c6e71547667ee19df6ddb94af3c8c80d/train.py#L53
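Read that way, the hyperparameter only gates whether mixup fires for a given sample. A sketch of the gating, assuming a hypothetical `load_other` callable that returns a second `(image, labels)` pair and using the B(8.0, 8.0) ratio mentioned earlier in this thread:

```python
import random

import numpy as np

def maybe_mixup(img, labels, load_other, p, alpha=8.0, beta=8.0):
    """Apply mixup with probability p, the role played by the `mixup` hyperparameter."""
    if random.random() < p:  # p = 0 disables mixup entirely; p = 1 always applies it
        img2, labels2 = load_other()  # hypothetical callable returning (image, labels)
        r = np.random.beta(alpha, beta)  # mixup ratio
        img = (img * r + img2 * (1 - r)).astype(np.uint8)
        labels = np.concatenate((labels, labels2), 0)
    return img, labels
```

The probability and the beta parameters are independent knobs: the first controls how often images are mixed, the second how evenly.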
@glenn-jocher I ran the genetic algorithm, but it had no effect on mixup, so the mixup parameter is still zero. How can I find the best mixup parameter?
@buimanhlinh96 If you look into it, it is just a probability setting. You have to tune the parameters of the beta function inside datasets.py: https://github.com/ultralytics/yolov5/blob/93684531c6e71547667ee19df6ddb94af3c8c80d/utils/datasets.py#L486-L491
Yeah, i see
@buimanhlinh96 yes. I set r=0.5 in my project. That's up to you.
@buimanhlinh96 hyp evolution cannot modify parameters whose starting value is zero. Mutation is applied as a multiplicative operator, so a fraction of the current generation's hyps are adjusted roughly as x *= 1 + 0.2 * randn(), and zero multiplied by anything stays zero.
In order to vary the mixup parameter, you must give it a nonzero initial probability (0-1). The beta parameters themselves are set so that images are mixed roughly evenly. As alpha = beta decreases toward zero, the mixup ratio concentrates at 0 and 1 (alpha = beta = 1 gives a uniform distribution). As alpha = beta increases toward infinity, the mixup ratio approaches a constant Dirac delta at r = 0.5.
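Why a zero starting value never moves can be shown with a simplified multiplicative mutation. This sketch scales every value by a clipped factor 1 + sigma * N(0, 1); the clip range [0.3, 3.0] is an assumption for illustration, and the real YOLOv5 evolve code differs in details.

```python
import numpy as np

def mutate(hyp: dict, sigma: float = 0.2, rng=None) -> dict:
    """Simplified multiplicative mutation: scale each value by a clipped 1 + sigma * N(0, 1).

    A hyperparameter that starts at 0 stays at 0, since 0 times any factor is 0.
    """
    rng = np.random.default_rng() if rng is None else rng
    out = {}
    for k, v in hyp.items():
        factor = float(np.clip(1 + sigma * rng.standard_normal(), 0.3, 3.0))  # assumed clip range
        out[k] = v * factor
    return out
```

Starting `mixup` at a small nonzero value such as 0.1 gives evolution something to scale.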
Hello, why is the rate 0.99? That seems very high.
@glenn-jocher Does the total_loss computation in the script include a mixup loss? I cannot find it.
@nanhui69 loss function does not change in our mixup implementation. Setting mixup to 0.2-0.6 seems to help with finetuning tasks on custom datasets. For VOC for example it raises mAP by about 1%.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
🚀 Feature
Apply MixUp augmentation when using mosaic.
Motivation
I'm happy to have gotten a high LB score in the Wheat competition with more data augmentation, so I want to try more augmentation methods, partly just for fun. I got this idea from here: https://www.kaggle.com/shonenkov/oof-evaluation-mixup-efficientdet/comments#Out-of-fold-prediction:
Pitch
I simply implemented as follows: https://github.com/ChristopherSTAN/yolov5/blob/f1c03085b0cf5cf9b1ff72158194dc36f5753c09/utils/datasets.py#L486-L492
Any suggestions are welcome.