def _save_checkpoint(self, epoch, verbose=True):
    r"""Store the model parameters information and training information.

    The checkpoint is written to ``<checkpoint_dir>/<model>-<epoch>.pth`` through
    ``self.lite.save``. The save call itself is not rank-guarded; only the log
    message is restricted to rank 0.

    Args:
        epoch (int): the current epoch id
        verbose (bool): whether rank 0 logs the path of the written checkpoint
    """
    state = {
        'model': self.model,
        'optimizer': self.optimizer,
        'scheduler': self.lr_scheduler,
        'config': self.config,
        'epoch': epoch,
        'cur_step': self.cur_step,
        'best_valid_score': self.best_valid_score,
        'rng_state': torch.get_rng_state(),
        # torch.cuda.get_rng_state() raises on hosts without CUDA, so only
        # query it when a GPU is present. The key is kept (as None) so the
        # set of checkpoint keys does not depend on the hardware.
        'cuda_rng_state': (
            torch.cuda.get_rng_state() if torch.cuda.is_available() else None
        ),
    }
    self.saved_model_name = '{}-{}.pth'.format(self.config['model'], epoch)
    checkpoint_path = os.path.join(self.checkpoint_dir, self.saved_model_name)
    # NOTE(review): self.lite is presumably a LightningLite / Fabric instance
    # that converts modules and optimizers to state dicts on save - confirm.
    self.lite.save(checkpoint_path, state=state)
    if self.rank == 0 and verbose:
        # Report the path that was actually written above, not an attribute
        # this method never sets.
        self.logger.info(set_color('Saving current', 'blue') + f': {checkpoint_path}')
if self.config['auto_resume']:
    # Only consider files that follow the '<model>-<epoch>.pth' naming used
    # when checkpoints are saved. Splitting on the first '-' breaks for model
    # names that contain a hyphen and for unrelated '.pth' files.
    name_prefix = '{}-'.format(self.config['model'])
    name_suffix = '.pth'
    epoch_of = {}
    for fname in os.listdir(self.checkpoint_dir):
        if not (fname.startswith(name_prefix) and fname.endswith(name_suffix)):
            continue
        epoch_text = fname[len(name_prefix):-len(name_suffix)]
        if epoch_text.isdecimal():
            epoch_of[fname] = int(epoch_text)
    checkpoint_files = list(epoch_of)
    if checkpoint_files:
        # The checkpoint with the highest epoch number is the one to resume from.
        latest_checkpoint = max(checkpoint_files, key=epoch_of.get)
        self.start_epoch = epoch_of[latest_checkpoint] + 1
        checkpoint_path = os.path.join(self.checkpoint_dir, latest_checkpoint)
        self.logger.info(f"Loading checkpoint from {checkpoint_path}")
        # Template handed to self.lite.load; 'epoch' and 'cur_step' are
        # placeholders. NOTE(review): presumably load() fills this dict in
        # place and returns the checkpoint entries it did not consume -
        # confirm against the Lightning version in use. Nothing visible here
        # copies the loaded values back onto self.
        state = {
            'model': self.model,
            'optimizer': self.optimizer,
            'scheduler': self.lr_scheduler,
            'config': self.config,
            'epoch': 0,
            'cur_step': 0,
            'best_valid_score': self.best_valid_score,
            'rng_state': torch.get_rng_state(),
            # Guarded: torch.cuda.get_rng_state() raises without CUDA.
            'cuda_rng_state': (
                torch.cuda.get_rng_state() if torch.cuda.is_available() else None
            ),
        }
        remainder = self.lite.load(checkpoint_path, state)
When training with the ZeRO stage 2 method, the saved checkpoint file seems to contain only the model and optimizer states. When I use the code above to resume training from a checkpoint, the returned `remainder` shows that the model parameters were not loaded correctly. Have you encountered a similar problem, or do you have any reference material that could help solve it?