mozilla/TTS

Weird spectrogram when using ExtractTTSSpectrogram #278

Closed: orbisAI closed this issue 5 years ago

orbisAI commented 5 years ago

When I try to use the notebook to generate spectrograms for training a vocoder, I get the following spectrogram (please note it's upside down):

[Screenshot: Screen Shot 2019-09-09 at 12 56 10 PM]

What's causing the vertical gaps in the spectrogram data? I don't see such gaps when checking spectrograms during training. When we run Griffin-Lim (GL) on this spectrogram, the sound is, as expected, jittery and broken.

[Screenshot: Screen Shot 2019-09-09 at 1 08 16 PM]

Here is the spectrogram I see in TensorBoard:

[Image: individualImage]
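
On the upside-down display: `matplotlib`'s `imshow` draws row 0 at the top by default, so mel spectrograms come out flipped unless `origin="lower"` is passed. A minimal sketch with a hypothetical mel array:

```python
import numpy as np
import matplotlib.pyplot as plt

mel = np.random.rand(400, 80)  # hypothetical stand-in for a [frames, mels] mel array
plt.imshow(mel.T, aspect="auto", origin="lower")  # origin="lower" puts mel bin 0 at the bottom
plt.xlabel("frames")
plt.ylabel("mel bins")
plt.show()
```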

FYI, the notebook has been modified slightly to work with LibriTTS:

This is a notebook to generate mel-spectrograms from a TTS model to be used for WaveRNN training.

```python
TTS_PATH = "workspace"
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append(TTS_PATH)
sys.path.append("..")
import torch
import importlib
import numpy as np
from tqdm import tqdm as tqdm
from torch.utils.data import DataLoader
from TTS.models.tacotron2 import Tacotron2
from TTS.datasets.TTSDataset import MyDataset
from TTS.utils.audio import AudioProcessor
from TTS.utils.visual import plot_spectrogram
from TTS.utils.generic_utils import load_config, setup_model
from TTS.datasets.preprocess import ljspeech
%matplotlib inline
```

```python
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

def set_filename(wav_path, out_path):
    wav_file = os.path.basename(wav_path)
    file_name = wav_file.split('.')[0]
    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
    wavq_path = os.path.join(out_path, "quant", file_name)
    mel_path = os.path.join(out_path, "mel", file_name)
    wav_path = os.path.join(out_path, "wav_gl", file_name)
    return file_name, wavq_path, mel_path, wav_path
OUT_PATH = "/waveglow_data"
DATA_PATH = "LibriTTS/train-360"
DATASET = "libri_tts"
METADATA_FILE = "metadata.txt"
CONFIG_PATH = "TTS/config_gst_libritts.json"
MODEL_FILE = "TTS/results/libritts-360-August-31-2019_08+50AM-234b44d/checkpoint_205000.pth.tar"
DRY_RUN = False   # if True, skip writing output files and only compute losses and visuals
BATCH_SIZE = 32
```
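
For reference, a quick usage sketch of `set_filename` above; the input wav path here is hypothetical:

```python
# set_filename strips the extension from the wav basename and prepares the
# three output subdirectories (quant/, mel/, wav_gl/) under out_path.
name, wavq_p, mel_p, wav_p = set_filename(
    "LibriTTS/train-360/100_121669_000001_000000.wav",  # hypothetical input wav
    "/waveglow_data")
# name   -> "100_121669_000001_000000"
# wavq_p -> "/waveglow_data/quant/100_121669_000001_000000"
# mel_p  -> "/waveglow_data/mel/100_121669_000001_000000"
# wav_p  -> "/waveglow_data/wav_gl/100_121669_000001_000000"
```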

```python
use_cuda = torch.cuda.is_available()
print(" > CUDA enabled: ", use_cuda)

c = load_config(CONFIG_PATH)
ap = AudioProcessor(bits=9, **c.audio)
c.prenet_dropout = False
c.separate_stopnet = True
```

Output:

```
 > CUDA enabled:  True
 > Setting up Audio Processor...
 | > sample_rate:24000
 | > num_mels:80
 | > min_level_db:-100
 | > frame_shift_ms:12.5
 | > frame_length_ms:50
 | > ref_level_db:20
 | > num_freq:1025
 | > power:1.5
 | > preemphasis:0.98
 | > griffin_lim_iters:60
 | > signal_norm:True
 | > symmetric_norm:False
 | > mel_fmin:0
 | > mel_fmax:8000.0
 | > max_norm:1.0
 | > clip_norm:True
 | > do_trim_silence:True
 | > n_fft:2048
 | > hop_length:300
 | > win_length:1200
```
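
As a quick sanity check, the hop and window lengths printed above are consistent with the frame-shift/frame-length settings at 24 kHz; a worked check using only the printed values:

```python
sample_rate = 24000
frame_shift_ms = 12.5
frame_length_ms = 50

hop_length = int(sample_rate * frame_shift_ms / 1000)   # 24000 * 12.5 / 1000 = 300
win_length = int(sample_rate * frame_length_ms / 1000)  # 24000 * 50   / 1000 = 1200
assert (hop_length, win_length) == (300, 1200)
```
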
```python
print(os.getcwd())
print(c.data_path)
c.data_path = '../' + c.data_path
print(c.data_path)
from datasets.preprocess import get_preprocessor_by_name

if "meta_data_train" not in globals():
    print("test2")
    if c.meta_file_train is not None:
        meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path, c.meta_file_train)
        print("test3")
    else:
        meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path)
        print("test")

dataset = MyDataset(
    c.r,
    c.text_cleaner,
    meta_data=meta_data_train,
    ap=ap,
    batch_group_size=c.batch_group_size * c.batch_size,
    min_seq_len=c.min_seq_len,
    max_seq_len=c.max_seq_len,
    phoneme_cache_path=c.phoneme_cache_path,
    use_phonemes=c.use_phonemes,
    phoneme_language=c.phoneme_language,
    enable_eos_bos=c.enable_eos_bos_chars)
loader = DataLoader(
    dataset,
    batch_size=c.batch_size,
    shuffle=False,
    collate_fn=dataset.collate_fn,
    drop_last=False,)
from utils.speakers import load_speaker_mapping, save_speaker_mapping, \
    get_speakers

speakers = get_speakers(c.data_path, c.meta_file_train, c.dataset)

speaker_mapping = {name: i
                   for i, name in enumerate(speakers)}
save_speaker_mapping(OUT_PATH, speaker_mapping)
num_speakers = len(speaker_mapping)
print("Training with {} speakers: {}".format(num_speakers,
                                             ", ".join(speakers)))
Training with 904 speakers: 100, 1001, 101, 1012, 1018, 1025, 1027, 1028, 1046, 1050, 1052, 1053, 1054, 1058, 1060, 1061, 1066, 1079, 1093, 1100, 1112, 112, 1121, 114, 115, 1160, 1165, 1175, 1182, 119, 1195, 1212, 122, 1222, 1224, 1226, 1241, 1259, 126, 1264, 1265, 1271, 1283, 1289, 1290, 1296, 1311, 1313, 1316, 1322, 1323, 1335, 1336, 1337, 1343, 1348, 1349, 1365, 1379, 1382, 1383, 1387, 1390, 1392, 14, 1401, 1413, 1417, 1422, 1425, 1445, 1446, 1448, 1460, 1463, 1472, 1473, 1482, 1487, 1498, 1509, 1513, 1535, 1536, 154, 1547, 1552, 1556, 157, 1571, 159, 16, 1603, 1607, 1629, 1638, 1639, 1641, 1645, 1649, 166, 1668, 1678, 17, 1705, 1724, 1731, 1734, 1740, 1748, 175, 1752, 1754, 176, 1769, 1776, 1777, 1779, 1789, 1801, 1806, 1811, 1825, 1826, 1827, 1845, 1849, 1851, 1859, 1874, 188, 1885, 1903, 1913, 1914, 192, 1923, 1933, 1943, 1944, 1958, 1961, 1974, 1987, 2004, 2010, 2012, 203, 2039, 204, 2045, 205, 2053, 2056, 2060, 2061, 207, 2074, 208, 2085, 209, 2093, 210, 2110, 2113, 2127, 2137, 2146, 2149, 2156, 216, 2162, 2167, 217, 2194, 22, 2201, 2204, 2229, 2230, 2238, 224, 2240, 225, 2254, 2256, 2269, 227, 2272, 2285, 2294, 2299, 231, 2319, 2364, 2368, 2388, 2393, 2397, 240, 2401, 2404, 2411, 242, 2427, 246, 2473, 2481, 249, 2494, 2498, 2499, 2512, 2517, 2531, 2532, 2562, 2570, 2573, 2577, 258, 2581, 2582, 2589, 2592, 2598, 2618, 2628, 2638, 2652, 2654, 2673, 2674, 2688, 2696, 2709, 272, 274, 2741, 2751, 2758, 2769, 2774, 2775, 278, 2785, 2787, 2790, 28, 2812, 2816, 2823, 2827, 2853, 288, 2882, 2920, 2929, 296, 2960, 2971, 2992, 2999, 30, 3001, 3003, 3008, 3009, 3025, 303, 3032, 3046, 3070, 3072, 3082, 3083, 3094, 3105, 3114, 3118, 3119, 3157, 3171, 318, 3180, 3185, 3187, 3215, 3221, 3224, 3228, 323, 3230, 3258, 3274, 3289, 329, 3294, 3307, 3328, 3330, 3340, 335, 3357, 3361, 3368, 337, 3370, 3379, 3380, 3389, 339, 340, 3446, 3448, 345, 3482, 3483, 3490, 3493, 3513, 3521, 353, 3537, 3540, 3546, 3549, 3551, 3584, 359, 3615, 362, 3630, 3638, 3645, 3654, 3686, 369, 3703, 3717, 3728, 373, 3733, 3738, 3781, 3790, 3792, 38, 380, 3816, 3825, 3835, 3851, 3852, 3864, 3866, 3869, 3876, 3889, 3905, 3914, 3922, 3923, 3927, 3945, 3967, 3972, 3977, 3979, 398, 3989, 3994, 4010, 4013, 4039, 4044, 4054, 4057, 4064, 4071, 408, 409, 4098, 4108, 4110, 4111, 4116, 4133, 4138, 4145, 4148, 4152, 4222, 4226, 4236, 4238, 4243, 4246, 4257, 4260, 4278, 4289, 4290, 4331, 4335, 434, 4356, 4358, 4363, 4381, 439, 4425, 4427, 4433, 4434, 4438, 4490, 4495, 451, 4519, 4535, 454, 4586, 459, 4590, 4592, 4595, 4598, 4629, 4681, 4719, 472, 4731, 4733, 4734, 4744, 475, 476, 479, 480, 4800, 4806, 4807, 4837, 4839, 4846, 4848, 4854, 4856, 4860, 487, 4899, 492, 4926, 4945, 4957, 4967, 497, 4973, 500, 5002, 5007, 501, 5012, 5029, 5039, 5054, 5062, 5063, 5092, 5093, 510, 511, 5115, 512, 5123, 5126, 5133, 5139, 5147, 5154, 5157, 5186, 5189, 5190, 5206, 5239, 5242, 5246, 5261, 5266, 5290, 5293, 5304, 5319, 5333, 5337, 534, 5386, 5389, 54, 5400, 5401, 543, 5448, 548, 5489, 549, 55, 5513, 5519, 5538, 5570, 5583, 5588, 559, 56, 5604, 5606, 561, 5618, 5622, 5635, 5655, 5656, 5660, 5672, 5684, 5712, 5717, 5723, 5724, 5727, 5731, 5740, 5746, 576, 5767, 5776, 580, 5802, 5809, 581, 5810, 583, 5868, 5876, 5883, 589, 5909, 5914, 5918, 593, 5935, 594, 5940, 596, 5968, 597, 5975, 598, 5984, 5985, 6006, 6014, 6032, 6037, 6038, 6054, 606, 6060, 6075, 6080, 6082, 6098, 6099, 6104, 6115, 6119, 612, 6120, 6139, 6157, 6160, 6167, 6188, 6189, 6206, 6215, 6233, 6235, 6258, 6269, 6286, 6288, 6294, 6300, 6308, 6317, 6330, 6339, 6341, 6352, 6359, 636, 637, 6371, 
6373, 6378, 6388, 639, 6395, 64, 6406, 6426, 6446, 6458, 6492, 6494, 6497, 6499, 6505, 6509, 6510, 6518, 6519, 6538, 6544, 6550, 6553, 6555, 6567, 6574, 6575, 6620, 663, 6637, 664, 6643, 666, 667, 6673, 6683, 6686, 6690, 6694, 6696, 6701, 6727, 6763, 6782, 6788, 6828, 6865, 6877, 688, 6895, 6904, 6918, 6924, 6927, 6937, 6956, 6965, 698, 6981, 699, 6993, 70, 7000, 7011, 7030, 7051, 7061, 7069, 707, 708, 7085, 7090, 7095, 711, 7117, 7120, 7126, 7128, 7134, 7139, 7140, 7145, 716, 7169, 718, 7188, 7229, 724, 7240, 7241, 7245, 7247, 7258, 7276, 7285, 7286, 7294, 7297, 731, 7313, 7314, 7316, 7318, 7335, 7339, 7342, 7383, 7384, 7395, 7398, 7416, 7434, 7437, 7445, 7460, 7475, 7478, 7481, 7484, 7495, 7498, 7515, 7518, 7520, 7525, 7538, 7540, 7553, 7555, 7558, 7569, 7594, 764, 7647, 7657, 7665, 7688, 770, 7704, 7705, 7717, 7720, 7730, 7732, 7733, 7739, 7752, 7754, 7766, 7777, 7783, 7789, 7802, 7809, 781, 7816, 7825, 7828, 783, 7832, 7833, 7837, 7867, 7868, 7874, 7881, 79, 7909, 7910, 7926, 7932, 7933, 7938, 7939, 7945, 7949, 7956, 7957, 7959, 7962, 7967, 7981, 7982, 7991, 7994, 7995, 8006, 8008, 8011, 8028, 803, 8050, 8057, 806, 8066, 8075, 8080, 8097, 81, 8113, 8118, 8119, 8138, 8142, 815, 8152, 816, 8163, 8176, 8183, 8190, 8193, 8194, 8195, 820, 8222, 8225, 8228, 8266, 829, 830, 8300, 8329, 8347, 835, 836, 8388, 8396, 8401, 8404, 8410, 8421, 8459, 8464, 8474, 8479, 8490, 8494, 8498, 850, 8506, 8527, 8534, 8545, 8573, 8575, 8591, 8592, 8605, 8619, 8635, 8643, 8677, 868, 8684, 8687, 8699, 8705, 8713, 8718, 8722, 8725, 8742, 8758, 8771, 8772, 8776, 8786, 8791, 882, 8820, 8824, 8825, 8848, 8855, 8875, 8879, 8887, 899, 90, 9022, 9023, 9026, 920, 922, 925, 93, 948, 949, 953, 954, 957, 968, 979, 98, 984, 986
```

```python
from TTS.utils.text.symbols import symbols, phonemes
from TTS.utils.generic_utils import sequence_mask
from TTS.layers.losses import L1LossMasked

# load the model
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, num_speakers, c)
checkpoint = torch.load(MODEL_FILE)
model.load_state_dict(checkpoint['model'])
print(checkpoint['step'])
model.eval()
if use_cuda:
    model = model.cuda()
```

Output:

```
 > Using model: TacotronGST
205000
```

```python
import pickle

file_idxs = []
losses = []
postnet_losses = []
criterion = L1LossMasked()

for data in tqdm(loader):
    # setup input data
    text_input = data[0]
    text_lengths = data[1]
    speaker_names = data[2]
    linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
    mel_input = data[4]
    mel_lengths = data[5]
    stop_targets = data[6]
    item_idx = data[7]
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())

    if c.use_speaker_embedding:
        speaker_ids = [speaker_mapping[speaker_name]
                   for speaker_name in speaker_names]
        speaker_ids = torch.LongTensor(speaker_ids)
    else:
        speaker_ids = None

    if use_cuda:
        text_input = text_input.cuda(non_blocking=True)
        text_lengths = text_lengths.cuda(non_blocking=True)
        mel_input = mel_input.cuda(non_blocking=True)
        mel_lengths = mel_lengths.cuda(non_blocking=True)
        linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron", "TacotronGST"] else None
        stop_targets = stop_targets.cuda(non_blocking=True)
        if speaker_ids is not None:
            speaker_ids = speaker_ids.cuda(non_blocking=True)

    mask = sequence_mask(text_lengths)
    mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(
            text_input, text_lengths, mel_input, speaker_ids=speaker_ids)

    mel_specs = []
    if c.model == "TacotronGST":
        postnet_outputs = postnet_outputs.data.cpu().numpy()
        for b in range(postnet_outputs.shape[0]):
            postnet_output = postnet_outputs[b]
            mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())
    postnet_outputs = torch.stack(mel_specs)

    loss = criterion(mel_outputs, mel_input, mel_lengths)
#     if c.model in ["Tacotron", "TacotronGST"]:
    loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)
    losses.append(loss.item())
    postnet_losses.append(loss_postnet.item())

    if not DRY_RUN:
        for idx in range(text_input.shape[0]):
            wav_file_path = item_idx[idx]
            wav = ap.load_wav(wav_file_path)
            file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)
            file_idxs.append(file_name)

#             # quantize and save wav
#             wavq = ap.quantize(wav)
#             np.save(wavq_path, wavq)

            # save TTS mel
            mel = postnet_outputs[idx]
            mel = mel.data.cpu().numpy()
            mel_length = mel_lengths[idx]
            mel = mel[:mel_length, :].T
            np.save(mel_path, mel)

            # save GL voice
    #         wav_gen = ap.inv_mel_spectrogram(mel.T) # mel to wav
    #         wav_gen = ap.quantize(wav_gen)
    #         np.save(wav_path, wav_gen)
```

Output:

```
  4%|▎         | 107/2858 [33:51<15:02:06, 19.68s/it]
```

### Check model performance

```python
idx = 1
mel_example = postnet_outputs[idx].data.cpu().numpy()
plot_spectrogram(mel_example[:mel_lengths[idx], :], ap);
print(mel_example[:mel_lengths[idx], :].shape)
mel_example = mel_outputs[idx].data.cpu().numpy()
plot_spectrogram(mel_example[:mel_lengths[idx], :], ap);
print(mel_example[:mel_lengths[idx], :].shape)
wav = ap.load_wav(item_idx[idx])
melt = ap.melspectrogram(wav)
print(melt.shape)
plot_spectrogram(melt.T, ap);
# postnet, decoder diff
from matplotlib import pylab as plt
mel_diff = mel_outputs[idx] - postnet_outputs[idx]
plt.figure(figsize=(16, 10))
plt.imshow(abs(mel_diff.detach().cpu().numpy()[:mel_lengths[idx], :]).T, aspect="auto", origin="lower");
plt.colorbar()
plt.tight_layout()
from matplotlib import pylab as plt
# mel = mel_poutputs[idx].detach().cpu().numpy()
mel = postnet_outputs[idx].detach().cpu().numpy()
mel_diff2 = melt.T - mel[:melt.shape[1]]
plt.figure(figsize=(16, 10))
plt.imshow(abs(mel_diff2).T, aspect="auto", origin="lower");
plt.colorbar()
plt.tight_layout()
```
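
The extraction loop above collects `losses` and `postnet_losses` but never summarizes them; a one-liner to inspect them, assuming the loop has run:

```python
import numpy as np

print("decoder L1: {:.4f} | postnet L1: {:.4f}".format(np.mean(losses), np.mean(postnet_losses)))
```
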
orbisAI commented 5 years ago

OK, so the problem was `r`. I had set `r` to decrease gradually over training, per the config below:

```
"gradual_training": [[0, 7, 32], [10000, 5, 32], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]],
```

I trained the model to 200k steps, so `r` should have been 2.
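
For anyone hitting the same issue, here is a minimal sketch (not from the repo) of reading the effective `r` off a `gradual_training` schedule for a given global step, assuming each entry is `[start_step, r, batch_size]`:

```python
def r_for_step(schedule, step):
    """Return the reduction factor r active at `step` under a gradual_training schedule."""
    r = schedule[0][1]
    for start_step, r_value, _batch_size in schedule:
        if step >= start_step:
            r = r_value
    return r

gradual_training = [[0, 7, 32], [10000, 5, 32], [50000, 3, 32],
                    [130000, 2, 16], [290000, 1, 8]]
print(r_for_step(gradual_training, 205000))  # -> 2
```

The checkpoint at step 205000 falls in the `[130000, 2, 16]` phase, so both the dataset and the model should be built with `r = 2`.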

But I was loading the dataset and the model with `r = 7`, which caused the weird spectrogram outputs. After changing `r` to 2, I now get this:

[Screenshot: Screen Shot 2019-09-09 at 2 38 31 PM]

Which makes much more sense.