lucidrains / e2-tts-pytorch

Implementation of E2-TTS, "Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS", in Pytorch
MIT License

inference code example #8

eschmidbauer opened this issue 1 month ago

eschmidbauer commented 1 month ago

Is there inference code? I could not find any, but I read through other issues and found this.

          I'll write an inference script next so we can do some quick experiments.

Originally posted by @manmay-nakhashi in https://github.com/lucidrains/e2-tts-pytorch/issues/1#issuecomment-2227175532

manmay-nakhashi commented 1 month ago

I'll put together some code this weekend.

cyber-phys commented 1 month ago

Okay, so I took a shot at hacking together some inference code. I trained a model for 400k steps on the MushanW/GLOBE dataset; when I test it I get a cacophony which is starting to resemble the TTS prompts, but the intermediate mel spectrogram is of very poor quality, so something might be wrong with my approach.

[image: intermediate mel spectrogram]

import os

import torch
import torchaudio
from torchaudio.transforms import GriffinLim, InverseMelScale, Resample, Speed

from einops import rearrange
from accelerate import Accelerator

from torch.optim import Adam

from e2_tts_pytorch.e2_tts import (
    E2TTS,
    DurationPredictor,
    MelSpec
)

# model hyperparameters must match those used to train the checkpoint being loaded
duration_predictor = DurationPredictor(
    transformer = dict(
        dim = 80,
        depth = 2,
    )
)

model = E2TTS(
    duration_predictor = duration_predictor,
    transformer = dict(
        dim = 80,
        depth = 4,
        skip_connect_type = 'concat'
    )
)

n_fft = 1024
sample_rate = 22050
checkpoint_path = "./e2tts.pt"

def exists(v):
    return v is not None

# crude vocoder: invert the mel scale, then reconstruct phase with Griffin-Lim
def vocoder(melspec):
    inverse_melscale_transform = InverseMelScale(n_stft=n_fft // 2 + 1, n_mels=80, sample_rate=sample_rate, norm="slaney", f_min=0, f_max=8000)
    spectrogram = inverse_melscale_transform(melspec)
    transform = GriffinLim(n_fft=n_fft, hop_length=256, power=2)
    waveform = transform(spectrogram)
    return waveform

def load_checkpoint(checkpoint_path, model, accelerator, optimizer):
    if not exists(checkpoint_path) or not os.path.exists(checkpoint_path):
        return 0

    checkpoint = torch.load(checkpoint_path)
    accelerator.unwrap_model(model).load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['step']

accelerator = Accelerator(
            log_with="all",
        )

optimizer = Adam(model.parameters(), lr=1e-4)

start_step = load_checkpoint(checkpoint_path=checkpoint_path, model=model, accelerator=accelerator, optimizer=optimizer)

# load the reference audio and resample it to the model's sample rate
ref_waveform, ref_sample_rate = torchaudio.load("ref.wav", normalize=True)
resampler = Resample(orig_freq=ref_sample_rate, new_freq=sample_rate)
ref_waveform = resampler(ref_waveform)
speed_factor = sample_rate / ref_sample_rate
respeed = Speed(ref_sample_rate, speed_factor)
ref_waveform = respeed(ref_waveform)
ref_waveform_resampled = ref_waveform[0]  # Speed returns (waveform, lengths); keep the waveform

# build the mel conditioning, one copy of the reference per text prompt
mel_model = MelSpec()
mel = mel_model(ref_waveform_resampled)
mel = torch.cat([mel, mel], dim=0)
mel = rearrange(mel, 'b d n -> b n d')  # (batch, mel bins, frames) -> (batch, frames, mel bins)

text = ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", "Waves crashed against the cliffs, their thunderous applause echoing for miles."]
sample = model.sample(mel[:,:25], text = text, vocoder=vocoder)
sample = sample.to('cpu')

waveform = sample

mono_channel_1 = waveform[0].unsqueeze(0)
mono_channel_2 = waveform[1].unsqueeze(0)

torchaudio.save("output_channel_1.wav", mono_channel_1, sample_rate)
torchaudio.save("output_channel_2.wav", mono_channel_2, sample_rate)
changjinhan commented 1 month ago

I'm also looking forward to it

lucasnewman commented 1 month ago

I haven't figured out what's up with the text conditioning yet, but here's a rough sample (it doesn't use the duration predictor) of the generation flow in a notebook. I left in some debugging outputs so you can see the flow resolving visually. The voice cloning aspect seems to work fine with different speakers, fwiw, they just say nonsense at the moment 😅

(This is from a quick ~100M param model I trained with ~1/100th the FLOPs used in the paper.)

generate.ipynb

Coice commented 1 month ago

@lucasnewman does it always output the reference audio, regardless of what you use as input for the reference text? I also left out the duration predictor, I wound up just simply doubling the input duration and doubling the reference text, since if it can't do the doubling, it sure won't work for anything else 🙃

I couldn't get it to generate anything aside from the input reference audio. I was told by the author "train it more", but I put considerable resources into it, and it never improved.

lucasnewman commented 1 month ago

@lucasnewman does it always output the reference audio, regardless of what you use as input for the reference text? I also left out the duration predictor, I wound up just simply doubling the input duration and doubling the reference text, since if it can't do the doubling, it sure won't work for anything else 🙃

Yep, this is exactly what I'm doing, and more or less what you see in the notebook — I just hard-coded the duration to keep it simple.

I couldn't get it to generate anything aside from the input reference audio. I was told by the author "train it more", but I put considerable resources into it, and it never improved.

Make sure you took the duration fix from a few days ago if you're explicitly passing it as an int, because otherwise it will stop generation after the conditioning. You don't need to retrain your model, as it only affects sampling.
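
For reference, here is a minimal sketch of that approach with a toy, randomly initialized model; it assumes sample() accepts an integer duration keyword (as the duration fix above implies) and that duration counts total frames, conditioning included, so the reference text is doubled to match:

import torch
from e2_tts_pytorch.e2_tts import E2TTS

# toy model purely to illustrate the call; a real run would load a trained checkpoint
e2tts = E2TTS(transformer = dict(dim = 128, depth = 2))

mel = torch.randn(1, 25, 100)        # stand-in reference mel: (batch, frames, mel bins)
text = ["hello world hello world"]   # reference text doubled, matching the doubled duration

duration = mel.shape[1] * 2          # total frames (conditioning + generated), i.e. double the reference
sample = e2tts.sample(mel, text = text, duration = duration)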

Ryu1845 commented 1 month ago

I haven't figured out what's up with the text conditioning yet, but here's a rough sample (it doesn't use the duration predictor) of the generation flow in a notebook. I left in some debugging outputs so you can see the flow resolving visually. The voice cloning aspect seems to work fine with different speakers, fwiw, they just say nonsense at the moment 😅

(This is from a quick ~100M param model I trained with ~1/100th the FLOPs used in the paper.)

generate.ipynb

The output sample seems to be gibberish after what I assume to be the prompt(?)

Thank you for your work though!

https://github.com/user-attachments/assets/3ba4ee21-e94b-4187-b751-e60a6543def9

Coice commented 1 month ago

@lucasnewman The code I was using was based on a modified version of Voicebox, though I did try training an early version of this repo, but at the time it was giving NaNs.

Just to be clear, if you put in any other text, do you still get the exact reference audio? The model I trained always ignored the text embeddings, I'm just wondering if you have the same issue. It looks like it just learns to pass through the input.

Another thing you can try is, pass in a masked region to see if it can do the training objective in inference mode. Can it do the infill? (When I tested this with my model, it was just gibberish, but the unmasked regions were the original audio basically).

lucasnewman commented 1 month ago

@lucasnewman The code I was using was based on a modified version of Voicebox, though I did try training an early version of this repo, but at the time it was giving NaNs.

Just to be clear, if you put in any other text, do you still get the exact reference audio? The model I trained always ignored the text embeddings, I'm just wondering if you have the same issue. It looks like it just learns to pass through the input.

I haven't tried, but that would correlate to what I was referencing with the text conditioning.

Another thing you can try is, pass in a masked region to see if it can do the training objective in inference mode. Can it do the infill? (When I tested this with my model, it was just gibberish, but the unmasked regions were the original audio basically).

Yeah, this is effectively the same task with a different mask region, so I would expect similar results for now since the text conditioning doesn't seem to be working right. I don't actually have a ton of extra time to spend on debugging it, but you're welcome to run some experiments! The latest version of the code is almost exactly what I trained.

juliakorovsky commented 1 month ago

I'm trying to train this model with another repo (a slightly modified Voicebox repo) on around 300 hours of data. Only at around 400,000 iterations did it start to output something that sounds like speech (but the speech was gibberish). I also get random noise a lot, as if the model were unable to fill in the blanks. The voice the model uses also does not resemble the target voice so far. I'm thinking about increasing gradient accumulation to match the paper's batch size, in case the model just doesn't see enough data per iteration.
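
A hedged sketch of the gradient accumulation idea with HuggingFace accelerate (which the training snippets in this thread already use); model, optimizer and dataloader are assumed to come from the surrounding training script, the accumulation factor of 8 is an arbitrary example, and the forward call follows this repo's argument names:

from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps = 8)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for mel_spec, text, mel_lengths in dataloader:
    with accelerator.accumulate(model):   # gradients are applied every 8 micro-batches
        loss, *_ = model(mel_spec, text = text, lens = mel_lengths)  # forward is assumed to return the loss first
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()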

HaiFengZeng commented 1 month ago

I think the model is hard to train when it has to learn the alignment from text to mel-spectrogram directly. Has anyone gotten a reasonable result? I also get speech with the same timbre, but the speech is not what is expected for the text input, so I think the model doesn't learn the alignment properly. Sorry for this: after training for a longer time, some of the results seem to become more related to the input text (some words missing, some just wrong words...), but much better.

eschmidbauer commented 1 month ago

I'm trying to test inference again. It is very slow, and it appears the code is using the CPU instead of the GPU: in the screenshot, GPU is at 0%, VRAM is at 9%, and CPU is at 2718%. Any ideas why this might be happening? [screenshot]

AbrarMahmud commented 1 month ago

Is there any pre-trained checkpoint available for this model?

Thanks in advance.

eschmidbauer commented 1 month ago

I was eventually able to get inference to work by changing this line:

sample = e2tts.sample(mel[:,:25].to("cuda"), text = text)

But I only get noise as output. Has anyone else been able to get inference to work?
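
For anyone hitting the CPU-only inference issue above, a minimal sketch (assuming CUDA is available and the e2tts, mel and text variables from the script earlier in this thread) that moves both the model and the conditioning mel to the GPU:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

e2tts = e2tts.to(device)                                    # move the model, not just the inputs
sample = e2tts.sample(mel[:, :25].to(device), text = text)  # conditioning mel on the same device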

eschmidbauer commented 1 month ago

Shared a checkpoint here.

Coice commented 4 weeks ago

@eschmidbauer Were you successful in getting it to generate speech from the text input?

eschmidbauer commented 4 weeks ago

@Coice no, but maybe my inference script needs work. Maybe someone else is able to generate speech and can share the code.

juliakorovsky commented 4 weeks ago

I was eventually able to get inference to work by changing this line:

sample = e2tts.sample(mel[:,:25].to("cuda"), text = text)

But I only get noise as output. Has anyone else been able to get inference to work?

I'll add what I know in case someone's interested. I'm trying to train E2 TTS with another repo on a small dataset. I had to rewrite some code because I used the Voicebox repo. I tried to train it for a couple of weeks, but the network only generated noise. I decided to print gradients for all parameters and found out that the attention gradients were always zero. After some digging I found out I had accidentally set my attention dropout to 1. When I fixed it I got something resembling speech instead of noise. The model still can't pronounce many sounds properly, but at least I can see now that it learns. If your model outputs only noise even at 400,000 iterations (400,000 is just an example; theoretically at this stage it should be able to generate something), I would recommend double-checking the gradients: maybe there's some mistake and the gradients are None, or they might be zero, or you might have vanishing gradients.
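
A minimal sketch of that kind of gradient check, run right after loss.backward():

def check_gradients(model):
    # flag parameters whose gradient is missing, all zero, or suspiciously small
    for name, param in model.named_parameters():
        if param.grad is None:
            print(f"{name}: grad is None")
        elif param.grad.abs().max() == 0:
            print(f"{name}: grad is all zeros")
        else:
            print(f"{name}: grad norm = {param.grad.norm().item():.3e}")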

cyber-phys commented 4 weeks ago

Yeah, so as we are finding, this model requires a considerable amount of training...

From the paper:

We utilized the Libriheavy dataset [30] to train our models. The Libriheavy dataset comprises 50,000 hours of read English speech from 6,736 speakers, accompanied by transcriptions that preserve case and punctuation marks.

We modeled the 100-dimensional log mel-filterbank features, extracted every 10.7 milliseconds from audio samples with a 24 kHz sampling rate.

All models were trained for 800,000 mini-batch updates with an effective mini-batch size of 307,200 audio frames.

Meaning the model saw 0.9131 hours of audio per mini-batch, 730,480 hours in total: roughly 15 epochs over Libriheavy.
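
The arithmetic, spelled out (assuming one mel frame per 10.7 ms hop):

frames_per_batch = 307_200
hours_per_batch = frames_per_batch * 0.0107 / 3600   # ~0.913 hours of audio per mini-batch
total_hours = hours_per_batch * 800_000              # ~730,480 hours seen over training
epochs = total_hours / 50_000                        # ~14.6 passes over Libriheavy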

From the WER graphs, we observe that the Voicebox models demonstrated a good WER even at the 10% training point, owing to the use of frame-wise phoneme alignment. On the other hand, E2 TTS required significantly more training to converge. Interestingly, E2 TTS achieved a better WER at the end of the training. We speculate this is because the E2 TTS model learned a more effective grapheme-to-phoneme mapping based on the large training data, compared to what was used for Voicebox. From the SIM-o graphs, we also observed that E2 TTS required more training iteration, but it ultimately achieved a better result at the end of the training. We believe this suggests the superiority of E2 TTS, where the audio model and duration model are jointly learned as a single flow-matching Transformer.

[image: WER and SIM-o training curves from the paper]

From this chart we can infer that the model doesn't really start to learn how to speak until it has seen ~73,048 hours of audio (about 10% of training).

Here is the output I am getting after training a model with the same specs as the paper on 4440.9 hours of audio, trying to say "son, he would tell him. son, he would tell him". Note that the first utterance is the reference audio; the second half is the TTS generation.

https://github.com/user-attachments/assets/f8c1b587-879f-43bf-9778-cc89b8cb8698

lucasnewman commented 4 weeks ago

@cyber-phys FWIW that lines up with my napkin math and your sample sounds similar to my experiments.

I tried a trick where I used a scale factor that ramps from (0, 1] for the random times selection for a few thousand steps, forcing the model to learn stronger conditioning from close-to-noise time steps, which seemed to help a little bit with pronunciation in a low data regime (you could recognize words with ~10k hours of audio training), but nothing close to the quality of Voicebox, which obviously has a big alignment advantage.

It seems like you need a bunch of training over 50k+ hours of audio to make a dent on this one, which is kind of cool because it's possible to just brute force the alignment, but also probably out of reach for most academic/unfunded settings, unfortunately.

skirdey commented 4 weeks ago

I have a training job running that has seen around 2,000,000 samples of speech out of 13M total. I am training on multilingual datasets, so most likely it will take a while before it can produce coherent speech. But it does "speak" a combination of languages now, with no apparent alignment to the text prompt. output_1089.webm

You can find latest checkpoint here https://drive.google.com/drive/folders/11m6ftmJbxua7-pVkQCA6qbfLMlsfC_Ls?usp=drive_link

# Initialize the duration predictor and TTS model
duration_predictor = DurationPredictor(
    transformer=dict(
        dim=512,
        depth=6,
        heads=2,
        dim_head=64,
        max_seq_len=4000
    )
)

e2tts = E2TTS(
    duration_predictor=duration_predictor,
    num_channels=100,
    transformer=dict(
        dim=1024,
        depth=24,
        skip_connect_type='concat',
        heads=16,
        dim_head=64,
        max_seq_len=4000
    ),
    text_num_embeds=256,
    cond_drop_prob=0.25,
)
optimizer = AdamWScheduleFree(e2tts.parameters(), lr=3e-5)
HaiFengZeng commented 4 weeks ago

I would like to share a sample based on another modified repo. It really needs a lot of resources (I only have 4x 4090 GPUs) and nearly two weeks of training, and the result still seems to need more training. I only used two datasets: GigaSpeech and LibriTTS.

https://github.com/user-attachments/assets/84863700-3ed6-4d3e-b97c-bdeb0957c4f3

text: you are very handsome.

https://github.com/user-attachments/assets/9079588c-a947-4222-af5f-31cc8241f6bf

juliakorovsky commented 4 weeks ago

@cyber-phys FWIW that lines up with my napkin math and your sample sounds similar to my experiments.

I tried a trick where I used a scale factor that ramps from (0, 1] for the random times selection for a few thousand steps, forcing the model to learn stronger conditioning from close-to-noise time steps, which seemed to help a little bit with pronunciation in a low data regime (you could recognize words with ~10k hours of audio training), but nothing close to the quality of Voicebox, which obviously has a big alignment advantage.

It seems like you need a bunch of training over 50k+ hours of audio to make a dent on this one, which is kind of cool because it's possible to just brute force the alignment, but also probably out of reach for most academic/unfunded settings, unfortunately.

Could you show the code for the scale factor trick? Or link to it if it's included in this repo.

lucasnewman commented 4 weeks ago

It was just an experiment to see if the text conditioning was working at all — I'm not sure it's a great idea in general.

My intuition was that the joint training objective is particularly difficult for alignment because the "fingerprint" of the flow is pretty well established in the first ~2-3% of the ODE steps and at that point the model will primarily use the flow from the previous timestep for prediction. If we force the model to predict from near-noise earlier, we can bias the training objective towards the text conditioning at the start.

(Also I forgot to mention that I used phonemes instead of the raw byte encoding to make it a little easier on the model because I'm using a smaller dataset.)

You can reproduce it with something like:

diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py
index b43a8d5..c83cf05 100644
--- a/e2_tts_pytorch/e2_tts.py
+++ b/e2_tts_pytorch/e2_tts.py
@@ -694,7 +694,8 @@ class E2TTS(Module):
         *,
         text: Int['b nt'] | List[str] | None = None,
         times: Int['b'] | None = None,
-        lens: Int['b'] | None = None,
+        times_scale: Float | None = None,
+        lens: Int['b'] | None = None
     ):
         # handle raw wave

@@ -740,6 +741,10 @@ class E2TTS(Module):
         # t is random times from above

         times = torch.rand((batch,), dtype = dtype, device = self.device)
+
+        if exists(times_scale):
+            times = times * max(1e-3, times_scale)
+        
         t = rearrange(times, 'b -> b 1 1')

         # sample xt (w in the paper)

And then in your trainer class define num_time_scale_steps and do:

if global_step < self.num_time_scale_steps:
    scale_progress = float(global_step) / float(self.num_time_scale_steps)
    times_scale = min(1.0, scale_progress)
else:
    times_scale = None

loss, ... = self.model(mel_spec, text = text, lens = mel_lengths, times_scale = times_scale)
lucidrains commented 4 weeks ago
[screenshot]
HaiFengZeng commented 3 weeks ago

It was just an experiment to see if the text conditioning was working at all — I'm not sure it's a great idea in general.

My intuition was that the joint training objective is particularly difficult for alignment because the "fingerprint" of the flow is pretty well established in the first ~2-3% of the ODE steps and at that point the model will primarily use the flow from the previous timestep for prediction. If we force the model to predict from near-noise earlier, we can bias the training objective towards the text conditioning at the start.

(Also I forgot to mention that I used phonemes instead of the raw byte encoding to make it a little easier on the model because I'm using a smaller dataset.)

You can reproduce it with something like:

diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py
index b43a8d5..c83cf05 100644
--- a/e2_tts_pytorch/e2_tts.py
+++ b/e2_tts_pytorch/e2_tts.py
@@ -694,7 +694,8 @@ class E2TTS(Module):
         *,
         text: Int['b nt'] | List[str] | None = None,
         times: Int['b'] | None = None,
-        lens: Int['b'] | None = None,
+        times_scale: Float | None = None,
+        lens: Int['b'] | None = None
     ):
         # handle raw wave

@@ -740,6 +741,10 @@ class E2TTS(Module):
         # t is random times from above

         times = torch.rand((batch,), dtype = dtype, device = self.device)
+
+        if exists(times_scale):
+            times = times * max(1e-3, times_scale)
+        
         t = rearrange(times, 'b -> b 1 1')

         # sample xt (w in the paper)

And then in your trainer class define num_time_scale_steps and do:

if global_step < self.num_time_scale_steps:
    scale_progress = float(global_step) / float(self.num_time_scale_steps)
    times_scale = min(1.0, scale_progress)
else:
    times_scale = None

loss, ... = self.model(mel_spec, text = text, lens = mel_lengths, times_scale = times_scale)

Good idea. How should inference be done when applying time_scale? Will it use fewer NFE steps?

acul3 commented 3 weeks ago

Has anyone gotten good results yet (both text alignment and audio similarity/quality)?

4x A100 for a week, 30k hours of data, and it still produces speech inconsistent with the text.

changjinhan commented 3 weeks ago

@acul3 I'm working with a similar setup, just using GLOBE. Here’s what I have so far—could you share your intermediate results as well?

Text: (separated from the main spacecraft and began its descent to the moon's surface.) Waves crashed against the cliffs, their thunderous applause echoing for miles.

https://github.com/user-attachments/assets/4a9cf4de-bcea-4ddd-8a09-191429023c15

acul3 commented 3 weeks ago

@changjinhan How long did you train it?

I am training multilingual (Indonesian and Malay).

The output is acceptable, but it seems hard for the model to follow the text.

Here is my config:

# Initialize the duration predictor and TTS model
duration_predictor = DurationPredictor(
    transformer=dict(
        dim=512,
        depth=6,
        heads=2,
        dim_head=64,
        max_seq_len=4000
    )
)

e2tts = E2TTS(
    duration_predictor=duration_predictor,
    num_channels=100,
    transformer=dict(
        dim=1024,
        depth=24,
        skip_connect_type='concat',
        heads=16,
        dim_head=64,
        max_seq_len=4000
    ),
    text_num_embeds=256,
    cond_drop_prob=0.25,
)
optimizer = AdamWScheduleFree(e2tts.parameters(), lr=3e-5)

Can you share yours? It would be appreciated.

changjinhan commented 3 weeks ago

@acul3 I trained it for 6 days and my config is as follows:

from itertools import chain

duration_predictor = DurationPredictor(
    transformer = dict(
        dim = 512,
        depth = 6,
    )
)

model = E2TTS(
    num_channels=80,
    transformer = dict(
        dim = 512,
        depth = 12,
        skip_connect_type = 'concat'
    )
)
optimizer = Adam(chain(model.parameters(), duration_predictor.parameters()), lr=7.5e-5)
eschmidbauer commented 3 weeks ago

@changjinhan are you on the latest code?

model = E2TTS(
    num_channels=80,

This line doesn't work for me because MelSpec() expects a value of 100. Also, what is the benefit of 80 vs. 100 (the default)?

lucidrains commented 2 weeks ago

@lucasnewman which phoneme library would you recommend using, if I were to add it as an option?

lucasnewman commented 2 weeks ago

@lucasnewman which phoneme library would you recommend using, if I were to add it as an option?

I typically use this one since it’s quick and simple to use, but it’s also English-centric: https://pypi.org/project/g2p-en/
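
A quick sketch of what g2p-en produces (English only); the phoneme list in the comment is illustrative:

from g2p_en import G2p

g2p = G2p()
phonemes = g2p("Easy text to speech.")
# e.g. ['IY1', 'Z', 'IY0', ' ', 'T', 'EH1', 'K', 'S', 'T', ' ', 'T', 'UW1', ' ', 'S', 'P', 'IY1', 'CH', '.']
print(phonemes)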

lucasnewman commented 2 weeks ago

@changjinhan are you on the latest code?

model = E2TTS(
    num_channels=80,

This line doesn't work for me because MelSpec() expects a value of 100. Also, what is the benefit of 80 vs. 100 (the default)?

One advantage of 100-bin / 24khz melspec is you can use an off-the-shelf vocoder like vocos: https://huggingface.co/charactr/vocos-mel-24khz

If you’re training everything yourself or using a different vocoder you can adjust it to whatever makes sense for your dataset.
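
A short sketch of decoding 100-bin / 24 kHz mels with that pretrained Vocos checkpoint (mel layout is (batch, n_mels, frames); the random mel here is just a stand-in for E2TTS output):

import torch
from vocos import Vocos

vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

mel = torch.randn(1, 100, 256)    # stand-in mel; a real one would come from E2TTS.sample
with torch.no_grad():
    waveform = vocos.decode(mel)  # (batch, samples) at 24 kHz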

lucidrains commented 2 weeks ago

@lucasnewman which phoneme library would you recommend using, if I were to add it as an option?

I typically use this one since it’s quick and simple to use, but it’s also English-centric: https://pypi.org/project/g2p-en/

Thanks! I'll start by integrating that.

wetdog commented 1 week ago

@changjinhan are you on the latest code?

model = E2TTS(
    num_channels=80,

This line doesn't work for me because MelSpec() expects a value of 100. Also, what is the benefit of 80 vs. 100 (the default)?

One advantage of 100-bin / 24khz melspec is you can use an off-the-shelf vocoder like vocos: https://huggingface.co/charactr/vocos-mel-24khz

If you’re training everything yourself or using a different vocoder you can adjust it to whatever makes sense for your dataset.

@lucasnewman We have a version of Vocos that uses 80 mel bins and 22050 for retro compatibility with other models. https://huggingface.co/BSC-LT/vocos-mel-22khz

lucasnewman commented 1 week ago

@lucasnewman We have a version of Vocos that uses 80 mel bins and 22050 for retro compatibility with other models. https://huggingface.co/BSC-LT/vocos-mel-22khz

Awesome, I didn't know about that one! That's helpful.

I did a quick run of 100k steps on an H100 with LJSpeech (which is comparatively tiny, 24h with a single speaker) and I'm able to produce sensible generated speech on the current version of the code. I'm sure it could be much better with more training and a larger dataset. Here's the original vs a sample where the second half is generated:

Original Generated

manmay-nakhashi commented 6 days ago

@lucasnewman We have a version of Vocos that uses 80 mel bins and 22050 for retro compatibility with other models. https://huggingface.co/BSC-LT/vocos-mel-22khz

Awesome, I didn't know about that one! That's helpful.

I did a quick run of 100k steps on an H100 with LJSpeech (which is comparatively tiny, 24h with a single speaker) and I'm able to produce sensible generated speech on the current version of the code. I'm sure it could be much better with more training and a larger dataset. Here's the original vs a sample where the second half is generated:

Original Generated

Did you just train the vanilla code? What does the config look like?

lucasnewman commented 6 days ago

Did you just train the vanilla code? What does the config look like?

I made a few tweaks but nothing major: 1) I didn't use the gateloop layers (not against it, but I don't have much compute and they add a lot of parameters), and 2) I used the phoneme tokenizer and extended it with punctuation characters so I could just pump in the captions from LJSpeech as-is.

I have a branch here in my fork if you want to see the changes I used. There's some other experiments in there that I decided against, like the noise schedule, so you can just ignore those.

This is the model config — it's ~50M parameters and fits a batch size of 32 on a single H100 in fp16. I trained it for 100k steps using a max audio duration of 10 seconds. I only used 1k warmup steps because the dataset was a lot smaller, and a max learning rate of 3e-4.

e2tts = E2TTS(
    tokenizer = 'phoneme_en',
    cond_drop_prob = 0.2,
    transformer = dict(
        dim = 384,
        depth = 12,
        heads = 6,
        max_seq_len = 1024,
        skip_connect_type = 'concat'
    ),
    mel_spec_kwargs = dict(
        filter_length = 1024,
        hop_length = 256,
        win_length = 1024,
        n_mel_channels = 100,
        sampling_rate = 24000,
    ),
    frac_lengths_mask = (0.7, 0.9)
)

This is miles away from the compute and dataset size used in the paper and the WER is still pretty high (as you would expect from the paper), but it's nice to see that it's tractable. Here's the average epoch loss curve and the final loss:

[screenshot: loss curves]
juliakorovsky commented 6 days ago

Did you just train the vanilla code? What does the config look like?

I made a few tweaks but nothing major: 1) I didn't use the gateloop layers (not against it, but I don't have much compute and they add a lot of parameters), and 2) I used the phoneme tokenizer and extended it with punctuation characters so I could just pump in the captions from LJSpeech as-is.

I have a branch here in my fork if you want to see the changes I used. There's some other experiments in there that I decided against, like the noise schedule, so you can just ignore those.

This is the model config — it's ~50M parameters and fits a batch size of 32 on a single H100 in fp16. I trained it for 100k steps using a max audio duration of 10 seconds. I only used 1k warmup steps because the dataset was a lot smaller, and a max learning rate of 3e-4.

e2tts = E2TTS(
    tokenizer = 'phoneme_en',
    cond_drop_prob = 0.2,
    transformer = dict(
        dim = 384,
        depth = 2,
        heads = 6,
        max_seq_len = 1024,
        skip_connect_type = 'concat'
    ),
    mel_spec_kwargs = dict(
        filter_length = 1024,
        hop_length = 256,
        win_length = 1024,
        n_mel_channels = 100,
        sampling_rate = 24000,
    ),
    frac_lengths_mask = (0.7, 0.9)
)

This is miles away from the compute and dataset size used in the paper and the WER is still pretty high (as you would expect from the paper), but it's nice to see that it's tractable. Here's the average epoch loss curve and the final loss:

[screenshot: loss curves]

Can you generate speech from LJ Speech samples the model hasn't seen?

lucasnewman commented 6 days ago

Can you generate speech from LJ Speech samples the model hasn't seen?

Kind of... here it is saying "easy text to speech". The word error rate is pretty high but you can make it out. More training would help!

lucidrains commented 6 days ago

@lucasnewman let's grab dinner again once someone else sees the same!

JingRH commented 5 days ago

Can you generate speech from LJ Speech samples the model hasn't seen?

Kind of... here it is saying "easy text to speech". The word error rate is pretty high but you can make it out. More training would help!

Excellent experimental results! I have been closely following your (@lucidrains and @lucasnewman) work, and I am currently training based on your configuration. Could you share your inference code for reference? I suspect there might be an issue with my inference code.

I also have another question: I have been trying to train a model using a multi-speaker Chinese dataset, filtering around 100 hours from the Wenetspeech4TTS dataset, but the generated results are not usable; the output does not align with the given text. Do you think this is solely due to the limited amount of data? For a single-speaker model, you have indeed achieved very good results.

lucasnewman commented 5 days ago

Excellent experimental results! I have been closely following your (@lucidrains and @lucasnewman) work, and I am currently training based on your configuration. Could you share your inference code for reference? I suspect there might be an issue with my inference code.

This is what I use for testing: infer.py

I also have another question: I have been trying to train a model using a multi-speaker Chinese dataset, filtering around 100 hours from the Wenetspeech4TTS dataset, but the generated results are not usable; the output does not align with the given text. Do you think this is solely due to the limited amount of data? For a single-speaker model, you have indeed achieved very good results.

Multi-speaker is going to take more training -- I would probably just train longer. The text alignment is the part that takes the longest to resolve into something sensible. Also, I would check to make sure the sample rate you're specifying matches your dataset, and if you have a phonemizer that works for your dataset you could use that to help accelerate the model's ability to pronounce words correctly.
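
A quick way to double-check the sample-rate point ("example.wav" is a placeholder path):

import torchaudio

info = torchaudio.info("example.wav")
print(info.sample_rate, info.num_channels)  # should match the rate configured in mel_spec_kwargs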

SWivid commented 5 days ago

@JingRH wenet4tts will work for sure. I have tried it, but with the vanilla structure from the E2 paper and others; lucidrains' design should be just as effective.

lucasnewman commented 4 days ago

@JingRH wenet4tts will work for sure. I have tried it, but with the vanilla structure from the E2 paper and others; lucidrains' design should be just as effective.

Just to confirm, I trained on LibriTTS-R with the same config as above for ~300k steps and it handles multi-speaker with voice matching perfectly fine:

Example 1 Example 2

juliakorovsky commented 3 days ago

@JingRH wenet4tts will work for sure. I have tried it, but with the vanilla structure from the E2 paper and others; lucidrains' design should be just as effective.

Just to confirm, I trained on LibriTTS-R with the same config as above for ~300k steps and it handles multi-speaker with voice matching perfectly fine:

Example 1 Example 2

Do you mean E2 TTS? I'm training vanilla E2 TTS on around 400 hours of high-quality multispeaker data in another language and can't get any words out at 1-2 million steps with 6 heads / 2 layers (I think that was written in your config previously) or 6 heads / 4 layers (not this repo, though).

lucidrains commented 3 days ago

@juliakorovsky that's way too small, increase your depth to at least 8 - 12

lucidrains commented 3 days ago

[image]

juliakorovsky commented 3 days ago

@juliakorovsky that's way too small, increase your depth to at least 8 - 12

I was surprised too; it was written in @lucasnewman's config, as you can see in my quote several posts higher. I guess it was just a mistake (it's corrected now), and I'll indeed try increasing the depth. Thanks.