sthalles / SimCLR

PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
https://sthalles.github.io/simple-self-supervised-learning/
MIT License

Including code for training from checkpoint #61

Status: Open. imjohnzakkam opened 1 year ago

imjohnzakkam commented 1 year ago

Updated simclr.py to include code for loading from checkpoints, along with some other minor changes.
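Resuming training usually needs the optimizer state as well as the model weights. The sketch below is a minimal illustration, not the code in this PR. It assumes checkpoints are dicts with the `'state_dict'` and `'epoch'` keys used in the review comment, plus a hypothetical `'optimizer'` key. `save_checkpoint` and `resume_from_checkpoint` are hypothetical names.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(model, optimizer, epoch, filepath):
    # Hypothetical helper: store weights, optimizer state and the last finished epoch
    torch.save({
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),  # assumed key name
    }, filepath)


def resume_from_checkpoint(model, optimizer, filepath):
    # Hypothetical helper: restore both states and return the epoch to start from
    ckpt = torch.load(filepath, map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return ckpt['epoch'] + 1


if __name__ == '__main__':
    model = nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth.tar')
    save_checkpoint(model, optimizer, epoch=9, filepath=path)

    # A freshly built model and optimizer pick up where the saved run stopped
    fresh = nn.Linear(4, 2)
    fresh_opt = torch.optim.Adam(fresh.parameters(), lr=3e-4)
    print(resume_from_checkpoint(fresh, fresh_opt, path))  # prints 10
```

Restoring the optimizer matters for Adam in particular. Its running moment estimates would otherwise restart from zero, so the resumed run would not match an uninterrupted one.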

Yadino commented 7 months ago

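In `load_checkpoint` you start with the check `if os.path.exists(filepath):`.

Yet in simclr.py you wrap the call to this function in a try/except. Because of that check, the function never raises an exception. It just fails quietly. It would be better to do something like:

```python
import os

import torch


class InvalidCheckpointPath(Exception):
    """Raised when a checkpoint file is missing or cannot be loaded."""


def load_checkpoint(model, filepath):
    try:
        # Explicit check rather than `assert`, which is skipped under `python -O`
        if not os.path.exists(filepath):
            raise FileNotFoundError(filepath)
        ckpt = torch.load(filepath)
        model.load_state_dict(ckpt['state_dict'])
        epoch = ckpt['epoch']
        return model, epoch
    except Exception as e:
        # Surface any failure (missing file, bad format, missing keys) to the caller
        raise InvalidCheckpointPath(filepath) from e
```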