raoyongming / DynamicViT

[NeurIPS 2021] [T-PAMI] DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification
https://dynamicvit.ivg-research.xyz/
MIT License

Fail to reproduce accuracy of DynamicViT-B/0.7: lower accuracy than reported #27

Closed ShiFengyuan1999 closed 1 year ago

ShiFengyuan1999 commented 1 year ago

Hi, I followed the training command:

python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --output_dir logs/dynamicvit_deit-b --model deit-b --input_size 224 --batch_size 128 --data_path /path/to/ILSVRC2012/ --epochs 30 --base_rate 0.7 --lr 1e-3

and got the final results:

I fail to reproduce the 81.3% reported in the paper for DeiT-base with a 0.7 keeping ratio (about a 0.5% drop). However, I do get 79.2% for DeiT-small (79.3% in the paper). Both experiments were run in the same Python environment.

My environment: Python 3.10, torch 1.12.1, torchvision 0.13.1, timm 0.3.2

Could you provide the training instructions/checkpoints that achieved ~81.3% accuracy for deit-base?

Thanks!

raoyongming commented 1 year ago

Hi, thanks for your interest in our work. I just checked our logs. We can achieve the reported accuracies in the same environment.

We also notice that training DeiT-B/DyViT-B is less stable than training smaller models, so it is important to keep a large global batch size (128x8; see the arithmetic sketch below). Besides, we find that training DyViT-B/0.8 and DyViT-B/0.9 is more stable, so you could use these two settings to check whether the lower accuracy is caused by the environment or by the configuration.
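
For reference, the "global batch size" here is the per-GPU batch size times the number of launched processes; a trivial sanity check (illustrative arithmetic, not repo code):

    # Global batch size implied by the training command above.
    per_gpu_batch = 128          # --batch_size
    num_gpus = 8                 # --nproc_per_node
    global_batch = per_gpu_batch * num_gpus
    assert global_batch == 1024  # the setting used for the reported results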

ShiFengyuan1999 commented 1 year ago

Thanks for the reply. Could you release the training logs and trained checkpoints for deit-b?

raoyongming commented 1 year ago

Sure. We are trying to reproduce these results and will upload the logs and checkpoints later.

ShiFengyuan1999 commented 1 year ago

Thanks a lot!

liuzuyan commented 1 year ago

Hi, thanks for your interest in our work. We have reproduced the results for DynamicViT-B/0.7 and updated the checkpoints in README.md, achieving 81.43% Top-1 accuracy and 95.46% Top-5 accuracy. Training logs can be found at this url

There are a few points that may help you reproduce the results:

  1. The default number of warmup epochs in the codebase is set to 20 for DynamicSwin/DynamicCNN, which is not appropriate for DynamicViT (we recommend 5 warmup epochs). We have corrected this bug and updated the training command.
  2. Training DeiT-B is less stable than training smaller models, and the following techniques may help. First, drop_path helps stabilize the training of larger models; you can try setting drop_path=0.2. Second, the ratio loss is critical for optimization when training DynamicViT. If the ratio loss does not converge (this may happen when the base ratio is less than 0.7), you can try increasing its weight (we recommend ratio_weight=5.0 in this setting); see the sketch after this list.
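
For concreteness, the ratio loss mentioned in point 2 penalizes the gap between the mean kept-token ratio and its per-stage target. Below is a minimal sketch in the spirit of the paper's formulation; the function and argument names are illustrative, not the repo's exact API:

    import torch

    def ratio_loss(decision_masks, base_rate=0.7, ratio_weight=5.0):
        # decision_masks: list of (B, N) soft keep-masks in [0, 1], one per
        # sparsification stage. The target keep ratio decays geometrically
        # per stage (0.7, 0.49, 0.343 for base_rate=0.7).
        loss = torch.zeros(())
        for stage, mask in enumerate(decision_masks):
            target = base_rate ** (stage + 1)
            kept_ratio = mask.mean(dim=1)  # (B,) fraction of tokens kept
            loss = loss + ((kept_ratio - target) ** 2).mean()
        return ratio_weight * loss / len(decision_masks)

Increasing ratio_weight strengthens the pull toward the target keep ratios, which is why it can help when the ratio loss fails to converge.
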
ShiFengyuan1999 commented 1 year ago

I will try again. Thanks!

ShiFengyuan1999 commented 1 year ago

Hi, I ran the new code, but the training loss becomes NaN at epoch 15 and stays NaN until the end of training. Here is the training log.

{"train_lr": 9.993604093380238e-05, "train_min_lr": 0.0, "train_loss": 3.64607219213395, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8946035196158019, "test_acc1": 80.4740024484253, "test_acc5": 94.80800252166748, "epoch": 0, "n_parameters": 89452846, "test_loss_ema": 1.2746423744794093, "test_acc1_ema": 72.67000229095459, "test_acc5_ema": 90.0580026019287}
{"train_lr": 0.00029996802046690124, "train_min_lr": 0.0, "train_loss": 3.340054861718802, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8810712225509413, "test_acc1": 80.38200237304687, "test_acc5": 94.84800268798828, "epoch": 1, "n_parameters": 89452846, "test_loss_ema": 0.9927789774356466, "test_acc1_ema": 78.19400250762939, "test_acc5_ema": 93.53400257720948}
{"train_lr": 0.0004999999999999999, "train_min_lr": 0.0, "train_loss": 3.3404851809537095, "train_weight_decay": 0.049999999999998865, "test_loss": 0.885483177548105, "test_acc1": 80.56200244689941, "test_acc5": 94.988002527771, "epoch": 2, "n_parameters": 89452846, "test_loss_ema": 0.9043222276324575, "test_acc1_ema": 80.36000228790283, "test_acc5_ema": 94.82400244537354}
{"train_lr": 0.0007000319795330988, "train_min_lr": 0.0, "train_loss": 3.3414161773822864, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8791586456425262, "test_acc1": 80.63400248321533, "test_acc5": 94.97400250305176, "epoch": 3, "n_parameters": 89452846, "test_loss_ema": 0.8997948506113255, "test_acc1_ema": 80.46400232574463, "test_acc5_ema": 94.92400255371093}
{"train_lr": 0.0009000639590661981, "train_min_lr": 0.0, "train_loss": 3.3440589879175646, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8772101524201307, "test_acc1": 80.53000243804932, "test_acc5": 94.96400236114502, "epoch": 4, "n_parameters": 89452846, "test_loss_ema": 0.8990779295563698, "test_acc1_ema": 80.49200249908448, "test_acc5_ema": 94.89200254180908}
{"train_lr": 0.0009986879800992594, "train_min_lr": 9.986879800992602e-06, "train_loss": 3.263468243711763, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8302653981654933, "test_acc1": 81.08400257141113, "test_acc5": 95.3740024053955, "epoch": 5, "n_parameters": 89452846, "test_loss_ema": 0.8908959823575887, "test_acc1_ema": 80.57600244140625, "test_acc5_ema": 94.92200240234375}
{"train_lr": 0.0009908344064966182, "train_min_lr": 9.908344064966173e-06, "train_loss": 3.2477996952170662, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8275814583017067, "test_acc1": 81.1560024017334, "test_acc5": 95.36400250488282, "epoch": 6, "n_parameters": 89452846, "test_loss_ema": 0.8804009578218965, "test_acc1_ema": 80.71000223083496, "test_acc5_ema": 95.00200256652832}
{"train_lr": 0.000975247966391946, "train_min_lr": 9.752479663919477e-06, "train_loss": 3.221702263116551, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8215054690160535, "test_acc1": 81.19800274383545, "test_acc5": 95.35400278656006, "epoch": 7, "n_parameters": 89452846, "test_loss_ema": 0.8695734071009087, "test_acc1_ema": 80.81000244049072, "test_acc5_ema": 95.08000265472413}
{"train_lr": 0.0009521744672565878, "train_min_lr": 9.521744672565882e-06, "train_loss": 3.220676154553366, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8234780598996263, "test_acc1": 81.13200243591308, "test_acc5": 95.41200253234864, "epoch": 8, "n_parameters": 89452846, "test_loss_ema": 0.8603077018351266, "test_acc1_ema": 80.91400250183105, "test_acc5_ema": 95.17600260314941}
{"train_lr": 0.0009219777919553448, "train_min_lr": 9.219777919553441e-06, "train_loss": 3.23265779596105, "train_weight_decay": 0.049999999999998865, "test_loss": 0.821601357423898, "test_acc1": 81.1300023260498, "test_acc5": 95.36600256439209, "epoch": 9, "n_parameters": 89452846, "test_loss_ema": 0.8542838512044965, "test_acc1_ema": 80.96200245147705, "test_acc5_ema": 95.23400258636475}
{"train_lr": 0.000885134160096337, "train_min_lr": 8.851341600963352e-06, "train_loss": 3.2203376480429577, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8211326608152101, "test_acc1": 81.20400239440917, "test_acc5": 95.41200243743897, "epoch": 10, "n_parameters": 89452846, "test_loss_ema": 0.8496657200157642, "test_acc1_ema": 80.98400244659423, "test_acc5_ema": 95.23800256866456}
{"train_lr": 0.0008422246177632995, "train_min_lr": 8.422246177632971e-06, "train_loss": 3.194862368545658, "train_weight_decay": 0.049999999999998865, "test_loss": 0.821151819193002, "test_acc1": 81.17600247772216, "test_acc5": 95.39800268524169, "epoch": 11, "n_parameters": 89452846, "test_loss_ema": 0.8454932100845106, "test_acc1_ema": 81.02000249206543, "test_acc5_ema": 95.2720022467041}
{"train_lr": 0.0007939258740717405, "train_min_lr": 7.93925874071742e-06, "train_loss": 3.2017458463601356, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8211172954038237, "test_acc1": 81.08800261077882, "test_acc5": 95.3540026107788, "epoch": 12, "n_parameters": 89452846, "test_loss_ema": 0.8418626639653336, "test_acc1_ema": 81.09000255706788, "test_acc5_ema": 95.29400256530762}
{"train_lr": 0.0007409996290619515, "train_min_lr": 7.409996290619525e-06, "train_loss": 3.2007912561523733, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8213914643521562, "test_acc1": 81.13200245300293, "test_acc5": 95.39400229705811, "epoch": 13, "n_parameters": 89452846, "test_loss_ema": 0.8404941935882424, "test_acc1_ema": 81.10000246551513, "test_acc5_ema": 95.26800264953613}
{"train_lr": 0.0006842805612343424, "train_min_lr": 6.842805612343424e-06, "train_loss": 3.1888397486208917, "train_weight_decay": 0.049999999999998865, "test_loss": 0.8174005874968839, "test_acc1": 81.23800253723145, "test_acc5": 95.4720027243042, "epoch": 14, "n_parameters": 89452846, "test_loss_ema": 0.8393016357087728, "test_acc1_ema": 81.09400246459961, "test_acc5_ema": 95.30600252807618}
{"train_lr": 0.0006246631641708816, "train_min_lr": 6.2466316417088085e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 15, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.0005630876398369641, "train_min_lr": 5.6308763983696485e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 16, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.0005005250710347763, "train_min_lr": 5.005250710347766e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 17, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.0004379621068473917, "train_min_lr": 4.37962106847392e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 18, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.00037638540259335207, "train_min_lr": 3.763854025933525e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 19, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.00031676605968288186, "train_min_lr": 3.167660596828817e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 20, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.00026004431076854807, "train_min_lr": 2.6004431076854766e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 21, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.00020711469171466005, "train_min_lr": 2.071146917146599e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 22, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.00015881193423231992, "train_min_lr": 1.5881193423232004e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 23, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 0.00011589780166169272, "train_min_lr": 1.1589780166169294e-06, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 24, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 7.904907550907799e-05, "train_min_lr": 7.904907550907813e-07, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 25, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 4.884688219826372e-05, "train_min_lr": 4.884688219826364e-07, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 26, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 2.5767528359677518e-05, "train_min_lr": 2.576752835967749e-07, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 27, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 1.0174989190291555e-05, "train_min_lr": 1.0174989190291552e-07, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 28, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}
{"train_lr": 2.3151683473384506e-06, "train_min_lr": 2.3151683473384558e-08, "train_loss": NaN, "train_weight_decay": 0.049999999999998865, "test_loss": NaN, "test_acc1": 0.10000000122070313, "test_acc5": 0.5000000244140625, "epoch": 29, "n_parameters": 89452846, "test_loss_ema": NaN, "test_acc1_ema": 0.10000000122070313, "test_acc5_ema": 0.5000000244140625}

Training can keep running only because I commented out the assert in engine.py (lines 64 to 66). I also find that train_loss turns to NaN when using AMP training.

        loss_value = loss.item()

        # if not math.isfinite(loss_value): # this could trigger if using AMP
        #     print("Loss is {}, stopping training".format(loss_value))
        #     assert math.isfinite(loss_value)

The train_loss becomes NaN regardless of whether I use full-precision or mixed-precision training. Have you ever encountered this situation?
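
For what it's worth, a safer variant of that workaround than silently continuing is to skip the optimizer step for the offending batch. Here is a minimal sketch of a generic training loop with this guard, assuming a plain supervised setup (this is not the repo's engine.py; the DynamicViT criterion and the AMP loss scaler are omitted for brevity):

    import math

    def train_one_epoch_skip_nan(model, criterion, loader, optimizer, device):
        skipped = 0
        for samples, targets in loader:
            samples, targets = samples.to(device), targets.to(device)
            loss = criterion(model(samples), targets)
            # Instead of asserting, drop the batch so one bad step does
            # not poison the weights; divergence still shows up in logs.
            if not math.isfinite(loss.item()):
                print(f"Non-finite loss {loss.item()}, skipping batch")
                optimizer.zero_grad()
                skipped += 1
                continue
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return skipped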

raoyongming commented 1 year ago

We also notice that training the ViT-B model is not stable when a small sparsification ratio (<=0.7) is applied. If the training loss becomes NaN, it is better to resume from the last good checkpoint or re-train the model from scratch.

The NaN occurs less frequently when we use a large batch size, and I am not sure what causes the NaN loss in your environment. You could change the random seed via args.seed and re-train the model (see the seeding sketch below), or sanity-check the environment with larger sparsification ratios (e.g., 0.8/0.9).
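
For reference, DeiT-style codebases typically apply args.seed along these lines; this is a sketch, and the exact helper in this repo may differ:

    import random
    import numpy as np
    import torch

    def set_seed(seed, rank=0):
        # Offset by the process rank so each GPU draws different
        # augmentation randomness while the run stays reproducible.
        seed = seed + rank
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)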