zhangqizky opened this issue 3 years ago
I tested it again, and it works fine on my machine. I think the problem is the PyTorch version. You can try a different PyTorch version, e.g. 1.4.0, and test it again.
On Mon, Nov 23, 2020 at 2:42 AM, buaazhangqi <notifications@github.com> wrote:
View it on GitHub: https://github.com/NathanUA/U-2-Net/issues/98
-- 
Xuebin Qin
PhD
Department of Computing Science
University of Alberta, Edmonton, AB, Canada
Homepage: https://webdocs.cs.ualberta.ca/~xuebin/
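A minimal sketch of how one might check whether a given PyTorch version is likely to hit this error. It rests on an assumption not stated in the thread: that the extra keys are the num_batches_tracked buffer BatchNorm layers began saving in PyTorch 0.4.1, so a checkpoint saved with 0.4.1 or later is rejected by an older version with "Unexpected key(s)". The names parse_version and has_num_batches_tracked are hypothetical, and the code avoids f-strings so it also runs under the Python 2.7 environment seen in the traceback.

```python
import re

# Assumption: BatchNorm layers began saving a "num_batches_tracked" buffer in
# PyTorch 0.4.1, so older versions report those keys as "Unexpected key(s)".
FIRST_VERSION_WITH_NUM_BATCHES_TRACKED = (0, 4, 1)


def parse_version(version):
    """Turn '1.4.0' or '1.13.1+cu117' into a comparable tuple such as (1, 13, 1).

    Rough parser: only the leading MAJOR.MINOR[.PATCH] digits are used, and a
    missing patch number counts as 0.
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", str(version))
    if match is None:
        raise ValueError("unrecognised version string: %r" % (version,))
    major, minor, patch = match.groups()
    return (int(major), int(minor), int(patch or 0))


def has_num_batches_tracked(version):
    """True if this PyTorch version is expected to accept num_batches_tracked keys."""
    return parse_version(version) >= FIRST_VERSION_WITH_NUM_BATCHES_TRACKED


# Usage in the failing environment (requires torch):
#   import torch
#   print(torch.__version__, has_num_batches_tracked(torch.__version__))
```

If this returns False for the installed version, that would be consistent with the suggestion to move to a newer release such as 1.4.0.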
You can change "net.load_state_dict(torch.load(model_dir))" on line 88 to "net.load_state_dict(torch.load(model_dir), False)" (the second argument is strict, so this is the same as strict=False), and run "u2net_portrait_test.py" again.
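Turning off strict matching makes load_state_dict ignore missing keys as well as unexpected ones, so a genuine architecture mismatch would also load without an error. A narrower alternative, sketched here under the assumption that the only extra keys are the BatchNorm num_batches_tracked counters listed in the traceback, is to drop just those entries and keep strict loading. drop_num_batches_tracked is a hypothetical helper, not part of PyTorch or U-2-Net.

```python
from collections import OrderedDict


def drop_num_batches_tracked(state_dict):
    """Return a copy of state_dict without BatchNorm 'num_batches_tracked' entries.

    Hypothetical helper. Unlike load_state_dict(..., False), it removes only
    these counters, so any other missing or unexpected key still raises an
    error when the result is loaded strictly.
    """
    return OrderedDict(
        (key, value)
        for key, value in state_dict.items()
        # Compare the last dotted component, e.g. "bn_s1.num_batches_tracked".
        if key.split(".")[-1] != "num_batches_tracked"
    )


# Usage in u2net_portrait_test.py (requires torch; strict loading is kept):
#   net.load_state_dict(drop_num_batches_tracked(torch.load(model_dir)))
```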
It works, thanks!
Hi, thanks for sharing. I wonder whether the pretrained model u2net_portrait.pth matches the model in u2net_portrait_test.py, where the model is net = U2NET(3,1). I ask because when I ran u2net_portrait_test.py as you described, I got this error:
Traceback (most recent call last):
  File "u2net_portrait_test.py", line 117, in <module>
    main()
  File "u2net_portrait_test.py", line 87, in main
    net.load_state_dict(torch.load(model_dir))
  File "/module.py", line 721, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for U2NET:
Unexpected key(s) in state_dict: "stage1.rebnconvin.bn_s1.num_batches_tracked", "stage1.rebnconv1.bn_s1.num_batches_tracked", "stage1.rebnconv2.bn_s1.num_batches_tracked", "stage1.rebnconv3.bn_s1.num_batches_tracked", "stage1.rebnconv4.bn_s1.num_batches_tracked", "stage1.rebnconv5.bn_s1.num_batches_tracked", "stage1.rebnconv6.bn_s1.num_batches_tracked", "stage1.rebnconv7.bn_s1.num_batches_tracked", "stage1.rebnconv6d.bn_s1.num_batches_tracked", "stage1.rebnconv5d.bn_s1.num_batches_tracked", "stage1.rebnconv4d.bn_s1.num_batches_tracked", "stage1.rebnconv3d.bn_s1.num_batches_tracked", "stage1.rebnconv2d.bn_s1.num_batches_tracked", "stage1.rebnconv1d.bn_s1.num_batches_tracked", "stage2.rebnconvin.bn_s1.num_batches_tracked", "stage2.rebnconv1.bn_s1.num_batches_tracked", "stage2.rebnconv2.bn_s1.num_batches_tracked", "stage2.rebnconv3.bn_s1.num_batches_tracked", "stage2.rebnconv4.bn_s1.num_batches_tracked", "stage2.rebnc
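To check the original question (whether u2net_portrait.pth matches U2NET(3,1)) without relying on strict=False, one could compare the two key sets while ignoring the num_batches_tracked entries. The sketch below is a hypothetical helper that works on plain key collections; in practice they would come from net.state_dict().keys() and torch.load(model_dir).keys(). Two empty lists would suggest the checkpoint matches the architecture apart from the BatchNorm counters.

```python
def compare_state_dict_keys(model_keys, checkpoint_keys,
                            ignored_name="num_batches_tracked"):
    """Return (missing, unexpected) key lists, skipping BatchNorm counters.

    Hypothetical helper. missing: keys the model expects but the checkpoint
    lacks; unexpected: keys in the checkpoint the model does not define.
    Keys whose last dotted component equals ignored_name are left out, so a
    PyTorch-version difference is not mistaken for an architecture mismatch.
    """
    def keep(key):
        return key.split(".")[-1] != ignored_name

    model_set = set(k for k in model_keys if keep(k))
    checkpoint_set = set(k for k in checkpoint_keys if keep(k))
    missing = sorted(model_set - checkpoint_set)
    unexpected = sorted(checkpoint_set - model_set)
    return missing, unexpected


# Usage (requires torch and the U2NET class from the repository):
#   net = U2NET(3, 1)
#   missing, unexpected = compare_state_dict_keys(
#       net.state_dict().keys(), torch.load(model_dir).keys())
#   print("missing:", missing, "unexpected:", unexpected)
```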