AlexHex7 / Non-local_pytorch

Implementation of Non-local Block.
Apache License 2.0

Running step 3, $ CUDA_VISIBLE_DEVICES=0,1 python nl_map_save.py, raises this error. How should I fix it? #47

Open wys2929 opened 2 years ago

wys2929 commented 2 years ago

Traceback (most recent call last):
  File "nl_map_save.py", line 20, in <module>
    net.load_state_dict(torch.load('weights/net.pth'))
  File "/home/user/anaconda3/envs/TEST2/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1045, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Network:
    Missing key(s) in state_dict: "conv_1.0.weight", "conv_1.0.bias", "conv_1.1.weight", "conv_1.1.bias", "conv_1.1.running_mean", "conv_1.1.running_var", "nl_1.g.0.weight", "nl_1.g.0.bias", "nl_1.W.0.weight", "nl_1.W.0.bias", "nl_1.W.1.weight", "nl_1.W.1.bias", "nl_1.W.1.running_mean", "nl_1.W.1.running_var", "nl_1.theta.weight", "nl_1.theta.bias", "nl_1.phi.0.weight", "nl_1.phi.0.bias", "conv_2.0.weight", "conv_2.0.bias", "conv_2.1.weight", "conv_2.1.bias", "conv_2.1.running_mean", "conv_2.1.running_var", "nl_2.g.0.weight", "nl_2.g.0.bias", "nl_2.W.0.weight", "nl_2.W.0.bias", "nl_2.W.1.weight", "nl_2.W.1.bias", "nl_2.W.1.running_mean", "nl_2.W.1.running_var", "nl_2.theta.weight", "nl_2.theta.bias", "nl_2.phi.0.weight", "nl_2.phi.0.bias", "conv_3.0.weight", "conv_3.0.bias", "conv_3.1.weight", "conv_3.1.bias", "conv_3.1.running_mean", "conv_3.1.running_var", "fc.0.weight", "fc.0.bias", "fc.3.weight", "fc.3.bias".
    Unexpected key(s) in state_dict: "module.conv_1.0.weight", "module.conv_1.0.bias", "module.conv_1.1.weight", "module.conv_1.1.bias", "module.conv_1.1.running_mean", "module.conv_1.1.running_var", "module.conv_1.1.num_batches_tracked", "module.nl_1.g.0.weight", "module.nl_1.g.0.bias", "module.nl_1.W.0.weight", "module.nl_1.W.0.bias", "module.nl_1.W.1.weight", "module.nl_1.W.1.bias", "module.nl_1.W.1.running_mean", "module.nl_1.W.1.running_var", "module.nl_1.W.1.num_batches_tracked", "module.nl_1.theta.weight", "module.nl_1.theta.bias", "module.nl_1.phi.0.weight", "module.nl_1.phi.0.bias", "module.conv_2.0.weight", "module.conv_2.0.bias", "module.conv_2.1.weight", "module.conv_2.1.bias", "module.conv_2.1.running_mean", "module.conv_2.1.running_var", "module.conv_2.1.num_batches_tracked", "module.nl_2.g.0.weight", "module.nl_2.g.0.bias", "module.nl_2.W.0.weight", "module.nl_2.W.0.bias", "module.nl_2.W.1.weight", "module.nl_2.W.1.bias", "module.nl_2.W.1.running_mean", "module.nl_2.W.1.running_var", "module.nl_2.W.1.num_batches_tracked", "module.nl_2.theta.weight", "module.nl_2.theta.bias", "module.nl_2.phi.0.weight", "module.nl_2.phi.0.bias", "module.conv_3.0.weight", "module.conv_3.0.bias", "module.conv_3.1.weight", "module.conv_3.1.bias", "module.conv_3.1.running_mean", "module.conv_3.1.running_var", "module.conv_3.1.num_batches_tracked", "module.fc.0.weight", "module.fc.0.bias", "module.fc.3.weight", "module.fc.3.bias".

wys2929 commented 2 years ago

Python 3.8.13, OpenCV 4.5.5, PyTorch 1.8.0. After changing net.load_state_dict(torch.load('weights/net.pth')) in nl_map_save.py to net.load_state_dict(torch.load('weights/net.pth'), strict=False), the error above no longer appears, but a new one is raised:

RuntimeError: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 804: forward compatibility was attempted on non supported HW

AlexHex7 commented 2 years ago

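@wys2929 Hi, could it be that no GPU is detected on your system? net.pth may have been saved with its tensors on the GPU. You can try forcing the weights to load onto the CPU: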

torch.load('weights/net.pth', map_location='cpu')

Sorry, this is an oversight in the GPU-checking logic in the code. https://github.com/AlexHex7/Non-local_pytorch/blob/ee894ba21bc038ff4892bf7e9be5ac19a896ecee/nl_map_save.py#L16
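The map_location suggestion can be tried in isolation with a minimal sketch like the one below. It is not the repository's code: the nn.Linear model and the temporary file are hypothetical stand-ins for Network and weights/net.pth, used only so the snippet runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for weights/net.pth: save a tiny model's weights.
toy = nn.Linear(3, 2)
path = os.path.join(tempfile.mkdtemp(), "net.pth")
torch.save(toy.state_dict(), path)

# map_location='cpu' remaps every stored tensor to the CPU while loading,
# so loading does not depend on a working CUDA device.
state_dict = torch.load(path, map_location="cpu")

# Load the CPU tensors into a fresh model of the same architecture.
net = nn.Linear(3, 2)
net.load_state_dict(state_dict)
```

Once the weights are on the CPU, the model can still be moved to a GPU afterwards with net.to(device) if one is usable.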

AlexHex7 commented 2 years ago

@wys2929

As for the Missing key(s) / Unexpected key(s) error in your first traceback: the model was trained on multiple GPUs with nn.DataParallel, so every key in the saved weights has the form "module.conv_xxx". In nl_map_save.py, however, when no GPU is detected the network is not wrapped in DataParallel before the weights are loaded, so it expects keys without the "module." prefix. This is also a logic bug in the code.
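One common workaround for this key mismatch is to strip the "module." prefix from the checkpoint keys before loading them into an unwrapped model. The sketch below is only an illustration under stated assumptions: strip_module_prefix and make_toy_net are hypothetical helpers, and the toy network stands in for the repository's Network class. Note that strict=False does not solve the mismatch: since none of the keys match, no weights are loaded and the network keeps its random initialization.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


def make_toy_net():
    # Hypothetical stand-in for the repository's Network class.
    return nn.Sequential(nn.Conv2d(1, 4, 3), nn.BatchNorm2d(4))


# Simulate a checkpoint saved from a DataParallel-wrapped model:
# its keys look like "module.0.weight".
checkpoint = nn.DataParallel(make_toy_net()).state_dict()

# Load into a plain (unwrapped) model with the default strict=True,
# so any remaining key mismatch would still raise an error.
net = make_toy_net()
stripped = strip_module_prefix(checkpoint)
result = net.load_state_dict(stripped)
```

The opposite approach, wrapping the model in nn.DataParallel before calling load_state_dict, should also make the keys match, at the cost of keeping the wrapper around for inference.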