xuebinqin / U-2-Net

The code for our newly accepted paper in Pattern Recognition 2020: "U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection."
Apache License 2.0

Continue training: how can I resume training my own U-2-Net model after the training was interrupted, or keep training an existing model to improve it? Can you provide a 'Continue-training.py'? #358

Closed: panpengfei21 closed this issue 1 year ago.

panpengfei21 commented 1 year ago

I am training my own model with U-2-Net. If the training is interrupted for some reason, or if I want to further train an existing model, how can I continue training? Can you provide a 'Continue-training.py'?

ghost commented 1 year ago
In u2net_train.py, load your saved weights before training starts:

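model_path = "your_model_path.pth"  # placeholder: path to your saved .pth weights
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_path))  # restore the trained weights
    net.cuda()  # move the network to the GPU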
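Loading a plain state_dict, as in the snippet above, restores only the network weights. To pick up an interrupted run where it stopped, it helps to also save the optimizer state and the epoch and iteration counters. Below is a minimal sketch of that pattern; the helper names, the dictionary keys, and the Adam optimizer in the usage comment are assumptions for illustration, not code from the U-2-Net repository.

```python
import torch


def save_checkpoint(path, net, optimizer, epoch, ite_num):
    """Save the weights plus everything else needed to resume training."""
    torch.save(
        {
            "epoch": epoch,  # last completed epoch
            "ite_num": ite_num,  # global iteration counter
            "model_state_dict": net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path, net, optimizer, device):
    """Restore model and optimizer state; return (start_epoch, ite_num)."""
    checkpoint = torch.load(path, map_location=device)
    net.load_state_dict(checkpoint["model_state_dict"])
    # The optimizer must already have been built from net's parameters,
    # with net moved to `device` first.
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # Continue with the epoch after the last one that finished.
    return checkpoint["epoch"] + 1, checkpoint["ite_num"]


# Usage sketch (hypothetical; U2NET, device, epoch_num and the path come from your own script):
# net = U2NET(3, 1).to(device)
# optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
# start_epoch, ite_num = load_checkpoint("saved_models/u2net/ckpt.pth", net, optimizer, device)
# for epoch in range(start_epoch, epoch_num):
#     ...  # training loop; call save_checkpoint(...) where the script saves weights
```

Older weight-only .pth files can still be loaded with net.load_state_dict as in the reply above; training then simply restarts with a fresh optimizer state, which is often acceptable when fine-tuning an existing model.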
panpengfei21 commented 1 year ago

Thank you, I will try.

panpengfei21 commented 1 year ago

My computer is a MacBook Pro, so torch.cuda.is_available() returns False. I changed your code like this:

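    # map_location='cpu' loads weights saved from a GPU run onto the CPU
    checkpoint = torch.load("/xxx/xxx/xxx.pth", map_location='cpu')
    net.load_state_dict(checkpoint)
    # Move to the GPU only when one is available
    if torch.cuda.is_available():
        net.cuda()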

Is it OK?
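Loading with map_location='cpu' and only then moving the network to a GPU is the usual way to run CUDA-trained weights on a machine without CUDA. As a sketch, the device choice can also be made once and reused. This assumes a PyTorch build recent enough to expose torch.backends.mps (1.12 or later), which targets GPUs on Macs; pick_device and load_weights are hypothetical helper names, not part of U-2-Net.

```python
import torch


def pick_device():
    """Prefer CUDA, then Apple's MPS backend, and otherwise use the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps only exists in PyTorch 1.12+; guard for older builds.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def load_weights(net, path, device):
    """Load a state_dict saved on any device, then move `net` to `device`."""
    # map_location remaps tensors saved from a CUDA run onto `device`.
    state_dict = torch.load(path, map_location=device)
    net.load_state_dict(state_dict)
    return net.to(device)
```

A call would look like net = load_weights(U2NET(3, 1), "/xxx/xxx/xxx.pth", pick_device()), assuming U2NET is imported as in u2net_train.py. Whether every operation U2NET uses runs on "mps" is not verified here; passing torch.device("cpu") instead behaves like the version in the question.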

ghost commented 1 year ago

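Your version should work. The path needs no doubled slashes on macOS: "/xxx/xxx/xxx.pth" is fine as written. Doubling only applies to backslashes in Windows paths, e.g. "C:\\xxx\\xxx.pth" (or a raw string such as r"C:\xxx\xxx.pth").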
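The separator rules can be illustrated with the standard-library pathlib module. The xxx paths are the same placeholders used in the thread, and saved_models/u2net is a hypothetical directory name.

```python
from pathlib import Path, PurePosixPath, PureWindowsPath

# On macOS/Linux a single "/" is the separator; nothing needs doubling.
mac_path = PurePosixPath("/xxx/xxx/xxx.pth")

# In an ordinary Python string "\" starts an escape sequence, so Windows
# paths use "\\" or a raw string r"..."; both spellings give the same path.
win_escaped = PureWindowsPath("C:\\xxx\\xxx.pth")
win_raw = PureWindowsPath(r"C:\xxx\xxx.pth")

# Windows paths also accept forward slashes.
win_forward = PureWindowsPath("C:/xxx/xxx.pth")

# Joining with the "/" operator picks the right separator for the current OS.
checkpoint = Path("saved_models") / "u2net" / "xxx.pth"
```

Passing str(checkpoint) (or the Path object itself) to torch.load then works the same way on macOS, Linux, and Windows.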