dxyang / StyleTransfer

Implementation of "Perceptual Losses for Real-Time Style Transfer and Super-Resolution" in PyTorch

Error when using nn.InstanceNorm2d with newer PyTorch versions #2

Open eeccxin opened 4 years ago

eeccxin commented 4 years ago

RuntimeError: Error(s) in loading state_dict for UnetGenerator:
    Unexpected running stats buffer(s) "model.U4.2.U3.4.U2.4.U1.23.U0.6.running_mean" and "model.U4.2.U3.4.U2.4.U1.23.U0.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
    Unexpected running stats buffer(s) "model.U4.2.U3.4.U2.4.U1.26.running_mean" and "model.U4.2.U3.4.U2.4.U1.26.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
    Unexpected running stats buffer(s) "model.U4.2.U3.4.U2.7.running_mean" and "model.U4.2.U3.4.U2.7.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
    Unexpected running stats buffer(s) "model.U4.2.U3.7.running_mean" and "model.U4.2.U3.7.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.

Solution: change nn.InstanceNorm2d(3, affine=True) to nn.InstanceNorm2d(3, affine=True, track_running_stats=True).
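The error message also suggests a second route: keep the model as it is and delete the stale keys from the checkpoint before loading. The following is a minimal sketch of that route. The helper name strip_stale_running_stats is hypothetical, and it assumes the checkpoint file holds a plain state_dict rather than a dict that wraps one. The two routes are not equivalent in eval mode. With track_running_stats=True, InstanceNorm2d normalizes using the stored running statistics. With the keys removed, it normalizes each image with that image's own statistics. Which behavior matches how this checkpoint was originally used has not been verified here.

```python
from typing import Dict, Iterable, Mapping, TypeVar

V = TypeVar("V")

# Buffers that nn.InstanceNorm2d registers only when track_running_stats=True.
_RUNNING_STAT_SUFFIXES = (".running_mean", ".running_var", ".num_batches_tracked")


def strip_stale_running_stats(checkpoint: Mapping[str, V], model_keys: Iterable[str]) -> Dict[str, V]:
    """Return a copy of `checkpoint` without running-stat entries the model does not expect.

    A key is dropped only if it ends in a running-stat suffix AND is absent from
    `model_keys`, so buffers the model does expect (e.g. BatchNorm's) are kept.
    Assumes `checkpoint` is a plain state_dict mapping (hypothetical helper).
    """
    expected = set(model_keys)
    return {
        key: value
        for key, value in checkpoint.items()
        if key in expected or not key.endswith(_RUNNING_STAT_SUFFIXES)
    }


# Made-up values for illustration. With PyTorch the call would look like:
#   model.load_state_dict(strip_stale_running_stats(torch.load(path, map_location="cpu"),
#                                                   model.state_dict().keys()))
ckpt = {
    "model.U4.2.U3.7.weight": 1.0,
    "model.U4.2.U3.7.bias": 0.0,
    "model.U4.2.U3.7.running_mean": 0.5,
    "model.U4.2.U3.7.running_var": 2.0,
}
model_keys = ["model.U4.2.U3.7.weight", "model.U4.2.U3.7.bias"]
print(sorted(strip_stale_running_stats(ckpt, model_keys)))
# ['model.U4.2.U3.7.bias', 'model.U4.2.U3.7.weight']
```

The helper compares against the model's own keys instead of dropping every running_* entry. That choice keeps the sketch safe for networks that mix InstanceNorm2d with BatchNorm2d.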

victor1cea commented 3 years ago

@1490581824 I had the same problem as you. Just go into network.py and replace every InstanceNorm2d(...) with InstanceNorm2d(..., track_running_stats=True), and it should work fine.
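The find-and-replace described above can also be scripted. The following is a rough sketch using a regular expression, not a Python parser. The helper names add_track_running_stats and patch_file are hypothetical. The sketch assumes each InstanceNorm2d(...) call in network.py has no nested parentheses in its argument list. Review the resulting diff before trusting it.

```python
import re
from pathlib import Path

# Matches "InstanceNorm2d(...)" when the argument list contains no nested parentheses
# (an assumption about how network.py declares its layers). [^()] also matches
# newlines, so calls split across several lines are handled too.
_INSTANCE_NORM_CALL = re.compile(r"InstanceNorm2d\(([^()]*)\)")


def add_track_running_stats(source: str) -> str:
    """Append track_running_stats=True to every InstanceNorm2d(...) call in `source`."""

    def patch(match):
        # Drop surrounding whitespace and a trailing comma before appending the new kwarg.
        args = match.group(1).strip().rstrip(",").strip()
        if "track_running_stats" in args:
            return match.group(0)  # already set explicitly: leave the call untouched
        new_args = f"{args}, track_running_stats=True" if args else "track_running_stats=True"
        return f"InstanceNorm2d({new_args})"

    return _INSTANCE_NORM_CALL.sub(patch, source)


def patch_file(path: str) -> bool:
    """Rewrite the file at `path` in place; return True if anything changed."""
    file = Path(path)
    original = file.read_text()
    patched = add_track_running_stats(original)
    if patched != original:
        file.write_text(patched)
    return patched != original


# An illustrative line, not copied from network.py:
print(add_track_running_stats("self.in1 = nn.InstanceNorm2d(32, affine=True)"))
# self.in1 = nn.InstanceNorm2d(32, affine=True, track_running_stats=True)
```

The rewrite is idempotent, so running patch_file("network.py") twice does not add the argument a second time.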

victor1cea commented 3 years ago

@dxyang, can I get this issue assigned to me?