zhysora / PSGan-Family

PyTorch Code for "PSGAN: A Generative Adversarial Network for Remote Sensing Image Pan-sharpening". TGRS 2021
48 stars, 14 forks

Why does the loss become NaN after several epochs of training on the provided QB dataset? #13

Open codgodtao opened 2 years ago

codgodtao commented 2 years ago

I ran into what looks like a bug while reproducing your experiment. I used the QB dataset you provided and generated the TFRecord files following the described process. The first epochs train normally, but after that the training loss becomes NaN (in the run below, from epoch 2 onward), so the resulting model does not work on the test set. As the log shows, the D and G losses are both NaN, which is really confusing, because I didn't change your code except for the parameters:

WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\compat\v2_compat.py:107: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version.
Instructions for updating:
non-resource variables are not supported in the long term
../../data/TFRecords/QB_train_64.tfrecords
../../data/TFRecords/QB_test_64.tfrecords
../../data/Output/QB_test_64_psgan
../../data/Output/QB_test_64_psgan
train_tfrecord = ../../data/TFRecords/QB_train_64.tfrecords
test_tfrecord = ../../data/TFRecords/QB_test_64.tfrecords
mode = train
output_dir = ../../data/Output/QB_test_64_psgan
checkpoint = None
max_steps = None
max_epochs = 5
summary_freq = 0
progress_freq = 200
trace_freq = 0
display_freq = 0
save_freq = 1000
batch_size = 4
lr = 0.0001
beta1 = 0.5
l1_weight = 100.0
gan_weight = 1.0
ndf = 32
train_count = 4821
test_count = 81
gpus = 0
blk = 64
Queue-based input pipelines have been replaced by tf.data. Use tf.data.Dataset.from_tensor_slices(string_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs). If shuffle=False, omit the .shuffle(...).
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\training\input.py:262: input_producer (from tensorflow.python.training.input) is deprecated and will be removed in a future version.
Instructions for updating:
Queue-based input pipelines have been replaced by tf.data. Use tf.data.Dataset.from_tensor_slices(input_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs). If shuffle=False, omit the .shuffle(...).
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\training\input.py:184: limit_epochs (from tensorflow.python.training.input) is deprecated and will be removed in a future version.
Instructions for updating:
Queue-based input pipelines have been replaced by tf.data. Use tf.data.Dataset.from_tensors(tensor).repeat(num_epochs).
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\training\input.py:192: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the tf.data module.
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\training\input.py:191: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
parameter_count = 2277536
2022-10-21 21:56:15.110273: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8401
2022-10-21 21:56:16.443692: I tensorflow/stream_executor/cuda/cuda_blas.cc:1614] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
progress epoch 1 step 199 image/sec 18.3 remaining 21m discrim_loss 0.73236054 gen_loss_GAN 1.3829831 gen_loss_L1 68.04657
progress epoch 1 step 399 image/sec 19.7 remaining 19m discrim_loss 1.0204148 gen_loss_GAN 1.1590607 gen_loss_L1 37.19431
progress epoch 1 step 599 image/sec 20.4 remaining 17m discrim_loss 0.8676947 gen_loss_GAN 1.3314455 gen_loss_L1 23.225008
progress epoch 1 step 799 image/sec 20.6 remaining 16m discrim_loss 0.9771492 gen_loss_GAN 1.5663068 gen_loss_L1 26.68383
progress epoch 1 step 999 image/sec 20.8 remaining 16m discrim_loss 0.83579296 gen_loss_GAN 1.7441618 gen_loss_L1 25.68571
saving model
progress epoch 1 step 1199 image/sec 20.9 remaining 15m discrim_loss 0.59473586 gen_loss_GAN 2.3351698 gen_loss_L1 27.571407
progress epoch 2 step 193 image/sec 21.1 remaining 14m discrim_loss nan gen_loss_GAN nan gen_loss_L1 nan
progress epoch 2 step 393 image/sec 21.3 remaining 13m discrim_loss nan gen_loss_GAN nan gen_loss_L1 nan
progress epoch 2 step 593 image/sec 21.5 remaining 13m discrim_loss nan gen_loss_GAN nan gen_loss_L1 nan
progress epoch 2 step 793 image/sec 21.6 remaining 12m discrim_loss nan gen_loss_GAN nan gen_loss_L1 nan
saving model
WARNING:tensorflow:From C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\training\saver.py:1064: remove_checkpoint (from tensorflow.python.checkpoint.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.
progress epoch 2 step 993 image/sec 21.7 remaining 11m discrim_loss nan gen_loss_GAN nan gen_loss_L1 nan
progress epoch 2 step 1193 image/sec 21.8 remaining 11m discrim_loss nan gen_loss_GAN nan gen_loss_L1 nan
progress epoch 3 step 187 image/sec 21.9 remaining 10m discrim_loss nan gen_loss_GAN nan gen_loss_L1 nan
progress epoch 3 step 387 image/sec 22.0 remaining 9m discrim_loss nan gen_loss_GAN nan gen_loss_L1 nan
progress epoch 3 step 587 image/sec 22.1 remaining 9m discrim_loss nan gen_loss_GAN nan gen_loss_L1 nan
saving model
progress epoch 3 step 787 image/sec 22.1 remaining 8m discrim_loss nan gen_loss_GAN nan gen_loss_L1 nan
progress epoch 3 step 987 image/sec 22.1 remaining 7m discrim_loss nan gen_loss_GAN nan gen_loss_L1 nan
progress epoch 3 step 1187 image/sec 22.2 remaining 7m discrim_loss nan gen_loss_GAN nan gen_loss_L1 nan
progress epoch 4 step 181 image/sec 22.2 remaining 6m discrim_loss nan gen_loss_GAN nan gen_loss_L1 nan
progress epoch 4 step 381 image/sec 22.3 remaining 6m discrim_loss nan gen_loss_GAN nan gen_loss_L1 nan
saving model
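To pin down where training diverges, a short script can scan the progress reports and return the last finite one and the first non-finite one. This is a minimal sketch that assumes the `progress epoch … step … discrim_loss … gen_loss_GAN … gen_loss_L1 …` format seen in the log above; `find_divergence` is a hypothetical helper, not part of the repository.

```python
import math
import re

# Matches one progress report in the format printed in this issue, e.g.
# "progress epoch 1 step 199 image/sec 18.3 remaining 21m discrim_loss 0.73 gen_loss_GAN 1.38 gen_loss_L1 68.0"
PROGRESS_RE = re.compile(
    r"progress epoch (\d+) step (\d+) .*?"
    r"discrim_loss (\S+) gen_loss_GAN (\S+) gen_loss_L1 (\S+)"
)
LOSS_NAMES = ("discrim_loss", "gen_loss_GAN", "gen_loss_L1")


def find_divergence(log_text):
    """Return (last_finite, first_bad); either may be None.

    Each report is a dict with 'epoch', 'step' and the three loss values.
    A report is "bad" if any of its losses is NaN or infinite.
    """
    # Collapse all whitespace so reports wrapped across lines still match.
    text = " ".join(log_text.split())
    last_finite = None
    for m in PROGRESS_RE.finditer(text):
        losses = [float(v) for v in m.group(3, 4, 5)]  # float("nan") -> nan
        report = {"epoch": int(m.group(1)), "step": int(m.group(2))}
        report.update(zip(LOSS_NAMES, losses))
        if not all(math.isfinite(v) for v in losses):
            return last_finite, report
        last_finite = report
    return last_finite, None
```

Run on the full log in this issue, it would report epoch 1 step 1199 as the last finite report and epoch 2 step 193 as the first NaN one. With `progress_freq = 200`, that only brackets the divergence to a window of roughly 200 steps; a smaller `progress_freq` would narrow it down.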

JUSTM0VE0N commented 1 year ago

Hello, I have the same problem as you. Have you solved it?

codgodtao commented 1 year ago

Sorry, I didn't solve this problem. Maybe this GAN-based model is just hard to train.
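One frequent source of NaN in GAN training is computing the adversarial loss as `log(p)` on saturated sigmoid outputs. The loss names and weights in the log (`discrim_loss`, `gen_loss_GAN`, `gen_loss_L1`, `l1_weight = 100.0`, `gan_weight = 1.0`) resemble a pix2pix-style objective; assuming, without having checked the PSGAN source, that the adversarial terms are computed that way, a commonly suggested change is to compute them from the discriminator's raw logits instead. The NumPy sketch below contrasts the two forms; both function names are hypothetical, and this is not a confirmed fix for this issue.

```python
import numpy as np


def naive_bce_with_logits(logits, labels, eps=0.0):
    """BCE computed through sigmoid probabilities, i.e. -[z*log(p) + (1-z)*log(1-p)].

    Once the sigmoid rounds to exactly 1.0 (roughly x > 37 in float64, x > 17 in
    float32), log(1 - p) becomes log(0) = -inf. Adding a small eps (e.g. 1e-12 in
    some pix2pix-style code) clamps the value instead, but the gradient through
    the saturated sigmoid is then zero.
    """
    p = 1.0 / (1.0 + np.exp(-logits))
    return -(labels * np.log(p + eps) + (1.0 - labels) * np.log(1.0 - p + eps))


def stable_bce_with_logits(logits, labels):
    """Same loss rewritten as max(x, 0) - x*z + log(1 + exp(-|x|)).

    This is the standard numerically stable form; it never evaluates log(0)
    and exp() only ever sees non-positive arguments, so it cannot overflow.
    """
    return np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
```

For moderate logits the two agree, but for a strongly saturated discriminator output (for example a logit of 40 with label 0) the naive form returns `inf` while the stable form returns 40. In TensorFlow the same stable form is provided by `tf.nn.sigmoid_cross_entropy_with_logits(labels=..., logits=...)`; using it requires the discriminator to expose its pre-sigmoid output, which would have to be checked against the repository's model code. Lowering `lr` or adding gradient clipping are other commonly tried mitigations.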

bingkun99 commented 1 month ago

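Hello, I've also been reading the PSGAN code recently. To build the dataset from the raw data the authors provide, should the images be cut directly as the code does, or do they first need some processing such as filtering or correction before cutting and then running to_patch? I'd appreciate it if you could explain. I'm just starting to learn this area, so please forgive me if this question is out of place!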
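Since the question concerns how the raw images are cut into patches, here is a minimal NumPy sketch of cutting a co-registered PAN/MS pair into spatially aligned training patches. `extract_aligned_patches` is a hypothetical helper, not the repository's `to_patch`; the PAN/MS resolution ratio of 4 (as for QuickBird) and the 64-pixel PAN patch size are assumptions for illustration.

```python
import numpy as np


def extract_aligned_patches(pan, ms, pan_size=64, ratio=4, stride=None):
    """Cut a co-registered PAN/MS pair into spatially aligned patches.

    Hypothetical helper for illustration; it is not the repository's to_patch.
    pan: (H, W) array. ms: (H // ratio, W // ratio, C) array.
    Returns arrays of shape (N, pan_size, pan_size) and
    (N, pan_size // ratio, pan_size // ratio, C).
    """
    stride = stride or pan_size  # default: non-overlapping patches
    if pan_size % ratio or stride % ratio:
        raise ValueError("pan_size and stride must be multiples of ratio")
    if ms.shape[:2] != (pan.shape[0] // ratio, pan.shape[1] // ratio):
        raise ValueError("MS size must be the PAN size divided by ratio")
    ms_size = pan_size // ratio
    pan_patches, ms_patches = [], []
    for y in range(0, pan.shape[0] - pan_size + 1, stride):
        for x in range(0, pan.shape[1] - pan_size + 1, stride):
            # PAN offset (y, x) maps to MS offset (y // ratio, x // ratio).
            my, mx = y // ratio, x // ratio
            pan_patches.append(pan[y:y + pan_size, x:x + pan_size])
            ms_patches.append(ms[my:my + ms_size, mx:mx + ms_size, :])
    if not pan_patches:
        raise ValueError("image is smaller than one patch")
    return np.stack(pan_patches), np.stack(ms_patches)
```

For example, a 256 × 256 PAN image with a 64 × 64 × 4 MS image yields 16 non-overlapping pairs of 64 × 64 PAN and 16 × 16 × 4 MS patches. Whether the raw data needs filtering or other correction before this step is exactly what the question asks, and this sketch does not settle it.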