huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
https://huggingface.co/docs/timm
Apache License 2.0

default training hyper-parameters #11

Closed cxxgtxy closed 5 years ago

cxxgtxy commented 5 years ago

Hi, impressive work! The train script contains a large combination of hyper-parameter options. However, there are many different types of models, and many variants even within the EfficientNet family. I wonder whether you trained the models with the default ones. If not, do you plan to release model-specific hyper-parameters? Thanks!

rwightman commented 5 years ago

Yeah, there are a LOT of combinations. I'm not sure I'll ever get through testing and verifying all of them for all models.

I've found some good ones for ResNet-style networks that train them better than most traditional hyper-params. I posted some here: https://github.com/pytorch/vision/pull/909#issuecomment-497825578
The best one, for a ResNet-34 on a single GPU, was:

python train.py /imagenet/ --model resnet34 -b 256 --epochs 125 --warmup-epochs 5 --sched cosine --lr 0.1 --weight-decay 1e-4 --reprob 0.4 --remode pixel

For larger ResNets, or ResNet-like networks, pushing the epochs out to 150-200 would be worthwhile.

As for the EfficientNets, they are a challenge to train. The techniques that work well for other networks don't necessarily result in good accuracy here. I've found the best approach is to stay as close to the paper hyper-params as possible in PyTorch, including my RMSprop tweak, scale the LR with the batch size (noting that they use really large batch sizes on TPUs), and make sure to use the EMA-smoothed copy of the weights. There is a branch on this repo that has the EMA capability.

On the EMA branch, with distributed_train.sh across 3 GPUs, I've been using roughly the following for the B0 EfficientNet and MNasNet/MobileNet-V3 style networks. The LR needs some wiggle with different batch sizes/networks, roughly .01-.03 with 3-GPU distributed training (so an effective batch size of 3 x 80 = 240 to 3 x 128 = 384). A decay rate in the 0.96-0.97 range is in the ballpark of being equivalent to the 2.4-epoch decays in the papers.

--lr 0.027 -b 128 --drop 0.2 --img-size 224 --sched step --epochs 550 --decay-epochs 3 --decay-rate 0.97 --opt rmsproptf -j 4 --warmup-epochs 5 --warmup-lr 1e-6 --weight-decay 1e-5 --opt-eps .001 --model-ema

cxxgtxy commented 5 years ago

Thanks for the detailed suggestions. I have another question about EMA: have you tested the above hyper-parameters on the standard MobileNet-V2? And can you report your experimental results? Thanks!

rwightman commented 5 years ago

@cxxgtxy I've not tried training MobileNet-V2 ... still trying to get some PyTorch native B1/B2/B3 efficientnets...

One other thing to mention: don't forget to set drop_connect_rate=0.2 in addition to the normal dropout when training these; it definitely has an impact. It's not set by default and I don't have a command-line argument passed through yet (TODO), so it has to be set in code for now.
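
A rough sketch of setting it in code (this assumes the timm create_model API passes kwargs through to the model entrypoint; newer releases rename the argument to drop_path_rate):

import timm

# drop_rate is the classifier dropout; drop_connect_rate is the per-block
# stochastic-depth rate (called drop_path_rate in newer timm versions).
model = timm.create_model('efficientnet_b0', drop_rate=0.2, drop_connect_rate=0.2)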

bermanmaxim commented 5 years ago

@rwightman Thanks for the details. I was wondering if you used similar efficientnet-style parameters for mobilenet-v3 or if those train with more traditional hyperparams.

rwightman commented 5 years ago

@bermanmaxim Pretty similar; B0 and MobileNet-V3 have a lot in common. The only issue with the specifics for MobileNet-V3 is that I trained most of it before I added the EMA weight-averaging support ... I tried a few times and got frustrated that it kept topping out short of the paper.

After seeing how much impact the EMA had on these models while fiddling with the EfficientNets, I revisited my best MobileNet-V3 attempt at around the 400-epoch mark, resumed with EMA enabled, and arrived at the model published here after 40 or so more epochs... I did the same for the MNasNet models and got those to just above the paper spec as well...

bermanmaxim commented 5 years ago

Ok thank you.

cxxgtxy commented 5 years ago

@rwightman Do you mean activating 0.2 drop-connect along with the mixup tricks?
Thank you!

triangleCZH commented 5 years ago

Firstly, thank you for your great work and sharing! This is really amazing!

I used your EMA with 8 V100s, following your instructions in https://forums.fast.ai/t/efficientnet/46978/70. I tried EMA rates of 0.9998 and 0.999, but the changes still seem too trivial. Could you please share some experience on how to determine a suitable EMA rate? You mentioned "useful in 10 epochs"; does this mean EMA will yield a better result than the original after 10 epochs?

You said the EMA rate is related to the capability of a system. I have 8 V100s, but I can only use a rate much smaller than your 0.9998. I find this quite strange; do you have any ideas?

Thank you!

rwightman commented 5 years ago

@cxxgtxy yup, definitely set drop_connect_rate=0.2 in the kwargs. I actually haven't used mixup with any of the efficientnet/mobilenetv3/mnasnet trainings, just the recommended dropout rate and drop connect rate. I have experimented with mixup for resnet networks though and it helps.

rwightman commented 5 years ago

@triangleCZH There are a lot of factors ... my comment about the 10 epochs was in relation to the center-of-mass equivalence of the exponential moving average vs a simple moving average with that window length. It's averaging over optimizer steps, and how many samples are in each step depends on your setup. I'm usually in the 2-5 epoch range. If you, say, started from scratch with an EMA window equivalent to 10 epochs, you likely wouldn't see the EMA validation results match or surpass the non-averaged model's validation for close to 10 epochs. You also need to pay attention to your LR schedule: if you've got big jumps or drops, consider the relation of your EMA window length to the duration between jumps, or to the settling time at the end of training if you have a rapid drop towards the end, as with a cosine decay.
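
As a rough back-of-the-envelope (a sketch, not an exact rule), you can relate the per-step EMA decay to a window measured in epochs, since the effective window in steps is on the order of 1 / (1 - decay):

def ema_decay_for_window(epochs, dataset_size, global_batch_size):
    # Approximate per-step decay whose effective averaging window spans
    # roughly `epochs` epochs worth of optimizer steps.
    steps_per_epoch = dataset_size / global_batch_size
    return 1.0 - 1.0 / (epochs * steps_per_epoch)

# e.g. ImageNet (~1.28M images) at a global batch of 2048 (8 x 256): a ~2 epoch
# window works out to a decay around 0.9992; smaller global batches (more steps
# per epoch) push the equivalent decay up toward 0.9998 and beyond.
print(ema_decay_for_window(2, 1_281_167, 2048))  # ~0.9992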

I've never run any of this on 8 high-end GPUs; you do need to tune everything accordingly, including scaling the initial LR with your batch size. You should be closer to the hyper-params in the Google papers with their TPUs. Make sure the dropout and drop-connect rates are set. If you're fine-tuning, you should scale the LR down from the LR used for training from scratch, as you normally would.

It's still a game of patience, not a magic bullet. The EMA model is worse at the very start, then it'll move well above the non-averaged training model for the bulk of training. Towards the end, as the LR falls and model convergence slows, you'll see the EMA and training-model validation scores come closer together, but you usually end up with an optimal EMA checkpoint that hits 0.5-1.5% higher than the best non-averaged checkpoint.

triangleCZH commented 5 years ago

@rwightman Thank you for your reply. The advice is really helpful, and I do see an increase in performance with EMA. As a next step I plan to run a series of EMAs with different decay rates in parallel, to find the rate that suits my situation best.

cxxgtxy commented 5 years ago

@triangleCZH Please share your EMA options once you have trained on your high-end GPUs. Thanks!

bermanmaxim commented 5 years ago

@rwightman If you don't mind I'm also interested in your retraining of spnasnet_100, was that also a challenge regarding EMA/...?

rwightman commented 5 years ago

> @rwightman If you don't mind I'm also interested in your retraining of spnasnet_100, was that also a challenge regarding EMA/...?

@bermanmaxim Those spnasnet weights have not been re-trained with EMA enabled; they are from earlier experiments, before I implemented the EMA in PyTorch. I believe they were from a 200-epoch run with cosine decay, per-pixel random erasing, and no dropout ... from other attempts with different models, I've found that models trained with cosine decay do not benefit from EMA when resuming training after the LR has already dropped.

I'm sure spnasnet could be trained much better starting again from scratch; it's not on my list of todos though.

bermanmaxim commented 5 years ago

@rwightman Alright, thank you!

zhaohui-yang commented 5 years ago

@rwightman Would you please share the training parameters for FBNet-C? I used your architecture definition, and the training parameters were the same as in the original paper, e.g. multi-step LR and weight decay, but I achieved 74.1 Top-1 accuracy, which is lower than reported. I think the training parameters matter.

Thanks.

rwightman commented 5 years ago

@zhaohui-yang The paper's training setup didn't work well for me with FBNet-C; I used something like the hyper-params I posted above for the MobileNet-V3/MNasNet networks. You'll need to scale the LR to your batch size depending on the GPUs you're using... and one other thing to mention, a lot of the low-parameter networks appear to be trained with less aggressive color augmentation than the defaults, so try backing that off to something in the .1 to .05 or lower range.

zhaohui-yang commented 5 years ago

@rwightman Thanks for your advice, I will try the hyper-params for MobileNet-V3. For the LR, I used a batch size of 2048 with the learning rate set to 0.8 at the beginning. Besides, I also tested with and without the ColorJitter used in DARTS, and it brought approximately a 0.1% Top-1 accuracy improvement. I will try less aggressive color augmentation as you suggested, thank you!

triangleCZH commented 5 years ago

@cxxgtxy Just to share some personal experience: I built my own MobileNet-V3 with SE and h-swish; it has the same number of params as rwightman's. After several tests on 8 V100s, the best EMA decay for me is 0.9, which helps me reach 74.5. I tried 0.9999, 0.9998 and 0.95, and the results were not promising. FYI.

jiefengpeng commented 5 years ago

@rwightman Thanks for your work. Would you please share the hyper-parameters for b0 (76.912)? I am trying to reproduce the training of efficientnet_b0 but I got 74.5 with a batch size of 512 and the other config the same as the paper, except auto-augmentation. Thanks.

rwightman commented 5 years ago

@jiefengpeng What's your exact command line, and did you manually enable drop_connect in the code? I fine-tuned the b0 model from the TF weights so the hyper-params won't match, but I have trained other models in the family well...

jiefengpeng commented 5 years ago

> @jiefengpeng What's your exact command line, and did you manually enable drop_connect in the code? I fine-tuned the b0 model from the TF weights so the hyper-params won't match, but I have trained other models in the family well...

Thank you. I've trained the b0 from scratch on 8 RTX 2080 Ti GPUs:

./distributed_train.sh 8 /data/imagenet --model efficientnet_b0 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --batch-size 128 -j 8 --drop 0.2

I also enabled drop_connect at 0.2; however, I only got 76.3.

Is this the command that you mentioned?

./distributed_train.sh 2 /data/x/imagenet/ --model mixnet_xl -b 128 --amp --sched step --epochs 500 --decay-epochs 3 --decay-rate 0.963 --opt rmsproptf --opt-eps .001 -j 4 --warmup-epochs 5 --weight-decay 1e-5 --drop 0.225 --color-jitter .1 --lr .02

rwightman commented 5 years ago

@jiefengpeng 76.3 isn't bad using those params; I get good ResNet results with that setup but never had decent EfficientNet results.

Yes, the other command is closer to what I'd use for EfficientNet, but with a few tweaks. I didn't include the EMA argument even though it was enabled.

Set the LR based on 0.016 * global_batch / 256. For a global batch of 128 x 8 that'd be 0.064. If you use AMP you should be able to fit a batch of 256 per card no problem, which gives an LR of 0.128.
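
In other words, linear LR scaling from a base of 0.016 per 256 images of global batch. A quick sketch of the arithmetic:

def scale_lr(per_gpu_batch, num_gpus, base_lr=0.016, base_batch=256):
    # Linear scaling rule: lr = base_lr * global_batch / base_batch
    return base_lr * (per_gpu_batch * num_gpus) / base_batch

print(scale_lr(128, 8))  # 0.064
print(scale_lr(256, 8))  # 0.128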

You could try a decay-rate anywhere from 0.96 to 0.97. Another arg you could enable is --bn-tf, which uses the batch norm eps and momentum values from the TF code. You then have to keep using the same non-default BN eps value at inference time though.
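
For reference, a sketch of what that amounts to in PyTorch terms, assuming --bn-tf maps to TF's usual BatchNorm defaults of eps=1e-3 and momentum=0.99 (versus PyTorch's eps=1e-5 and momentum=0.1; PyTorch's momentum is the complement of TF's):

import torch.nn as nn

# PyTorch default: nn.BatchNorm2d(C, eps=1e-5, momentum=0.1)
# TF-style, as described above: eps=1e-3, momentum=0.01 in PyTorch's convention
# (i.e. TF momentum 0.99). Keep the same eps when running inference.
bn_tf = nn.BatchNorm2d(32, eps=1e-3, momentum=0.01)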

./distributed_train.sh 8 /data/x/imagenet/ --model efficientnet_b0 -b 128 --sched step --epochs 500 --decay-epochs 3 --decay-rate 0.963 --opt rmsproptf --opt-eps .001 -j 4 --warmup-epochs 5 --weight-decay 1e-5 --drop 0.2 --color-jitter .06 --model-ema --lr .064

and make sure drop_connect=0.2 in the model entrypoint function in code as you've already done.

You won't need all 500 epochs. I leave some margin in case it happens to still be converging, but the code keeps the best checkpoints along the way by default.

jiefengpeng commented 5 years ago

Thank you so much. I will try it.

jiefengpeng commented 5 years ago

Fantastic, I have reproduced 76.91 from scratch with your suggestions. Here is my command:

./distributed_train.sh 8 ../ImageNet/ --model efficientnet_b0 -b 256 --sched step --epochs 500 --decay-epochs 3 --decay-rate 0.963 --opt rmsproptf --opt-eps .001 -j 8 --warmup-epochs 5 --weight-decay 1e-5 --drop 0.2 --color-jitter .06 --model-ema --lr .128

Thanks again.

rwightman commented 5 years ago

@jiefengpeng good to hear, thanks for the update.