XiaoXiao-Woo / PanCollection

"PanCollection" for Remote Sensing Pansharpening
https://pancollection.readthedocs.io
GNU General Public License v2.0
54 stars, 9 forks

Unable to Test the Pre-Trained Models #1

Open UsamaShami11 opened 1 year ago

UsamaShami11 commented 1 year ago

I'm using the PanCollection dataset and am currently testing the DLPAN-Toolbox on WV3, so my files are train_wv3.h5, valid_wv3.h5, and test_wv3.h5; the model under evaluation is FusionNet.

Following your guidelines, I'm able to train the model, and the results seem fine. However, whenever I test the saved or pre-trained model, the results are very poor, i.e., negative PSNR and very high SAM/ERGAS values, and some warnings are printed as well.
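For reference, lower SAM and ERGAS are better and higher PSNR is better, so a negative PSNR together with SAM and ERGAS values above 20 points to a badly broken reconstruction rather than a merely weak model. The sketch below is a minimal NumPy version of the three metrics under their common textbook definitions (mean spectral angle in degrees, ERGAS with a resolution ratio of 1/4, PSNR against a fixed peak value). The function names and the `peak` argument are illustrative assumptions, not the toolbox's own evaluation code. The demo shows one way PSNR can become negative: if the peak is taken as 1.0 while the arrays are still on the raw 0 to 2047 scale, the MSE exceeds peak^2 and 10*log10(peak^2/MSE) drops below zero.

```python
import numpy as np


def psnr(pred: np.ndarray, ref: np.ndarray, peak: float) -> float:
    """PSNR in dB; `peak` must be the maximum value on the scale the arrays are actually on."""
    mse = np.mean((pred.astype(np.float64) - ref.astype(np.float64)) ** 2)
    return float(10.0 * np.log10(peak ** 2 / mse))


def sam_degrees(pred: np.ndarray, ref: np.ndarray, eps: float = 1e-12) -> float:
    """Mean spectral angle (degrees) between per-pixel spectra; bands on the last axis."""
    p = pred.reshape(-1, pred.shape[-1]).astype(np.float64)
    r = ref.reshape(-1, ref.shape[-1]).astype(np.float64)
    cos = np.sum(p * r, axis=1) / (np.linalg.norm(p, axis=1) * np.linalg.norm(r, axis=1) + eps)
    return float(np.degrees(np.mean(np.arccos(np.clip(cos, -1.0, 1.0)))))


def ergas(pred: np.ndarray, ref: np.ndarray, ratio: float = 0.25) -> float:
    """ERGAS = 100 * ratio * sqrt(mean_k(RMSE_k^2 / mu_k^2)); ratio = 1/4 for 4x pansharpening."""
    p = pred.reshape(-1, pred.shape[-1]).astype(np.float64)
    r = ref.reshape(-1, ref.shape[-1]).astype(np.float64)
    rmse_sq = np.mean((p - r) ** 2, axis=0)  # per-band squared RMSE
    mean_sq = np.mean(r, axis=0) ** 2        # per-band squared reference mean
    return float(100.0 * ratio * np.sqrt(np.mean(rmse_sq / mean_sq)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    gt = rng.uniform(100.0, 2047.0, size=(64, 64, 8))   # synthetic 11-bit-like image
    pred = gt + rng.normal(0.0, 20.0, size=gt.shape)    # small reconstruction error
    print(psnr(pred / 2047.0, gt / 2047.0, peak=1.0))    # consistent scale: positive dB
    print(psnr(pred, gt, peak=1.0))                      # raw values, peak 1.0: negative dB
    print(sam_degrees(pred, gt), ergas(pred, gt))
```

The key design point is that `peak` has to match the scale of the arrays; mixing normalized and raw data is enough on its own to turn a reasonable PSNR into a large negative one.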

Below is the output from a debug run of run_test_pansharpening.py. It seems the code either cannot find the pre-trained model (checkpoint/weights) or reports a mismatch, even though I am giving the correct path:

PS D:\VS_Code_Projects\01-DL-toolbox(Pytorch)>  d:; cd 'd:\VS_Code_Projects\01-DL-toolbox(Pytorch)'; & 'C:\Users\User\.conda\envs\Test2\python.exe' 'c:\Users\User\.vscode\extensions\ms-python.python-2022.20.2\pythonFiles\lib\python\debugpy\adapter/../..\debugpy\launcher' '51674' '--' 'd:\VS_Code_Projects\01-DL-toolbox(Pytorch)\UDL\pansharpening\run_test_pansharpening.py' 
Backend TkAgg is interactive backend. Turning interactive mode on.
111 entrypoint
111 None
222 pansharpening
d:\vs_code_projects\01-dl-toolbox(pytorch)\UDL\Basis\option.py:114: UserWarning: Note: FusionNet, DiCNN, PNN don't have high-pass filter
  warnings.warn(warning)
111 None
333 FusionNet
use_log = True
log_dir = 'logs'
tfb_dir = None
use_tfb = False
launcher = 'none'
local_rank = 0
backend = 'nccl'
dist_url = 'env://'
amp = None
amp_opt_level = 'O0'
accumulated_step = 1
clip_max_norm = 0
seed = 1
device = 'cuda'
reg = True
crop_batch_size = 128
rgb_range = 255
model_style = None
mode = None
task = 'pansharpening'
arch = 'FusionNet'
global_rank = 0
once_epoch = False
reset_lr = False
save_top_k = 5
save_print_freq = 10
start_epoch = 1
load_model_strict = True
resume_mode = 'best'
validate = False
gpu_ids = [0]
scale = [1]
data_dir = 'D:/VS_Code_Projects/01-DL-toolbox(Pytorch)/UDL/Data/pansharpening'
best_prec1 = 10000
best_prec5 = 10000
metrics = 'loss'
save_fmt = 'mat'
taskhead = 'pansharpening'
out_dir = 'd:/vs_code_projects/01-dl-toolbox(pytorch)/UDL//results/pansharpening'
lr = 0.0003
samples_per_gpu = 32
print_freq = 50
epochs = 400
workers_per_gpu = 0
resume_from = 'D:/VS_Code_Projects/01-DL-toolbox(Pytorch)/UDL/pretrained-model/WV3/fusionnet.pth'
dataset = dict(val='TestData_wv3')
eval = True
best_epoch = 1
experimental_desc = 'Test'
img_range = 2047.0
workflow = [('val', 1)]

dict_keys(['pansharpening', 'entrypoint', 'BDPN', 'DiCNN1', 'DRPNN', 'FusionNet', 'MSDCNN', 'PanNet', 'PNN'])
=> creating d:\vs_code_projects\01-dl-toolbox(pytorch)\UDL\results\pansharpening\TestData_wv3\FusionNet\Test
- Set random seed to 1
C:\Users\User\.conda\envs\Test2\lib\site-packages\torch\nn\_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
  warnings.warn(warning.format(ret))
**d:\vs_code_projects\01-dl-toolbox(pytorch)\UDL\AutoDL\trainer.py:88: UserWarning: config is now expected to have a `runner` section, please set `runner` in your config.**
  warnings.warn(
loading MultiExmTest_h5: D:/VS_Code_Projects/01-DL-toolbox(Pytorch)/UDL/Data/pansharpening/test_data/TestData_wv3.h5 with 2047.0
<KeysViewHDF5 ['gt', 'lms', 'ms', 'pan']>
lms: torch.Size([20, 8, 256, 256]), ms: torch.Size([20, 8, 64, 64]), pan: torch.Size([20, 1, 256, 256]), gt: torch.Size([20, 256, 256, 8])
**- no checkpoint found at d:\vs_code_projects\01-dl-toolbox(pytorch)\UDL\results\pansharpening\TestData_wv3\FusionNet\Test\WV3\model_best_85**
- resumed epoch 1, iter 1
- Start running, host: User@DESKTOP-EDNA95N, work_dir: d:\vs_code_projects\01-dl-toolbox(pytorch)\UDL\results\pansharpening\TestData_wv3\FusionNet\Test\WV3
- Hooks will be executed in the following order:
before_run:
(NORMAL      ) ModelCheckpoint
(LOW         ) EvalHook
(VERY_LOW    ) TextLoggerHook
 --------------------
before_train_epoch:
(LOW         ) IterTimerHook
(LOW         ) EvalHook
(VERY_LOW    ) TextLoggerHook
 --------------------
before_train_iter:
(LOW         ) IterTimerHook
(LOW         ) EvalHook
 --------------------
after_train_iter:
(ABOVE_NORMAL) OptimizerHook
(NORMAL      ) ModelCheckpoint
(LOW         ) IterTimerHook
(LOW         ) EvalHook
(VERY_LOW    ) TextLoggerHook
 --------------------
after_train_epoch:
(NORMAL      ) ModelCheckpoint
(LOW         ) EvalHook
(VERY_LOW    ) TextLoggerHook
 --------------------
before_val_epoch:
(LOW         ) IterTimerHook
(VERY_LOW    ) TextLoggerHook
 --------------------
before_val_iter:
(LOW         ) IterTimerHook
 --------------------
after_val_iter:
(LOW         ) IterTimerHook
(VERY_LOW    ) TextLoggerHook
 --------------------
after_val_epoch:
(VERY_LOW    ) TextLoggerHook
 --------------------
after_run:
(VERY_LOW    ) TextLoggerHook
 --------------------
- workflow: [('val', 1)], max: 1 epochs
- Checkpoints will be saved to d:\vs_code_projects\01-dl-toolbox(pytorch)\UDL\results\pansharpening\TestData_wv3\FusionNet\Test\WV3
Backend TkAgg is interactive backend. Turning interactive mode on.
- Epoch(val) [1][1]     SAM: 22.84085, ERGAS: 34.62882, PSNR: -46.77319
- Epoch(val) [1][2]     SAM: 22.48370, ERGAS: 31.51278, PSNR: -48.47853
- Epoch(val) [1][3]     SAM: 23.08300, ERGAS: 30.93331, PSNR: -48.41457
- Epoch(val) [1][4]     SAM: 23.56104, ERGAS: 30.82047, PSNR: -48.42678
- Epoch(val) [1][5]     SAM: 23.37117, ERGAS: 30.51880, PSNR: -48.82153
- Epoch(val) [1][6]     SAM: 23.27561, ERGAS: 30.52608, PSNR: -48.84397
- Epoch(val) [1][7]     SAM: 23.12976, ERGAS: 29.98178, PSNR: -48.95483
- Epoch(val) [1][8]     SAM: 22.90906, ERGAS: 29.75901, PSNR: -49.06594
- Epoch(val) [1][9]     SAM: 23.12697, ERGAS: 29.50100, PSNR: -48.84274
- Epoch(val) [1][10]    SAM: 23.19153, ERGAS: 29.14926, PSNR: -48.68628
- Epoch(val) [1][11]    SAM: 23.06567, ERGAS: 29.31538, PSNR: -48.92280
- Epoch(val) [1][12]    SAM: 22.98491, ERGAS: 29.27910, PSNR: -49.09107
- Epoch(val) [1][13]    SAM: 22.46682, ERGAS: 29.52684, PSNR: -49.38488
- Epoch(val) [1][14]    SAM: 22.38683, ERGAS: 29.92165, PSNR: -49.36148
- Epoch(val) [1][15]    SAM: 21.96438, ERGAS: 29.72373, PSNR: -49.71417
- Epoch(val) [1][16]    SAM: 21.58497, ERGAS: 29.56194, PSNR: -50.03238
- Epoch(val) [1][17]    SAM: 21.39214, ERGAS: 29.48476, PSNR: -50.26566
- Epoch(val) [1][18]    SAM: 21.17842, ERGAS: 29.38613, PSNR: -50.47970
- Epoch(val) [1][19]    SAM: 21.03338, ERGAS: 29.24421, PSNR: -50.60560
- Epoch(val) [1][20]    SAM: 20.92074, ERGAS: 29.09417, PSNR: -50.69608
test time: 26.28206777572632
- Epoch(val) [1][20]    SAM: 20.92074, ERGAS: 29.09417, PSNR: -50.69608
- Training time 0:00:29
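Two details in this log may be worth checking before blaming the weights, both offered as hypotheses rather than a confirmed cause: the loader reports `gt` as `[20, 256, 256, 8]` (bands last) while `lms`, `ms`, and `pan` are bands first, and the data are read with a range of 2047.0. If the prediction and the ground truth reach the metric code in different layouts or on different scales, SAM, ERGAS, and PSNR stop being meaningful. The helpers below (`to_nhwc`, `align_for_metrics`, and the `IMG_RANGE` constant are hypothetical names, not part of the toolbox) are a minimal sanity check that puts both arrays in `(N, H, W, C)` layout on the raw 0 to 2047 scale before comparison.

```python
import numpy as np

IMG_RANGE = 2047.0  # matches img_range in the config dump above (11-bit WV3 data)


def to_nhwc(x: np.ndarray, channels: int) -> np.ndarray:
    """Return a 4-D batch in (N, H, W, C) layout, transposing from (N, C, H, W) if needed.

    Heuristic: ambiguous if H or W also equals `channels`, which is not the case for 256x256 tiles.
    """
    if x.ndim != 4:
        raise ValueError(f"expected a 4-D batch, got shape {x.shape}")
    if x.shape[-1] == channels:
        return x
    if x.shape[1] == channels:
        return np.transpose(x, (0, 2, 3, 1))
    raise ValueError(f"no axis of size {channels} at position 1 or 3: {x.shape}")


def align_for_metrics(pred, gt, channels=8, pred_normalized=True, gt_normalized=False):
    """Bring prediction and ground truth to the same layout and to the raw 0..IMG_RANGE scale.

    The two *_normalized flags are assumptions the caller must verify, e.g. by checking max().
    """
    pred = to_nhwc(np.asarray(pred, dtype=np.float64), channels)
    gt = to_nhwc(np.asarray(gt, dtype=np.float64), channels)
    if pred.shape != gt.shape:
        raise ValueError(f"shape mismatch after alignment: {pred.shape} vs {gt.shape}")
    if pred_normalized:
        pred = pred * IMG_RANGE
    if gt_normalized:
        gt = gt * IMG_RANGE
    return pred, gt
```

Printing `gt.max()` and the model output's maximum is a quick way to tell which arrays are still on the raw 11-bit scale and which were divided by 2047.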

I even tried copying the same checkpoint (produced during training) into the Test folder, but the following messages were printed (along with the same results as above):

d:\vs_code_projects\01-dl-toolbox(pytorch)\UDL\mmcv\mmcv\runner\checkpoint.py:585: UserWarning: checkpoint in directory d:\vs_code_projects\01-dl-toolbox(pytorch)\UDL\results\pansharpening\TestData_wv3\FusionNet\Test\WV3\checkpoint don't exist or is empty
  warnings.warn(msg)
**- loading best model failed, maybe it's from scratch currently.
- load checkpoint from local path: D:/VS_Code_Projects/01-DL-toolbox(Pytorch)/UDL/pretrained-model/WV3/fusionnet.pth
- The model and loaded state dict do not match exactly**

unexpected key in source state_dict: meta
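The `unexpected key in source state_dict: meta` line means the loader found a top-level `meta` entry among what it treated as weights. Checkpoints written by stock mmcv's `save_checkpoint` keep the weights under a `state_dict` key next to `meta`, and stock mmcv's loader unwraps `state_dict` when it is present, so seeing `meta` reported here suggests this file is laid out differently (the toolbox bundles its own copy of mmcv, so this is an inference to verify, not a certainty). The unchanged metrics after pointing `resume_from` at fusionnet.pth are at least consistent with the weights not taking effect, although the log alone does not prove it. The pure-Python sketch below, with hypothetical helper names and an assumed list of wrapper and metadata keys, unwraps a checkpoint dictionary and compares its keys with the keys the model expects.

```python
from __future__ import annotations

from collections.abc import Iterable, Mapping
from typing import Any

WRAPPER_KEYS = ("state_dict", "model", "net")            # assumption: common nesting keys
METADATA_KEYS = {"meta", "optimizer", "epoch", "iter"}   # assumption: non-weight entries


def extract_weights(checkpoint: Mapping[str, Any]) -> dict[str, Any]:
    """Unwrap a nested checkpoint, drop metadata entries, and strip a DataParallel 'module.' prefix."""
    for key in WRAPPER_KEYS:
        if key in checkpoint and isinstance(checkpoint[key], Mapping):
            checkpoint = checkpoint[key]
            break
    weights = {}
    for name, value in checkpoint.items():
        if name in METADATA_KEYS:
            continue
        if name.startswith("module."):
            name = name[len("module."):]
        weights[name] = value
    return weights


def compare_keys(weights: Mapping[str, Any], expected: Iterable[str]) -> tuple[list[str], list[str]]:
    """Return (missing, unexpected) key lists, i.e. what a strict load would complain about."""
    expected_set = set(expected)
    present = set(weights)
    return sorted(expected_set - present), sorted(present - expected_set)
```

With PyTorch installed, the file could be checked with `weights = extract_weights(torch.load(path, map_location="cpu"))` followed by `compare_keys(weights, model.state_dict().keys())`; if both lists come back empty, `model.load_state_dict(weights, strict=True)` should load every parameter.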

Could you please suggest a solution? I'm unable to reproduce the reported results. Thanks!

songwenhao123 commented 1 year ago

I ran into the same problem when testing with test_wv3_multiExm1 provided by the author: my results were just as poor as yours. Testing on the validation set gave somewhat better results, but still not as good as those reported in the author's paper. If you manage to solve this problem, I would appreciate it if you could share the solution. Thank you very much.