yl4579 / StyleTTS2

StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models
MIT License

Extremely weird DDP issue for train_second.py #7

Open yl4579 opened 1 year ago

yl4579 commented 1 year ago

So far train_second.py only works with DataParallel (DP), not DistributedDataParallel (DDP). The major problem is that if we simply translate the DP code to DDP (code in the comment below), we encounter the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 6; expected version 5 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

It is insanely difficult to debug. The tensor has no batch dimension, indicating it might be a parameter in the neural network. I found the tensor to be the bias term of the last Conv1D layer of predictor_encoder (prosodic style encoder): https://github.com/yl4579/StyleTTS2/blob/main/models.py#L152. This is extremely weird because the problem does not trigger for any Conv1D layer before this.
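For readers who have not met this class of error before, the following is a minimal sketch of how autograd's version counter produces it. It uses toy CPU tensors, not anything from StyleTTS2, and the names `w` and `b` are purely illustrative. It only shows the mechanism (a tensor saved for backward is written in place before backward runs); it does not claim to reproduce the DDP case.

```python
import torch

# Toy stand-ins, not StyleTTS2 tensors: `w` plays a learnable weight,
# `b` plays a tensor that autograd saves for the backward pass.
w = torch.ones(3, requires_grad=True)
b = torch.full((3,), 2.0)

y = (w * b).sum()              # autograd saves `b` to compute dy/dw
version_at_forward = b._version

b.add_(1.0)                    # in-place write bumps the version counter
version_after_write = b._version

try:
    y.backward()               # saved version no longer matches -> error
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(version_at_forward, version_after_write)
print(error_message)
```

The "is at version N; expected version N-1" part of the message is this counter mismatch, so the search is for whatever writes to the tensor between the forward pass and the backward pass.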

More mysteriously, the issue disappears if we add model.predictor_encoder.train() near line 250 of train_second.py. However, this causes the F0 loss to be much higher than without this line. This is true for both DP and DDP, so the higher F0 loss is caused by model.predictor_encoder.train(), not by DDP. Yet the predictor_encoder, which is a StyleEncoder, has no module that changes its behavior depending on whether it is in train or eval mode. The output is exactly the same whether it is set to train or eval.

TLDR: There are three issues with train_second.py:

  1. DDP does not work because of the in-place operation error. The error disappears if model.predictor_encoder.train() is called before training.
  2. However, model.predictor_encoder.train() causes F0 loss to be much higher after convergence. This issue is independent of using DP or DDP.
  3. model.predictor_encoder is an instantiation of StyleEncoder, which has no components that change the output depending on its train or eval mode.
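Point 3 can be checked mechanically. The helper below is a hypothetical sketch (the name `output_depends_on_mode` and the toy modules are assumptions, not part of the repository) that compares a module's forward output in eval and train mode on the same input. It only tests the forward output, so a negative result does not rule out other mode-dependent side effects.

```python
import copy

import torch
from torch import nn


def output_depends_on_mode(module: nn.Module, x: torch.Tensor) -> bool:
    """Return True if train() and eval() give different outputs for `x`."""
    probe = copy.deepcopy(module)      # keep the caller's module untouched
    with torch.no_grad():
        probe.eval()
        out_eval = probe(x)
        probe.train()
        out_train = probe(x)
    return not torch.allclose(out_eval, out_train)


torch.manual_seed(0)
x = torch.randn(8, 4, 16) * 3 + 5      # deliberately not zero-mean, unit-variance

# Toy modules for illustration only.
conv_only = nn.Sequential(nn.Conv1d(4, 4, 3), nn.LeakyReLU(0.2))
with_batchnorm = nn.Sequential(nn.BatchNorm1d(4), nn.Conv1d(4, 4, 3))

conv_result = output_depends_on_mode(conv_only, x)
bn_result = output_depends_on_mode(with_batchnorm, x)
print(conv_result, bn_result)
```

Running the same helper on the real predictor_encoder with a batch of mel inputs would confirm or refute the claim that its output is mode-independent.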

This problem has bugged me for more than a month, but I can't find a solution to it. It would be greatly appreciated if anyone has any insight into how to fix this problem. I have pasted the broken DDP code with accelerator below.

yl4579 commented 1 year ago

The following is the broken (and unfinished) code for train_second.py with DDP:

# load packages
import os
import os.path as osp
import random
import yaml
import time
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa
import click
import shutil
import warnings
warnings.simplefilter('ignore')
from torch.utils.tensorboard import SummaryWriter

from meldataset import build_dataloader

from Utils.ASR.models import ASRCNN
from Utils.JDC.model import JDCNet
from Utils.PLBERT.util import load_plbert

from models import *
from losses import *
from utils import *
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule

from optimizers import build_optimizer 

from accelerate import Accelerator
from accelerate.utils import LoggerType
from accelerate import DistributedDataParallelKwargs

import logging
from accelerate.logging import get_logger
logger = get_logger(__name__, log_level="DEBUG")

def _load(states, model, force_load=True):
    model_states = model.state_dict()
    for key, val in states.items():
        try:
            if key not in model_states:
                continue
            if isinstance(val, nn.Parameter):
                val = val.data

            if val.shape != model_states[key].shape:
                print("%s does not have same shape" % key)
                print(val.shape, model_states[key].shape)
                if not force_load:
                    continue

                min_shape = np.minimum(np.array(val.shape), np.array(model_states[key].shape))
                slices = tuple(slice(0, int(n)) for n in min_shape)
                model_states[key][slices].copy_(val[slices])
            else:
                model_states[key].copy_(val)
        except Exception:
            # report the key that failed to load and keep going
            print("not exist :%s" % key)

@click.command()
@click.option('-p', '--config_path', default='Configs/config.yml', type=str)
def main(config_path):
    with open(config_path) as f:
        config = yaml.safe_load(f)

    log_dir = config['log_dir']
    os.makedirs(log_dir, exist_ok=True)
    shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(project_dir=log_dir, split_batches=True, kwargs_handlers=[ddp_kwargs])    
    if accelerator.is_main_process:
        writer = SummaryWriter(log_dir + "/tensorboard")

    # write logs
    file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
    logger.logger.addHandler(file_handler)

    batch_size = config.get('batch_size', 10)

    epochs = config.get('epochs_2nd', 200)
    save_freq = config.get('save_freq', 2)
    log_interval = config.get('log_interval', 10)
    saving_epoch = config.get('save_freq', 2)

    data_params = config.get('data_params', None)
    sr = config['preprocess_params'].get('sr', 24000)
    train_path = data_params['train_data']
    val_path = data_params['val_data']
    root_path = data_params['root_path']
    min_length = data_params['min_length']
    OOD_data = data_params['OOD_data']

    max_len = config.get('max_len', 200)

    loss_params = Munch(config['loss_params'])
    diff_epoch = loss_params.diff_epoch
    joint_epoch = loss_params.joint_epoch

    optimizer_params = Munch(config['optimizer_params'])

    train_list, val_list = get_data_path_list(train_path, val_path)
    device = accelerator.device

    train_dataloader = build_dataloader(train_list,
                                        root_path,
                                        OOD_data=OOD_data,
                                        min_length=min_length,
                                        batch_size=batch_size,
                                        num_workers=2,
                                        dataset_config={},
                                        device=device)

    val_dataloader = build_dataloader(val_list,
                                      root_path,
                                      OOD_data=OOD_data,
                                      min_length=min_length,
                                      batch_size=batch_size,
                                      validation=True,
                                      num_workers=0,
                                      device=device,
                                      dataset_config={})
    with accelerator.main_process_first():
        # load pretrained ASR model
        ASR_config = config.get('ASR_config', False)
        ASR_path = config.get('ASR_path', False)
        text_aligner = load_ASR_models(ASR_path, ASR_config)

        # load pretrained F0 model
        F0_path = config.get('F0_path', False)
        pitch_extractor = load_F0_models(F0_path)

        # load PL-BERT model
        BERT_path = config.get('PLBERT_dir', False)
        plbert = load_plbert(BERT_path)

    # build model
    model_params = recursive_munch(config['model_params'])
    multispeaker = model_params.multispeaker
    model = build_model(model_params, text_aligner, pitch_extractor, plbert)
    _ = [model[key].to(device) for key in model]

    # DDP
    for k in model:
        model[k] = accelerator.prepare(model[k])
    model.predictor._set_static_graph()

    train_dataloader, val_dataloader = accelerator.prepare(
        train_dataloader, val_dataloader
    )

    start_epoch = 0
    iters = 0

    load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)

    if not load_pretrained:
        if config.get('first_stage_path', '') != '':
            first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
            print('Loading the first stage model at %s ...' % first_stage_path)
            model, _, start_epoch, iters = load_checkpoint(model, 
                None, 
                first_stage_path,
                load_only_params=True,
                ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log

            # these epochs should be counted from the start epoch
            diff_epoch += start_epoch
            joint_epoch += start_epoch
            epochs += start_epoch

#             model.predictor_encoder = copy.deepcopy(model.style_encoder)
            _load(model.style_encoder.state_dict(), model.predictor_encoder)
        else:
            raise ValueError('You need to specify the path to the first stage model.') 

    gl = GeneratorLoss(model.mpd, model.msd).to(device)
    dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
    wl = WavLMLoss(model_params.slm.model, 
                   model.wd, 
                   sr, 
                   model_params.slm.sr).to(device)

    gl = accelerator.prepare(gl)
    dl = accelerator.prepare(dl)
    wl = accelerator.prepare(wl)

    try:
        n_down = model.text_aligner.module.n_down
        distributed = True
    except:
        n_down = model.text_aligner.n_down
        distributed = False

    sampler = DiffusionSampler(
        model.diffusion.module.diffusion if distributed else model.diffusion.diffusion,
        sampler=ADPM2Sampler(),
        sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
        clamp=False
    )

    scheduler_params = {
        "max_lr": optimizer_params.lr,
        "pct_start": float(0),
        "epochs": epochs,
        "steps_per_epoch": len(train_dataloader),
    }
    scheduler_params_dict= {key: scheduler_params.copy() for key in model}
    scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2

    optimizer = build_optimizer({key: model[key].parameters() for key in model},
                                          scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)

    # adjust BERT learning rate
    for g in optimizer.optimizers['bert'].param_groups:
        g['betas'] = (0.9, 0.99)
        g['lr'] = optimizer_params.bert_lr
        g['initial_lr'] = optimizer_params.bert_lr
        g['min_lr'] = 0
        g['weight_decay'] = 0.01

    # adjust acoustic module learning rate
    for module in ["decoder", "style_encoder"]:
        for g in optimizer.optimizers[module].param_groups:
            g['betas'] = (0.0, 0.99)
            g['lr'] = optimizer_params.ft_lr
            g['initial_lr'] = optimizer_params.ft_lr
            g['min_lr'] = 0
            g['weight_decay'] = 1e-4

    for k, v in optimizer.optimizers.items():
        optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
        optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])

    # load models if there is a model
    if load_pretrained:
        with accelerator.main_process_first():
            model, optimizer, start_epoch, iters = load_checkpoint(model,  optimizer, config['pretrained_model'],
                                        load_only_params=config.get('load_only_params', True))

    best_loss = float('inf')  # best test loss
    loss_train_record = list([])
    loss_test_record = list([])
    iters = 0

    criterion = nn.L1Loss() # F0 loss (regression)
    torch.cuda.empty_cache()

    stft_loss = MultiResolutionSTFTLoss().to(device)
    stft_loss = accelerator.prepare(stft_loss)

    print(optimizer.optimizers['bert'])

    start_ds = False

    for epoch in range(start_epoch, epochs):
        running_loss = 0
        start_time = time.time()

        _ = [model[key].eval() for key in model]

        model.predictor.train()
#        model.predictor_encoder.train() # uncommenting this line fixes the in-place operation error but gives a higher F0 loss and a worse model
        model.bert_encoder.train()
        model.bert.train()
        model.msd.train()
        model.mpd.train()

        if epoch >= diff_epoch:
            start_ds = True

        for i, batch in enumerate(train_dataloader):
            waves = batch[0]
            batch = [b.to(device) for b in batch[1:]]
            texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch

            with torch.no_grad():
                mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
                mel_mask = length_to_mask(mel_input_length).to(device)
                text_mask = length_to_mask(input_lengths).to(texts.device)

                try:
                    _, _, s2s_attn = model.text_aligner(mels, mask, texts)
                    s2s_attn = s2s_attn.transpose(-1, -2)
                    s2s_attn = s2s_attn[..., 1:]
                    s2s_attn = s2s_attn.transpose(-1, -2)
                except:
                    continue

                mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
                s2s_attn_mono = maximum_path(s2s_attn, mask_ST)

                # encode
                t_en = model.text_encoder(texts, input_lengths, text_mask)
                asr = (t_en @ s2s_attn_mono)

                d_gt = s2s_attn_mono.sum(axis=-1).detach()

            # compute the style of the entire utterance
            # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
            ss = []
            gs = []
            for bib in range(len(mel_input_length)):
                mel_length = int(mel_input_length[bib].item())
                mel = mels[bib, :, :mel_input_length[bib]]
                s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
                ss.append(s)
                s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
                gs.append(s)

            s_dur = torch.stack(ss).squeeze()  # global prosodic styles
            gs = torch.stack(gs).squeeze() # global acoustic styles
            s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser

            bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
            d_en = model.bert_encoder(bert_dur).transpose(-1, -2) 

            # denoiser training
            if epoch >= diff_epoch:
                num_steps = np.random.randint(3, 5)

                if model_params.diffusion.dist.estimate_sigma_data:
                    model.diffusion.module.diffusion.sigma_data = s_trg.std().item()

                if multispeaker:
                    s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device), 
                          embedding=bert_dur,
                          embedding_scale=1,
                                   features=ref, # reference from the same speaker as the embedding
                             embedding_mask_proba=0.1,
                             num_steps=num_steps).squeeze(1)
                    loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
                    loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
                else:
                    s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device), 
                          embedding=bert_dur,
                          embedding_scale=1,
                             embedding_mask_proba=0.1,
                             num_steps=num_steps).squeeze(1)                    
                    loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
                    loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
            else:
                loss_sty = 0
                loss_diff = 0

            d, p = model.predictor(d_en, s_dur, 
                                                    input_lengths, 
                                                    s2s_attn_mono, 
                                                    text_mask)

            # get clips
            mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
            mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])

            en = []
            gt = []
            p_en = []
            wav = []

            for bib in range(len(mel_input_length)):
                mel_length = int(mel_input_length[bib].item() / 2)

                random_start = np.random.randint(0, mel_length - mel_len)
                en.append(asr[bib, :, random_start:random_start+mel_len])
                p_en.append(p[bib, :, random_start:random_start+mel_len])
                gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])

                y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
                wav.append(torch.from_numpy(y).to(device))

            wav = torch.stack(wav).float().detach()

            en = torch.stack(en)
            p_en = torch.stack(p_en)
            gt = torch.stack(gt).detach()

            if gt.size(-1) < 80:
                continue

            s_dur = model.predictor_encoder(gt.unsqueeze(1))

            with torch.no_grad():
                s = model.style_encoder(gt.unsqueeze(1))

                F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
                F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()

                N_real = log_norm(gt.unsqueeze(1)).squeeze(1)

                # ground truth from reconstruction
                y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
                # ground truth from recording
                y_rec_gt = wav.unsqueeze(1)

                if epoch >= joint_epoch:
                    wav = y_rec_gt # use recording since decoder is tuned
                else:
                    wav = y_rec_gt_pred # use reconstruction since decoder is fixed

            F0_fake, N_fake = model.predictor.module.F0Ntrain(p_en, s_dur)

            y_rec = model.decoder(en, F0_fake, N_fake, s)

            loss_F0_rec =  (F.smooth_l1_loss(F0_real, F0_fake)) / 10
            loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)

            if start_ds:
                optimizer.zero_grad()
                d_loss = dl(wav.detach(), y_rec.detach()).mean()
                accelerator.backward(d_loss)
                optimizer.step('msd')
                optimizer.step('mpd')
            else:
                d_loss = 0

            # generator loss
            optimizer.zero_grad()

            loss_mel = stft_loss(y_rec, wav)
            if start_ds:
                loss_gen_all = gl(wav, y_rec).mean()
            else:
                loss_gen_all = 0
            loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()

            loss_ce = 0
            loss_dur = 0
            for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
                _s2s_pred = _s2s_pred[:_text_length, :]
                _text_input = _text_input[:_text_length].long()
                _s2s_trg = torch.zeros_like(_s2s_pred)
                for p in range(_s2s_trg.shape[0]):
                    _s2s_trg[p, :_text_input[p]] = 1
                _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)

                loss_dur += F.l1_loss(_dur_pred[1:_text_length-1], 
                                       _text_input[1:_text_length-1])
                loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())

            loss_ce /= texts.size(0)
            loss_dur /= texts.size(0)

            g_loss = loss_params.lambda_mel * loss_mel + \
                     loss_params.lambda_F0 * loss_F0_rec + \
                     loss_params.lambda_ce * loss_ce + \
                     loss_params.lambda_norm * loss_norm_rec + \
                     loss_params.lambda_dur * loss_dur + \
                     loss_params.lambda_gen * loss_gen_all + \
                     loss_params.lambda_slm * loss_lm

            running_loss += accelerator.gather(loss_mel).mean().item()

            with torch.autograd.set_detect_anomaly(True):
                accelerator.backward(g_loss)

            if torch.isnan(g_loss):
                from IPython.core.debugger import set_trace
                set_trace()

            optimizer.step('bert_encoder')
            optimizer.step('bert')
            optimizer.step('predictor')
            optimizer.step('predictor_encoder')

            if epoch >= diff_epoch:
                optimizer.step('diffusion')

            if epoch >= joint_epoch:
                optimizer.step('style_encoder')
                optimizer.step('decoder')

            iters = iters + 1
            d_loss_slm = 0
            loss_gen_lm = 0

            if (i+1)%log_interval == 0 and accelerator.is_main_process:
                print ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
                    %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm))

                writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
                writer.add_scalar('train/gen_loss', loss_gen_all, iters)
                writer.add_scalar('train/d_loss', d_loss, iters)
                writer.add_scalar('train/ce_loss', loss_ce, iters)
                writer.add_scalar('train/dur_loss', loss_dur, iters)
                writer.add_scalar('train/slm_loss', loss_lm, iters)
                writer.add_scalar('train/norm_loss', loss_norm_rec, iters)
                writer.add_scalar('train/F0_loss', loss_F0_rec, iters)
                writer.add_scalar('train/sty_loss', loss_sty, iters)
                writer.add_scalar('train/diff_loss', loss_diff, iters)
                writer.add_scalar('train/d_loss_slm', d_loss_slm, iters)
                writer.add_scalar('train/gen_loss_slm', loss_gen_lm, iters)

                running_loss = 0

                print('Time elapsed:', time.time()-start_time)

        loss_test = 0
        loss_align = 0
        loss_f = 0
        _ = [model[key].eval() for key in model]

        with torch.no_grad():
            iters_test = 0
            for batch_idx, batch in enumerate(val_dataloader):
                optimizer.zero_grad()

                try:
                    waves = batch[0]
                    batch = [b.to(device) for b in batch[1:]]
                    texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
                    with torch.no_grad():
                        mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
                        text_mask = length_to_mask(input_lengths).to(texts.device)

                        _, _, s2s_attn = model.text_aligner(mels, mask, texts)
                        s2s_attn = s2s_attn.transpose(-1, -2)
                        s2s_attn = s2s_attn[..., 1:]
                        s2s_attn = s2s_attn.transpose(-1, -2)

                        mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
                        s2s_attn_mono = maximum_path(s2s_attn, mask_ST)

                        # encode
                        t_en = model.text_encoder(texts, input_lengths, text_mask)
                        asr = (t_en @ s2s_attn_mono)

                        d_gt = s2s_attn_mono.sum(axis=-1).detach()

                    ss = []
                    gs = []

                    for bib in range(len(mel_input_length)):
                        mel_length = int(mel_input_length[bib].item())
                        mel = mels[bib, :, :mel_input_length[bib]]
                        s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
                        ss.append(s)
                        s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
                        gs.append(s)

                    s = torch.stack(ss).squeeze()
                    gs = torch.stack(gs).squeeze()
                    s_trg = torch.cat([s, gs], dim=-1).detach()

                    bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
                    d_en = model.bert_encoder(bert_dur).transpose(-1, -2) 
                    d, p = model.predictor(d_en, s, 
                                                        input_lengths, 
                                                        s2s_attn_mono, 
                                                        text_mask)
                    # get clips
                    mel_len = int(mel_input_length.min().item() / 2 - 1)
                    en = []
                    gt = []
                    p_en = []
                    wav = []

                    for bib in range(len(mel_input_length)):
                        mel_length = int(mel_input_length[bib].item() / 2)

                        random_start = np.random.randint(0, mel_length - mel_len)
                        en.append(asr[bib, :, random_start:random_start+mel_len])
                        p_en.append(p[bib, :, random_start:random_start+mel_len])

                        gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])

                        y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
                        wav.append(torch.from_numpy(y).to(device))

                    wav = torch.stack(wav).float().detach()

                    en = torch.stack(en)
                    p_en = torch.stack(p_en)
                    gt = torch.stack(gt).detach()

                    s = model.predictor_encoder(gt.unsqueeze(1))

                    F0_fake, N_fake = model.predictor.module.F0Ntrain(p_en, s)

                    loss_dur = 0
                    for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
                        _s2s_pred = _s2s_pred[:_text_length, :]
                        _text_input = _text_input[:_text_length].long()
                        _s2s_trg = torch.zeros_like(_s2s_pred)
                        for bib in range(_s2s_trg.shape[0]):
                            _s2s_trg[bib, :_text_input[bib]] = 1
                        _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
                        loss_dur += F.l1_loss(_dur_pred[1:_text_length-1], 
                                               _text_input[1:_text_length-1])

                    loss_dur /= texts.size(0)

                    s = model.style_encoder(gt.unsqueeze(1))

                    y_rec = model.decoder(en, F0_fake, N_fake, s)
                    loss_mel = stft_loss(y_rec.squeeze(), wav.detach())

                    F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1)) 

                    loss_F0 = F.l1_loss(F0_real, F0_fake) / 10

                    loss_test += accelerator.gather(loss_mel).mean()
                    loss_align += accelerator.gather(loss_dur).mean()
                    loss_f += accelerator.gather(loss_F0).mean()

                    iters_test += 1
                except:
                    continue
        if accelerator.is_main_process:

            print('Epochs:', epoch + 1)
            print('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n\n\n')
            print('\n\n\n')
            writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
            writer.add_scalar('eval/dur_loss', loss_align / iters_test, epoch + 1)
            writer.add_scalar('eval/F0_loss', loss_f / iters_test, epoch + 1)

            with torch.no_grad():
                for bib in range(len(asr)):
                    mel_length = int(mel_input_length[bib].item())
                    gt = mels[bib, :, :mel_length].unsqueeze(0)
                    en = asr[bib, :, :mel_length // 2].unsqueeze(0)

                    F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
                    F0_real = F0_real.unsqueeze(0)
                    s = model.style_encoder(gt.unsqueeze(1))
                    real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)

                    y_rec = model.decoder(en, F0_real, real_norm, s)

                    writer.add_audio('eval/y' + str(bib), y_rec.cpu().numpy().squeeze(), epoch, sample_rate=sr)

                    s_dur = model.predictor_encoder(gt.unsqueeze(1))
                    p_en = p[bib, :, :mel_length // 2].unsqueeze(0)

                    F0_fake, N_fake = model.predictor.module.F0Ntrain(p_en, s_dur)

                    y_pred = model.decoder(en, F0_fake, N_fake, s)

                    writer.add_audio('pred/y' + str(bib), y_pred.cpu().numpy().squeeze(), epoch, sample_rate=sr)

                    if epoch == 0:
                        writer.add_audio('gt/y' + str(bib), waves[bib].squeeze(), epoch, sample_rate=sr)

                    if bib >= 5:
                        break

            if epoch % saving_epoch == 0:
                if (loss_test / iters_test) < best_loss:
                    best_loss = loss_test / iters_test
                print('Saving..')
                state = {
                    'net':  {key: model[key].state_dict() for key in model}, 
                    'optimizer': optimizer.state_dict(),
                    'iters': iters,
                    'val_loss': loss_test / iters_test,
                    'epoch': epoch,
                }
                save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
                torch.save(state, save_path)

if __name__=="__main__":
    main()
zhouyong64 commented 1 year ago

Did you try other versions of PyTorch?

yl4579 commented 1 year ago

@zhouyong64 This issue (in-place operation) was first identified by @ABC0408, who used a different PyTorch version than me, though we both used PyTorch > 2.0. Not sure if it is relevant. I will try PyTorch < 2.0 when I get time.

stevenhillis commented 1 year ago

This is a PyTorch gotcha at the intersection of DDP, buffers, and GANs (multiple forward passes). DDP modules broadcast the root process module's buffers at every forward pass, which is treated as an in-place op. Buffers show up mostly from BatchNorm, which can be solved by using SyncBatchNorm. InstanceNorm can also be a culprit, since it inherits from the same primitive as BatchNorm, but only if track_running_stats is set to True, which it isn't here. I think the culprit in this case is probably spectral_norm, which has buffers but is also supposed to handle this broadcasting issue by cloning (reference: https://pytorch.org/docs/stable/_modules/torch/nn/utils/spectral_norm.html#SpectralNorm).

Not sure why that wouldn't be working here, but regardless of the root cause, you can disable the broadcasting by changing the DDP kwargs to ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True, broadcast_buffers=False). Please let me know if this results in the same regression in F0 loss as calling .train() did!
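One way to follow up on the buffer hypothesis is to list which submodules actually register buffers. The snippet below is a small sketch on a toy ModuleDict (the helper `buffer_names` and the module keys are made up for illustration); the same helper could be pointed at each entry of the real model dictionary.

```python
from torch import nn
from torch.nn.utils import spectral_norm


def buffer_names(module: nn.Module) -> list:
    """Names of all registered buffers, i.e. the tensors DDP may broadcast."""
    return [name for name, _ in module.named_buffers()]


# Toy modules for illustration, not the StyleTTS2 architecture.
toy = nn.ModuleDict({
    "sn_conv": spectral_norm(nn.Conv1d(4, 4, 3)),   # registers weight_u / weight_v
    "inst_norm": nn.InstanceNorm1d(4),              # track_running_stats=False by default
    "batch_norm": nn.BatchNorm1d(4),                # registers running statistics
})

names = buffer_names(toy)
for name in names:
    print(name)
```

A submodule that prints nothing here has no buffers for DDP to broadcast, which narrows down where an in-place write from buffer synchronization could come from.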

yl4579 commented 1 year ago

Hi @stevenhillis, thanks for your help. The problem happens even before the discriminator kicks in, so it is unlikely to be caused by spectral_norm. I have tried your suggestion and luckily the in-place error disappeared without having to set model.predictor_encoder.train()! I will keep you posted when I have extra GPUs; right now my GPUs are being used for the multispeaker LibriTTS model training.

lawlietlight commented 1 year ago

I can sponsor 3 to 4 T4 instances on Azure for a week. I'm not sure whether that will help with the current accelerate problem to speed up multispeaker TTS training any further.

yl4579 commented 1 year ago

@lawlietlight Thanks for your willingness to help. Maybe you can debug this problem if you have time?

WendongGan commented 1 year ago

I look forward to this problem being solved. I estimated the training time with the current DP setup: with 4×A100, batch size 16, training on LibriTTS-460, I need (15 epochs × 7 h + 5 epochs × 14 h + 15 epochs × 18 h) / 24 h ≈ 18.5 days. That is really too long. If the training data grows to thousands or tens of thousands of hours, it will take even longer. I'll also start debugging this problem. I look forward to discussing and solving it together.
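For reference, the estimate above can be reproduced in a few lines of Python. The three (epochs, hours per epoch) pairs are taken directly from the figures quoted in the comment; the variable names are only illustrative.

```python
# (number of epochs, hours per epoch) for each training phase quoted above
phases = [(15, 7), (5, 14), (15, 18)]

total_hours = sum(epochs * hours for epochs, hours in phases)
total_days = total_hours / 24

print(f"{total_hours} h = {total_days:.1f} days")  # prints "445 h = 18.5 days"
```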

@yl4579 I look forward to your updates, too! Best wishes!

hermanseu commented 1 year ago

As duration and F0 are independent of each other in ProsodyPredictor, I separated ProsodyPredictor into two classes, one for duration and the other for F0, and also renamed the function F0Ntrain to forward. Then I deleted the line model.predictor._set_static_graph(). With these modifications, I can train the second stage normally in DDP mode until diff_epoch. From diff_epoch on, DDP deadlocks at accelerator.backward(g_loss). When I disable loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean(), I can train the second stage normally. Maybe something is wrong with the diffusion model under DDP.

yl4579 commented 1 year ago

@hermanseu I think separating F0 and duration is probably fine, but you also need to sample more dimensions in the diffusion model. Did you notice any performance drop from doing this?

hermanseu commented 1 year ago

Yes, without the diffusion model, the performance dropped. So we should find out why DDP deadlocks when the diffusion model is used. But right now I have no idea about that. 😅

joe-none416 commented 1 year ago

During my experiments, I found that certain conditions can lead to DDP hangs. In the code, when one process gets slm_out is None and continues to the next iteration, the other processes are left waiting. In another case, if a NaN value appears in one process, it may cause gradient anomalies and likewise make all processes hang. Hope it helps.

One possible solution is to use torch.distributed.all_reduce so that, if slm_out is None or the loss is NaN in one process, ALL processes skip the current iteration. I tried training a model this way, but the model was not good, sadly.
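A rough sketch of that idea, not the code actually used in the experiment above (which, as noted, did not give a good model). The function names local_skip_condition and any_rank_should_skip are hypothetical; the only library call relied on for synchronization is torch.distributed.all_reduce with a MAX reduction.

```python
import math

import torch
import torch.distributed as dist


def local_skip_condition(slm_out, loss) -> bool:
    """True if this rank alone would want to skip the iteration."""
    # Skip when the SLM adversarial step returned nothing or the loss is NaN/inf.
    return slm_out is None or not math.isfinite(float(loss))


def any_rank_should_skip(local_skip: bool, device="cpu") -> bool:
    """Return the same answer on every rank: True if at least one rank wants to skip.

    Every rank must call this once per iteration, otherwise the all_reduce
    itself would hang. With the NCCL backend, `device` must be a CUDA device.
    """
    if not (dist.is_available() and dist.is_initialized()):
        # Single-process run: nothing to synchronize.
        return local_skip
    flag = torch.tensor([1.0 if local_skip else 0.0], device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)  # max over ranks acts as a logical OR
    return bool(flag.item() > 0)
```

In a training loop this would be used as something like `if any_rank_should_skip(local_skip_condition(slm_out, loss), device): continue`, so that all ranks skip together instead of one rank continuing alone.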

danielmsu commented 1 year ago

UPD: nvm, seems like it's another issue and it is already reported here: https://github.com/yl4579/StyleTTS2/issues/72

Old comment

I have a similar issue while running train_finetune.py script (which uses DP) on a custom dataset:

```
Traceback (most recent call last):
  File "/workspace/StyleTTS2/train_finetune.py", line 714, in <module>
    main()
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/workspace/StyleTTS2/train_finetune.py", line 509, in main
    loss_gen_lm.backward()
  File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 1, 1]] is at version 3; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
```

For some reason the script runs fine on 4090, but the error above appears when I run it on A100 (80G) at the first joint training epoch (30th by default).
yl4579 commented 1 year ago

@joe-none416 Why is it not good?

teamblubee commented 11 months ago

Those errors are caused by in-place operations. They are typically harmless in non-distributed training, but once you switch to distributed computation, modifying a tensor while something else is still using it throws this error.

To debug, we first have to find what could be modifying the data the tensors are using from anywhere other than where it was initially assigned. This tends to happen a lot in GPU programming when data is touched from different contexts.

effusiveperiscope commented 10 months ago

Has anyone here tried the fix by @stevenhillis?

logikstate commented 10 months ago

Has anyone here tried the fix by @stevenhillis?

where is the fix?

effusiveperiscope commented 10 months ago

Has anyone here tried the fix by @stevenhillis?

where is the fix?

broadcast_buffers=False in DistributedDataParallelKwargs. Seems to work OK if you remove the isnan() check (I read that GradScaler is supposed to skip steps with nan loss?)

effusiveperiscope commented 10 months ago

Doesn't seem to work with slmadv training though.

soaking-proses commented 9 months ago

Hi!

Joining the conversation a lil late, but would it help if we could sponsor some dev/gpu time? @yl4579

Also happy to help with test cases/help reproduce w/ different dependency versions if it'll help 😅

Thank you!

starmoon-1134 commented 5 months ago

Doesn't seem to work with slmadv training though.

I encountered a similar issue. If slm_out is None, the next iteration at accelerator.backward(g_loss) would throw an error:

[rank0]: RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
[rank0]: Parameter at index 246 with name wl.wavlm.encoder.layers.11.final_layer_norm.weight has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration.

When find_unused_parameters=True, if you call the forward method of SLMAdversarialLoss, PyTorch automatically marks the gradients of unused parameters in SLMAdversarialLoss as ready. When returning None early due to data-related reasons, all parameters in SLMAdversarialLoss, such as all parameters of wl, will have their gradients marked as ready. The next time accelerator.backward(g_loss) is executed, it attempts to calculate gradients for wl.wavlm.encoder.layers.11.final_layer_norm.weight again, but finds that its gradients have already been marked as ready, resulting in an error.

The solution is to modify the SLMAdversarialLoss forward method. Change the return None to raise an exception, like this:

global_min_batch = accelerator.gather(torch.tensor([len(wav)], device=ref_text.device)).min().item()
if global_min_batch <= 1:
    raise SomeError("skip slmadv")

Then, in train_second.py, use a try-except block:

try:
    slm_out = slmadv(...)
except SomeError:
    slm_out = None

Raising an exception in the forward function can prevent the gradients of the WavLM-related parameters from being marked as ready. I hope this helps you.
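The control flow described above can be sketched without any model code. This is only an illustration under stated assumptions: SkipSLMAdversarial stands in for the SomeError placeholder, slmadv_forward stands in for the SLMAdversarialLoss forward method, and global_min_batch is assumed to have been computed across all ranks beforehand (as with the gather call above) so that every rank takes the same branch.

```python
class SkipSLMAdversarial(Exception):
    """Hypothetical exception signalling that the SLM adversarial step should be skipped."""


def slmadv_forward(global_min_batch: int):
    """Stand-in for the forward method: raise instead of returning None."""
    if global_min_batch <= 1:
        raise SkipSLMAdversarial("skip slmadv")
    # Placeholder for the real outputs of the adversarial loss.
    return ("d_loss_slm", "loss_gen_lm")


def run_slmadv_step(global_min_batch: int):
    """Training-loop side: translate the exception back into slm_out = None."""
    try:
        slm_out = slmadv_forward(global_min_batch)
    except SkipSLMAdversarial:
        slm_out = None
    return slm_out
```

Catching only the dedicated exception type, rather than using a bare except, keeps unrelated errors (for example out-of-memory errors or keyboard interrupts) from being silently treated as a skipped step.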

starmoon-1134 commented 5 months ago

Another issue is that updating d_loss_slm and loss_gen_lm using data from the same batch will result in an error:

 RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 256, 1]] is at version 75; expected version 73 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

The solution is to stagger these updates. For example:

if d_loss_slm != 0:
    optimizer.zero_grad()
    accelerator.backward(d_loss_slm)
    optimizer.step('wd')
else:
    # Process loss_gen_lm
    ...
acul3 commented 4 months ago

hello @starmoon-1134 did you find success after training using that solution?

starmoon-1134 commented 4 months ago

hello @starmoon-1134 did you find success after training using that solution?

Yeah, it works for me

schnekk commented 4 months ago

Hello @starmoon-1134, how are the results? Is the performance the same as training with the current script, train_second.py, in terms of losses and sound quality?

starmoon-1134 commented 4 months ago

Hello @starmoon-1134, how are the results? Is the performance the same as training with the current script, train_second.py, in terms of losses and sound quality?

Hi @schnekk, I haven't conducted rigorous tests, but based on the loss and synthesized speech quality, there are no noticeable negative impacts.

w11wo commented 4 months ago

Hi @starmoon-1134, thanks for the info. Could you please share your code if it is hosted somewhere? I think it'll be a great contribution to everybody here 🙏

martinambrus commented 2 months ago

@starmoon-1134 just pinging you to see if you could help us with a DDP version of the second stage training, since I suspect many of us (like me) are not proficient enough in writing such code ourselves

starmoon-1134 commented 2 months ago

I'm sorry, I cannot share the code:

  1. It has become an internal company project, and I do not have permission to open-source it.
  2. We have made extensive modifications for generating speech in other languages, making it incompatible with the current repository, such as replacing PL-BERT with other types of BERT, adding tone embeddings, etc.
martinambrus commented 2 months ago

I'm sorry, I cannot share the code:

1. It has become an internal company project, and I do not have permission to open-source it.

2. We have made extensive modifications for generating speech in other languages, making it incompatible with the current repository, such as replacing PL-BERT with other types of BERT, adding tone embeddings, etc.

I understand, thank you kindly for letting us know.

martinambrus commented 1 month ago

While this problem still remains to be solved, I recently found out that on a VAST.AI instance some VRAM sharing seems to be possible, which eliminates this issue.

When I used A100 SXM4 instances, I could use 8x of those GPUs and just use the train_second.py script as is - with config set to use all of the 8x64GB of VRAM (I had 96 batch size and 800 max_len set) - and it worked just fine.