facebookresearch / deit

Official DeiT repository
Apache License 2.0

Can you provide the hyperparameters to reproduce your ResMLP results? #106

Closed. Dongshengjiang closed this issue 3 years ago.

Dongshengjiang commented 3 years ago

Hi, can you provide the hyperparameters to reproduce your ResMLP results? And is the final result for ResMLP-12 76.6 without distillation?

TouvronHugo commented 3 years ago

Hi @Dongshengjiang, yes, the hparams are similar to the default settings used in DeiT. We only adapt stochastic depth as a function of depth, as in CaiT; we use the LAMB optimizer with lr = 5e-3 and weight decay = 0.2, and we don't use FP16 training.

To summarize, we use:

- the default DeiT training setting
- the LAMB optimizer with lr = 5e-3 and weight decay = 0.2
- stochastic depth adapted as a function of depth, as in CaiT
- no FP16 training

Yes, the final result for ResMLP-S12 without distillation is 76.6%.

Best, Hugo
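
For reference, a minimal sketch of building that optimizer with timm's factory. The 'resmlp_12_224' model name, create_model / create_optimizer_v2, and the 'lamb' option assume a recent timm version, and drop_path_rate is a placeholder value, not one given in this thread:

from timm import create_model
from timm.optim import create_optimizer_v2

# Sketch of the recipe quoted above: LAMB, lr = 5e-3, weight decay = 0.2.
# A recent timm version is assumed for the ResMLP model and the LAMB optimizer.
# drop_path_rate is a placeholder; the thread does not give the per-model value.
model = create_model('resmlp_12_224', drop_path_rate=0.05)
optimizer = create_optimizer_v2(model, opt='lamb', lr=5e-3, weight_decay=0.2)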

Dongshengjiang commented 3 years ago

Hugo, thanks for the rapid reply. I will try to reproduce your results with these hparams. By the way, another question: you said above to "adapt stochastic depth as a function of depth as in CaiT", but the CaiT and ResMLP code has "dpr = [drop_path_rate for i in range(depth)]", while the DeiT code has "dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]". Does that mean the drop path rate is fixed across depth in CaiT and ResMLP, but increases with depth in DeiT?

TouvronHugo commented 3 years ago

Yes, exactly. But when I said "adapt stochastic depth as a function of depth as in CaiT", I was referring to Figure 3 of the CaiT paper. Best, Hugo
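
Side by side, the two schedules quoted above (a minimal sketch; depth and drop_path_rate are illustrative values):

import torch

depth, drop_path_rate = 12, 0.1  # illustrative values

# CaiT / ResMLP code: the same rate for every block (uniform across depth)
dpr_uniform = [drop_path_rate for i in range(depth)]

# DeiT code: the rate increases linearly from 0 to drop_path_rate with depth
dpr_linear = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

print(dpr_uniform)  # [0.1, 0.1, ..., 0.1]
print(dpr_linear)   # [0.0, 0.009..., ..., 0.1]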

Aguin commented 3 years ago

Hi @TouvronHugo, thanks for your work and for providing the details. It's common to skip weight decay for norm and bias parameters (at least the timm library sets this as the default), and we found it really matters when reproducing DeiT. Did you apply weight decay to the Affine layers and LayerScale parameters in ResMLP?

TouvronHugo commented 3 years ago

Hi @Aguin, thanks for your question. Yes, we skip weight decay for the Affine and LayerScale parameters. Best, Hugo
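
A minimal sketch of how such a skip could be implemented with parameter groups. param_groups_skip_wd is an illustrative helper, not from the repo, and the keyword names ('alpha', 'beta', 'gamma') are assumptions based on typical ResMLP Affine and LayerScale parameter names; the exact names depend on the model code:

import torch

def param_groups_skip_wd(model: torch.nn.Module, weight_decay: float = 0.2,
                         skip_keywords=('alpha', 'beta', 'gamma')):
    """Two parameter groups: weight decay on ordinary weights, none on
    1-D parameters (biases, norm/Affine scales) or parameters whose name
    contains one of skip_keywords (e.g. LayerScale)."""
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.ndim <= 1 or any(k in name for k in skip_keywords):
            no_decay.append(p)
        else:
            decay.append(p)
    return [{'params': decay, 'weight_decay': weight_decay},
            {'params': no_decay, 'weight_decay': 0.0}]

# e.g. optimizer = timm.optim.Lamb(param_groups_skip_wd(model), lr=5e-3)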

ggjy commented 3 years ago

Hi @TouvronHugo, when I choose AdamW as the optimizer it's easy to train ResMLP-12 (AMP O1 mixed precision), but I always get NaN (e.g. at epoch 4 or epoch 22) when training ResMLP-24/36, no matter the hyperparameters (e.g. AMP O0, batch 2048, lr 5e-3 gives NaN; AMP O0, batch 2048, lr 2e-3, same as DeiT, also gives NaN). Is this normal? Or does it happen because I chose AdamW, and LAMB is the crucial component?

TouvronHugo commented 3 years ago

Hi @ggjy, I think learning rates of 2e-3 and 5e-3 are not suitable for this architecture with AdamW and a batch size of 2048. It is better to use LAMB with the hparams given above. Best, Hugo

ggjy commented 3 years ago

Hi Hugo @TouvronHugo, sorry to bother you again. I tried LAMB, but I still get a NaN loss; is there something special I have missed? I trained resmlp_24 on 8 V100s (batch 256 per card) and set args.lr = 1.25e-3 according to

linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
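
(With batch_size = 256 per GPU and world_size = 8, this gives linear_scaled_lr = 1.25e-3 * 2048 / 512 = 5e-3, i.e. the value recommended above.)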

And the hparams are as follows:

Namespace(aa='rand-m9-mstd0.5-inc1', amp_opt_level='O0', apex_amp=True, batch_size=256, clip_grad=None, color_jitter=0.4,
cooldown_epochs=10, cutmix=1.0, cutmix_minmax=None, data_path='/cache/data/imagenet/', data_set='IMNET', debug=False,
decay_epochs=30, decay_rate=0.1, decouple_decay=0.05, device='cuda', dist_backend='nccl', dist_url='env://', distributed=True,
drop=0.0, drop_block=None, drop_path=0.1, epochs=300, eval=False, gpu=0, inat_category='name', init_method=None,
input_size=224, local_rank=0, lr=0.00125, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, min_lr=1e-05, mixup=0.8,
mixup_mode='batch', mixup_prob=1.0, mixup_switch_prob=0.5, model='resmlp_24',momentum=0.9, num_workers=4,
opt='lamb', opt_betas=None, opt_eps=1e-08, recount=1, remode='pixel', repeated_aug=True, reprob=0.25, resplit=False,
resume='', sched='cosine', seed=0, smoothing=0.1, start_epoch=0, warmup_epochs=5, warmup_lr=1e-06, weight_decay=0.2, world_size=8)
Creating model: resmlp_24

But the loss still goes to NaN during the early epochs:

Epoch: [2]  [  0/625]  eta: 0:49:13  lr: 0.001001  loss: 6.9097 (6.9097)  time: 4.7257  data: 3.3613  max mem: 18965
Epoch: [2]  [200/625]  eta: 0:05:28  lr: 0.001001  loss: 6.8753 (6.8938)  time: 0.7524  data: 0.0005  max mem: 18965
Epoch: [2]  [400/625]  eta: 0:02:50  lr: 0.001001  loss: 6.8523 (6.8810)  time: 0.7367  data: 0.0004  max mem: 18965
Epoch: [2]  [600/625]  eta: 0:00:18  lr: 0.001001  loss: 6.8468 (6.8725)  time: 0.7457  data: 0.0004  max mem: 18965
Epoch: [2]  [624/625]  eta: 0:00:00  lr: 0.001001  loss: 6.8566 (6.8717)  time: 0.7299  data: 0.0008  max mem: 18965
Epoch: [2] Total time: 0:07:50 (0.7532 s / it)
Averaged stats: lr: 0.001001  loss: 6.8566 (6.8710)
Epoch: [3]  [  0/625]  eta: 0:48:29  lr: 0.002001  loss: 6.8584 (6.8584)  time: 4.6553  data: 3.8440  max mem: 18965
Epoch: [3]  [200/625]  eta: 0:05:23  lr: 0.002001  loss: 6.8790 (6.8592)  time: 0.7401  data: 0.0004  max mem: 18965
Epoch: [3]  [400/625]  eta: 0:02:49  lr: 0.002001  loss: 6.8791 (6.8868)  time: 0.7442  data: 0.0004  max mem: 18965
Epoch: [3]  [600/625]  eta: 0:00:18  lr: 0.002001  loss: 6.8770 (6.8833)  time: 0.7380  data: 0.0003  max mem: 18965
Epoch: [3]  [624/625]  eta: 0:00:00  lr: 0.002001  loss: 6.8799 (6.8831)  time: 0.7252  data: 0.0007  max mem: 18965
Epoch: [3] Total time: 0:07:47 (0.7478 s / it)
Averaged stats: lr: 0.002001  loss: 6.8799 (6.8842)
Epoch: [4]  [  0/625]  eta: 0:47:28  lr: 0.003000  loss: 6.8747 (6.8747)  time: 4.5572  data: 3.3035  max mem: 18965
ERROR:root:Loss is nan, stopping training

TouvronHugo commented 3 years ago

Hi @ggjy, what modifications did you make to add the elements corresponding to the amp_opt_level and apex_amp flags? (I do not use mixed precision with my ResMLP models.) Did you make any other changes to the code? Best, Hugo

ggjy commented 3 years ago

Hi Hugo, thanks for your reply. I use apex_amp to train DeiT like this:

if args.apex_amp:
    model = ApexDDP(model, delay_allreduce=True)
    model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level)
    loss_scaler = ApexScaler()
else:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    loss_scaler = NativeScaler()

And I don't make any other changes to the code. I will try using NativeScaler() and torch.nn.parallel.DistributedDataParallel() to initialize the model; maybe APEX is causing the NaN problem above.
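
For reference, a minimal sketch of that native path, assuming timm's NativeScaler (as used in the DeiT codebase) and an already-initialized torch.distributed process group; wrap_model_native is just an illustrative helper:

import torch
from timm.utils import NativeScaler

def wrap_model_native(model: torch.nn.Module, gpu: int):
    """Wrap the model with native PyTorch DDP (no apex) and return a
    timm NativeScaler, which wraps torch.cuda.amp.GradScaler; assumes
    torch.distributed.init_process_group has already been called."""
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    loss_scaler = NativeScaler()
    return model, loss_scaler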

ggjy commented 3 years ago

Using torch.nn.parallel.DistributedDataParallel() trains the ResMLP normally, thanks.