X-LANCE / SLAM-LLM

Speech, Language, Audio, Music Processing with Large Language Model

Resume training from checkpoint with the same hyperparameters #146

Open Amg9794 opened 1 month ago

Amg9794 commented 1 month ago

🚀 The feature, motivation and pitch

Hi

I have trained both the speech encoder (Whisper large-v3) and the linear projector with a frozen Llama 3.2 1B model for an ASR task. All training steps completed, but the eval loss had not yet saturated and there was still room for improvement in the model.

When I resumed training from the last saved checkpoint (which only stores the trainable parameters, using the original code's method to resume), the results degraded, which was unexpected.

Is there a way to resume training from the same state, with the last saved optimizer, scheduler, and other hyperparameters?

I wrote this function to save a checkpoint that keeps all of these details:

    import logging
    import os
    from collections import OrderedDict

    import torch

    logger = logging.getLogger(__name__)


    def save_model_checkpoint_peft(model, optimizer, lr_scheduler, epoch, step,
                                   best_val_loss, best_val_acc, scaler, cfg,
                                   checkpoint_name="checkpoint"):
        logger.info("--> saving model checkpoint...")
        save_dir = os.path.join(cfg.output_dir, checkpoint_name)
        os.makedirs(save_dir, exist_ok=True)
        save_full_path = os.path.join(save_dir, "checkpoint.pt")

        # Unwrap the DDP wrapper so parameter names match on reload
        if cfg.enable_ddp:
            model = model.module

        # Save only the trainable parameters (encoder/projector; the LLM stays frozen)
        trainable_params = OrderedDict()
        for name, param in model.named_parameters():
            if param.requires_grad:
                trainable_params[name] = param.data.cpu()

        checkpoint = {
            'model_state_dict': trainable_params,
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler_state_dict': lr_scheduler.state_dict() if lr_scheduler else None,
            'epoch': epoch,
            'step': step,
            'best_val_loss': best_val_loss,
            'best_val_acc': best_val_acc,
            'random_state': torch.get_rng_state(),
            'cuda_random_state': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
            'config': cfg.__dict__,
            'scaler': scaler.state_dict() if scaler else None,
        }

        torch.save(checkpoint, save_full_path)
        logger.info(f"Checkpoint saved at {save_full_path}")

Can someone help me with this? Are all of these details necessary to save, and how should the saved lr_scheduler state be used in the train function?
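
This is roughly how I imagine wiring it into the train function. It is only a sketch: `resume_from_checkpoint`, `num_epochs`, `train_dataloader`, and `train_step` are placeholder names I made up, not SLAM-LLM APIs:

```python
# Sketch of the top of a train() function; the loop body is placeholder logic
# standing in for the existing training step.
start_epoch, global_step = 0, 0
best_val_loss, best_val_acc = float("inf"), 0.0

if getattr(cfg, "resume_from_checkpoint", False):  # hypothetical config flag
    start_epoch, global_step, best_val_loss, best_val_acc = load_model_checkpoint_peft(
        model, optimizer, lr_scheduler, scaler, cfg
    )

for epoch in range(start_epoch, cfg.num_epochs):
    for batch in train_dataloader:
        train_step(model, batch, optimizer, scaler)  # placeholder for the existing step
        lr_scheduler.step()  # scheduler state was restored, so the LR curve continues
        global_step += 1
```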

It would be a great help for me and others too.

Thank you

Alternatives

No response

Additional context

No response

ddlBoJack commented 1 month ago

Hi, we did not implement resuming of the hyperparameters; only the model parameters are saved. We would like to implement it if time permits, and contributions are welcome.