nbasyl / OFQ

The official implementation of the ICML 2023 paper OFQ-ViT
MIT License

Does mixed precision training greatly affect performance? #1

Closed: lixcli closed this issue 5 months ago

lixcli commented 1 year ago

Hi, thanks for sharing your code! I'm a bit confused by what happens when I run it with mixed-precision training. I reproduced the W2A2 experiment using the script under train_script, slightly modified for mixed-precision training as follows:

python3 train.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
/media/DATASET/ImageNet \
--dataset 'torch/folder' \
--epochs 300 \
--batch-size 140 \
--weight-decay 0.05 \
--warmup-lr 1.0e-6 \
--lr 5.47e-4 \
--warmup-epochs 5 \
--mixup 0.0 --cutmix 0.0 \
--aq-enable \
--aq-mode lsq \
--aq-per-channel \
--aq_clip_learnable \
--aq-bitw 2 \
--wq-enable \
--wq-per-channel \
--wq-bitw 2 \
--wq-mode statsq \
--model_type deit \
--quantized \
--pretrained \
--pretrained_initialized \
--use-kd --teacher deit_tiny_distilled_patch16_224 \
--kd_hard_and_soft 1 \
--qk_reparam \
--qk_reparam_type 0 \
--teacher_pretrained \
--output ./outputs/w2a2_deit_t_qkreparam/ \
--visible_gpu '4,5,6,7' \
--world_size '4' \
--tcp_port '36969' \
--amp 

But the best result I got is 62+%, with or without CGA, which is lower than the result reported in the paper. Does mixed-precision training greatly affect performance?

nbasyl commented 1 year ago

Hi, judging from your script, you are still training W2A2, which is a single fixed bit-width setting rather than mixed-precision quantization, and before CGA fine-tuning the model accuracy is indeed 62+%. To get the reported result, you need to fine-tune the trained W2A2 model with cga.py, as indicated in the training script.
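The two-stage process described in this reply (quantization-aware training with train.py, then CGA fine-tuning with cga.py resumed from the stage-1 best checkpoint) can be sketched as one bash script. This is a minimal sketch, assuming the flags used in this thread match the repository's train_script; the variable names, the `latest_best_ckpt` helper, and the `RUN=1` guard are hypothetical additions, and `--amp` is deliberately left out so both stages run in FP32.

```bash
#!/usr/bin/env bash
# Sketch: two-stage W2A2 DeiT-Tiny recipe from this thread, run in FP32 (no --amp).
# Flag values are copied from the commands in the thread; DATA/OUT_* are placeholders.
# Nothing is launched unless RUN=1 is set in the environment.
set -eu

DATA=/media/DATASET/ImageNet
OUT_STAGE1=./outputs/w2a2_deit_t_qkreparam
OUT_STAGE2=./outputs/w2a2_deit_t_qkreparam_cga_0005

# Flags shared by train.py (stage 1) and cga.py (stage 2); --amp is intentionally absent.
COMMON=(
  -c ./configs/ours_imagenet_recipe.attn_q.yml
  --model deit_tiny_distilled_patch16_224 "$DATA"
  --dataset torch/folder --epochs 300 --batch-size 140
  --weight-decay 0.05 --warmup-lr 1.0e-6 --lr 5.47e-4 --warmup-epochs 5
  --mixup 0.0 --cutmix 0.0
  --aq-enable --aq-mode lsq --aq-per-channel --aq_clip_learnable --aq-bitw 2
  --wq-enable --wq-per-channel --wq-bitw 2 --wq-mode statsq
  --model_type deit --quantized --pretrained --pretrained_initialized
  --use-kd --teacher deit_tiny_distilled_patch16_224 --kd_hard_and_soft 1
  --qk_reparam --teacher_pretrained
  --visible_gpu 4,5,6,7 --world_size 4 --tcp_port 36969
)

# Hypothetical helper: print the newest model_best.pth.tar found one directory
# below $1 (the --resume path in this thread sits in a timestamped run
# directory under --output), or an empty line if there is none.
latest_best_ckpt() {
  local f newest=""
  for f in "$1"/*/model_best.pth.tar; do
    [ -e "$f" ] || continue   # the glob matched nothing
    if [ -z "$newest" ] || [ "$f" -nt "$newest" ]; then
      newest=$f
    fi
  done
  printf '%s\n' "$newest"
}

if [ "${RUN:-0}" = 1 ]; then
  # Stage 1: quantization-aware training (qk_reparam_type 0, as in the first command).
  python3 train.py "${COMMON[@]}" --qk_reparam_type 0 --output "$OUT_STAGE1/"

  ckpt=$(latest_best_ckpt "$OUT_STAGE1")
  if [ -z "$ckpt" ]; then
    echo "no model_best.pth.tar under $OUT_STAGE1" >&2
    exit 1
  fi

  # Stage 2: CGA fine-tuning resumed from the stage-1 best checkpoint.
  python3 cga.py "${COMMON[@]}" --qk_reparam_type 1 \
    --boundaryRange 0.005 --freeze_for_n_epochs 30 \
    --resume "$ckpt" --output "$OUT_STAGE2/"
fi
```

Picking the newest `model_best.pth.tar` avoids hard-coding the timestamped directory name (such as `20230614-222953-...`) into the stage-2 `--resume` argument.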

lixcli commented 1 year ago

I still got 62+% accuracy after CGA fine-tuning for W2A2 :(
Here is my fine-tuning script:

python3 cga.py -c./configs/ours_imagenet_recipe.attn_q.yml --model deit_tiny_distilled_patch16_224 \
/media/DATASET/ImageNet \
--dataset 'torch/folder' \
--epochs 300 \
--batch-size 140 \
--weight-decay 0.05 \
--warmup-lr 1.0e-6 \
--lr 5.47e-4 \
--warmup-epochs 5 \
--mixup 0.0 --cutmix 0.0 \
--aq-enable \
--aq-mode lsq \
--aq-per-channel \
--aq_clip_learnable \
--aq-bitw 2 \
--wq-enable \
--wq-per-channel \
--wq-bitw 2 \
--wq-mode statsq \
--model_type deit \
--quantized \
--pretrained \
--pretrained_initialized \
--use-kd --teacher deit_tiny_distilled_patch16_224 \
--kd_hard_and_soft 1 \
--qk_reparam \
--qk_reparam_type 1 \
--boundaryRange 0.005 \
--freeze_for_n_epochs 30 \
--teacher_pretrained \
--resume outputs/w2a2_deit_t_qkreparam/20230614-222953-deit_tiny_distilled_patch16_224-224/model_best.pth.tar \
--output ./outputs/w2a2_deit_t_qkreparam_cga_0005/ \
--visible_gpu '4,5,6,7' \
--world_size '4' \
--tcp_port '36969' \
--amp

nbasyl commented 1 year ago

What mixed-precision training are you referring to? It seems to me that the script you provided is identical to the one under train_script. Are you saying you are training the network with, say, NVIDIA Apex FP16? If so, we recommend keeping the training process in FP32, as we didn't test the training results with Apex or other training frameworks.

Oh, I just saw it: you pass --amp in the script. Please use FP32 training to reproduce the result :)
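To check which launch scripts would still train with AMP rather than FP32, a small grep-based helper is enough. This is a hedged sketch: the function name `scripts_using_amp` is hypothetical, and it only looks for `--amp` as a standalone command-line token, so AMP enabled some other way (for example, if the YAML config passed with `-c` can also set it) would not be caught.

```bash
#!/usr/bin/env bash
# Hypothetical helper: list launch scripts that still pass --amp as a standalone
# token, i.e. commands that would not train in plain FP32.
scripts_using_amp() {
  # -r: recurse into directories, -l: print matching file names only,
  # -E: extended regex. The token must be preceded by whitespace or the line
  # start, and followed by whitespace, a continuation backslash, or the line end,
  # so longer flags such as --amp-foo are not reported.
  grep -rlE -- '(^|[[:space:]])--amp([[:space:]]|\\|$)' "$@"
}
```

For example, `scripts_using_amp train_script/` lists every file under train_script that passes the flag; removing `--amp` from those commands is what the reply above asks for.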

nbasyl commented 1 year ago

I will mark this issue as resolved.

nbasyl commented 1 year ago

For others who might also have similar questions