Open yl4579 opened 1 year ago
The following is the broken (and unfinished) code for train_second.py
with DDP:
# load packages
import random
import yaml
import time
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa
import click
import shutil
import warnings
warnings.simplefilter('ignore')
from torch.utils.tensorboard import SummaryWriter
from meldataset import build_dataloader
from Utils.ASR.models import ASRCNN
from Utils.JDC.model import JDCNet
from Utils.PLBERT.util import load_plbert
from models import *
from losses import *
from utils import *
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
from optimizers import build_optimizer
from accelerate import Accelerator
from accelerate.utils import LoggerType
from accelerate import DistributedDataParallelKwargs
from torch.utils.tensorboard import SummaryWriter
import logging
from accelerate.logging import get_logger
logger = get_logger(__name__, log_level="DEBUG")
def _load(states, model, force_load=True):
model_states = model.state_dict()
for key, val in states.items():
try:
if key not in model_states:
continue
if isinstance(val, nn.Parameter):
val = val.data
if val.shape != model_states[key].shape:
print("%s does not have same shape" % key)
print(val.shape, model_states[key].shape)
if not force_load:
continue
min_shape = np.minimum(np.array(val.shape), np.array(model_states[key].shape))
slices = [slice(0, min_index) for min_index in min_shape]
model_states[key][slices].copy_(val[slices])
else:
model_states[key].copy_(val)
except:
print("not exist :%s" % key)
print("not exist ", key)
@click.command()
@click.option('-p', '--config_path', default='Configs/config.yml', type=str)
def main(config_path):
    """Run second-stage training under Hugging Face Accelerate (DDP).

    Args:
        config_path: Path to the YAML configuration file. It is copied into
            the configured ``log_dir`` next to the logs and checkpoints.

    Raises:
        ValueError: If no pretrained second-stage model is requested and the
            config does not give ``first_stage_path``.
    """
    # Imported locally so the function does not depend on these names leaking
    # in through the module's star imports.
    import os
    import os.path as osp

    def unwrap(module):
        # Return the underlying module whether or not it has been wrapped
        # (DDP exposes the original module as ``.module``).
        return getattr(module, 'module', module)

    with open(config_path) as config_file:
        config = yaml.safe_load(config_file)

    log_dir = config['log_dir']
    os.makedirs(log_dir, exist_ok=True)
    shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))

    # broadcast_buffers=False: DDP otherwise re-broadcasts module buffers on
    # every forward pass, which autograd sees as an in-place modification when
    # a module is run several times before backward.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True, broadcast_buffers=False)
    accelerator = Accelerator(project_dir=log_dir, split_batches=True, kwargs_handlers=[ddp_kwargs])
    if accelerator.is_main_process:
        writer = SummaryWriter(log_dir + "/tensorboard")

    # write logs
    file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
    logger.logger.addHandler(file_handler)

    batch_size = config.get('batch_size', 10)
    epochs = config.get('epochs_2nd', 200)
    log_interval = config.get('log_interval', 10)
    saving_epoch = config.get('save_freq', 2)

    data_params = config.get('data_params', None)
    sr = config['preprocess_params'].get('sr', 24000)
    train_path = data_params['train_data']
    val_path = data_params['val_data']
    root_path = data_params['root_path']
    min_length = data_params['min_length']
    OOD_data = data_params['OOD_data']

    max_len = config.get('max_len', 200)

    loss_params = Munch(config['loss_params'])
    diff_epoch = loss_params.diff_epoch
    joint_epoch = loss_params.joint_epoch

    optimizer_params = Munch(config['optimizer_params'])

    train_list, val_list = get_data_path_list(train_path, val_path)
    device = accelerator.device

    train_dataloader = build_dataloader(train_list,
                                        root_path,
                                        OOD_data=OOD_data,
                                        min_length=min_length,
                                        batch_size=batch_size,
                                        num_workers=2,
                                        dataset_config={},
                                        device=device)

    val_dataloader = build_dataloader(val_list,
                                      root_path,
                                      OOD_data=OOD_data,
                                      min_length=min_length,
                                      batch_size=batch_size,
                                      validation=True,
                                      num_workers=0,
                                      device=device,
                                      dataset_config={})

    with accelerator.main_process_first():
        # load pretrained ASR model
        ASR_config = config.get('ASR_config', False)
        ASR_path = config.get('ASR_path', False)
        text_aligner = load_ASR_models(ASR_path, ASR_config)

        # load pretrained F0 model
        F0_path = config.get('F0_path', False)
        pitch_extractor = load_F0_models(F0_path)

        # load PL-BERT model
        BERT_path = config.get('PLBERT_dir', False)
        plbert = load_plbert(BERT_path)

    # build model
    model_params = recursive_munch(config['model_params'])
    multispeaker = model_params.multispeaker
    model = build_model(model_params, text_aligner, pitch_extractor, plbert)
    _ = [model[key].to(device) for key in model]

    # DDP
    for k in model:
        model[k] = accelerator.prepare(model[k])
    # Only a DDP-wrapped module offers _set_static_graph; in a single-process
    # run the prepared module is left unwrapped.
    if hasattr(model.predictor, '_set_static_graph'):
        model.predictor._set_static_graph()

    train_dataloader, val_dataloader = accelerator.prepare(
        train_dataloader, val_dataloader
    )

    start_epoch = 0
    iters = 0

    load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)

    if not load_pretrained:
        if config.get('first_stage_path', '') != '':
            first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
            print('Loading the first stage model at %s ...' % first_stage_path)
            model, _, start_epoch, iters = load_checkpoint(model,
                None,
                first_stage_path,
                load_only_params=True,
                ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion'])  # keep starting epoch for tensorboard log

            # these epochs should be counted from the start epoch
            diff_epoch += start_epoch
            joint_epoch += start_epoch
            epochs += start_epoch

            # initialise the prosodic style encoder from the acoustic one
            _load(model.style_encoder.state_dict(), model.predictor_encoder)
        else:
            raise ValueError('You need to specify the path to the first stage model.')

    gl = GeneratorLoss(model.mpd, model.msd).to(device)
    dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
    wl = WavLMLoss(model_params.slm.model,
                   model.wd,
                   sr,
                   model_params.slm.sr).to(device)

    gl = accelerator.prepare(gl)
    dl = accelerator.prepare(dl)
    wl = accelerator.prepare(wl)

    n_down = unwrap(model.text_aligner).n_down

    sampler = DiffusionSampler(
        unwrap(model.diffusion).diffusion,
        sampler=ADPM2Sampler(),
        sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),  # empirical parameters
        clamp=False
    )

    scheduler_params = {
        "max_lr": optimizer_params.lr,
        "pct_start": float(0),
        "epochs": epochs,
        "steps_per_epoch": len(train_dataloader),
    }
    scheduler_params_dict = {key: scheduler_params.copy() for key in model}
    scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2

    optimizer = build_optimizer({key: model[key].parameters() for key in model},
                                scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)

    # adjust BERT learning rate
    for g in optimizer.optimizers['bert'].param_groups:
        g['betas'] = (0.9, 0.99)
        g['lr'] = optimizer_params.bert_lr
        g['initial_lr'] = optimizer_params.bert_lr
        g['min_lr'] = 0
        g['weight_decay'] = 0.01

    # adjust acoustic module learning rate
    for module in ["decoder", "style_encoder"]:
        for g in optimizer.optimizers[module].param_groups:
            g['betas'] = (0.0, 0.99)
            g['lr'] = optimizer_params.ft_lr
            g['initial_lr'] = optimizer_params.ft_lr
            g['min_lr'] = 0
            g['weight_decay'] = 1e-4

    for k in list(optimizer.optimizers):
        optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
        optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])

    # load models if there is a model
    if load_pretrained:
        with accelerator.main_process_first():
            model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
                                                                   load_only_params=config.get('load_only_params', True))

    best_loss = float('inf')  # best test loss
    # NOTE(review): this discards the `iters` value returned by
    # load_checkpoint above -- confirm that restarting the step counter is intended.
    iters = 0

    torch.cuda.empty_cache()

    stft_loss = MultiResolutionSTFTLoss().to(device)
    stft_loss = accelerator.prepare(stft_loss)

    print(optimizer.optimizers['bert'])

    start_ds = False

    for epoch in range(start_epoch, epochs):
        running_loss = 0
        start_time = time.time()

        _ = [model[key].eval() for key in model]

        model.predictor.train()
        # model.predictor_encoder.train() # uncomment this line will fix the in-place operation problem but will give you a higher F0 loss and worse model
        model.bert_encoder.train()
        model.bert.train()
        model.msd.train()
        model.mpd.train()

        if epoch >= diff_epoch:
            start_ds = True

        for i, batch in enumerate(train_dataloader):
            waves = batch[0]
            batch = [b.to(device) for b in batch[1:]]
            texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch

            with torch.no_grad():
                mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
                text_mask = length_to_mask(input_lengths).to(texts.device)

                try:
                    _, _, s2s_attn = model.text_aligner(mels, mask, texts)
                    s2s_attn = s2s_attn.transpose(-1, -2)
                    s2s_attn = s2s_attn[..., 1:]
                    s2s_attn = s2s_attn.transpose(-1, -2)
                except Exception:
                    # best-effort: skip batches the aligner cannot process
                    continue

                mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
                s2s_attn_mono = maximum_path(s2s_attn, mask_ST)

                # encode
                t_en = model.text_encoder(texts, input_lengths, text_mask)
                asr = (t_en @ s2s_attn_mono)

                d_gt = s2s_attn_mono.sum(axis=-1).detach()

                # Reference styles for the multispeaker denoiser; without this
                # `ref` is undefined in the multispeaker branch below.
                # Concatenation order (acoustic, prosodic) matches s_trg.
                if multispeaker and epoch >= diff_epoch:
                    ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
                    ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
                    ref = torch.cat([ref_ss, ref_sp], dim=1)

            # compute the style of the entire utterance
            # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
            ss = []
            gs = []
            for bib in range(len(mel_input_length)):
                mel = mels[bib, :, :mel_input_length[bib]]
                s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
                ss.append(s)
                s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
                gs.append(s)

            s_dur = torch.stack(ss).squeeze()  # global prosodic styles
            gs = torch.stack(gs).squeeze()  # global acoustic styles
            s_trg = torch.cat([gs, s_dur], dim=-1).detach()  # ground truth for denoiser

            bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
            d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

            # denoiser training
            if epoch >= diff_epoch:
                num_steps = np.random.randint(3, 5)

                if model_params.diffusion.dist.estimate_sigma_data:
                    unwrap(model.diffusion).diffusion.sigma_data = s_trg.std().item()

                if multispeaker:
                    s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
                                      embedding=bert_dur,
                                      embedding_scale=1,
                                      features=ref,  # reference from the same speaker as the embedding
                                      embedding_mask_proba=0.1,
                                      num_steps=num_steps).squeeze(1)
                    loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean()  # EDM loss
                    loss_sty = F.l1_loss(s_preds, s_trg.detach())  # style reconstruction loss
                else:
                    s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
                                      embedding=bert_dur,
                                      embedding_scale=1,
                                      embedding_mask_proba=0.1,
                                      num_steps=num_steps).squeeze(1)
                    loss_diff = unwrap(model.diffusion).diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean()  # EDM loss
                    loss_sty = F.l1_loss(s_preds, s_trg.detach())  # style reconstruction loss
            else:
                loss_sty = 0
                loss_diff = 0

            d, p = model.predictor(d_en, s_dur,
                                   input_lengths,
                                   s2s_attn_mono,
                                   text_mask)

            # get clips
            mel_input_length_all = accelerator.gather(mel_input_length)  # for balanced load
            mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])
            en = []
            gt = []
            p_en = []
            wav = []

            for bib in range(len(mel_input_length)):
                mel_length = int(mel_input_length[bib].item() / 2)

                random_start = np.random.randint(0, mel_length - mel_len)
                en.append(asr[bib, :, random_start:random_start+mel_len])
                p_en.append(p[bib, :, random_start:random_start+mel_len])
                gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])

                y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
                wav.append(torch.from_numpy(y).to(device))

            wav = torch.stack(wav).float().detach()

            en = torch.stack(en)
            p_en = torch.stack(p_en)
            gt = torch.stack(gt).detach()

            if gt.size(-1) < 80:
                continue

            s_dur = model.predictor_encoder(gt.unsqueeze(1))

            with torch.no_grad():
                s = model.style_encoder(gt.unsqueeze(1))
                F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
                F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()

                N_real = log_norm(gt.unsqueeze(1)).squeeze(1)

                # ground truth from reconstruction
                y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
                # ground truth from recording
                y_rec_gt = wav.unsqueeze(1)

                if epoch >= joint_epoch:
                    wav = y_rec_gt  # use recording since decoder is tuned
                else:
                    wav = y_rec_gt_pred  # use reconstruction since decoder is fixed

            F0_fake, N_fake = unwrap(model.predictor).F0Ntrain(p_en, s_dur)

            y_rec = model.decoder(en, F0_fake, N_fake, s)

            loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
            loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)

            if start_ds:
                optimizer.zero_grad()
                d_loss = dl(wav.detach(), y_rec.detach()).mean()
                accelerator.backward(d_loss)
                optimizer.step('msd')
                optimizer.step('mpd')
            else:
                d_loss = 0

            # generator loss
            optimizer.zero_grad()

            loss_mel = stft_loss(y_rec, wav)
            if start_ds:
                loss_gen_all = gl(wav, y_rec).mean()
            else:
                loss_gen_all = 0
            loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()

            loss_ce = 0
            loss_dur = 0
            for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
                _s2s_pred = _s2s_pred[:_text_length, :]
                _text_input = _text_input[:_text_length].long()
                _s2s_trg = torch.zeros_like(_s2s_pred)
                # `tok` rather than `p`, so the predictor output `p` is not shadowed
                for tok in range(_s2s_trg.shape[0]):
                    _s2s_trg[tok, :_text_input[tok]] = 1
                _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)

                loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
                                      _text_input[1:_text_length-1])
                loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())

            loss_ce /= texts.size(0)
            loss_dur /= texts.size(0)

            # NOTE(review): loss_sty and loss_diff are computed above but are
            # not part of g_loss, although the 'diffusion' optimizer is stepped
            # below -- confirm whether style/diffusion terms should be added.
            g_loss = loss_params.lambda_mel * loss_mel + \
                     loss_params.lambda_F0 * loss_F0_rec + \
                     loss_params.lambda_ce * loss_ce + \
                     loss_params.lambda_norm * loss_norm_rec + \
                     loss_params.lambda_dur * loss_dur + \
                     loss_params.lambda_gen * loss_gen_all + \
                     loss_params.lambda_slm * loss_lm

            running_loss += accelerator.gather(loss_mel).mean().item()
            with torch.autograd.set_detect_anomaly(True):
                accelerator.backward(g_loss)
            # NOTE(review): an interactive debugger stops only this process;
            # in a multi-process launch the other ranks will block waiting for it.
            if torch.isnan(g_loss):
                from IPython.core.debugger import set_trace
                set_trace()

            optimizer.step('bert_encoder')
            optimizer.step('bert')
            optimizer.step('predictor')
            optimizer.step('predictor_encoder')

            if epoch >= diff_epoch:
                optimizer.step('diffusion')

            if epoch >= joint_epoch:
                optimizer.step('style_encoder')
                optimizer.step('decoder')

            iters = iters + 1

            # SLM adversarial losses are not computed here; logged as zero
            d_loss_slm = 0
            loss_gen_lm = 0

            if (i+1)%log_interval == 0 and accelerator.is_main_process:
                print ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
                       %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm))

                writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
                writer.add_scalar('train/gen_loss', loss_gen_all, iters)
                writer.add_scalar('train/d_loss', d_loss, iters)
                writer.add_scalar('train/ce_loss', loss_ce, iters)
                writer.add_scalar('train/dur_loss', loss_dur, iters)
                writer.add_scalar('train/slm_loss', loss_lm, iters)
                writer.add_scalar('train/norm_loss', loss_norm_rec, iters)
                writer.add_scalar('train/F0_loss', loss_F0_rec, iters)
                writer.add_scalar('train/sty_loss', loss_sty, iters)
                writer.add_scalar('train/diff_loss', loss_diff, iters)
                writer.add_scalar('train/d_loss_slm', d_loss_slm, iters)
                writer.add_scalar('train/gen_loss_slm', loss_gen_lm, iters)

                running_loss = 0

                print('Time elasped:', time.time()-start_time)

        loss_test = 0
        loss_align = 0
        loss_f = 0
        _ = [model[key].eval() for key in model]

        with torch.no_grad():
            iters_test = 0
            for batch_idx, batch in enumerate(val_dataloader):
                optimizer.zero_grad()

                try:
                    waves = batch[0]
                    batch = [b.to(device) for b in batch[1:]]
                    texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
                    with torch.no_grad():
                        # use the accelerator's device rather than a hard-coded 'cuda'
                        mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
                        text_mask = length_to_mask(input_lengths).to(texts.device)

                        _, _, s2s_attn = model.text_aligner(mels, mask, texts)
                        s2s_attn = s2s_attn.transpose(-1, -2)
                        s2s_attn = s2s_attn[..., 1:]
                        s2s_attn = s2s_attn.transpose(-1, -2)

                        mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
                        s2s_attn_mono = maximum_path(s2s_attn, mask_ST)

                        # encode
                        t_en = model.text_encoder(texts, input_lengths, text_mask)
                        asr = (t_en @ s2s_attn_mono)

                        d_gt = s2s_attn_mono.sum(axis=-1).detach()

                    ss = []
                    gs = []

                    for bib in range(len(mel_input_length)):
                        mel = mels[bib, :, :mel_input_length[bib]]
                        s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
                        ss.append(s)
                        s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
                        gs.append(s)

                    s = torch.stack(ss).squeeze()
                    gs = torch.stack(gs).squeeze()
                    s_trg = torch.cat([s, gs], dim=-1).detach()

                    bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
                    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
                    d, p = model.predictor(d_en, s,
                                           input_lengths,
                                           s2s_attn_mono,
                                           text_mask)

                    # get clips
                    mel_len = int(mel_input_length.min().item() / 2 - 1)
                    en = []
                    gt = []
                    p_en = []
                    wav = []

                    for bib in range(len(mel_input_length)):
                        mel_length = int(mel_input_length[bib].item() / 2)

                        random_start = np.random.randint(0, mel_length - mel_len)
                        en.append(asr[bib, :, random_start:random_start+mel_len])
                        p_en.append(p[bib, :, random_start:random_start+mel_len])

                        gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])

                        y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
                        wav.append(torch.from_numpy(y).to(device))

                    wav = torch.stack(wav).float().detach()

                    en = torch.stack(en)
                    p_en = torch.stack(p_en)
                    gt = torch.stack(gt).detach()

                    s = model.predictor_encoder(gt.unsqueeze(1))

                    F0_fake, N_fake = unwrap(model.predictor).F0Ntrain(p_en, s)

                    loss_dur = 0
                    for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
                        _s2s_pred = _s2s_pred[:_text_length, :]
                        _text_input = _text_input[:_text_length].long()
                        _s2s_trg = torch.zeros_like(_s2s_pred)
                        for bib in range(_s2s_trg.shape[0]):
                            _s2s_trg[bib, :_text_input[bib]] = 1
                        _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
                        loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
                                              _text_input[1:_text_length-1])

                    loss_dur /= texts.size(0)

                    s = model.style_encoder(gt.unsqueeze(1))

                    y_rec = model.decoder(en, F0_fake, N_fake, s)
                    loss_mel = stft_loss(y_rec.squeeze(), wav.detach())

                    F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))

                    loss_F0 = F.l1_loss(F0_real, F0_fake) / 10

                    loss_test += accelerator.gather(loss_mel).mean()
                    loss_align += accelerator.gather(loss_dur).mean()
                    loss_f += accelerator.gather(loss_F0).mean()

                    iters_test += 1
                except Exception:
                    # best-effort: skip validation batches that fail
                    continue

        if iters_test == 0:
            # No validation batch succeeded: there is nothing to average, log
            # or synthesise from, so move on instead of dividing by zero.
            if accelerator.is_main_process:
                print('Epochs:', epoch + 1)
                print('Validation skipped: no batch was evaluated successfully.')
            continue

        val_mel = loss_test / iters_test
        val_dur = loss_align / iters_test
        val_f0 = loss_f / iters_test

        if accelerator.is_main_process:
            print('Epochs:', epoch + 1)
            print('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (val_mel, val_dur, val_f0) + '\n\n\n')
            print('\n\n\n')
            writer.add_scalar('eval/mel_loss', val_mel, epoch + 1)
            writer.add_scalar('eval/dur_loss', val_dur, epoch + 1)  # was logging the mel loss
            writer.add_scalar('eval/F0_loss', val_f0, epoch + 1)

            # synthesise a few samples from the last validation batch
            with torch.no_grad():
                for bib in range(len(asr)):
                    mel_length = int(mel_input_length[bib].item())
                    gt = mels[bib, :, :mel_length].unsqueeze(0)
                    en = asr[bib, :, :mel_length // 2].unsqueeze(0)

                    F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
                    F0_real = F0_real.unsqueeze(0)
                    s = model.style_encoder(gt.unsqueeze(1))
                    real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)

                    y_rec = model.decoder(en, F0_real, real_norm, s)

                    writer.add_audio('eval/y' + str(bib), y_rec.cpu().numpy().squeeze(), epoch, sample_rate=sr)

                    s_dur = model.predictor_encoder(gt.unsqueeze(1))
                    p_en = p[bib, :, :mel_length // 2].unsqueeze(0)

                    F0_fake, N_fake = unwrap(model.predictor).F0Ntrain(p_en, s_dur)

                    y_pred = model.decoder(en, F0_fake, N_fake, s)

                    writer.add_audio('pred/y' + str(bib), y_pred.cpu().numpy().squeeze(), epoch, sample_rate=sr)

                    if epoch == 0:
                        writer.add_audio('gt/y' + str(bib), waves[bib].squeeze(), epoch, sample_rate=sr)

                    if bib >= 5:
                        break

        if epoch % saving_epoch == 0:
            if val_mel < best_loss:
                best_loss = val_mel
            # Write the checkpoint from one process only so that ranks do not
            # race on the same file.
            if accelerator.is_main_process:
                print('Saving..')
                state = {
                    'net': {key: model[key].state_dict() for key in model},
                    'optimizer': optimizer.state_dict(),
                    'iters': iters,
                    'val_loss': val_mel,
                    'epoch': epoch,
                }
                save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
                torch.save(state, save_path)
# Script entry point: run the click command when executed directly.
if __name__ == "__main__":
    main()
Did you try other versions of PyTorch?
@zhouyong64 This issue (in-place operation) was first identified by @ABC0408, who used a different PyTorch version than me, though we both used the PyTorch > 2.0. Not sure if it is relevant. I will try PyTorch < 2.0 when I get time.
This is a pytorch gotcha at the intersection of ddp, buffers, and gans (multiple forward passes). DDP modules broadcast the root process module's buffers at every forward pass, which is treated as an inplace op. Buffers show up mostly from batchnorm, which can be solved by using syncbatchnorm. Instancenorm can also be a culprit, since it inherits from the same primitive as batchnorm, but only if track_running_stats is set to True, which it isn't. I think the culprit in this case is probably spectral_norm, which has buffers but is also supposed to handle this broadcasting issue by cloning (reference https://pytorch.org/docs/stable/_modules/torch/nn/utils/spectral_norm.html#SpectralNorm). Not sure why that wouldn't be working here, but regardless of the root cause, you can disable the broadcasting by changing the ddp kwargs to be
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True, broadcast_buffers=False)
. Please let me know if this results in the same regression with F0 loss as calling .train() did!
Hi @stevenhillis , thanks for your help. The problem happens even before the discriminator kicks in, so it is unlikely caused by spectral_norm
. I have tried your suggestions and luckily the in-place error disappeared without having to set model.predictor_encoder.train()
! I will keep you posted when I have extra GPUs but now my GPUs are used for the multispeaker LibriTTS model training now.
I can sponsor 3 to 4 T4 instances in azure cloud for a week. Not sure whether that will help with current accelerate problem to speed up multispeaker tts training any further
@lawlietlight Thanks for your willingness to help. Maybe you can debug this problem if you have time?
Look forward to this problem being solved. I have calculated the current DP, I use 4 *A100, batch size 16, training libritts-460, I need to spend (15epoch x 7h+5epoch x 14h+15epoch x 18h)/24h=18.5days. It is really too long. If increase the training data to thousands or tens of thousands of hours, this time is even longer. I'll also start debugging this problem. We look forward to discussing and solving it together.
@yl4579 look forward to your share, too! Best wish!!!
As duration and F0 are independent in ProsodyPredictor, I separated ProsodyPredictor into two classes, one for duration and the other for F0, and also renamed the function F0Ntrain to forward. Then I deleted the line model.predictor._set_static_graph(). With these modifications, I can train the second stage normally in DDP mode until diff_epoch. Once diff_epoch is reached, DDP deadlocks at accelerator.backward(g_loss). When I disable loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean(), I can train the second stage normally. Maybe there is something wrong in the diffusion module under DDP.
@hermanseu I think separating F0 and duration is probably fine but you also need to sample more dimensions in diffusion model. Did you notice any performance drop by doing these?
Yes, without the diffusion model, the performance dropped. So we should find out why DDP deadlocks when using the diffusion model. But for now I have no idea about that.
During my experiments, I found that certain phenomena can lead to DDP hangs. For code, the continuation of a particular process with slm_out
is None can cause other processes to wait. In another case, if there is a "nan" value in a certain process, it may result in gradient anomalies, and similarly, it can also cause all processes to hang. Hope it helps.
One possible solution is to use torch.distributed.all_reduce so that, if slm_out is None or the loss is NaN in one process, the current iteration is skipped for ALL processes. I tried training a model this way, but the model is not good, sad.
UPD: nvm, seems like it's another issue and it is already reported here: https://github.com/yl4579/StyleTTS2/issues/72
@joe-none416 Why is it not good?
Those errors are caused because of in place operations, it's typically fine when non distributed but when you switch to distributed computation if one tensor is working on something and you modify it, then it will throw this error.
to debug, we have to first see what can be modifying the data the tensors are using from anywhere other than the tensor it was initially assigned to. This tends to happen a lot in GPU programming when data is touched from different contexts.
Has anyone here tried the fix by @stevenhillis?
Has anyone here tried the fix by @stevenhillis?
where is the fix?
Has anyone here tried the fix by @stevenhillis?
where is the fix?
broadcast_buffers=False in DistributedDataParallelKwargs. Seems to work OK if you remove the isnan() check (I read that GradScaler is supposed to skip steps with nan loss?)
Doesn't seem to work with slmadv training though.
Hi!
Joining the conversation a lil late, but would it help if we could sponsor some dev/gpu time? @yl4579
Also happy to help with test cases/help reproduce w/ different dependency versions if it'll help π
Thank you!
Doesn't seem to work with slmadv training though.
I encountered a similar issue. If slm_out is None, the next iteration at accelerator.backward(g_loss) would throw an error:
[rank0]: RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
[rank0]: Parameter at index 246 with name wl.wavlm.encoder.layers.11.final_layer_norm.weight has been marked as ready twice. This means that multiple autograd engine hooks have fired for this particular parameter during this iteration.
When find_unused_parameters=True, if you call the forward method of SLMAdversarialLoss, PyTorch automatically marks the gradients of unused parameters in SLMAdversarialLoss as ready. When returning None early due to data-related reasons, all parameters in SLMAdversarialLoss, such as all parameters of wl, will have their gradients marked as ready. The next time accelerator.backward(g_loss) is executed, it attempts to calculate gradients for wl.wavlm.encoder.layers.11.final_layer_norm.weight again, but finds that its gradients have already been marked as ready, resulting in an error.
The solution is to modify the SLMAdversarialLoss forward method. Change the return None to raise an exception, like this:
global_min_batch = accelerator.gather(torch.tensor([len(wav)], device=ref_text.device)).min().item()
if global_min_batch <= 1:
raise SomeError("skip slmadv")
Then, in train_second.py, use a try-except block:
try:
slm_out = slmadv(...)
except:
slm_out = None
Raising an exception in forward func can prevent the gradients of wavlm-related parameters from being marked as ready. I hope this helps you.
Another issue is that updating d_loss_slm and loss_gen_lm using data from the same batch will result in an error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 256, 1]] is at version 75; expected version 73 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
The solution is to stagger these updates. For example:
if d_loss_slm != 0:
optimizer.zero_grad()
accelerator.backward(d_loss_slm)
optimizer.step('wd')
else:
# Process loss_gen_lm
...
hello @starmoon-1134 did you find success after training using that solution?
hello @starmoon-1134 did you find success after training using that solution?
Yeah, it works for me
Hello, @starmoon-1134 How about the result? Is the performance the same compare to training using current script, second_train.py
in terms of losses and sound quality.
Hello, @starmoon-1134 How about the result? Is the performance the same compare to training using current script,
second_train.py
in terms of losses and sound quality.
Hi @schnekk, I haven't conducted rigorous tests, but based on the loss and synthesized speech quality, there are no noticeable negative impacts.
Hi @starmoon-1134, thanks for the info. Could you please share your code if it is hosted somewhere? I think it'll be a great contribution to everybody here π
@starmoon-1134 just pinging you to see if you could help us with a DDP version of the second stage training, since I suspect many of us (like me) are not proficient enough in writing such code ourselves
I'm sorry, I cannot share the code:
I'm sorry, I cannot share the code:
1. It has become an internal company project, and I do not have permission to open-source it. 2. We have made extensive modifications for generating speech in other languages, making it incompatible with the current repository, such as replacing PL-BERT with other types of BERT, adding tone embeddings, etc.
I understand, thank you kindly for letting us know.
While this problem still remains to be solved, I recently found out that, using a VAST.AI instance, some VRAM sharing seems to be possible - which eliminates this issue.
When I used A100 SXM4 instances, I could use 8x of those GPUs and just use the train_second.py script as is - with config set to use all of the 8x64GB of VRAM (I had 96 batch size
and 800 max_len
set) - and it worked just fine.
So far train_second.py only works with DataParallel (DP) but not DistributedDataParallel (DDP). One major problem with this is that if we simply translate DP to DDP (code in the comment section), we encounter the following problem:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 6; expected version 5 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
It is insanely difficult to debug. The tensor has no batch dimension, indicating it might be a parameter in the neural network. I found the tensor to be the bias term of the last Conv1D layer of `predictor_encoder` (prosodic style encoder): https://github.com/yl4579/StyleTTS2/blob/main/models.py#L152. This is extremely weird because the problem does not trigger for any Conv1D layer before this.
More mysteriously, the issue surprisingly disappears if we add `model.predictor_encoder.train()` near line 250 of `train_second.py`. However, this causes the F0 loss to be much higher than without this line. This is true for both DP and DDP, so the higher F0 loss value is caused by `model.predictor_encoder.train()`, not DDP. Unfortunately, the `predictor_encoder`, which is a `StyleEncoder`, has no module that changes the behavior depending on whether it is in train or eval mode. The output is exactly the same whether it is set to train or eval.
TLDR: There are three issues with `train_second.py`:
1. DDP fails with the in-place operation error above unless `model.predictor_encoder.train()` is called before training.
2. `model.predictor_encoder.train()` causes F0 loss to be much higher after convergence. This issue is independent of using DP or DDP.
3. `model.predictor_encoder` is an instantiation of `StyleEncoder`, which has no components that change the output depending on its train or eval mode.
This problem has bugged me for more than a month, but I can't find a solution to it. It would be greatly appreciated if anyone has any insight into how to fix this problem. I have pasted the broken DDP code with accelerator below.