mratsim opened this issue 6 years ago
Setting track_running_stats=True in InstanceNorm2d should fix this.
track_running_stats = True is buggy and does not work (or I missed something).
I went the other way with the following:
    # https://github.com/zhanghang1989/PyTorch-Multi-Style-Transfer/issues/21
    # Compatibility shim for PyTorch 0.4
    import torch

    model_dict = torch.load('21styles.model')
    model_dict_clone = model_dict.copy()  # We can't mutate while iterating
    for key, value in model_dict_clone.items():
        if key.endswith(('running_mean', 'running_var')):
            del model_dict[key]

    # Next cell (assumes Net, the MSG-Net model class from the repo's net.py, is already in scope)
    style_model = Net(ngf=128)
    # strict=False: missing or unexpected keys are ignored instead of raising
    style_model.load_state_dict(model_dict, strict=False)
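A variant of the shim above that builds a filtered copy instead of deleting from the original, so no separate clone is needed. This is a minimal sketch, assuming the checkpoint is a plain (ordered) dict mapping parameter names to tensors; the helper name strip_running_stats and the example keys are hypothetical, and string placeholders stand in for tensors so the filtering logic runs without torch.

```python
from collections import OrderedDict

# Suffixes of the entries that the shim above removes.
RUNNING_STAT_SUFFIXES = ('running_mean', 'running_var')


def strip_running_stats(state_dict):
    """Return a new OrderedDict without running-statistics entries.

    The input dict is left untouched, so there is nothing to clone
    before iterating.
    """
    return OrderedDict(
        (key, value)
        for key, value in state_dict.items()
        if not key.endswith(RUNNING_STAT_SUFFIXES)
    )


# Hypothetical keys; string placeholders stand in for tensors.
checkpoint = OrderedDict([
    ('model.0.weight', 'w0'),
    ('model.1.running_mean', 'rm1'),
    ('model.1.running_var', 'rv1'),
    ('model.2.conv_block.1.weight', 'w2'),
])
filtered = strip_running_stats(checkpoint)
print(list(filtered))  # ['model.0.weight', 'model.2.conv_block.1.weight']
```

With a real checkpoint the call would look like style_model.load_state_dict(strip_running_stats(torch.load('21styles.model')), strict=False).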
pip install torch==0.3.0.post4
In the camera_demo.py and main.py files, the above translates into changing
    style_model = Net(ngf=args.ngf)
    style_model.load_state_dict(torch.load(args.model))
to
    model_dict = torch.load(args.model)
    model_dict_clone = model_dict.copy()  # We can't mutate while iterating
    for key, value in model_dict_clone.items():
        if key.endswith(('running_mean', 'running_var')):
            del model_dict[key]
    style_model = Net(ngf=args.ngf)
    style_model.load_state_dict(model_dict, strict=False)
and changing style_v.data() to style_v.data.
Just got camera_demo.py and main.py working. Thanks @alvinwan and @mratsim for the hints above.
For a while I was getting this error:
  File "camera_demo.py", line 105, in <module>
    main()
  File "camera_demo.py", line 102, in main
    run_demo(args, mirror=True)
  File "camera_demo.py", line 75, in run_demo
    simg = simg.transpose(1, 2, 0).astype('uint8')
ValueError: axes don't match array
The quickest way to debug this was to replace python on my command line with python -m pdb and, once it crashed and gave me a prompt, check the shape of simg. It turned out that simg now has 4 dimensions rather than 3, which I fixed with the reshape in step 3 below.
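A minimal way to reproduce this ValueError with NumPy alone, assuming (as the pdb session showed) that the output now carries an extra leading batch dimension of size 1. The array is a small zero-filled placeholder, not real model output.

```python
import numpy as np

# Placeholder for the stylized output: zeros shaped like a batch of one
# RGB image, (batch, channels, height, width) = (1, 3, 4, 5).
simg = np.zeros((1, 3, 4, 5), dtype=np.float32)
print(simg.ndim)  # 4

try:
    # Three axes given for a four-dimensional array, as in the traceback.
    simg.transpose(1, 2, 0)
except ValueError as err:
    print('ValueError:', err)

# Dropping the batch dimension first leaves (3, 4, 5), which
# transpose(1, 2, 0) turns into (height, width, channels).
hwc = simg[0].transpose(1, 2, 0).astype('uint8')
print(hwc.shape)  # (4, 5, 3)
```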
My full fixes were:
1. Downgrade torch:
    pip uninstall torch
    pip install torch==0.3.0.post4
2. In camera_demo.py and main.py, replace
    style_model = Net(ngf=args.ngf)
with
    model_dict = torch.load(args.model)  # or args.resume, matching the argument in the style_model.load_state_dict line
    model_dict_clone = model_dict.copy()  # We can't mutate while iterating
    for key, value in model_dict_clone.items():
        if key.endswith(('running_mean', 'running_var')):
            del model_dict[key]
    style_model = Net(ngf=128)  # to run with torch-0.3.0.post4
    # style_model = Net(ngf=args.ngf)  # to run main.py with torch-0.4.0
Then replace
    style_model.load_state_dict(torch.load(args.model))  # torch.load(args.resume) in one place
with
    style_model.load_state_dict(model_dict, strict=False)
3. Replace
    simg = style_v.data().numpy()
with
    simg = style_v.data.numpy().reshape((3,512,512))
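The reshape((3,512,512)) in step 3 only works when the output is exactly 512x512. A size-independent sketch, assuming the output array has shape (1, 3, H, W) with values roughly in the 0 to 255 range: it indexes away the batch dimension instead of hard-coding the size. The helper name to_hwc_uint8 is hypothetical, and the clipping step is an addition that is not part of the fixes above.

```python
import numpy as np


def to_hwc_uint8(batch):
    """Convert a (1, C, H, W) float array to an (H, W, C) uint8 image.

    Indexing [0] drops the batch dimension for any H and W, unlike a
    hard-coded reshape((3, 512, 512)).
    """
    if batch.ndim != 4 or batch.shape[0] != 1:
        raise ValueError('expected shape (1, C, H, W), got %r' % (batch.shape,))
    chw = batch[0]
    # Assumption: clip to [0, 255] before casting so out-of-range
    # values do not wrap around when converted to uint8.
    return np.clip(chw, 0, 255).transpose(1, 2, 0).astype('uint8')


# Placeholder output with every value above 255, at a non-512 size.
out = np.full((1, 3, 48, 64), 300.0, dtype=np.float32)
img = to_hwc_uint8(out)
print(img.shape, img.max())  # (48, 64, 3) 255
```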
QUESTION
Instead of downgrading torch, I also tried setting track_running_stats=True for InstanceNorm2d in net.py. I had to do this in several places by following norm_layer through the code, including the Bottleneck and UpBottleneck classes.
(Note that the documentation shows track_running_stats=True as the default for the BatchNorm classes; InstanceNorm2d, however, defaults to track_running_stats=False.)
I've gotten main.py working with the upgraded torch, but camera_demo.py produces an all-black image. Any comments or ideas are welcome!
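Since several people asked how the flag is set, here is a sketch with the current torch.nn API showing track_running_stats as an ordinary constructor keyword. It assumes PyTorch 0.4.0 or later (earlier InstanceNorm2d did not take this argument); routing it through a norm_layer factory with functools.partial is an illustration, not a change taken from net.py. Per the PyTorch docs, a layer with this flag set normalizes with its running estimates in eval() mode rather than with per-instance statistics.

```python
import functools

import torch.nn as nn

# Default InstanceNorm2d: affine=False and track_running_stats=False,
# so it registers no parameters or running buffers.
plain = nn.InstanceNorm2d(64)
print(sorted(plain.state_dict()))  # []

# With the flag set, running_mean / running_var buffers are created,
# matching the keys the older checkpoints contain.
tracking = nn.InstanceNorm2d(64, track_running_stats=True)
print(sorted(tracking.state_dict()))

# If code builds its norm layers as norm_layer(num_features), a partial
# bakes the flag in without editing every call site.
norm_layer = functools.partial(nn.InstanceNorm2d, track_running_stats=True)
layer = norm_layer(128)
print(layer.track_running_stats, tuple(layer.running_mean.shape))  # True (128,)
```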
@karenerobinson I think the most reasonable approach is to wait for PyTorch 1.0, which should be released within days, so that the APIs are more stable and we don't have to fix something new all over again once it lands.
How do you set track_running_stats = True? I'm a beginner, so sorry if it's obvious, but I haven't been able to find it for the past hour or so.
Thanks
Try what @mratsim mentioned above:
    model_dict = torch.load('21styles.model')
    model_dict_clone = model_dict.copy()  # We can't mutate while iterating
    for key, value in model_dict_clone.items():
        if key.endswith(('running_mean', 'running_var')):
            del model_dict[key]
    style_model = Net(ngf=128)
    style_model.load_state_dict(model_dict, strict=False)
I fixed my issue: I went to the nn package in my Python site-packages directory and set track_running_stats=True in the InstanceNorm file. I didn't know how to do that at first. After a bit more tweaking, I got it to work. Thanks anyway :)
I really appreciate the comments on fixing the compatibility issues in the code. I haven't worked on this project for a while. Could you consider submitting a pull request to the master branch? Thanks a lot :)
Thanks to @alvinwan for sharing the fixes. I have tried them, and they worked for both main.py and camera_demo.py. @zhanghang1989 As this is still not fixed in the master branch, I have created a pull request for it (including another fix for load_lua) here.
I have been looking for this for 50 hours, thanks
PyTorch 0.4.0 was released on April 24 and, unfortunately, the earlier pre-trained weights are not compatible with it.
In the notebook I get