as-ideas / TransformerTTS

🤖💬 Transformer TTS: Implementation of a non-autoregressive Transformer based neural network for text to speech.
https://as-ideas.github.io/TransformerTTS/

Fine-tuning HifiGAN using output mels #92

Closed: kudanai closed this issue 3 years ago

kudanai commented 3 years ago

The direct output mels from TransformerTTS seem to be incompatible with what HifiGAN expects as input. I was able to make it work by applying the following patch to HifiGAN (please ignore the debug prints).

This appears to work. I just wanted to confirm whether this is the correct approach; I also thought it might be helpful to someone having the same issue.

**Note:** the value 80 is `config['mel_channels']`.

```diff
diff --git a/meldataset.py b/meldataset.py
index 4502924..b72981b 100644
--- a/meldataset.py
+++ b/meldataset.py
@@ -142,10 +142,14 @@ class MelDataset(torch.utils.data.Dataset):
         else:
             mel = np.load(
                 os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
+
             mel = torch.from_numpy(mel)

             if len(mel.shape) < 3:
                 mel = mel.unsqueeze(0)
+            
+            if not mel.shape[1] == 80:
+              mel = mel.transpose(1,2)

             if self.split:
                 frames_per_seg = math.ceil(self.segment_size / self.hop_size)
diff --git a/train.py b/train.py
index 3b55094..c019e01 100644
--- a/train.py
+++ b/train.py
@@ -73,16 +73,24 @@ def train(rank, a, h):

     training_filelist, validation_filelist = get_dataset_filelist(a)

+    print(a)
+
+
+    print(f"Training with {len(training_filelist)}, validation with {len(validation_filelist)} files")
+    print(training_filelist[:5])
+    print(validation_filelist[:5])
     trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
                           h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
                           shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
                           fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)

+    # print(trainset[0].shape)
+
     train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None

     train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
                               sampler=train_sampler,
-                              batch_size=h.batch_size,
+                              batch_size=1,   #h.batch_size,
                               pin_memory=True,
                               drop_last=True)

@@ -196,6 +204,12 @@ def train(rank, a, h):
                             y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
                                                           h.hop_size, h.win_size,
                                                           h.fmin, h.fmax_for_loss)
+
+                            if y_mel.size(2) > y_g_hat_mel.size(2):
+                                y_g_hat_mel = torch.nn.functional.pad(y_g_hat_mel, (0, y_mel.size(2) - y_g_hat_mel.size(2)), 'constant')
+                            elif y_mel.size(2) < y_g_hat_mel.size(2):
+                                y_mel = torch.nn.functional.pad(y_mel, (0, y_g_hat_mel.size(2) - y_mel.size(2)), 'constant')
+
                             val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()

                             if j <= 4:
```
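
For reference, the second hunk just zero-pads whichever mel is shorter along the time axis, so the L1 validation loss is computed on equal shapes. A minimal standalone sketch of that alignment (tensor names here are illustrative, not from the HifiGAN code):

```python
import torch
import torch.nn.functional as F

def align_time_axis(y_mel, y_hat_mel):
    """Zero-pad the shorter of two (batch, mel_channels, frames) tensors
    so both end up with the same number of frames."""
    diff = y_mel.size(2) - y_hat_mel.size(2)
    if diff > 0:
        y_hat_mel = F.pad(y_hat_mel, (0, diff))   # pad last dim on the right
    elif diff < 0:
        y_mel = F.pad(y_mel, (0, -diff))
    return y_mel, y_hat_mel

# Example: a 3-frame mismatch between ground truth and generator output.
a, b = align_time_axis(torch.randn(1, 80, 100), torch.randn(1, 80, 97))
assert a.shape == b.shape == torch.Size([1, 80, 100])
```
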
cfrancesco commented 3 years ago

Hi, that is correct: in order to feed the mels to HiFiGAN or MelGAN, you need to swap the last two axes. No other changes should be needed, AFAIK.

Have a look at the "vocoding" branch if you want; there I have some (ugly) code for predicting with these vocoders.
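
Concretely, TransformerTTS writes mels time-major, e.g. (frames, mel_channels) or (1, frames, mel_channels), while the HiFiGAN generator expects (batch, mel_channels, frames). A hedged sketch of the swap (the file name and the commented-out generator call are illustrative, not the exact vocoding-branch code):

```python
import numpy as np
import torch

mel = np.load('sample_mel.npy')      # mel saved by TransformerTTS (illustrative path)
mel = torch.from_numpy(mel).float()
if mel.dim() < 3:
    mel = mel.unsqueeze(0)           # add a batch dim -> (1, frames, 80)
if mel.shape[1] != 80:               # 80 == config['mel_channels']
    mel = mel.transpose(1, 2)        # -> (1, 80, frames)
# mel is now shaped for the HiFiGAN generator:
# audio = generator(mel).squeeze().cpu().numpy()
```
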

kudanai commented 3 years ago

Thank you for the response.

I just tested out your suggestion and it does indeed work with just the axis swap. I'm updating the diff here.

We've had very good results with TransformerTTS + HiFiGAN: Sound Sample

(In retrospect, this issue would make more sense under the new Discussions feature)

```diff
diff --git a/meldataset.py b/meldataset.py
index 4502924..b72981b 100644
--- a/meldataset.py
+++ b/meldataset.py
@@ -142,10 +142,14 @@ class MelDataset(torch.utils.data.Dataset):
         else:
             mel = np.load(
                 os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
             mel = torch.from_numpy(mel)

             if len(mel.shape) < 3:
                 mel = mel.unsqueeze(0)
+            if not mel.shape[1] == 80:
+              mel = mel.transpose(1,2)

             if self.split:
                 frames_per_seg = math.ceil(self.segment_size / self.hop_size)
```
cfrancesco commented 3 years ago

Hi, cool! What language is your sample in? Did you use the phonemizer for text conversion? Not knowing the language, my comment is probably invalid, but it seems a little flat. Did you train with pitch prediction too?

kudanai commented 3 years ago

It's "Dhivehi" using the Thaana script. Unfortunately phonemizer support is lacking right now so I patched it to skip the phonemizer and pick up a raw charset from config instead link to fork here.

Turned off stress and breathing. All other settings are default. The flatness probably comes more from the dataset itself, although to a native speaker it isn't bad at all.
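
A hypothetical sketch of what "pick up a raw charset from config" might look like (all names here are illustrative; see the linked fork for the actual patch):

```python
# Character-level fallback standing in for the phonemizer.
# 'charset' would come from the model config; this list is illustrative
# and truncated, not the full Thaana alphabet.
charset = list('ހށނރބޅކއވމފދތލގ ')
char_to_id = {c: i + 1 for i, c in enumerate(charset)}  # 0 reserved for padding

def encode(text):
    """Map raw text straight to integer IDs, skipping phonemization."""
    return [char_to_id[c] for c in text if c in char_to_id]
```
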

leminhnguyen commented 3 years ago

@kudanai why did you set the batch size of HifiGAN to 1?

kudanai commented 3 years ago

> @kudanai why did you set the batch size of HifiGAN to 1?

I'm not entirely sure; please try the second patch first, as it seems to be enough. On my first attempt at a fix I encountered some issues that appeared to be mitigated by setting `batch_size` to 1.
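
One plausible explanation, though it is an assumption and not confirmed anywhere in this thread: before the transpose fix, the dataset could cut segments along the wrong axis, so batch items had unequal sizes and PyTorch's default collate failed to stack them; `batch_size=1` sidesteps the stacking entirely. A toy illustration:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class VarLenMels(Dataset):
    """Toy dataset yielding mels with unequal frame counts."""
    def __len__(self):
        return 4
    def __getitem__(self, i):
        return torch.randn(80, 100 + i)   # (mel_channels, frames)

try:
    next(iter(DataLoader(VarLenMels(), batch_size=2)))
except RuntimeError as e:
    print(e)   # default collate cannot stack tensors of different sizes

# batch_size=1 avoids stacking altogether:
print(next(iter(DataLoader(VarLenMels(), batch_size=1))).shape)  # (1, 80, 100)
```
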

cfrancesco commented 3 years ago

@kudanai can this be closed?