yitu-opensource / T2T-ViT

ICCV2021, Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet

Nan during training even without '--amp' #26

Closed Evgeneus closed 3 years ago

Evgeneus commented 3 years ago

Hello! I would like to train the T2t_vit_14 model on the ImageNet-100 dataset with 3 Quadro RTX 5000 GPUs, but I get NaN in the loss and then the error below. Could you please help me run the code?

I run the following command: CUDA_VISIBLE_DEVICES=1,2,6,7 bash distributed_train.sh 4 /data/datasets/imagenet-100/ --model T2t_vit_14 -b 128 --lr 1e-3 --weight-decay .03 --cutmix 0.0 --reprob 0.25 --img-size 224

Some printout:

200,4.4025687376658125,4.13563300743103,7.760000015258789,24.43999983520508
201,4.374216079711914,4.1354576759338375,7.639999951171875,24.200000134277342
202,4.392171382904053,4.136218957519532,7.7399999694824215,24.439999853515626
203,4.384297768274943,4.140018928909302,7.659999923706055,24.29999981689453
204,4.371897141138713,4.14544691696167,7.619999905395508,23.91999990234375
205,4.374680519104004,4.1505038471221924,7.7600000366210935,23.93999978027344
206,4.359750032424927,4.154387146377563,7.619999943542481,24.040000201416017
207,4.37085485458374,4.158743778991699,7.820000009155273,24.019999743652345
208,4.367704391479492,4.161868629837036,7.62000002746582,23.86000018310547
209,nan,4.160795520019532,7.7200000274658205,24.01999976196289
210,nan,nan,1.0,5.0
211,nan,nan,1.0,5.0
212,nan,nan,1.0,5.0
213,nan,nan,1.0,5.0

Error:

File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/handle.py", line 123, in scale_loss optimizer._post_amp_backward(loss_scaler) File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 249, in post_backward_no_master_weights post_backward_models_are_masters(scaler, params, stashed_grads) File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 135, in post_backward_models_are_masters scale_override=(grads_have_scale, stashed_have_scale, out_scale)) File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/scaler.py", line 183, in unscale_with_stashed out_scale/grads_have_scale, ZeroDivisionError: float division by zero Traceback (most recent call last): File "main.py", line 764, in <module> main() File "main.py", line 560, in main amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) File "main.py", line 637, in train_epoch loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order) File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/timm/utils/cuda.py", line 20, in __call__ scaled_loss.backward(create_graph=create_graph) File "/usr/lib/python3.6/contextlib.py", line 88, in __exit__ next(self.gen) File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/handle.py", line 123, in scale_loss optimizer._post_amp_backward(loss_scaler) File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 249, in post_backward_no_master_weights post_backward_models_are_masters(scaler, params, stashed_grads) File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 135, in post_backward_models_are_masters scale_override=(grads_have_scale, stashed_have_scale, out_scale)) File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/scaler.py", line 183, in unscale_with_stashed out_scale/grads_have_scale, ZeroDivisionError: float division by zero Traceback (most recent call last): File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main "__main__", mod_spec) File "/usr/lib/python3.6/runpy.py", line 85, in _run_code exec(code, run_globals) File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/torch/distributed/launch.py", line 261, in <module> main() File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/torch/distributed/launch.py", line 257, in main cmd=cmd) subprocess.CalledProcessError: Command '['/home/ekrivosheev/cv_env/bin/python', '-u', 'main.py', '--local_rank=3', '/data/datasets/imagenet-100/', '--model', 'T2t_vit_14', '-b', '128', '--lr', '1e-3', '--weight-decay', '.03', '--cutmix', '0.0', '--reprob', '0.25', '--img-size', '224']' returned non-zero exit status 1.

yuanli2333 commented 3 years ago

Hi, it seems that the error still comes from amp. Are you using our newest 'main.py'? Or have you changed the default value of --amp to False in this line?
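(For readers landing here later: the exact parser line in main.py is not quoted in this thread, so the snippet below is only a sketch of what flipping that default looks like; the help text and the original default value are assumptions.)

```python
import argparse

parser = argparse.ArgumentParser()
# Assumed shape of the flag in main.py: '--amp' is effectively on by default,
# so FP32 training requires changing the default (or editing the code), not
# just omitting the flag.
parser.add_argument('--amp', action='store_true', default=False,  # assumption: originally default=True
                    help='use mixed-precision (amp) training')
args = parser.parse_args([])   # parsing with no flags -> args.amp == False, i.e. amp disabled
```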

michuanhaohao commented 3 years ago

> Hi, it seems that the error still comes from amp. Are you using our newest 'main.py'? Or have you changed the default value of --amp to False in this line?

Hi! I also got NaN in the loss with FP16 training. I ran the code on V100 GPUs.

I resolved the problem by following the solution in https://github.com/yitu-opensource/T2T-ViT/issues/19#issuecomment-782592358

yuanli2333 commented 3 years ago

@michuanhaohao Hi, very nice try!

Which solution did you try?

  1. disable amp, or
  2. modify this line in token_performer.py as follows (a fuller sketch of the surrounding function is shown after this list):
    return torch.exp((wtx - xd) - torch.max((wtx - xd), dim=-1, keepdim=True).values + self.epsilon) / math.sqrt(self.m)
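For context, this is roughly where that return statement sits inside prm_exp in token_performer.py. Only the return line is quoted from the thread; the surrounding lines, attribute names (self.w, self.m, self.epsilon), and comments are a reconstruction, not a verbatim copy of the file:

```python
import math
import torch

def prm_exp(self, x):
    # Positive random features for the Gaussian kernel used by the token Performer.
    # x: (B, T, head_dim); self.w: (m, head_dim) random projection; self.m: number of features.
    xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, self.m) / 2
    wtx = torch.einsum('bti,mi->btm', x.float(), self.w)
    # Stabilized exponent: subtracting the per-row max keeps exp() from overflowing
    # in FP16, which is what produces the NaNs under amp. self.epsilon is a small
    # constant (e.g. 1e-8) defined in __init__.
    return torch.exp((wtx - xd) - torch.max(wtx - xd, dim=-1, keepdim=True).values
                     + self.epsilon) / math.sqrt(self.m)
```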
michuanhaohao commented 3 years ago

> @michuanhaohao Hi, very nice try!
>
> Which solution did you try?
>
> 1. disable amp, or
> 2. modify this line in token_performer.py as:
>    return torch.exp((wtx - xd) - torch.max((wtx - xd), dim=-1, keepdim=True).values + self.epsilon) / math.sqrt(self.m)

I modified this line in token_performer.py (the second option).
yuanli2333 commented 3 years ago

@michuanhaohao Great! Please use the updated 'main.py' with the new hyper-parameters; it achieves higher performance than before.

huixiancheng commented 3 years ago

NaN loss still appears on a single 2080 Ti GPU with torch.amp in downstream tasks. [screenshot: QQ截图20210319205349] Since it only happens for one particular batch, I try the following way to keep it from affecting training (taken from another repo), as sketched below. [screenshot: QQ截图20210319210209] Not sure it really works, since I am a beginner and have not tested it.
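The screenshots above are not recoverable here, so the following is only a minimal sketch of the kind of guard being described: skipping the optimizer step whenever a batch yields a non-finite loss. The function name and loop structure are assumptions, not code from this repo:

```python
import math
import torch

def train_one_epoch(model, criterion, optimizer, scaler, loader, device="cuda"):
    """Sketch of an amp training loop that skips batches with a NaN/Inf loss."""
    model.train()
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        with torch.cuda.amp.autocast():
            loss = criterion(model(inputs), targets)
        # Skip the rare batch that produces a non-finite loss instead of
        # letting one bad step poison the weights.
        if not math.isfinite(loss.item()):
            optimizer.zero_grad(set_to_none=True)
            continue
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
```

Here scaler would be a torch.cuda.amp.GradScaler(); dropping an occasional batch like this is a workaround, not a fix for the underlying FP16 overflow.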

yuanli2333 commented 3 years ago

@huixiancheng Please disable the amp in downstream tasks!

Evgeneus commented 3 years ago

For me, removing amp helped.