davidus27 opened this issue 4 days ago
The file contains only the model's `state_dict` (just the weights, not the architecture), and that's the root of the issue. Loading it takes one extra step: build the network first, then load the weights into it:
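```python
import torch
from model.u2net import U2NET  # model code from the U-2-Net repo (model/u2net.py)

MODEL_PATH = "u2net.pth"  # example path to the downloaded weights file
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

net = U2NET(3, 1)  # 3 input channels (RGB), 1 output channel (mask)
net.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device(DEVICE)))
net.to(DEVICE)  # load_state_dict copies into the existing (CPU) parameters, so move the model too
net.eval()  # inference mode
# Do stuff with the model...
```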
The `U2NET` model code is in the original U-2-Net repository: https://github.com/xuebinqin/U-2-Net/tree/master/model
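To show why the extra step is needed, here is a minimal sketch of a `state_dict` round trip. It is an illustration, not code from this project: `TinyNet` is a hypothetical stand-in for `U2NET`, and an in-memory buffer replaces the `.pth` file. `torch.load` on such a file returns a dictionary of tensors, not a model, so the architecture has to be rebuilt before the weights can be used.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for U2NET; any nn.Module behaves the same way."""

    def __init__(self, in_ch: int = 3, out_ch: int = 1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


# Save only the weights, like the model file discussed in this issue.
original = TinyNet(3, 1)
buffer = io.BytesIO()  # in-memory stand-in for a file on disk
torch.save(original.state_dict(), buffer)
buffer.seek(0)

# torch.load returns a dict-like mapping of parameter names to tensors.
loaded = torch.load(buffer, map_location="cpu")
print(isinstance(loaded, dict))  # True
print(list(loaded.keys()))       # ['conv.weight', 'conv.bias']

# To get a usable model, build the architecture first, then copy the weights in.
restored = TinyNet(3, 1)
restored.load_state_dict(loaded)
restored.eval()

x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    same = torch.equal(original.eval()(x), restored(x))
print(same)  # True
```

The file carries no class definition, which is why the model code from the repository is required.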
Hi, I like your project. When I tried to load the PyTorch model, it returned an error:
My code:
The result:
Perhaps there is some problem with how the architecture is stored?
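One way to check whether a weights file fits a given architecture is to load it with `strict=False` and inspect the key lists that `load_state_dict` returns. This is a minimal sketch: `check_weights_match` is a hypothetical helper, and the small `nn.Sequential` model stands in for `U2NET`.

```python
import torch
import torch.nn as nn


def check_weights_match(model: nn.Module, state_dict: dict) -> bool:
    """Report key mismatches between a model and a state_dict (hypothetical helper).

    strict=False makes load_state_dict copy the matching tensors and return
    the missing/unexpected key lists instead of raising. A shape mismatch on
    a key present in both still raises a RuntimeError.
    """
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print("missing from file:", result.missing_keys)
    if result.unexpected_keys:
        print("not in model:", result.unexpected_keys)
    return not result.missing_keys and not result.unexpected_keys


net = nn.Sequential(nn.Linear(4, 2))

# Weights saved from the same architecture: keys line up.
ok = check_weights_match(net, nn.Sequential(nn.Linear(4, 2)).state_dict())
print(ok)  # True

# Weights whose names do not match this architecture.
mismatch = check_weights_match(net, {"weight": torch.zeros(2, 4)})
# missing from file: ['0.weight', '0.bias']
# not in model: ['weight']
print(mismatch)  # False
```

If every key lines up, the file is a plain `state_dict` for that architecture and only the model class is missing.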