pkuanjie / StyleNAS

Official pytorch implementation of the paper: "Ultrafast Photorealistic Style Transfer via Neural Architecture Search"
MIT License

Unable to run photo_transfer.py #1

Open rkoppula opened 4 years ago

rkoppula commented 4 years ago

Hi, I am trying to run photo_transfer.py as documented in the README, but I get the following errors:

1) load_lua is not available in PyTorch 1.3, so the load_lua import fails.

The most common recommendation is to use torchfile (https://github.com/bshillingford/python-torchfile), which I tried.

[Screenshot: 2020-05-31, 5:55:39 PM]

2) However, this then fails at line 12 of VGG_with_decoder.py.

[Screenshot: 2020-05-31, 5:56:14 PM]

Any help is appreciated.

Thanks!

pkuanjie commented 4 years ago

Hi, there are two solutions:

  1. Set up a PyTorch 0.4.1 environment (this requires CUDA 9.0 or a lower CUDA version).
  2. Temporarily set up a PyTorch 0.4.1 environment (no CUDA needed) and use it to convert the Torch model to a .pkl file with torch.save. You can then load the converted pre-trained model in PyTorch 1.3 with CUDA support.

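
Option 2 might look like the following one-off script, run inside the PyTorch 0.4.1 environment. This is a minimal sketch, not the authors' conversion code: it assumes the object returned by `load_lua` keeps its layers in a `.modules` list whose convolution layers expose `.weight` and `.bias`. The function names `collect_params` and `convert` and the `"<layer index>.weight"` key format are hypothetical. Only `convert` needs the old PyTorch; `collect_params` is plain Python.

```python
from collections import OrderedDict


def collect_params(modules):
    """Gather weight/bias tensors from a list of legacy layers.

    Keys use each layer's position in the list (hypothetical format,
    e.g. "0.weight", "2.bias"), so parameter-free layers such as ReLU
    or padding are skipped but do not shift the indices.
    """
    params = OrderedDict()
    for idx, layer in enumerate(modules):
        for name in ("weight", "bias"):
            tensor = getattr(layer, name, None)
            if tensor is not None:
                params["{}.{}".format(idx, name)] = tensor
    return params


def convert(t7_path, out_path):
    """Run inside PyTorch 0.4.1: read a .t7 model and torch.save its parameters."""
    import torch
    from torch.utils.serialization import load_lua  # not available in PyTorch 1.3

    vgg = load_lua(t7_path)
    # Assumption: the loaded legacy Sequential keeps its layers in `.modules`.
    params = collect_params(vgg.modules)
    # Save plain float tensors so that a newer PyTorch can torch.load them.
    torch.save(OrderedDict((k, v.float()) for k, v in params.items()), out_path)


# Example call (in the PyTorch 0.4.1 environment):
# convert('./models/models_photorealistic_nas/vgg_normalised_conv5_1.t7',
#         './models/models_photorealistic_nas/vgg_normalised_conv5_1.pkl')
```

Keying by layer index rather than by a running counter keeps the saved names aligned with the positions the original model used.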
rkoppula commented 4 years ago

OK, thanks! I will try option 2.

xurong1981 commented 4 years ago

@rkoppula Hi, have you resolved this problem with solution 2? I tried it, but got an error at "vgg.get(0)". What about your case?
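
A hedged guess at the `vgg.get(0)` error: if VGG_with_decoder.py still expects the legacy model returned by `load_lua` and reads parameters as `vgg.get(i).weight` (an assumption based on where the error occurs), passing it a plain dict from `torch.load` will not work, because `dict.get(0)` returns `None` when there is no key `0`, and the following `.weight` then fails. One workaround is a small adapter that gives the converted dict a legacy-style `.get(i)`. The class `LegacyLayers` is hypothetical and assumes keys of the form `"<layer index>.weight"` / `"<layer index>.bias"`, where the index counts every layer of the original model, including ReLU and padding layers.

```python
class LegacyLayers:
    """Hypothetical adapter: exposes a converted parameter dict through a
    legacy-style .get(i) accessor so code written as vgg.get(i).weight
    keeps working. Assumes keys like "0.weight" and "0.bias".
    """

    class _Layer:
        # Minimal stand-in for a legacy layer: just .weight and .bias.
        def __init__(self, weight, bias):
            self.weight = weight
            self.bias = bias

    def __init__(self, state):
        self._state = state

    def get(self, index):
        weight_key = "{}.weight".format(index)
        if weight_key not in self._state:
            # Fail loudly instead of returning None like dict.get would.
            raise KeyError("no parameters stored for layer {}".format(index))
        bias = self._state.get("{}.bias".format(index))
        return self._Layer(self._state[weight_key], bias)


# Usage sketch (PyTorch 1.3), replacing the load_lua result:
# vgg = LegacyLayers(torch.load('./models/models_photorealistic_nas/vgg_normalised_conv5_1.pkl'))
```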

lauraset commented 2 years ago

Hi, I converted the model successfully with the conversion script convert_torch.py, on Windows 10. First, install PyTorch 0.4.1 (CPU) to run the conversion script, then modify photo_transfer.py as follows:


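```python
import torch

# `encoder` is the network class already used in photo_transfer.py.
# Original loading code, which fails because load_lua is unavailable:
# encoder_param = load_lua('./models/models_photorealistic_nas/vgg_normalised_conv5_1.t7')
weightpath = './models/models_photorealistic_nas/vgg_normalised_conv5_1.pth'
encoder_param = torch.load(weightpath)  # weights produced by convert_torch.py
weight_keys = list(encoder_param.keys())

net_e = encoder()
state = net_e.state_dict()
state_keys = list(state.keys())

# Copy tensors by position: both dicts must list the same layers in the same order.
assert len(state_keys) == len(weight_keys), "parameter count mismatch"
for k1, k2 in zip(state_keys, weight_keys):
    state[k1] = encoder_param[k2]
net_e.load_state_dict(state)
```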
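
The positional copy in this snippet assumes both dictionaries list the same layers in the same order. A slightly stricter variant also compares tensor shapes, so a mismatch fails loudly instead of silently loading the wrong weights. The helper name `map_by_position` is hypothetical; it relies only on each value having a `.shape`, which PyTorch tensors provide.

```python
from collections import OrderedDict


def map_by_position(src, dst):
    """Return a new dict with dst's keys and src's values, paired in order.

    `src` is e.g. the converted .pth dict, `dst` is net.state_dict().
    Raises ValueError if the counts or any pair of shapes disagree.
    """
    if len(src) != len(dst):
        raise ValueError("{} source tensors vs {} target tensors".format(len(src), len(dst)))
    mapped = OrderedDict()
    for (dst_name, dst_val), (src_name, src_val) in zip(dst.items(), src.items()):
        if tuple(src_val.shape) != tuple(dst_val.shape):
            raise ValueError("shape mismatch: {} {} vs {} {}".format(
                src_name, tuple(src_val.shape), dst_name, tuple(dst_val.shape)))
        mapped[dst_name] = src_val
    return mapped
```

In photo_transfer.py it would replace the loop: `net_e.load_state_dict(map_by_position(torch.load(weightpath), net_e.state_dict()))`.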