Closed. steven850 closed this issue 1 year ago.
In freevc-24.json, segment_size=8640 is a multiple of 480. For 48khz, a multiple of 960 is fine. [10,8,6,2] is fine for 48k.
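To make that relationship concrete, here is a small sanity-check sketch (not code from the repo; the segment_size value below is only an example): the product of the decoder upsample rates is the number of output wav samples produced per content/spectrogram frame (the features stay at the 16 kHz frame rate of 50 frames per second), and segment_size has to be a multiple of that number.

spec_hop_16k = 320                        # hop of the 16 kHz spectrogram/WavLM features -> 50 frames/sec
upsample_rates_48k = [10, 8, 6, 2]        # candidate decoder upsample rates for 48 kHz

samples_per_frame = 1
for r in upsample_rates_48k:
    samples_per_frame *= r                # 10*8*6*2 = 960 output samples per frame

assert samples_per_frame == 48000 // (16000 // spec_hop_16k)   # 48000 / 50 = 960

segment_size = 17280                      # example: 18 * 960, analogous to 8640 = 18 * 480 at 24 kHz
assert segment_size % samples_per_frame == 0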
So I'm running into some issues trying to make this work, and I believe there are still some bugs in the 24khz code that are now showing up in the 48khz version.
In your code for the 24khz data_utils you have:
def __init__(self, audiopaths, hparams):
    self.audiopaths = load_filepaths_and_text(audiopaths)
    self.max_wav_value = hparams.data.max_wav_value
    self.sampling_rate = hparams.data.sampling_rate
    self.filter_length = hparams.data.filter_length
    self.hop_length = hparams.data.hop_length
    self.win_length = hparams.data.win_length
    self.use_sr = hparams.train.use_sr
    self.use_spk = hparams.model.use_spk
    self.spec_len = hparams.train.max_speclen

    random.seed(1234)
    random.shuffle(self.audiopaths)
    self._filter()
def _filter(self):
    """
    Filter text & store spec lengths
    """
    # Store spectrogram lengths for Bucketing
    # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
    # spec_length = wav_length // hop_length
    lengths = []
    for audiopath in self.audiopaths:
        lengths.append(os.path.getsize(audiopath[0]) // (2 * self.hop_length))
    self.lengths = lengths
def get_audio(self, filename):
    audio, sampling_rate = load_wav_to_torch(filename.replace("DUMMY", "dataset/vctk-24k"))
    if sampling_rate != 24000:
        raise ValueError("{} SR doesn't match target {} SR".format(
            sampling_rate, self.sampling_rate))
    audio_norm = audio / self.max_wav_value
    audio_norm = audio_norm.unsqueeze(0)
    spec_filename = filename.replace(".wav", ".spec.pt")
    if os.path.exists(spec_filename):
        spec = torch.load(spec_filename)
    else:
        spec = spectrogram_torch(audio_norm, self.filter_length,
            self.sampling_rate, self.hop_length, self.win_length,
            center=False)
        spec = torch.squeeze(spec, 0)
        torch.save(spec, spec_filename)
So for the lengths, you pull from the 16khz folder, correct?
lengths = []
for audiopath in self.audiopaths:
    lengths.append(os.path.getsize(audiopath[0]) // (2 * self.hop_length))
self.lengths = lengths
but for audio, sampling_rate = load_wav_to_torch(filename.replace("DUMMY", "dataset/vctk-24k"))
you pull from the 24khz folder.
But if that's the case, then this section doesn't make any sense:
spec = spectrogram_torch(audio_norm, self.filter_length,
    self.sampling_rate, self.hop_length, self.win_length,
    center=False)
because all of these values are pulled from config.json and are still set to the 16khz values, so hop 320, etc.
This explains why Likkez had to manually change the hop to 480 to make it work, but he didn't change any of the other values, so now the files are not using the correct SR etc. The training runs, but I don't think the results will be any good.
So when trying to run a model at 48khz, I get the following error if I leave the values as they are:
RuntimeError: The expanded size of the tensor (111360) must match the existing size (90560) at non-singleton dimension 1. Target sizes: [1, 111360]. Tensor sizes: [90560]
If I change the hop length manually to 960, I get this error:
RuntimeError: negative padding is not supported
If I change all the values manually, I get this error:
RuntimeError: Given groups=1, weight of size [192, 641, 1], expected input[64, 1921, 37] to have 641 channels, but got 1921 channels instead
The decoder takes z as input (50 frames per second) and outputs a 24khz, 240 hop length wav (100 frames per second), so all we need to do is set the upsample rate of the decoder to 480 = 240 * (100 / 50). And for the dataloader, only the wav is 24khz; the spectrogram and WavLM features are all 16khz. The same applies to 48khz. Hope these explanations can help you debug.
I can set it to use the 16khz folder and the existing .pt files; then there is no error regarding the .pt files, but then it is also only loading the 16khz files and not the 48khz files.
I mean, this line here loads the wav files, correct?
def get_audio(self, filename):
    audio, sampling_rate = load_wav_to_torch(filename.replace("DUMMY", "dataset/vctk-48k"))
So to train a model on 48khz, this needs to point to the 48khz folder, correct?
If I do that, I get this error:
RuntimeError: Given groups=1, weight of size [192, 641, 1], expected input[64, 1921, 37] to have 641 channels, but got 1921 channels instead
And if I try to run it like this
lengths = []
for audiopath in self.audiopaths:
    audiopath = audiopath[0].replace("\\","/").replace("DUMMY", "dataset/vctk-16k")
    lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
self.lengths = lengths

def get_audio(self, filename):
    audio, sampling_rate = load_wav_to_torch(filename.replace("DUMMY", "dataset/vctk-48k"))
    if sampling_rate != 48000:
        raise ValueError("{} SR doesn't match target {} SR".format(
            sampling_rate, self.sampling_rate))
    audio_norm = audio / self.max_wav_value
    audio_norm = audio_norm.unsqueeze(0)
    spec_filename = filename.replace("DUMMY", "dataset/vctk-16k").replace(".wav", ".spec.pt")
    if os.path.exists(spec_filename):
        spec = torch.load(spec_filename)
    else:
        spec = spectrogram_torch(audio_norm, self.filter_length,
            self.sampling_rate, self.hop_length, self.win_length,
            center=False)
        spec = torch.squeeze(spec, 0)
        torch.save(spec, spec_filename)
I get this error
Traceback (most recent call last):
File "train_48.py", line 295, in <module>
main()
File "train_48.py", line 49, in main
mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\multiprocessing\spawn.py", line 230, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\multiprocessing\spawn.py", line 188, in start_processes
while not context.join():
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\multiprocessing\spawn.py", line 150, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\multiprocessing\spawn.py", line 59, in _wrap
fn(i, *args)
File "Z:\FreeVC-TrainHIFI\train_48.py", line 115, in run
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval])
File "Z:\FreeVC-TrainHIFI\train_48.py", line 149, in train_and_evaluate
(z, z_p, m_p, logs_p, m_q, logs_q) = net_g(c, spec, g=g)
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\parallel\distributed.py", line 886, in forward
output = self.module(*inputs[0], **kwargs[0])
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "Z:\FreeVC-TrainHIFI\models.py", line 336, in forward
o = self.dec(z_slice, g=g)
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "Z:\FreeVC-TrainHIFI\models.py", line 113, in forward
x = self.ups[i](x)
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1120, in _call_impl
result = forward_call(*input, **kwargs)
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\conv.py", line 774, in forward
output_padding, self.groups, self.dilation)
RuntimeError: negative padding is not supported
Yes, this line points to the 48khz folder. There are many hardcoded numbers in this dirty code, for example this; maybe you did not modify all of these hardcoded numbers.
I compared all of the changes between the original files and the 24khz version using Compare in Notepad++, so I was able to find everything you changed for the 24khz version; I then doubled all the values to make the 48khz version. I did find those and changed them.
Here is the complete code if you would like to take a look
data_utils_48.py
import time
import os
import random
import numpy as np
import torch
import torch.utils.data
import commons
from mel_processing import spectrogram_torch, spec_to_mel_torch
from utils import load_wav_to_torch, load_filepaths_and_text, transform
#import h5py
"""Multi speaker version"""
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
    """
    1) loads audio, speaker_id, text pairs
    2) normalizes text and converts them to sequences of integers
    3) computes spectrograms from audio files.
    """
    def __init__(self, audiopaths, hparams):
        self.audiopaths = load_filepaths_and_text(audiopaths)
        self.max_wav_value = hparams.data.max_wav_value
        self.sampling_rate = hparams.data.sampling_rate
        self.filter_length = hparams.data.filter_length
        self.hop_length = hparams.data.hop_length
        self.win_length = hparams.data.win_length
        self.use_sr = hparams.train.use_sr
        self.use_spk = hparams.model.use_spk
        self.spec_len = hparams.train.max_speclen

        random.seed(1234)
        random.shuffle(self.audiopaths)
        self._filter()

    def _filter(self):
        """
        Filter text & store spec lengths
        """
        # Store spectrogram lengths for Bucketing
        # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
        # spec_length = wav_length // hop_length
        lengths = []
        for audiopath in self.audiopaths:
            audiopath = audiopath[0].replace("\\","/").replace("DUMMY", "dataset/vctk-16k")
            lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
        self.lengths = lengths

    def get_audio(self, filename):
        audio, sampling_rate = load_wav_to_torch(filename.replace("DUMMY", "dataset/vctk-48k"))
        if sampling_rate != 48000:
            raise ValueError("{} SR doesn't match target {} SR".format(
                sampling_rate, self.sampling_rate))
        audio_norm = audio / self.max_wav_value
        audio_norm = audio_norm.unsqueeze(0)
        spec_filename = filename.replace("DUMMY", "dataset/vctk-16k").replace(".wav", ".spec.pt")
        if os.path.exists(spec_filename):
            spec = torch.load(spec_filename)
        else:
            spec = spectrogram_torch(audio_norm, self.filter_length,
                self.sampling_rate, self.hop_length, self.win_length,
                center=False)
            spec = torch.squeeze(spec, 0)
            torch.save(spec, spec_filename)

        if self.use_spk:
            spk_filename = filename.replace(".wav", ".npy")
            spk_filename = spk_filename.replace("DUMMY", "dataset/spk")
            spk = torch.from_numpy(np.load(spk_filename))

        if not self.use_sr:
            c_filename = filename.replace(".wav", ".pt")
            c_filename = c_filename.replace("DUMMY", "dataset/wavlm")
            c = torch.load(c_filename).squeeze(0)
        else:
            i = random.randint(68,92)
            '''
            basename = os.path.basename(filename)[:-4]
            spkname = basename[:4]
            #print(basename, spkname)
            with h5py.File(f"dataset/rs/wavlm/{spkname}/{i}.hdf5","r") as f:
                c = torch.from_numpy(f[basename][()]).squeeze(0)
            #print(c)
            '''
            c_filename = filename.replace(".wav", f"_{i}.pt")
            c_filename = c_filename.replace("DUMMY", "dataset/sr/wavlm")
            c = torch.load(c_filename).squeeze(0)

        '''
        lmin = min(c.size(-1), spec.size(-1))
        spec, c = spec[:, :lmin], c[:, :lmin]
        audio_norm = audio_norm[:, :lmin*960]
        _spec, _c, _audio_norm = spec, c, audio_norm
        while spec.size(-1) < self.spec_len:
            spec = torch.cat((spec, _spec), -1)
            c = torch.cat((c, _c), -1)
            audio_norm = torch.cat((audio_norm, _audio_norm), -1)
        start = random.randint(0, spec.size(-1) - self.spec_len)
        end = start + self.spec_len
        spec = spec[:, start:end]
        c = c[:, start:end]
        audio_norm = audio_norm[:, start*960:end*960]
        '''

        if self.use_spk:
            return c, spec, audio_norm, spk
        else:
            return c, spec, audio_norm

    def __getitem__(self, index):
        return self.get_audio(self.audiopaths[index][0])

    def __len__(self):
        return len(self.audiopaths)
class TextAudioSpeakerCollate():
    """ Zero-pads model inputs and targets
    """
    def __init__(self, hps):
        self.hps = hps
        self.use_sr = hps.train.use_sr
        self.use_spk = hps.model.use_spk

    def __call__(self, batch):
        """Collate's training batch from normalized text, audio and speaker identities
        PARAMS
        ------
        batch: [text_normalized, spec_normalized, wav_normalized, sid]
        """
        # Right zero-pad all one-hot text sequences to max input length
        _, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([x[0].size(1) for x in batch]),
            dim=0, descending=True)

        max_spec_len = max([x[1].size(1) for x in batch])
        max_wav_len = max([x[2].size(1) for x in batch])

        spec_lengths = torch.LongTensor(len(batch))
        wav_lengths = torch.LongTensor(len(batch))
        if self.use_spk:
            spks = torch.FloatTensor(len(batch), batch[0][3].size(0))
        else:
            spks = None

        c_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
        spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
        wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
        c_padded.zero_()
        spec_padded.zero_()
        wav_padded.zero_()

        for i in range(len(ids_sorted_decreasing)):
            row = batch[ids_sorted_decreasing[i]]

            c = row[0]
            c_padded[i, :, :c.size(1)] = c

            spec = row[1]
            spec_padded[i, :, :spec.size(1)] = spec
            spec_lengths[i] = spec.size(1)

            wav = row[2]
            wav_padded[i, :, :wav.size(1)] = wav
            wav_lengths[i] = wav.size(1)

            if self.use_spk:
                spks[i] = row[3]

        spec_seglen = spec_lengths[-1] if spec_lengths[-1] < self.hps.train.max_speclen + 1 else self.hps.train.max_speclen + 1
        wav_seglen = spec_seglen * 960

        spec_padded, ids_slice = commons.rand_spec_segments(spec_padded, spec_lengths, spec_seglen)
        wav_padded = commons.slice_segments(wav_padded, ids_slice * 960, wav_seglen)

        c_padded = commons.slice_segments(c_padded, ids_slice, spec_seglen)[:,:,:-1]

        spec_padded = spec_padded[:,:,:-1]
        wav_padded = wav_padded[:,:,:-960]

        if self.use_spk:
            return c_padded, spec_padded, wav_padded, spks
        else:
            return c_padded, spec_padded, wav_padded
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
    """
    Maintain similar input lengths in a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.

    It removes samples which are not included in the boundaries.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
    """
    def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        self.boundaries = boundaries

        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas

    def _create_buckets(self):
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i in range(len(self.lengths)):
            length = self.lengths[i]
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)

        for i in range(len(buckets) - 1, 0, -1):
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i+1)

        num_samples_per_bucket = []
        for i in range(len(buckets)):
            len_bucket = len(buckets[i])
            total_batch_size = self.num_replicas * self.batch_size
            rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        if self.shuffle:
            for bucket in self.buckets:
                indices.append(torch.randperm(len(bucket), generator=g).tolist())
        else:
            for bucket in self.buckets:
                indices.append(list(range(len(bucket))))

        batches = []
        for i in range(len(self.buckets)):
            bucket = self.buckets[i]
            len_bucket = len(bucket)
            ids_bucket = indices[i]
            num_samples_bucket = self.num_samples_per_bucket[i]

            # add extra samples to make it evenly divisible
            rem = num_samples_bucket - len_bucket
            ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]

            # subsample
            ids_bucket = ids_bucket[self.rank::self.num_replicas]

            # batching
            for j in range(len(ids_bucket) // self.batch_size):
                batch = [bucket[idx] for idx in ids_bucket[j*self.batch_size:(j+1)*self.batch_size]]
                batches.append(batch)

        if self.shuffle:
            batch_ids = torch.randperm(len(batches), generator=g).tolist()
            batches = [batches[i] for i in batch_ids]
        self.batches = batches

        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)

    def _bisect(self, x, lo=0, hi=None):
        if hi is None:
            hi = len(self.boundaries) - 1

        if hi > lo:
            mid = (hi + lo) // 2
            if self.boundaries[mid] < x and x <= self.boundaries[mid+1]:
                return mid
            elif x <= self.boundaries[mid]:
                return self._bisect(x, lo, mid)
            else:
                return self._bisect(x, mid + 1, hi)
        else:
            return -1

    def __len__(self):
        return self.num_samples // self.batch_size
train_48.py
import os
import json
import argparse
import itertools
import math
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
import time
import commons
import utils
from data_utils_48 import (
TextAudioSpeakerLoader,
TextAudioSpeakerCollate,
DistributedBucketSampler
)
from models import (
SynthesizerTrn,
MultiPeriodDiscriminator,
)
from losses import (
generator_loss,
discriminator_loss,
feature_loss,
kl_loss
)
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
torch.backends.cudnn.benchmark = True
global_step = 0
#os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'INFO'
def main():
    """Assume Single Node Multi GPUs Training Only"""
    assert torch.cuda.is_available(), "CPU training is not allowed."
    hps = utils.get_hparams()

    n_gpus = torch.cuda.device_count()
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = hps.train.port

    mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))
def run(rank, n_gpus, hps):
    global global_step
    if rank == 0:
        logger = utils.get_logger(hps.model_dir)
        logger.info(hps)
        utils.check_git_hash(hps.model_dir)
        writer = SummaryWriter(log_dir=hps.model_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))

    dist.init_process_group(backend='gloo', init_method='env://', world_size=n_gpus, rank=rank)
    torch.manual_seed(hps.train.seed)
    torch.cuda.set_device(rank)

    train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps)
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size,
        [32,300,400,500,600,700,800,900,1000],
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True)
    collate_fn = TextAudioSpeakerCollate(hps)
    train_loader = DataLoader(train_dataset, num_workers=4, shuffle=False, pin_memory=True,
        collate_fn=collate_fn, batch_sampler=train_sampler)
    if rank == 0:
        eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps)
        eval_loader = DataLoader(eval_dataset, num_workers=4, shuffle=True,
            batch_size=hps.train.batch_size, pin_memory=False,
            drop_last=False, collate_fn=collate_fn)

    net_g = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // 960,
        **hps.model).cuda(rank)
    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps)
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps)
    net_g = DDP(net_g, device_ids=[rank])#, find_unused_parameters=True)
    net_d = DDP(net_d, device_ids=[rank])

    try:
        _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g)
        _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d)
        global_step = (epoch_str - 1) * len(train_loader)
    except:
        epoch_str = 1
        global_step = 0

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str-2)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str-2)

    scaler = GradScaler(enabled=hps.train.fp16_run)

    for epoch in range(epoch_str, hps.train.epochs + 1):
        if rank==0:
            train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval])
        else:
            train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None)
        scheduler_g.step()
        scheduler_d.step()
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
    net_g, net_d = nets
    optim_g, optim_d = optims
    scheduler_g, scheduler_d = schedulers
    train_loader, eval_loader = loaders
    if writers is not None:
        writer, writer_eval = writers

    train_loader.batch_sampler.set_epoch(epoch)
    global global_step

    net_g.train()
    net_d.train()
    for batch_idx, items in enumerate(train_loader):
        start_time = time.time()
        if hps.model.use_spk:
            c, spec, y, spk = items
            g = spk.cuda(rank, non_blocking=True)
        else:
            c, spec, y = items
            g = None
        spec, y = spec.cuda(rank, non_blocking=True), y.cuda(rank, non_blocking=True)
        c = c.cuda(rank, non_blocking=True)

        with autocast(enabled=hps.train.fp16_run):
            y_hat, ids_slice, z_mask,\
            (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(c, spec, g=g)
            #print(ids_slice)

            mel = mel_spectrogram_torch(
                y.squeeze(1),
                1920,
                hps.data.n_mel_channels,
                48000,
                480,
                1920,
                hps.data.mel_fmin,
                hps.data.mel_fmax
            )
            y_mel = commons.slice_segments(mel, ids_slice * 2, hps.train.segment_size // 480)
            y_hat_mel = mel_spectrogram_torch(
                y_hat.squeeze(1),
                1920,
                hps.data.n_mel_channels,
                48000,
                480,
                1920,
                hps.data.mel_fmin,
                hps.data.mel_fmax
            )
            y = commons.slice_segments(y, ids_slice * 960, hps.train.segment_size) # slice

            # Discriminator
            y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())

            with autocast(enabled=False):
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
                loss_disc_all = loss_disc

        optim_d.zero_grad()
        scaler.scale(loss_disc_all).backward()
        scaler.unscale_(optim_d)
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)

        with autocast(enabled=hps.train.fp16_run):
            # Generator
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
            with autocast(enabled=False):
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl

                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl

        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()

        if rank==0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]['lr']
                losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
                logger.info('Train Epoch: {} [{:.0f}%]'.format(
                    epoch,
                    100. * batch_idx / len(train_loader)))
                logger.info([x.item() for x in losses] + [global_step, lr])

                scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g}
                scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl})

                scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
                scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
                scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})

                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
                    "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
                }
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict)

            if global_step % hps.train.eval_interval == 0:
                evaluate(hps, net_g, eval_loader, writer_eval)
                utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
                utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))

        global_step += 1
        new_time = time.time()
        print(f"step {global_step}: speed={1/(new_time-start_time):.2f} steps/sec, loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}, loss_fm={loss_fm:.3f}, loss_gen={loss_gen:.3f}")

    if rank == 0:
        logger.info('====> Epoch: {}'.format(epoch))
def evaluate(hps, generator, eval_loader, writer_eval):
    generator.eval()
    with torch.no_grad():
        for batch_idx, items in enumerate(eval_loader):
            if hps.model.use_spk:
                c, spec, y, spk = items
                g = spk[:1].cuda(0)
            else:
                c, spec, y = items
                g = None
            spec, y = spec[:1].cuda(0), y[:1].cuda(0)
            c = c[:1].cuda(0)
            break

        mel = mel_spectrogram_torch(
            y.squeeze(1),
            1920,
            hps.data.n_mel_channels,
            48000,
            480,
            1920,
            hps.data.mel_fmin,
            hps.data.mel_fmax
        )
        y_hat = generator.module.infer(c, g=g)

        y_hat_mel = mel_spectrogram_torch(
            y_hat.squeeze(1).float(),
            1920,
            hps.data.n_mel_channels,
            48000,
            480,
            1920,
            hps.data.mel_fmin,
            hps.data.mel_fmax
        )

    image_dict = {
        "gen/mel": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()),
        "gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())
    }
    audio_dict = {
        "gen/audio": y_hat[0],
        "gt/audio": y[0]
    }
    utils.summarize(
        writer=writer_eval,
        global_step=global_step,
        images=image_dict,
        audios=audio_dict,
        audio_sampling_rate=48000
    )
    generator.train()


if __name__ == "__main__":
    main()
Run process_flist.py only on wav files. For me this error was coming up when .spec.pt files were in the 16khz folder and I was running process_flist.py.
I haven't run the flist script; I'm using the original filelist.
Can you run this code? When I run the code above, I get the error message below. I didn't change the segment_size, and I changed the upsample_rates to [10,8,6,2].
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "/home/hudson-4way/anaconda3/envs/freevc/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/home/hudson-4way/Voice/FreeVC/train_48.py", line 120, in run
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler,
File "/home/hudson-4way/Voice/FreeVC/train_48.py", line 155, in train_and_evaluate
(z, z_p, m_p, logs_p, m_q, logs_q) = net_g(c, spec, g=g)
File "/home/hudson-4way/anaconda3/envs/freevc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/hudson-4way/anaconda3/envs/freevc/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1040, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/home/hudson-4way/anaconda3/envs/freevc/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1000, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0])
File "/home/hudson-4way/anaconda3/envs/freevc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/hudson-4way/Voice/FreeVC/models.py", line 340, in forward
o = self.dec(z_slice, g=g)
File "/home/hudson-4way/anaconda3/envs/freevc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/hudson-4way/Voice/FreeVC/models.py", line 116, in forward
x = self.ups[i](x)
File "/home/hudson-4way/anaconda3/envs/freevc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
result = forward_call(*input, **kwargs)
File "/home/hudson-4way/anaconda3/envs/freevc/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 801, in forward
return F.conv_transpose1d(
RuntimeError: negative padding is not supported
No, I can't; that's the same error I get. I posted it here: https://github.com/OlaWod/FreeVC/issues/63#issuecomment-1434845714
I compared all of the changes between the original files and the 24khz version using Compare in Notepad++, so I was able to find everything you changed for the 24khz version; I then doubled all the values to make the 48khz version. I did find those and changed them.
You can't simply double all the values.
mel = mel_spectrogram_torch(
    y.squeeze(1),
    960,
    hps.data.n_mel_channels,
    48000,
    240,
    960,
    hps.data.mel_fmin,
    hps.data.mel_fmax
)
y_mel = commons.slice_segments(mel, ids_slice * 4, hps.train.segment_size // 240)
y_hat_mel = mel_spectrogram_torch(
    y_hat.squeeze(1),
    960,
    hps.data.n_mel_channels,
    48000,
    240,
    960,
    hps.data.mel_fmin,
    hps.data.mel_fmax
)
y = commons.slice_segments(y, ids_slice * 960, hps.train.segment_size) # slice
Ahh ok, I applied those changes; still getting the same error. Applied them to the evaluate section as well.
Traceback (most recent call last):
File "train_48.py", line 295, in <module>
main()
File "train_48.py", line 49, in main
mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\multiprocessing\spawn.py", line 230, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\multiprocessing\spawn.py", line 188, in start_processes
while not context.join():
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\multiprocessing\spawn.py", line 150, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\multiprocessing\spawn.py", line 59, in _wrap
fn(i, *args)
File "Z:\FreeVC-TrainHIFI\train_48.py", line 115, in run
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval])
File "Z:\FreeVC-TrainHIFI\train_48.py", line 149, in train_and_evaluate
(z, z_p, m_p, logs_p, m_q, logs_q) = net_g(c, spec, g=g)
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\parallel\distributed.py", line 886, in forward
output = self.module(*inputs[0], **kwargs[0])
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "Z:\FreeVC-TrainHIFI\models.py", line 336, in forward
o = self.dec(z_slice, g=g)
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "Z:\FreeVC-TrainHIFI\models.py", line 113, in forward
x = self.ups[i](x)
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\module.py", line 1120, in _call_impl
result = forward_call(*input, **kwargs)
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\modules\conv.py", line 774, in forward
output_padding, self.groups, self.dilation)
RuntimeError: negative padding is not supported
This error can be because some upsample_kernel_size is smaller than the corresponding upsample_rate. Adjust upsample_kernel_sizes and upsample_rates in the config.
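For what it's worth, the likely mechanism behind that error: in the VITS/HiFi-GAN-style decoder the upsampling layers are typically built as ConvTranspose1d(..., kernel_size=k, stride=u, padding=(k - u) // 2), so any kernel smaller than its rate yields a negative padding. A quick check with the values from this thread (a sketch, illustrative only):

upsample_rates = [10, 8, 6, 2]            # product 960 for 48 kHz
upsample_kernel_sizes = [16, 16, 4, 4]    # third entry: 4 < 6

for u, k in zip(upsample_rates, upsample_kernel_sizes):
    print(u, k, (k - u) // 2)             # (4 - 6) // 2 == -1 -> "negative padding is not supported"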
Increasing upsample_kernel_sizes from [16,16,4,4] to [16,16,6,4] did the trick; it's now running, but I ran out of memory and had to decrease the batch size all the way down to 32. Thanks for the help!
Well, this is strange: it finished 5 epochs with no issues, then halfway through epoch 6 I got this error:
File "Z:\FreeVC-TrainHIFI\data_utils_48.py", line 158, in __call__
c_padded[i, :, :c.size(1)] = c
RuntimeError: The expanded size of the tensor (250) must match the existing size (251) at non-singleton dimension 1. Target sizes: [1024, 250]. Tensor sizes: [1024, 251]
What could have caused this? I mean, it made it through the dataset 5 times...
So I replaced line 158 with the code from the original data_utils:
try:
    c_padded[i, :, :c.size(1)] = c
except:
    if len(c.shape) > 2:
        c_padded[i, :, :c.size(1)] = c[:,:-1,:]
    else:
        c_padded[i, :, :c.size(1)] = c[:,:-1]
This was removed in the 24khz datautils. That seems to have solved the error.
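An alternative way to handle the same mismatch (a sketch, not code from the repo): trim the content features and the spectrogram to their common frame count before copying them into the padded batch buffers, so an extra WavLM frame can never overflow the buffer sized from the spectrogram lengths.

import torch

def trim_to_common_frames(c: torch.Tensor, spec: torch.Tensor):
    # c and spec are (channels, frames); WavLM occasionally yields one frame more than the spectrogram
    lmin = min(c.size(1), spec.size(1))
    return c[:, :lmin], spec[:, :lmin]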
Now I'm getting this error after 30 epochs. Any suggestions or ideas?
Traceback (most recent call last):
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\multiprocessing\spawn.py", line 59, in _wrap
fn(i, *args)
File "Z:\FreeVC-TrainHIFI\train_48.py", line 115, in run
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval])
File "Z:\FreeVC-TrainHIFI\train_48.py", line 136, in train_and_evaluate
for batch_idx, items in enumerate(train_loader):
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\data\dataloader.py", line 521, in __next__
data = self._next_data()
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\data\dataloader.py", line 1203, in _next_data
return self._process_data(data)
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\data\dataloader.py", line 1229, in _process_data
data.reraise()
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\_utils.py", line 434, in reraise
raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 1.
Original Traceback (most recent call last):
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\data\_utils\worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\data\_utils\fetch.py", line 52, in fetch
return self.collate_fn(data)
File "Z:\FreeVC-TrainHIFI\data_utils_48.py", line 180, in __call__
spec_padded, ids_slice = commons.rand_spec_segments(spec_padded, spec_lengths, spec_seglen)
File "Z:\FreeVC-TrainHIFI\commons.py", line 73, in rand_spec_segments
ret = slice_segments(x, ids_str, segment_size)
File "Z:\FreeVC-TrainHIFI\commons.py", line 53, in slice_segments
ret[i] = x[i, :, idx_str:idx_end]
RuntimeError: The expanded size of the tensor (36) must match the existing size (0) at non-singleton dimension 1. Target sizes: [641, 36]. Tensor sizes: [641, 0]
Not much idea. Maybe some data is too short?
Running at 24khz it works fine; the problem only appears when running at 48khz. EDIT: spoke too soon, the exact same error shows up when running at 24khz as well.
Doesn't the bucket sampler discard data that is too short?
I've recently found this repository: https://github.com/innnky/so-vits-svc/blob/32k/Eng_docs.md It says it's based on FreeVC, and I see a lot of similar code in the preprocessing and such, but it supports 48k natively and doesn't eat up the whole hard drive with a billion SR files. I got some very clear results, except for when there's breathing in the recording. Check it out!
Like here:
This is Zoey (L4D) to Alyx (HL 2). I added the breathing clip at the end just to check and it is pretty bad. Idk why that is, probably because their checkpoint was trained on some singing rather than speech. But the speech itself is pretty good I think and pretty high res.
Banging my head against a wall here, still getting those same errors, always at random: sometimes after 5 epochs, sometimes after 57...
Traceback (most recent call last):
File "train_48.py", line 295, in <module>
main()
File "train_48.py", line 49, in main
mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\multiprocessing\spawn.py", line 230, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\multiprocessing\spawn.py", line 188, in start_processes
while not context.join():
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\multiprocessing\spawn.py", line 150, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 0 terminated with the following error:
Traceback (most recent call last):
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\multiprocessing\spawn.py", line 59, in _wrap
fn(i, *args)
File "Z:\FreeVC-TrainHIFI\train_48.py", line 115, in run
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, eval_loader], logger, [writer, writer_eval])
File "Z:\FreeVC-TrainHIFI\train_48.py", line 136, in train_and_evaluate
for batch_idx, items in enumerate(train_loader):
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\data\dataloader.py", line 521, in __next__
data = self._next_data()
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\data\dataloader.py", line 1203, in _next_data
return self._process_data(data)
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\data\dataloader.py", line 1229, in _process_data
data.reraise()
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\_utils.py", line 434, in reraise
raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 4.
Original Traceback (most recent call last):
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\data\_utils\worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "C:\Users\steven\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\utils\data\_utils\fetch.py", line 52, in fetch
return self.collate_fn(data)
File "Z:\FreeVC-TrainHIFI\data_utils_48.py", line 180, in __call__
spec_padded, ids_slice = commons.rand_spec_segments(spec_padded, spec_lengths, spec_seglen)
File "Z:\FreeVC-TrainHIFI\commons.py", line 73, in rand_spec_segments
ret = slice_segments(x, ids_str, segment_size)
File "Z:\FreeVC-TrainHIFI\commons.py", line 53, in slice_segments
ret[i] = x[i, :, idx_str:idx_end]
RuntimeError: The expanded size of the tensor (36) must match the existing size (0) at non-singleton dimension 1. Target sizes: [641, 36]. Tensor sizes: [641, 0]
Likkez, thanks for the info, I'll take a look, but at first glance I'm not sure it's going to do what I need.
I added the following to train_48.py:
try:
    for batch_idx, items in enumerate(train_loader):
        .
        .
        .
except Exception as e:
    if isinstance(e, RuntimeError):
        print("Exception! Bad batch, skip to next epoch.")
    else:
        raise e
Now I get an exception every 30 or so epochs; it skips the bad batch and continues on. No idea how to pinpoint the bad file at this point, but at least it runs like this...
For the input files (spectrogram, WavLM, 48k wav): the spectrogram and WavLM features come from the 16k wav; the spectrogram uses a 1280 filter length and 320 hop length, just like those used to train the 16k model; and the 48k wav is trimmed with the same indices as the 16k wav. I.e., the input files can be inherited from those used to train 16k; only the 48k wav is new. For the code, just modify those numbers like 480 -> 960. And to pinpoint the bad file, a simple way is to print out the filename, tensor sizes, etc., every time the data loads.
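A standalone sketch of that kind of check (the paths and the max_speclen value are assumptions based on this thread; adjust to your setup): scan the cached spectrograms and flag anything with fewer frames than a training segment needs, since a too-short item is one plausible source of the random "existing size (0)" slice error.

import glob
import torch

MAX_SPECLEN = 36   # placeholder: set this to hps.train.max_speclen from your config

for spec_path in glob.glob("dataset/vctk-16k/**/*.spec.pt", recursive=True):
    spec = torch.load(spec_path, map_location="cpu")
    n_frames = spec.size(-1)
    if n_frames <= MAX_SPECLEN:
        print(f"suspiciously short ({n_frames} frames): {spec_path}")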
I reprocessed the 24 and 48khz data using your downsample 24khz file, and training at 48khz is now working.
The problem was with the long scalars in the downsample file: (index[0]*48000/22050).
I changed it to wav1 = wav1[int(index[0]*(48000/22050)): int(index[1]*(48000/22050))]
and now it works.
@OlaWod, one last question: since I had to change the upsample kernel sizes, I can't resume the 900k model and train it to 48khz. Since I can't resume the existing model and need to train from scratch, do I need to train with 16khz and then switch to 48khz later, or can I do 48khz from the start?
No need to train with 16k. You can do 48k from scratch, or, to save training time, you can also load part of the 900k model w/o optimizer:
for k, v in state_dict.items():
    try:
        if 'dec.ups' not in k:
            new_state_dict[k] = saved_state_dict[k]
        ...
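Fleshed out a little as a sketch (assumed names: net_g is the SynthesizerTrn instance from train_48.py above, and "G_900000.pth" stands in for whichever pretrained generator checkpoint is used; FreeVC checkpoints keep the weights under the "model" key): copy everything from the old checkpoint except the decoder upsampling layers whose shapes changed with the new kernel sizes.

import torch

saved_state_dict = torch.load("G_900000.pth", map_location="cpu")["model"]
model = net_g.module if hasattr(net_g, "module") else net_g
state_dict = model.state_dict()

new_state_dict = {}
for k, v in state_dict.items():
    try:
        if 'dec.ups' not in k:
            new_state_dict[k] = saved_state_dict[k]
        else:
            new_state_dict[k] = v      # keep the freshly initialised upsampling weights
    except KeyError:
        new_state_dict[k] = v          # key missing in the old checkpoint: keep the new weights

model.load_state_dict(new_state_dict)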
So the spectrogram settings are not ideal for 48khz; I want to change them to the following:
mel = mel_spectrogram_torch(
    y.squeeze(1),
    4096,
    512,
    48000,
    600,
    2400,
    55,
    hps.data.mel_fmax
)
But I keep getting errors regarding this line here
y_mel = commons.slice_segments(mel, ids_slice * 4, hps.train.segment_size // 150)
File "Z:\FreeVC-TrainHIFI\commons.py", line 53, in slice_segments
ret[i] = x[i, :, idx_str:idx_end]
RuntimeError: The expanded size of the tensor (57) must match the existing size (0) at non-singleton dimension 1. Target sizes: [512, 57]. Tensor sizes: [512, 0]
What changes do I need to make to be able to use those spectrogram settings?
mel = mel_spectrogram_torch(
    y.squeeze(1),
    4096,
    512,
    48000,
    960, # this needs to be 960/480/240/120, etc.
    2400,
    55,
    hps.data.mel_fmax
)
y_mel = commons.slice_segments(mel, ids_slice * 1, hps.train.segment_size // 960)
# y_mel = commons.slice_segments(mel, ids_slice * 2, hps.train.segment_size // 480)
and set the decoder upsample rates to a total of 960.
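The bookkeeping behind those numbers, as a small sketch (illustrative values only; segment_size here is just an example multiple of 960): the ids_slice multiplier is the number of output samples per spec frame divided by the mel hop, and the mel segment length is segment_size divided by the mel hop.

decoder_upsample = 960                       # 48 kHz output samples per spec/content frame
segment_size = 17280                         # example: must be a multiple of decoder_upsample

for mel_hop in (960, 480, 240, 120):         # the mel hop must divide decoder_upsample
    assert decoder_upsample % mel_hop == 0
    ids_slice_multiplier = decoder_upsample // mel_hop      # 1, 2, 4, 8
    mel_seglen = segment_size // mel_hop                    # length passed to slice_segments
    print(mel_hop, ids_slice_multiplier, mel_seglen)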
Ok, that makes sense. And the 960 here, this always has to match the upsample rate, correct?
y = commons.slice_segments(y, ids_slice * 960, hps.train.segment_size) # slice
You got any good results out of that? I've been trying that other repo I linked and there's lots of noise in breathing, which sucks.
Yeah, I've been getting really good results with the 48khz model. It's not perfect, but it could be... The problem is the long hop of 320 in the base settings, so WavLM etc. I'm trying to change the hop to 128; WavLM isn't set up for that, but I fed it 40khz wavs instead of 16khz, so it generates a content output the same length as the audio file but with the effect of a 128 hop. Changing the hop from 320 to 128 in the model isn't as easy as I thought it would be, though. Getting tons of errors from the slicer again.
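The arithmetic behind that 40khz trick, for anyone following along (a sketch; the only assumption is WavLM's fixed downsampling factor of 320):

wavlm_hop = 320                               # WavLM reduces its input by a fixed factor of 320
fps_16k = 16000 / wavlm_hop                   # 50 frames/sec -> effective hop 320 at 16 kHz
fps_40k = 40000 / wavlm_hop                   # 125 frames/sec when fed 40 kHz audio
effective_hop_16k = 16000 / fps_40k           # 128 -> the "hop 128" effect described above
print(fps_16k, fps_40k, effective_hop_16k)    # 50.0 125.0 128.0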
@steven850 Hi, did you get any improvements from your previous experiments?
Still working on trying to get it to run with a 128 hop. That slicer is a real pain in the ass.
Banging my head against a wall here, still getting those same errors, always at random: sometimes after 5 epochs, sometimes after 57...
RuntimeError: The expanded size of the tensor (36) must match the existing size (0) at non-singleton dimension 1. Target sizes: [641, 36]. Tensor sizes: [641, 0]
I'm stuck at exactly the same point; I'm getting this error once every few epochs. I've also tried downsampling with
wav1 = wav1[int(index[0]*(48000/22050)): int(index[1]*(48000/22050))]
but it still happens. @steven850 Did you find a way to pinpoint those corrupt files?
Thank you for the insightful discussion. @steven850, could you kindly share the latest data utils and training code that you found effective, along with any updated hyperparameters you've come across?
Since I found those issues with the 24khz preprocess, I figured I might as well run a 48khz model instead of the 24khz one, since I have to start from the 900k model regardless.
I have already updated the code in datautils and train for 48khz, as well as the preprocess. (I took all the 24khz relevant values in train and datautils and doubled the values) My question is what is about the config.json file. im not sure how you came up with the segment_size in the freevc-24.json? What would I need for 48khz? same with the upsample_rates, will [10,8,6,2] work?