hezhangsprinter / DCPDN

Densely Connected Pyramid Dehazing Network (CVPR'2018)

cPickle.UnpicklingError: invalid load key, '<'. demo/train execution error #39

Closed JaniceLC closed 5 years ago

JaniceLC commented 5 years ago

Hello, when I execute demo.py, I get stuck at netG.load_state_dict(model).

My settings: Linux, Python 2.7, NVIDIA GPU + CUDA/cuDNN (CUDA 9.0), PyTorch 0.3.1

When I execute demo.py, I get the output shown below:

Namespace(annealEvery=400, annealStart=0, batchSize=1, beta1=0.5, dataroot='./facades/nat_new4', dataset='pix2pix', display=5, evalIter=500, exp='sample', imageSize=1024, inputChannelSize=3, lambdaGAN=0.01, lambdaIMG=1, lrD=0.0002, lrG=0.0002, mode='B2A', ndf=64, netD='', netG='./demo_model/netG_epoch_8.pth', ngf=64, niter=400, originalSize=1024, outputChannelSize=3, poolSize=50, valBatchSize=1, valDataroot='./facades/nat_new4', wd=0.0, workers=1)
Random Seed:  7984
/home/snf4/anaconda3/envs/py27_env/lib/python2.7/site-packages/torchvision/transforms/transforms.py:208: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
  "please use transforms.Resize instead.")
Traceback (most recent call last):
  File "demo.py", line 134, in <module>
    netG.load_state_dict(torch.load(opt.netG))
  File "/home/snf4/anaconda3/envs/py27_env/lib/python2.7/site-packages/torch/serialization.py", line 267, in load
    return _load(f, map_location, pickle_module)
  File "/home/snf4/anaconda3/envs/py27_env/lib/python2.7/site-packages/torch/serialization.py", line 410, in _load
    magic_number = pickle_module.load(f)
cPickle.UnpicklingError: invalid load key, '<'.

This error also appeared when executing train.py. I tried to replace

netG = net.dehaze(inputChannelSize, outputChannelSize, ngf)
if opt.netG != '':
      netG.load_state_dict(torch.load(opt.netG))

with the code below, following a suggested fix for a UnicodeDecodeError. However, it doesn't work in this case.

from functools import partial
import pickle
pickle.load = partial(pickle.load, encoding="latin1")
pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
model = torch.load(opt.netG, map_location=lambda storage, loc: storage, pickle_module=pickle)
netG = net.dehaze(inputChannelSize, outputChannelSize, ngf)

if opt.netG != '':
      netG.load_state_dict(model)

Do you have any advice? Thanks in advance!

JaniceLC commented 5 years ago

I downloaded netG_epoch_8.pth again, and the problem was solved!
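
For anyone who hits the same error: the character quoted in "invalid load key, '<'" is the first byte the unpickler read from the file. A checkpoint that begins with '<' is often an HTML or XML page saved in place of the binary, for example a web page returned by a failed download, which would fit with a re-download fixing it. Below is a minimal sketch for checking a file before passing it to torch.load. The helper name looks_like_markup and the sample file names are hypothetical and not part of DCPDN or PyTorch, and the check is only a heuristic.

```python
from __future__ import print_function

import os
import tempfile


def looks_like_markup(path, n=64):
    """Heuristic: True if the file starts with '<' after optional whitespace.

    Hypothetical helper, not part of DCPDN or PyTorch. A file that passes
    this check is not guaranteed to be a valid checkpoint.
    """
    with open(path, "rb") as f:
        head = f.read(n)
    return head.lstrip().startswith(b"<")


if __name__ == "__main__":
    # Stand-in files for illustration only (not real checkpoints).
    tmpdir = tempfile.mkdtemp()
    bad = os.path.join(tmpdir, "bad.pth")
    ok = os.path.join(tmpdir, "ok.pth")
    with open(bad, "wb") as f:
        f.write(b"<!DOCTYPE html><html>...</html>")
    with open(ok, "wb") as f:
        f.write(b"\x80\x02 placeholder")  # begins with a pickle protocol 2 header
    print(looks_like_markup(bad))
    print(looks_like_markup(ok))
```

If the check returns True for ./demo_model/netG_epoch_8.pth, opening the file in a text editor should show what was actually downloaded, and downloading it again is the likely fix.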