victoresque / pytorch-template

PyTorch deep learning projects made easy.
MIT License
4.75k stars 1.09k forks

A little question about loading checkpoint. #78

Closed Hoodythree closed 4 years ago

Hoodythree commented 4 years ago

Great work! I have a little question about loading checkpoints. Checkpoints in your project are saved this way:

state = {
    'arch': arch,
    'epoch': epoch,
    'state_dict': self.model.state_dict(),
    'optimizer': self.optimizer.state_dict(),
    'monitor_best': self.mnt_best,
    'config': self.config
}
filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
torch.save(state, filename)

For my use case, I just want to load the checkpoint in a very simple, standalone way, like this:

checkpoint = torch.load('model_best.pth', map_location=torch.device('cpu'))

But an error occurred:

Traceback (most recent call last):
  File ".\test.py", line 11, in <module>
    checkpoint = torch.load(Path('model_best.pth'), map_location=torch.device('cpu'))
  File "C:\Users\user\anaconda3\envs\luoqiuhong\lib\site-packages\torch\serialization.py", line 593, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "C:\Users\user\anaconda3\envs\luoqiuhong\lib\site-packages\torch\serialization.py", line 773, in _legacy_load
    result = unpickler.load()
ModuleNotFoundError: No module named 'parse_config'

Is there a way to do this simple checkpoint loading (without 'parse_config')? Thanks in advance.

SunQpark commented 4 years ago

Calling torch.load('model_best.pth') requires the parse_config module, since the saved dictionary contains the config object as an item. You can remove this item by simply commenting out that line in base_trainer.py:

state = {
    'arch': arch,
    'epoch': epoch,
    'state_dict': self.model.state_dict(),
    'optimizer': self.optimizer.state_dict(),
    'monitor_best': self.mnt_best,
    # 'config': self.config
}

Then load the checkpoint as in test.py#L34-L37. Unfortunately, checkpoints saved this way cannot be used to resume training. If you don't need that, or you have to use an existing checkpoint that was already saved with the config object, you can write a simple script that removes the config object from the checkpoint file:

import torch

checkpoint_path = 'model_best.pth'
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
del checkpoint['config']
torch.save('updated_model_best.pth')

Then run this script from the project root directory, since unpickling the original checkpoint requires the parse_config module.

Hoodythree commented 4 years ago

Thanks a lot.

ahmedgamaleldin14 commented 3 years ago

@SunQpark I think you are missing checkpoint in torch.save(checkpoint, 'updated_model_best.pth')
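With that fix applied, the complete stripping script looks like this. This is a self-contained sketch: the torch.save call at the top just creates a stand-in checkpoint with the same structure, whereas in practice 'model_best.pth' is produced by the trainer and its 'config' entry is the pickled config object:

```python
import torch

# Stand-in checkpoint so the snippet runs on its own; in a real run this
# file already exists and was written by base_trainer.py.
torch.save(
    {
        'state_dict': {'weight': torch.zeros(3)},
        'epoch': 5,
        'monitor_best': 0.1,
        'config': {'name': 'demo'},  # stand-in for the config object
    },
    'model_best.pth',
)

checkpoint_path = 'model_best.pth'
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
del checkpoint['config']  # drop the entry that requires parse_config
torch.save(checkpoint, 'updated_model_best.pth')  # dict first, then the path
```

The resulting 'updated_model_best.pth' can then be loaded anywhere with a plain torch.load, without the project on sys.path.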