amanchadha / iSeeBetter

iSeeBetter: Spatio-Temporal Video Super Resolution using Recurrent-Generative Back-Projection Networks | Python3 | PyTorch | GANs | CNNs | ResNets | RNNs | Published in Springer Journal of Computational Visual Media, September 2020, Tsinghua University Press
https://arxiv.org/abs/2006.11161
MIT License

gpu mode loading state_dict for DataParallel error #12

Closed. rishftw closed this issue 4 years ago.

rishftw commented 4 years ago

Hi, your work is pretty impressive. I may have found a small bug. I am testing iSeeBetter (ISB) in Colab in GPU mode, and when I run:

!python3 iSeeBetterTest.py -c --threads 8

I get a runtime error with the following output:

```
Namespace(chop_forward=False, data_dir='./Vid4', debug=False, file_list='foliage_test.txt', future_frame=True, gpu_mode=True, gpus=1, model='weights/netG_epoch_4_1.pth', model_type='RBPN', nFrames=7, other_dataset=True, output='Results/', residual=False, seed=123, testBatchSize=1, threads=8, upscale_factor=4)
Using GPU mode
==> Loading datasets
==> Building model RBPN
[ INFO] ------------- iSeeBetter Network Architecture -------------
[ INFO] ----------------- Generator Architecture ------------------
[ INFO] DataParallel(
  (module): Net(

..............XXXX................OMITTED................XXXX...............

    (output): ConvBlock(
      (conv): Conv2d(384, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
)
[ INFO] Total number of parameters: 12771943
Traceback (most recent call last):
  File "/content/drive/My Drive/isb/iSeeBetter/utils.py", line 19, in loadPreTrainedModel
    model.load_state_dict(state_dict)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 839, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
    Missing key(s) in state_dict: "module.feat0.conv.weight", "module.feat0.conv.bias", "module.feat0.act.weight", "module.feat1.conv.weight", "module.feat1.conv.bias", ..............XXXX................OMITTED................XXXX............... "module.output.conv.weight", "module.output.conv.bias".
    Unexpected key(s) in state_dict: "feat0.conv.weight", "feat0.conv.bias", "feat0.act.weight", "feat1.conv.weight", "feat1.conv.bias", "feat1.act.weight", ..............XXXX................OMITTED................XXXX............... "res_feat3.5.act.weight", "output.conv.weight", "output.conv.bias".

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "iSeeBetterTest.py", line 194, in <module>
    eval()
  File "iSeeBetterTest.py", line 79, in eval
    utils.loadPreTrainedModel(gpuMode=args.gpu_mode, model=model, modelPath=modelPath)
  File "/content/drive/My Drive/isb/iSeeBetter/utils.py", line 35, in loadPreTrainedModel
    model.load_state_dict(new_state_dict)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 839, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
    Missing key(s) in state_dict: "module.feat0.conv.weight", "module.feat0.conv.bias", "module.feat0.act.weight", "module.feat1.conv.weight", "module.feat1.conv.bias", ..............XXXX................OMITTED................XXXX............... "module.output.conv.weight", "module.output.conv.bias".
    Unexpected key(s) in state_dict: "onv.weight", "onv.bias", "ct.weight", "at1.conv.weight", "at1.conv.bias", "at1.act.weight", "1.up_conv1.deconv.weight", "1.up_conv1.deconv.bias", ..............XXXX................OMITTED................XXXX............... "conv.weight", "conv.bias".
```
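The truncated names in the second traceback ("onv.weight", "onv.bias", "ct.weight") are what you get by cutting `len('module.')`, i.e. 7, characters off keys that never had the prefix. The sketch below replays the slicing line the reporter quotes from `loadPreTrainedModel` on a few key names copied from the first traceback's "Unexpected key(s)" list. The function name `strip_prefix_unconditionally` is illustrative only and does not come from the repository.

```python
from collections import OrderedDict

# Checkpoint keys copied from the first traceback: note there is no "module." prefix.
checkpoint_keys = ["feat0.conv.weight", "feat0.conv.bias", "feat0.act.weight",
                   "feat1.conv.weight", "feat1.conv.bias", "feat1.act.weight"]


def strip_prefix_unconditionally(keys):
    """Replay the quoted slicing line; store the original key as the value so overwrites show."""
    new_state_dict = OrderedDict()
    for k in keys:
        name = k[len('module.'):]  # drops 7 characters whether or not "module." is present
        new_state_dict[name] = k
    return new_state_dict


result = strip_prefix_unconditionally(checkpoint_keys)
print(list(result))  # ['onv.weight', 'onv.bias', 'ct.weight']
```

Besides producing names the model does not recognise, the slice also makes different layers collide: `feat0.conv.weight` and `feat1.conv.weight` both become `onv.weight`, so six entries shrink to three and the later weights overwrite the earlier ones.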

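The issue seems to be in the `except` block of the `loadPreTrainedModel(gpuMode, model, modelPath)` function in `iSeeBetter/utils.py`. In GPU mode the model is wrapped in `DataParallel` (see the `DataParallel(` line in the log), so it expects every key to start with "module.", while the keys in the checkpoint have no prefix. The `except` block, however, removes the first `len('module.')` characters from each key when building `new_state_dict`, whereas in this situation "module." needs to be prepended to each key instead. Changing these two lines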

    name = k[len('module.'):]  # remove module.
    new_state_dict[name] = v

to

    new_state_dict["module." + k] = v

seems to fix the problem for me on Colab. However, I am not an expert, so I am not sure whether this causes other problems. Hope it helps someone.
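A minimal sketch of a direction-agnostic alternative, assuming the checkpoint and the model differ only by the "module." prefix that `DataParallel` adds to parameter names. Instead of trying one layout and falling back in an `except` block, it looks at the model's own keys and adds or strips the prefix only where the two disagree. The helper `remap_state_dict_keys` is hypothetical; it is not the change merged in #14, whose contents are not shown in this thread.

```python
from collections import OrderedDict
from typing import Any, Dict, Iterable, Mapping

PREFIX = "module."  # prefix DataParallel puts in front of every parameter name


def remap_state_dict_keys(state_dict: Mapping[str, Any],
                          expected_keys: Iterable[str],
                          prefix: str = PREFIX) -> Dict[str, Any]:
    """Hypothetical helper: add or strip `prefix` so checkpoint keys match the model's keys.

    A key is only sliced if it actually starts with `prefix`, so names such as
    "feat0.conv.weight" can never be truncated to "onv.weight".
    """
    expected = list(expected_keys)
    if set(state_dict) == set(expected):
        return OrderedDict(state_dict)  # already compatible, keep as-is
    # Treat the model as wrapped (e.g. DataParallel) if all of its keys carry the prefix.
    want_prefix = bool(expected) and all(k.startswith(prefix) for k in expected)
    remapped = OrderedDict()
    for key, value in state_dict.items():
        has_prefix = key.startswith(prefix)
        if want_prefix and not has_prefix:
            key = prefix + key          # wrapped model, plain checkpoint (this issue)
        elif not want_prefix and has_prefix:
            key = key[len(prefix):]     # plain model, checkpoint saved from a wrapped one
        remapped[key] = value
    return remapped
```

With PyTorch this would be called along the lines of `model.load_state_dict(remap_state_dict_keys(torch.load(modelPath, map_location='cpu'), model.state_dict().keys()))`. Unlike always prepending "module.", the same helper also covers the reverse case of loading a checkpoint saved from a `DataParallel` model into an unwrapped one.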

The full output of `!python3 iSeeBetterTest.py -c --threads 8` is attached below: out.txt

amanchadha commented 4 years ago

Thanks for reporting! #14, which was just merged, should fix this issue. If you still see hiccups, feel free to open a new issue and let us know!