After some investigation, I found that the "loss nan" problem may be caused by running Adam in half precision. In the community I found the torch.cuda.amp package, which provides automatic mixed precision and gradient scaling. The modified code and the training logs are below:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
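For reference, the scaler is then used in the training step roughly as follows (a minimal sketch, not the actual ChangeFormer training loop; model, optimizer, criterion, and train_loader are placeholders):

for imgs, labels in train_loader:
    optimizer.zero_grad()
    with autocast():                      # forward pass and loss in mixed precision
        preds = model(imgs)
        loss = criterion(preds, labels)
    scaler.scale(loss).backward()         # scale the loss before backprop
    scaler.step(optimizer)                # unscale gradients, skip the step if they contain inf/nan
    scaler.update()                       # adjust the scale factor for the next iteration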
After this modification the "loss nan" problem is solved, but training becomes much slower: close to 5 minutes per epoch, compared with roughly 50 seconds per epoch before the change.
I would like to ask if there is any way to solve the "loss nan" problem while maintaining the training speed.
----------------------------------------------- after (with torch.cuda.amp) -----------------------------------------------
Initializing backbone weights from: pretrained_changeformer/pretrained_changeformer.pt
0 400
lr: 0.0001000
0%| | 0/21 [00:00<?, ?it/s]/home/ubuntu/anaconda3/envs/ChangeFormer/lib/python3.8/site-packages/torch/nn/functional.py:3631: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
warnings.warn(
5%|██ | 1/21 [00:06<02:14, 6.72s/it]Is_training: True. [0,399][1,21], imps: 0.12, est: 18.91h, G_loss: 2.57264, running_mf1: 0.44941
100%|███████████████████████████████████████████| 21/21 [04:00<00:00, 11.46s/it]
Is_training: True. Epoch 0 / 399, epoch_mF1= 0.49792
acc: 0.89754 miou: 0.46145 mf1: 0.49792 iou_0: 0.89726 iou_1: 0.02563 F1_0: 0.94585 F1_1: 0.04999 precision_0: 0.94412 precision_1: 0.05165 recall_0: 0.94758 recall_1: 0.04843
Begin evaluation...
Is_training: False. [0,399][1,2], imps: 0.01, est: 289.58h, G_loss: 2.51084, running_mf1: 0.49455
Is_training: False. Epoch 0 / 399, epoch_mF1= 0.47554
acc: 0.89846 miou: 0.45036 mf1: 0.47554 iou_0: 0.89843 iou_1: 0.00229 F1_0: 0.94650 F1_1: 0.00457 precision_0: 0.90089 precision_1: 0.07873 recall_0: 0.99697 recall_1: 0.00235
----------------------------------------------- before (plain half precision) -----------------------------------------------
Initializing backbone weights from: pretrained_changeformer/pretrained_changeformer.pt
0 400
lr: 0.0001000
0%| | 0/21 [00:00<?, ?it/s]/home/ubuntu/anaconda3/envs/ChangeFormer/lib/python3.8/site-packages/torch/nn/functional.py:3631: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
warnings.warn(
5%|███▉ | 1/21 [00:03<01:10, 3.52s/it]Is_training: True. [0,399][1,21], imps: 0.19, est: 12.39h, G_loss: nan, running_mf1: 0.48448
100%|█████████████████████████████████████████████████████████████████████████████████| 21/21 [00:18<00:00, 1.11it/s]
Is_training: True. Epoch 0 / 399, epoch_mF1= 0.48279
acc: 0.92241 miou: 0.46269 mf1: 0.48279 iou_0: 0.92239 iou_1: 0.00299 F1_0: 0.95963 F1_1: 0.00596 precision_0: 0.94450 precision_1: 0.00983 recall_0: 0.97525 recall_1: 0.00427
Begin evaluation...
Is_training: False. [0,399][1,2], imps: 0.08, est: 30.80h, G_loss: nan, running_mf1: 0.49474
Is_training: False. Epoch 0 / 399, epoch_mF1= 0.47395
acc: 0.90095 miou: 0.45048 mf1: 0.47395 iou_0: 0.90095 iou_1: 0.00000 F1_0: 0.94790 F1_1: 0.00000 precision_0: 0.90095 precision_1: 0.00000 recall_0: 1.00000 recall_1: 0.00000
I have a set of 1024*1024 images, and the framework converges well when img_size is set to 512. When I set img_size to 1024 and switch to half precision (to fit within GPU memory limits), the model's G_loss becomes nan and it does not converge.
The specific changes are as follows: add
torch.set_default_tensor_type(torch.HalfTensor)
to main_cd.py, and change six places in ChangeFormer.py from torch.linspace to np.linspace (to make the code run in half precision). The training log for this run is the "before" section above. (I also tried training without the pretrained model and lowering the lr from 1e-4 to 1e-5, 1e-6, and 1e-7; none of this solved the problem.)
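For clarity, the half-precision change amounts to roughly the following (a sketch only; the linspace call and its values are illustrative, not the actual lines in ChangeFormer.py):

import numpy as np
import torch

# main_cd.py: create every new tensor as float16 by default
torch.set_default_tensor_type(torch.HalfTensor)

# ChangeFormer.py: torch.linspace replaced by np.linspace in six places,
# e.g. when building a schedule of rates (numbers here are made up)
drop_path_rate, num_blocks = 0.1, 12
rates = [float(x) for x in np.linspace(0, drop_path_rate, num_blocks)]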