bytedance / HLLM

HLLM: Enhancing Sequential Recommendations via Hierarchical Large Language Models for Item and User Modeling

checkpoint resume #18

Closed threestone965 closed 6 days ago

threestone965 commented 6 days ago
def _save_checkpoint(self, epoch, verbose=True):
    r"""Store the model parameters information and training information.

    Args:
        epoch (int): the current epoch id

    """
    state = {
        "model": self.model,
        "optimizer": self.optimizer,
        'scheduler': self.lr_scheduler, 
        'config': self.config,
        'epoch': epoch,
        'cur_step': self.cur_step,
        'best_valid_score': self.best_valid_score,
        'rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state()
    }
    self.saved_model_name = '{}-{}.pth'.format(self.config['model'], epoch)
    self.lite.save(os.path.join(self.checkpoint_dir, self.saved_model_name), state=state)
    if self.rank == 0 and verbose:
        self.logger.info(set_color('Saving current', 'blue') + f': {self.saved_model_name}')

    if self.config['auto_resume']:
        checkpoint_files = [f for f in os.listdir(self.checkpoint_dir) if f.endswith('.pth')]
        if checkpoint_files:
            # Assuming the latest checkpoint is the one to resume from
            latest_checkpoint = max(checkpoint_files, key=lambda x: int(x.split('-')[1].split('.')[0]))
            self.start_epoch = int(latest_checkpoint.split('-')[1].split('.')[0]) + 1
            checkpoint_path = os.path.join(self.checkpoint_dir, latest_checkpoint)
            self.logger.info(f"Loading checkpoint from {checkpoint_path}")

            state = {
                "model": self.model,
                "optimizer": self.optimizer,
                'scheduler': self.lr_scheduler,
                'config': self.config,
                'epoch': 0,
                'cur_step': 0,
                'best_valid_score': self.best_valid_score,
                'rng_state': torch.get_rng_state(),
                'cuda_rng_state': torch.cuda.get_rng_state()
            }
            remainder = self.lite.load(checkpoint_path, state)

When training with ZeRO stage 2, the saved checkpoint seems to contain only the model and optimizer states. When using the above code to resume training from that checkpoint, the remainder returned by self.lite.load indicates that the model parameters are not loaded correctly. Have you encountered similar problems, or do you have any other reference information that could help solve this?
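
For reference, one direction I am considering (an assumption on my side, not something taken from the HLLM code): when the checkpoint is written by the DeepSpeed ZeRO strategy it is stored as a sharded directory, and DeepSpeed provides a helper that consolidates the shards into a single fp32 state dict, which can then be loaded into the bare, unwrapped model. A minimal sketch, assuming the checkpoint directory follows DeepSpeed's standard layout; the function name load_consolidated_zero_checkpoint is hypothetical:

import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

def load_consolidated_zero_checkpoint(model: torch.nn.Module, checkpoint_dir: str):
    """Hedged sketch: load a ZeRO-2 checkpoint directory into the unwrapped
    model and report which keys failed to match (not the HLLM trainer's own
    resume logic)."""
    # Gather the sharded fp32 weights from the per-rank files onto CPU.
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)

    # strict=False returns the mismatched keys instead of raising, which helps
    # diagnose why the remainder above suggests parameters were not restored.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print('missing keys:', missing)
    print('unexpected keys:', unexpected)
    return missing, unexpected

Whether this applies depends on how self.lite.save serializes the extra entries (epoch, rng state, etc.) under the DeepSpeed strategy, so I am treating it only as a starting point for debugging.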