samhaswon / skin_segmentation

This repository details the various methods I have attempted for skin segmentation.
GNU General Public License v3.0
9 stars, 1 fork

Loading the PyTorch model #1

Status: Open. davidus27 opened this issue 4 days ago

davidus27 commented 4 days ago

Hi, I like your project. When I tried to load the PyTorch model, it returned an error:

My code:

    checkpoint = torch.load("skin_u2net.pth", map_location='cpu', weights_only=False)

The result:

    checkpoint = torch.load("skin_u2net.pth", map_location='cpu', weights_only=False)
  File "/usr/local/lib/python3.8/site-packages/torch/serialization.py", line 1097, in load
    return _load(
  File "/usr/local/lib/python3.8/site-packages/torch/serialization.py", line 1525, in _load
    result = unpickler.load()
  File "/usr/local/lib/python3.8/site-packages/torch/serialization.py", line 1515, in find_class
    return super().find_class(mod_name, name)
ModuleNotFoundError: No module named 'model.u2net'; 'model' is not a package

Perhaps there is some problem with how the architecture is stored?
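You can check whether a checkpoint stores only plain weights or also references architecture classes without unpickling it. The following is a minimal sketch that rests on two assumptions. First, the file uses the zip format that `torch.save` writes by default in recent PyTorch versions, with a `.../data.pkl` pickle inside the archive. Second, it uses the default pickle protocol 2, where class references appear as `GLOBAL` opcodes. The helper names `referenced_globals` and `checkpoint_globals` are hypothetical. The sketch only reads opcodes with the standard `pickletools` module, so nothing is imported or executed.

```python
import pickletools
import zipfile


def referenced_globals(pickle_bytes):
    """List the "module name" globals a pickle would import when loaded.

    Only opcodes are read; nothing is unpickled or imported. This handles the
    GLOBAL opcode used by pickle protocols 0-3; STACK_GLOBAL (protocol 4+)
    is not decoded by this sketch.
    """
    found = set()
    for opcode, arg, _pos in pickletools.genops(pickle_bytes):
        if opcode.name == "GLOBAL":
            found.add(arg)  # arg has the form "module.path ClassName"
    return sorted(found)


def checkpoint_globals(path):
    """Scan the data.pkl member of a zip-format checkpoint (assumed layout)."""
    with zipfile.ZipFile(path) as archive:
        # Raises StopIteration if the archive has no data.pkl member.
        pkl_name = next(n for n in archive.namelist() if n.endswith("data.pkl"))
        return referenced_globals(archive.read(pkl_name))
```

For example, `checkpoint_globals("skin_u2net.pth")` returns strings of the form `"module.path ClassName"`. Any entry under `model.u2net` names a class the file will try to import at load time, which is the lookup the traceback's `find_class` call performs. A legacy (non-zip) checkpoint raises `zipfile.BadZipFile` here instead.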

samhaswon commented 4 days ago

The file contains only the model's state_dict (just the weights), not the architecture, and that is the root of the issue. Loading it takes a few extra steps: build the network from its class, then load the weights into it:

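    import torch

    from model.u2net import U2NET  # architecture code from the U-2-Net repo

    MODEL_PATH = "skin_u2net.pth"  # the checkpoint file from the issue
    DEVICE = "cpu"  # or "cuda" if a GPU is available

    # ...
    net = U2NET(3, 1)  # 3 input channels (RGB), 1 output channel (the mask)
    net.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device(DEVICE)))
    net.eval()  # inference mode: BatchNorm layers use their stored statistics
    # Do stuff with the model...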

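The U2NET model code is available in the original U-2-Net repository: https://github.com/xuebinqin/U-2-Net/tree/master/model

Copy that `model` directory next to your script so that `from model.u2net import U2NET` resolves. The `'model' is not a package` part of the error means Python found a `model` module that is a plain file rather than a package directory.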
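The traceback shows `torch.load` going through a pickle unpickler (`unpickler.load()` and `find_class`). Pickle records a class only by its module path and name, not by its code. The following stdlib-only sketch reproduces that failure and the fix of making the module importable again. `demo_arch` and `Net` are hypothetical stand-ins for `model.u2net` and `U2NET`, and no PyTorch is involved.

```python
import pickle
import sys
import types


class Net:
    """Hypothetical stand-in for an architecture class such as U2NET."""

    def __init__(self):
        self.weights = {"conv.weight": [0.1, 0.2]}


# Pretend Net lives in its own module (hypothetical name "demo_arch"),
# the way U2NET lives in model/u2net.py.
arch = types.ModuleType("demo_arch")
Net.__module__ = "demo_arch"
arch.Net = Net
sys.modules["demo_arch"] = arch

# Pickling the whole object stores the reference "demo_arch" + "Net",
# not the class's source code.
blob = pickle.dumps(Net())

# Simulate loading where the architecture code is not importable.
del sys.modules["demo_arch"]
try:
    pickle.loads(blob)
    error_message = None
except ModuleNotFoundError as exc:
    error_message = str(exc)
print(error_message)  # No module named 'demo_arch'

# Making the module importable again (like copying model/ in) fixes loading.
sys.modules["demo_arch"] = arch
restored = pickle.loads(blob)
print(restored.weights)  # {'conv.weight': [0.1, 0.2]}
```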