FunAudioLLM / CosyVoice

Multi-lingual large voice generation model, providing inference, training and deployment full-stack ability.
https://funaudiollm.github.io/
Apache License 2.0

Restart training from a checkpoint, with steps, etc #282

Open rlenain opened 1 month ago

rlenain commented 1 month ago

Hello,

I was wondering whether there is an easy way to restart training from a checkpoint, resuming the steps, epochs, optimizer states, etc. This is for the case where training dies and we want to restart from somewhere other than epoch 0.

Thanks

aluminumbox commented 1 month ago

Specify --checkout, but the step and epoch will start from 0.

rlenain commented 1 month ago

Do you mean --checkpoint? I've tried that, but it doesn't work. I think the LR scheduler ending up in a different state makes the training loss go kind of crazy, and in the end it does not return to where it was.
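
To illustrate how large the jump in learning rate can be when a scheduler counts from zero again, here is a minimal sketch. It assumes a Noam-style warmup formula (linear ramp, then inverse square root decay); that formula and the function name `warmup_lr` are assumptions and are not taken from the CosyVoice source. The values 0.001 and 2500 match the example config posted later in this thread, and step 33000 is just an illustrative stopping point.

```python
def warmup_lr(step: int, base_lr: float = 0.001, warmup_steps: int = 2500) -> float:
    """Noam-style warmup (assumed formula): linear ramp to base_lr, then 1/sqrt(step) decay."""
    step = max(step, 1)  # avoid a zero step, which would break step ** -0.5
    return base_lr * warmup_steps ** 0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


# Learning rate the run had reached when it stopped (illustrative step count).
lr_before_crash = warmup_lr(33000)
# Learning rate right after a restart if the scheduler counts from zero again.
lr_after_restart = warmup_lr(1)

print(f"lr at step 33000: {lr_before_crash:.6f}")
print(f"lr at step 1:     {lr_after_restart:.8f}")
```

Under these assumptions the restarted run begins with a learning rate several hundred times smaller than the one it stopped at, then ramps past it up to the peak again, which is consistent with a loss curve that does not continue smoothly.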

CriDora commented 3 weeks ago

Hello, have you solved this problem? I also specified --checkpoint, and the learning rate of training also started from 0

drlor2k commented 3 weeks ago

I think the author will update the code for this issue. In the meantime, here is a temporary solution you can refer to.

  1. When starting a new training session, set the --checkpoint path in run.sh to the last saved checkpoint. Example: --checkpoint CosyVoice/examples/libritts/cosyvoice/exp/cosyvoice/llm/torch_ddp/epoch_2_whole.pt \

  2. Make a small change to the file CosyVoice/cosyvoice/bin/train.py:

    
    # Save init checkpoints
    info_dict = deepcopy(configs['train_conf'])
    save_model(model, 'init', info_dict)
    
    current_epoch = info_dict['current_epoch'] # add
    current_step = info_dict['current_step']       # add
    
    # Get executor
    executor = Executor()
    
    # Start training loop
    for epoch in range(current_epoch, info_dict['max_epoch']): # change
        executor.epoch = epoch
        executor.step = current_step # add
        train_dataset.set_epoch(epoch)
        dist.barrier()
        group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
        executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
        dist.destroy_process_group(group_join)

3. Add config in file: `CosyVoice/examples/libritts/cosyvoice/conf/cosyvoice.yaml`

Note: current_epoch and current_step must match the checkpoint you set in the `run.sh` file.

Example:

    train_conf:
        optim: adam
        optim_conf:
            lr: 0.001 # change to 1e-5 during sft
        scheduler: warmuplr # change to constantlr during sft
        scheduler_conf:
            warmup_steps: 2500
        max_epoch: 200
        grad_clip: 5
        accum_grad: 2
        log_interval: 100
        save_per_step: 1500
        current_epoch: 2    # add
        current_step: 1311  # add

CriDora commented 3 weeks ago

Thank you, after modifying the code you provided, the checkpoint can be loaded normally.

drlor2k commented 2 weeks ago

I realized the code above resets step to current_step at the start of every new epoch, so I modified it a bit to restore the step only once, at the start of the session.

    # Save init checkpoints
    info_dict = deepcopy(configs['train_conf'])
    save_model(model, 'init', info_dict)

    # ADD
    current_epoch = info_dict['current_epoch']
    current_step  = info_dict['current_step']
    start_session = True

    # Get executor
    executor = Executor()

    # Start training loop
    for epoch in range(current_epoch, info_dict['max_epoch']): # change
        executor.epoch = epoch
        if start_session:
            executor.step = current_step  # add
        train_dataset.set_epoch(epoch)
        dist.barrier()
        group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
        executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
        dist.destroy_process_group(group_join)
        start_session = False
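
The difference between the two versions can be checked with a small simulation. This is only a sketch: `simulate` is a hypothetical helper, `steps_per_epoch=100` is a made-up value, and it assumes the executor advances its own step counter while `train_one_epoc` runs.

```python
def simulate(current_epoch, max_epoch, current_step, steps_per_epoch, restore_once):
    """Return the global step seen at the start of each epoch after a resume."""
    step = 0
    start_session = True
    steps_at_epoch_start = []
    for epoch in range(current_epoch, max_epoch):
        # First version: restore on every epoch. Second version: only once.
        if start_session or not restore_once:
            step = current_step
        steps_at_epoch_start.append(step)
        step += steps_per_epoch  # stand-in for executor.train_one_epoc(...)
        start_session = False
    return steps_at_epoch_start


every_epoch = simulate(2, 5, 1311, 100, restore_once=False)
only_once = simulate(2, 5, 1311, 100, restore_once=True)
print(every_epoch)  # [1311, 1311, 1311]
print(only_once)    # [1311, 1411, 1511]
```

With the `start_session` flag the counter keeps growing across epochs instead of jumping back to the configured value.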

CriDora commented 2 weeks ago

Thanks, you are right. Do I need to modify the current_epoch and current_step values in cosyvoice.fromscratch.yaml every time I resume training from a checkpoint?

drlor2k commented 2 weeks ago

Yes, you need to modify the current_epoch and current_step values in cosyvoice.fromscratch.yaml every time you resume training from a checkpoint.

You can see current_epoch and current_step in the checkpoint filename. Example: epoch_1_step_33000.pt
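
To avoid copying the numbers by hand, they can be read from the filename. The following is a minimal sketch, assuming checkpoint names follow the two patterns seen in this thread (`epoch_1_step_33000.pt` and `epoch_2_whole.pt`); `parse_checkpoint_name` is a hypothetical helper, not part of CosyVoice. It only returns what the name encodes, so a `_whole` checkpoint yields no step.

```python
import re
from pathlib import Path
from typing import Optional, Tuple

# Matches "epoch_<N>_step_<M>" or "epoch_<N>_whole" (patterns assumed from this thread).
_NAME_RE = re.compile(r"epoch_(\d+)_(?:step_(\d+)|whole)")


def parse_checkpoint_name(path: str) -> Tuple[int, Optional[int]]:
    """Extract (epoch, step) from a checkpoint filename; step is None for '_whole' names."""
    stem = Path(path).stem  # drop directories and the ".pt" suffix
    match = _NAME_RE.fullmatch(stem)
    if match is None:
        raise ValueError(f"unrecognized checkpoint name: {path}")
    epoch = int(match.group(1))
    step = int(match.group(2)) if match.group(2) is not None else None
    return epoch, step


print(parse_checkpoint_name("epoch_1_step_33000.pt"))  # (1, 33000)
print(parse_checkpoint_name("epoch_2_whole.pt"))       # (2, None)
```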

CriDora commented 5 days ago

Hello, sorry to bother you again. I want to confirm whether the learning rate decays normally after restoring a checkpoint with your code above. After I applied your changes, the restored step count is correct, but the learning rate still starts warming up from the beginning.
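
One possible explanation, sketched below, is that the scheduler keeps its own step counter, separate from `executor.step`, so restoring only the executor leaves the scheduler at zero. `ToyWarmupScheduler` is a hypothetical stand-in with a simple linear warmup; it is not CosyVoice's scheduler class, and whether the CosyVoice checkpoint stores scheduler state is not confirmed in this thread.

```python
class ToyWarmupScheduler:
    """Toy scheduler: linear warmup to base_lr, then constant. Not CosyVoice's class."""

    def __init__(self, base_lr: float = 0.001, warmup_steps: int = 2500):
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.last_step = 0  # the scheduler's own counter, separate from executor.step

    def step(self) -> None:
        self.last_step += 1

    def get_lr(self) -> float:
        return self.base_lr * min(1.0, self.last_step / self.warmup_steps)

    def state_dict(self) -> dict:
        return {"last_step": self.last_step}

    def load_state_dict(self, state: dict) -> None:
        self.last_step = state["last_step"]


# Original run: train for 33000 steps, then save the scheduler state.
original = ToyWarmupScheduler()
for _ in range(33000):
    original.step()
saved = original.state_dict()

# Resume A: only the executor's step counter is restored; the scheduler is new.
fresh = ToyWarmupScheduler()

# Resume B: the scheduler's counter is restored as well.
restored = ToyWarmupScheduler()
restored.load_state_dict(saved)

print(fresh.get_lr())     # warmup starts over
print(restored.get_lr())  # continues where the original run stopped
```

In PyTorch, optimizers and the schedulers in `torch.optim.lr_scheduler` expose `state_dict()` and `load_state_dict()`, so saving both alongside the model weights and loading them on resume is the usual way to keep the learning rate continuous.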