ultralytics / yolov3

YOLOv3 in PyTorch > ONNX > CoreML > TFLite
https://docs.ultralytics.com
GNU Affero General Public License v3.0

yolov3-darknet scheduler #974

Closed pfeatherstone closed 4 years ago

pfeatherstone commented 4 years ago

Context: It is INSANELY hard to train yolov3 accurately. If one simply uses Adam with a basic scheduler, one doesn't even get close to the original weights trained by pjreddie. It seems that you have performed some wizardry in setting up the lambda scheduler to get it working.

Question: Can we/you write a PyTorch optimizer-scheduler pair specifically for yolov3? Ideally, we would simply need two objects, an optimizer and a scheduler, and call .step() on each one after each batch. That would be amazing, and would probably benefit other projects too.
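
For illustration, a minimal sketch of what such a bundled object could look like: a thin wrapper over a stock PyTorch optimizer and scheduler behind a single .step() call. The class name and defaults here are hypothetical, not anything the repo provides.

import torch

class OptimizerScheduler:
    """Hypothetical wrapper: one .step() advances both the weights and the LR schedule."""
    def __init__(self, optimizer, scheduler):
        self.optimizer = optimizer
        self.scheduler = scheduler

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self):
        self.optimizer.step()   # update model weights
        self.scheduler.step()   # advance the learning-rate schedule (called once per batch)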

pfeatherstone commented 4 years ago

It seems improving the loss function would help the optimizer. Maybe we need to do a "loss-landscape" study to figure out what works best. For example, I've found that using a CIoU loss for box regression helps an awful lot compared to MSE. Furthermore, I've found that focal loss doesn't help a great deal for objectness and classification. It seems to me we either have to improve the overall loss function and stick to an Adam-like optimizer, or stick with the current yolo loss function and improve the optimizer-scheduler. Another thought: maybe some of the optimizer hyperparameters could themselves be learnable parameters.
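
For reference, a rough sketch of a CIoU box-regression loss of the kind described, for (cx, cy, w, h) boxes; this is a generic implementation of the published CIoU formula, not the repo's own implementation.

import math
import torch

def ciou_loss(pred, target, eps=1e-9):
    """Complete-IoU loss for boxes given as (cx, cy, w, h) tensors of shape (N, 4)."""
    # corner coordinates
    px1, py1 = pred[:, 0] - pred[:, 2] / 2, pred[:, 1] - pred[:, 3] / 2
    px2, py2 = pred[:, 0] + pred[:, 2] / 2, pred[:, 1] + pred[:, 3] / 2
    tx1, ty1 = target[:, 0] - target[:, 2] / 2, target[:, 1] - target[:, 3] / 2
    tx2, ty2 = target[:, 0] + target[:, 2] / 2, target[:, 1] + target[:, 3] / 2

    # intersection, union, IoU
    inter = (torch.min(px2, tx2) - torch.max(px1, tx1)).clamp(0) * \
            (torch.min(py2, ty2) - torch.max(py1, ty1)).clamp(0)
    union = pred[:, 2] * pred[:, 3] + target[:, 2] * target[:, 3] - inter + eps
    iou = inter / union

    # squared diagonal of the smallest enclosing box and squared centre distance
    cw = torch.max(px2, tx2) - torch.min(px1, tx1)
    ch = torch.max(py2, ty2) - torch.min(py1, ty1)
    c2 = cw ** 2 + ch ** 2 + eps
    rho2 = (pred[:, 0] - target[:, 0]) ** 2 + (pred[:, 1] - target[:, 1]) ** 2

    # aspect-ratio consistency term
    v = (4 / math.pi ** 2) * (torch.atan(target[:, 2] / (target[:, 3] + eps)) -
                              torch.atan(pred[:, 2] / (pred[:, 3] + eps))) ** 2
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)

    return (1.0 - (iou - rho2 / c2 - alpha * v)).mean()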

FranciscoReveriano commented 4 years ago

The repository's defaults are optimized for COCO. On custom datasets, I have found that it takes considerable effort on hyperparameters and the scheduler to get good results. Some would call that the "art of machine learning".

I have experimented with most of the schedulers and noticed that most datasets require their own specific scheduler.

glenn-jocher commented 4 years ago

@pfeatherstone this repo trains with default settings to better results than pjreddie's. Neither the scheduler nor the optimizer is special; they are simply SGD with a common cosine scheduler. To reproduce: https://github.com/ultralytics/yolov3#reproduce-our-results
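
For concreteness, a minimal sketch of that kind of setup using only stock PyTorch pieces; the exact lambda shape, the stand-in model, and the hyperparameter values here are illustrative assumptions rather than the repo's exact ones.

import math
import torch

model = torch.nn.Linear(10, 1)  # stand-in for the detection model
epochs = 300
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.937, nesterov=True)
# cosine decay from the initial LR down to a small fraction of it over the run
lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * 0.95 + 0.05
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

for epoch in range(epochs):
    # ... one epoch of training ...
    scheduler.step()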

pfeatherstone commented 4 years ago

@glenn-jocher There's the whole burn-in funny business that is special. There's stuff like this:

pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
for k, v in dict(model.named_parameters()).items():
    if '.bias' in k:
        pg2 += [v]  # biases
    elif 'Conv2d.weight' in k:
        pg1 += [v]  # conv weights: apply weight_decay
    else:
        pg0 += [v]  # all else (e.g. batchnorm weights)

There's stuff to do with pre-bias settings too. So it's not just a cosine scheduler... I think it's great work, but I think it needs wrapping up into its own optimizer-scheduler.
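
For completeness, parameter groups like the ones above are typically handed to the optimizer so that weight decay only touches the conv weights; a sketch, with illustrative values:

import torch

# pg0: batchnorm weights and everything else, pg1: conv weights, pg2: biases
optimizer = torch.optim.SGD(pg0, lr=0.01, momentum=0.937, nesterov=True)
optimizer.add_param_group({'params': pg1, 'weight_decay': 0.000484})  # decay conv weights only
optimizer.add_param_group({'params': pg2})  # biases, no decay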

pfeatherstone commented 4 years ago

I remember, yonks back, there was a paper where they made the learning rate inside Adam a learnable parameter. Maybe we could do the same with the fine-tuned hyperparameters found here:

hyp = {'giou': 3.54,  # giou loss gain
       'cls': 37.4,  # cls loss gain
       'cls_pw': 1.0,  # cls BCELoss positive_weight
       'obj': 64.3,  # obj loss gain (*=img_size/320 if img_size != 320)
       'obj_pw': 1.0,  # obj BCELoss positive_weight
       'iou_t': 0.225,  # iou training threshold
       'lr0': 0.01,  # initial learning rate (SGD=5E-3, Adam=5E-4)
       'lrf': -4.,  # final LambdaLR learning rate = lr0 * (10 ** lrf)
       'momentum': 0.937,  # SGD momentum
       'weight_decay': 0.000484,  # optimizer weight decay
       'fl_gamma': 0.0,  # focal loss gamma (efficientDet default is gamma=1.5)
       'hsv_h': 0.0138,  # image HSV-Hue augmentation (fraction)
       'hsv_s': 0.678,  # image HSV-Saturation augmentation (fraction)
       'hsv_v': 0.36,  # image HSV-Value augmentation (fraction)
       'degrees': 1.98 * 0,  # image rotation (+/- deg)
       'translate': 0.05 * 0,  # image translation (+/- fraction)
       'scale': 0.05 * 0,  # image scale (+/- gain)
       'shear': 0.641 * 0}  # image shear (+/- deg)

I bet the parameters above are overfitted to COCO.
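
On the learnable-parameter idea: the loss gains (though not the LR or augmentation settings) could in principle be made learnable via homoscedastic uncertainty weighting (Kendall et al.), which avoids the trivial solution of driving every gain to zero. A hedged sketch of that idea, not something this repo does:

import torch
import torch.nn as nn

class LearnableLossWeights(nn.Module):
    """Weight the giou/obj/cls loss terms with learnable log-variances (uncertainty weighting)."""
    def __init__(self, n_terms=3):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(n_terms))  # one s_i per loss term

    def forward(self, losses):
        # total = sum_i exp(-s_i) * L_i + s_i ; the +s_i term stops the weights collapsing to zero
        losses = torch.stack(losses)
        return (torch.exp(-self.log_vars) * losses + self.log_vars).sum()

# usage: add weigher.parameters() to the optimizer alongside the model parameters, then
# total_loss = weigher([loss_giou, loss_obj, loss_cls])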

glenn-jocher commented 4 years ago

@pfeatherstone yes, burn-in is a pretty common practice. The darknet repos do it as well. It just means you ramp up your LR from 0.
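
A minimal sketch of what that ramp looks like in practice (the burn-in length and target LR here are assumed values, not the repo's):

n_burn = 1000  # number of warmup iterations (assumed)
lr0 = 0.01     # target learning rate after warmup (assumed)

def warmup_lr(ni):
    """Linearly ramp the learning rate from 0 to lr0 over the first n_burn batches."""
    return lr0 * min(ni / n_burn, 1.0)

# inside the training loop, before optimizer.step():
# for g in optimizer.param_groups:
#     g['lr'] = warmup_lr(ni)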

The parameter groups you show are there to prevent weight decay from being applied to the batchnorm layers, which is also best practice.

Yes the hyperparameters are evolved for yolov3-spp on COCO. They can be evolved on other architectures and datasets using https://github.com/ultralytics/yolov3/issues/392, though this takes time.
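
Roughly speaking, that evolution is a mutate-and-retrain loop over the hyp dict quoted above; a very rough sketch of the idea (the repo's actual evolution code differs in detail, and train_and_evaluate here is a hypothetical helper):

import random

def mutate(hyp, sigma=0.2):
    """Randomly perturb each numeric hyperparameter by up to +/- sigma (illustrative only)."""
    return {k: v * (1 + random.uniform(-sigma, sigma)) if isinstance(v, (int, float)) and v else v
            for k, v in hyp.items()}

best_hyp, best_fitness = dict(hyp), 0.0
for generation in range(10):
    candidate = mutate(best_hyp)
    fitness = train_and_evaluate(candidate)  # hypothetical: short training run returning mAP
    if fitness > best_fitness:
        best_hyp, best_fitness = candidate, fitness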

glenn-jocher commented 4 years ago

@pfeatherstone a lot of the added functionality here should probably be native to PyTorch. Some of it is already present in TF/Keras, such as EMA.
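
For context, the EMA mentioned is just an exponential moving average of the model weights kept alongside training and used for evaluation; a minimal sketch of the idea, not the repo's implementation:

import copy
import torch

class EMA:
    """Keep an exponential moving average of a model's parameters for evaluation."""
    def __init__(self, model, decay=0.9999):
        self.ema = copy.deepcopy(model).eval()  # shadow copy, never trained directly
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:
                v.mul_(self.decay).add_(msd[k].detach(), alpha=1 - self.decay)

# usage: call ema.update(model) after each optimizer.step(); evaluate with ema.ema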

github-actions[bot] commented 4 years ago

This issue is stale because it has been open 30 days with no activity. Remove Stale label or comment or this will be closed in 5 days.

pfeatherstone commented 4 years ago

As an experiment, I added bias to all the conv layers. It turns out that the instability problem early on in training vanishes for me. So there's no need for burn-in, and I don't get any NaN or Inf losses/gradients even when using DIoU or CIoU losses (CIoU uses atan2, which can give you some infinities during training). The motivation was that adding biases barely adds to the overall parameter count, I don't think it impacts computational speed too much, and biases converge faster (or at least they are more stable at higher learning rates). Interesting...
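
In case anyone wants to reproduce the experiment, one way to bolt a bias onto convs that were built without one; a quick hack, run before building the optimizer so the new parameters are picked up (add_conv_biases is a hypothetical helper, not part of the repo):

import torch
import torch.nn as nn

def add_conv_biases(model):
    """Give every bias-free Conv2d a zero-initialised bias parameter (experimental hack)."""
    for m in model.modules():
        if isinstance(m, nn.Conv2d) and m.bias is None:
            m.bias = nn.Parameter(torch.zeros(m.out_channels, device=m.weight.device))
    return model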

glenn-jocher commented 4 years ago

@pfeatherstone that's an interesting observation, but we remove the bias because it is redundant with the bias in the batchnorm layer. You are simply adding two biases if you do not remove the conv2d bias.

You may be able to achieve the effect you are looking for by modifying the batchnorm parameters instead.
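
A small sketch of the sort of batchnorm adjustment meant here, assuming model is the network being trained; the specific eps/momentum values are assumptions, not a recommendation:

import torch.nn as nn

for m in model.modules():  # model: the detection network (assumed to exist)
    if isinstance(m, nn.BatchNorm2d):
        m.eps = 1e-4        # slightly larger eps for numerical stability (assumed value)
        m.momentum = 0.03   # slower running-stat updates (assumed value)
        # the BN bias (beta) is the per-channel offset that replaces the conv bias;
        # it can be re-initialised if desired, e.g. nn.init.constant_(m.bias, 0.0)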

pfeatherstone commented 4 years ago

Good point. I forgot about the bias in the batchnorm layer. Hmm it’s the second time someone has suggested tampering with the batchnorm parameters. The first time was to adjust for minibatching. I was hoping I was never going to have to do that but maybe I need to care about it more. Thanks @glenn-jocher

glenn-jocher commented 10 months ago

@pfeatherstone you're welcome! Batch normalization adjustments can indeed have a big impact, particularly when dealing with specific training challenges. Feel free to reach out if you have further questions or need assistance with YOLOv3. Good luck with your experiments!