Hi, thanks for your interest in our work. I just checked our logs. We can achieve the reported accuracies in the same environment.
We also notice that the training of the DeiT-B and DyViT-B models is less stable than that of smaller models, so it is important to keep a large global batch size (128x8). Besides, we find that the training of DyViT-B/0.8 and DyViT-B/0.9 is more stable. You could use these two settings to check whether the lower performance is caused by the environment or the configuration.
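(For clarity, the global batch size here is the per-GPU batch size times the number of GPUs launched; a quick arithmetic check matching the 128x8 setup in this thread, not code from the repo:)

# Effective global batch size under an 8-process distributed launch.
per_gpu_batch_size = 128
nproc_per_node = 8
global_batch_size = per_gpu_batch_size * nproc_per_node
print(global_batch_size)  # 1024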
Thanks for the reply. Could you release the training logs and the trained checkpoints for deit-b?
Sure. We are trying to reproduce these results. We will upload them later.
Thanks a lot!
Hi, thanks for your interest in our work. We have reproduced the results for DynamicViT-B/0.7 and updated the checkpoints in README.md, achieving 81.43% Top-1 accuracy and 95.46% Top-5 accuracy. The training logs can be found at this URL.
There are several issues that may help reproduce the results. Firstly, the warmup epoch in the codebase is set to 20 for DynamicSwin/DynamicCNN, which is not appropriate for DynamicViT (we recommend warmup epoch=5). We have corrected this bug and updated the training command. In addition, drop_path is helpful for stabilizing the training of larger models, and you can try setting drop_path=0.2. Secondly, the ratio loss is critical for optimization when training DynamicViT. If the ratio_loss does not converge (this may happen when the base ratio is less than 0.7), you can try increasing the weight of the ratio loss (we recommend ratio_weight=5.0 in this setting); a sketch of such a regularizer is shown below.
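For illustration, here is a minimal sketch of what such a keep-ratio regularizer could look like; the function and argument names are hypothetical, and this is not the repo's exact losses.py implementation:

import torch

def ratio_loss(keep_masks, target_ratios, ratio_weight=5.0):
    """Sketch: penalize the fraction of tokens kept at each pruning stage
    for deviating from its target keep ratio.

    keep_masks: list of (B, N) soft keep decisions, one per pruning stage
    target_ratios: list of target keep ratios, e.g. [0.7, 0.49, 0.343]
    """
    loss = 0.0
    for mask, target in zip(keep_masks, target_ratios):
        kept_fraction = mask.mean(dim=1)  # (B,) fraction of tokens kept per sample
        loss = loss + ((kept_fraction - target) ** 2).mean()
    return ratio_weight * loss / len(keep_masks)

Increasing ratio_weight strengthens the pull of each stage's kept fraction toward its target, which is why raising it can help when the ratio loss fails to converge.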
I will try again. Thanks!
Hi, I ran the new code, but the train loss becomes NaN from epoch 15 and stays NaN until the end of training. Here is the training log:
{"train_lr": 9.993604093380238e-05, "train_min_lr": 0.0, "train_loss": 3.64607219213395, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8946035196158019, "test_acc1": 80.4740024484253, "test_acc5": 94.80800252166748, "epoch": 0, "n_parameters": 89452846, "test_loss_ema": 1.2746423744794093, "test_acc1_ema": 72.67000229095459, "test_acc5_ema": 90.0580026019287}
{"train_lr": 0.00029996802046690124, "train_min_lr": 0.0, "train_loss": 3.340054861718802, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8810712225509413, "test_acc1": 80.38200237304687, "test_acc5": 94.84800268798828, "epoch": 1, "n_parameters": 89452846, "test_loss_ema": 0.9927789774356466, "test_acc1_ema": 78.19400250762939, "test_acc5_ema": 93.53400257720948}
{"train_lr": 0.0004999999999999999, "train_min_lr": 0.0, "train_loss": 3.3404851809537095, "train_weight_decay": 0.049999999999998865, "test_loss": 0.885483177548105, "test_acc1": 80.56200244689941, "test_acc5": 94.988002527771, "epoch": 2, "n_parameters": 89452846, "test_loss_ema": 0.9043222276324575, "test_acc1_ema": 80.36000228790283, "test_acc5_ema": 94.82400244537354}
{"train_lr": 0.0007000319795330988, "train_min_lr": 0.0, "train_loss": 3.3414161773822864, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8791586456425262, "test_acc1": 80.63400248321533, "test_acc5": 94.97400250305176, "epoch": 3, "n_parameters": 89452846, "test_loss_ema": 0.8997948506113255, "test_acc1_ema": 80.46400232574463, "test_acc5_ema": 94.92400255371093}
{"train_lr": 0.0009000639590661981, "train_min_lr": 0.0, "train_loss": 3.3440589879175646, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8772101524201307, "test_acc1": 80.53000243804932, "test_acc5": 94.96400236114502, "epoch": 4, "n_parameters": 89452846, "test_loss_ema": 0.8990779295563698, "test_acc1_ema": 80.49200249908448, "test_acc5_ema": 94.89200254180908}
{"train_lr": 0.0009986879800992594, "train_min_lr": 9.986879800992602e-06, "train_loss": 3.263468243711763, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8302653981654933, "test_acc1": 81.08400257141113, "test_acc5": 95.3740024053955, "epoch": 5, "n_parameters": 89452846, "test_loss_ema": 0.8908959823575887, "test_acc1_ema": 80.57600244140625, "test_acc5_ema": 94.92200240234375}
{"train_lr": 0.0009908344064966182, "train_min_lr": 9.908344064966173e-06, "train_loss": 3.2477996952170662, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8275814583017067, "test_acc1": 81.1560024017334, "test_acc5": 95.36400250488282, "epoch": 6, "n_parameters": 89452846, "test_loss_ema": 0.8804009578218965, "test_acc1_ema": 80.71000223083496, "test_acc5_ema": 95.00200256652832}
{"train_lr": 0.000975247966391946, "train_min_lr": 9.752479663919477e-06, "train_loss": 3.221702263116551, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8215054690160535, "test_acc1": 81.19800274383545, "test_acc5": 95.35400278656006, "epoch": 7, "n_parameters": 89452846, "test_loss_ema": 0.8695734071009087, "test_acc1_ema": 80.81000244049072, "test_acc5_ema": 95.08000265472413}
{"train_lr": 0.0009521744672565878, "train_min_lr": 9.521744672565882e-06, "train_loss": 3.220676154553366, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8234780598996263, "test_acc1": 81.13200243591308, "test_acc5": 95.41200253234864, "epoch": 8, "n_parameters": 89452846, "test_loss_ema": 0.8603077018351266, "test_acc1_ema": 80.91400250183105, "test_acc5_ema": 95.17600260314941}
{"train_lr": 0.0009219777919553448, "train_min_lr": 9.219777919553441e-06, "train_loss": 3.23265779596105, "train_weight_decay": 0.049999999999998865, "test_loss": 0.821601357423898, "test_acc1": 81.1300023260498, "test_acc5": 95.36600256439209, "epoch": 9, "n_parameters": 89452846, "test_loss_ema": 0.8542838512044965, "test_acc1_ema": 80.96200245147705, "test_acc5_ema": 95.23400258636475}
{"train_lr": 0.000885134160096337, "train_min_lr": 8.851341600963352e-06, "train_loss": 3.2203376480429577, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8211326608152101, "test_acc1": 81.20400239440917, "test_acc5": 95.41200243743897, "epoch": 10, "n_parameters": 89452846, "test_loss_ema": 0.8496657200157642, "test_acc1_ema": 80.98400244659423, "test_acc5_ema": 95.23800256866456}
{"train_lr": 0.0008422246177632995, "train_min_lr": 8.422246177632971e-06, "train_loss": 3.194862368545658, "train_weight_decay": 0.049999999999998865, "test_loss": 0.821151819193002, "test_acc1": 81.17600247772216, "test_acc5": 95.39800268524169, "epoch": 11, "n_parameters": 89452846, "test_loss_ema": 0.8454932100845106, "test_acc1_ema": 81.02000249206543, "test_acc5_ema": 95.2720022467041}
{"train_lr": 0.0007939258740717405, "train_min_lr": 7.93925874071742e-06, "train_loss": 3.2017458463601356, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8211172954038237, "test_acc1": 81.08800261077882, "test_acc5": 95.3540026107788, "epoch": 12, "n_parameters": 89452846, "test_loss_ema": 0.8418626639653336, "test_acc1_ema": 81.09000255706788, "test_acc5_ema": 95.29400256530762}
{"train_lr": 0.0007409996290619515, "train_min_lr": 7.409996290619525e-06, "train_loss": 3.2007912561523733, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8213914643521562, "test_acc1": 81.13200245300293, "test_acc5": 95.39400229705811, "epoch": 13, "n_parameters": 89452846, "test_loss_ema": 0.8404941935882424, "test_acc1_ema": 81.10000246551513, "test_acc5_ema": 95.26800264953613}
{"train_lr": 0.0006842805612343424, "train_min_lr": 6.842805612343424e-06, "train_loss": 3.1888397486208917, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8174005874968839, "test_acc1": 81.23800253723145, "test_acc5": 95.4720027243042, "epoch": 14, "n_parameters": 89452846, "test_loss_ema": 0.8393016357087728, "test_acc1_ema": 81.09400246459961, "test_acc5_ema": 95.30600252807618}
{"train_lr": 0.0006246631641708816, "train_min_lr": 6.2466316417088085e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 15, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.0005630876398369641, "train_min_lr": 5.6308763983696485e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 16, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.0005005250710347763, "train_min_lr": 5.005250710347766e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 17, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.0004379621068473917, "train_min_lr": 4.37962106847392e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 18, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.00037638540259335207, "train_min_lr": 3.763854025933525e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 19, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.00031676605968288186, "train_min_lr": 3.167660596828817e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 20, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.00026004431076854807, "train_min_lr": 2.6004431076854766e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 21, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.00020711469171466005, "train_min_lr": 2.071146917146599e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 22, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.00015881193423231992, "train_min_lr": 1.5881193423232004e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 23, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.00011589780166169272, "train_min_lr": 1.1589780166169294e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 24, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 7.904907550907799e-05, "train_min_lr": 7.904907550907813e-07, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 25, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 4.884688219826372e-05, "train_min_lr": 4.884688219826364e-07, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 26, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 2.5767528359677518e-05, "train_min_lr": 2.576752835967749e-07, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 27, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 1.0174989190291555e-05, "train_min_lr": 1.0174989190291552e-07, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 28, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 2.3151683473384506e-06, "train_min_lr": 2.3151683473384558e-08, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 29, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
The code could continue running because I commented out the assert statements in engine.py (lines 64 to 66). I also find that train_loss turns to NaN when using AMP training:
loss_value = loss.item()
# if not math.isfinite(loss_value):  # this could trigger if using AMP
#     print("Loss is {}, stopping training".format(loss_value))
#     assert math.isfinite(loss_value)
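For reference, an alternative to removing the check entirely would be to skip the offending batch. A sketch with a hypothetical helper name (step_is_finite), assuming it is called inside the training loop right after the backward pass; this is a suggestion, not the repo's code:

import math

def step_is_finite(loss, optimizer):
    # Return False and discard the gradients when the loss is non-finite,
    # so the caller can skip optimizer.step() for this batch instead of
    # asserting; one bad AMP step then does not end the whole run.
    loss_value = loss.item()
    if not math.isfinite(loss_value):
        print("Loss is {}, skipping this batch".format(loss_value))
        optimizer.zero_grad()
        return False
    return True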
The train_loss becomes NaN regardless of whether I use full-precision or mixed-precision training. Have you ever encountered such a situation?
We also notice that the training of the ViT-B model is not stable when a small sparsification ratio (<=0.7) is applied. If the training loss becomes NaN, it would be better to resume from the last checkpoint or re-train the model from the beginning.
The NaN occurs less frequently when we use a large batch size, and I am not sure what causes the NaN loss in your environment. Maybe you can change the random seed by setting args.seed and re-train the model, or check the environment by using larger sparsification ratios (e.g., 0.8/0.9).
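If re-training from scratch is too expensive, one option is to watch for a non-finite epoch loss and fall back to the last good checkpoint. A minimal sketch with hypothetical names (maybe_recover, ckpt_path) and assumed checkpoint keys, not code from this repo:

import math
import torch

def maybe_recover(model, optimizer, train_loss, ckpt_path):
    # If the epoch's train loss is non-finite, reload the last good
    # checkpoint so training can resume instead of continuing on NaNs.
    if not math.isfinite(train_loss):
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optimizer"])
        return True  # caller should redo this epoch
    return False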
Hi, I followed the training command:
python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --output_dir logs/dynamicvit_deit-b --model deit-b --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 30 --base_rate 0.7 --lr 1e-3
and got the final results:
Acc@1 80.754 Acc@5 94.950 loss 0.888
Accuracy of the model on the 50000 test images: 80.8%
Max accuracy: 80.80%
Acc@1 80.760 Acc@5 94.970 loss 0.887
Accuracy of the model EMA on the 50000 test images: 80.8%
Max EMA accuracy: 80.86%
I failed to reproduce the 81.3% reported in the paper for deit-base with a 0.7 keeping ratio (~0.5% drop), but I did get 79.2% for deit-small (79.3% in the paper). Both experiments were conducted in the same Python environment.
My environment: Python 3.10, torch 1.12.1, torchvision 0.13.1, timm 0.3.2
Could you provide the training instructions/checkpoints that achieved ~81.3% accuracy for deit-base?
Thanks!