johnmai-dev / NotebookMLX

📋 NotebookMLX - An Open Source version of NotebookLM (Ported NotebookLlama)
MIT License

Error when generating audio for long text prompt #1

Closed: namuan closed this issue 2 weeks ago.

namuan commented 2 weeks ago

I'm getting this error when generating audio for a long text prompt:

ValueError: Invalid high padding size (-185) passed to pad for axis 1. Padding sizes must be non-negative

Here is the full stack trace, using the example text_prompt:

ValueError                                Traceback (most recent call last)
Cell In[33], line 6
      1 # Define text and description
      2 text_prompt = """
      3 Alright, folks, welcome to the podcast where we dive deep into the fascinating world of knowledge distillation in large language models, or LLMs for short. If you're here, you're likely curious about how these powerful AI tools can be made more accessible and effective for everyone. Today, we're going to peel back the layers and see how knowledge distillation works, and why it's such a game-changer. I’m your host, and I've got some incredible stories to share. So grab your coffee, get comfy, and let's jump into this episode where we talk about transforming complex AI capabilities into more manageable and accessible forms.
      4 """
----> 6 generate(
      7     generation_text=text_prompt,
      8     model_name=MODEL,
      9     output_path=TEST_AUDIO_FILE
     10 )

File ~/workspace/NotebookMLX/venv/lib/python3.12/site-packages/f5_tts_mlx/generate.py:70, in generate(generation_text, duration, model_name, ref_audio_path, ref_audio_text, steps, method, cfg_strength, sway_sampling_coef, speed, seed, output_path)
     67 if duration is not None:
     68     duration = int(duration * FRAMES_PER_SEC)
---> 70 wave, _ = f5tts.sample(
     71     mx.expand_dims(audio, axis=0),
     72     text=text,
     73     duration=duration,
     74     steps=steps,
     75     method=method,
     76     speed=speed,
     77     cfg_strength=cfg_strength,
     78     sway_sampling_coef=sway_sampling_coef,
     79     seed=seed,
     80 )
     82 # trim the reference audio
     83 wave = wave[audio.shape[0] :]

File ~/workspace/NotebookMLX/venv/lib/python3.12/site-packages/f5_tts_mlx/cfm.py:271, in F5TTS.sample(self, cond, text, duration, lens, steps, method, cfg_strength, speed, sway_sampling_coef, seed, max_duration, no_ref_audio, edit_mask)
    268 # duration
    270 if duration is None and self._duration_predictor is not None:
--> 271     duration_in_sec = self._duration_predictor(cond, text)
    272     frame_rate = self.mel_spec.sample_rate // self.mel_spec.hop_length
    273     duration = (duration_in_sec * frame_rate / speed).astype(mx.int32).item()

File ~/workspace/NotebookMLX/venv/lib/python3.12/site-packages/f5_tts_mlx/duration.py:207, in DurationPredictor.__call__(self, inp, text, lens, return_loss)
    201 # attending
    203 inp = mx.where(
    204     repeat(mask, "b n -> b n d", d=self.num_channels), inp, mx.zeros_like(inp)
    205 )
--> 207 x = self.transformer(inp, text=text)
    209 x = maybe_masked_mean(x, mask)
    211 pred = self.to_pred(x)

File ~/workspace/NotebookMLX/venv/lib/python3.12/site-packages/f5_tts_mlx/duration.py:117, in DurationTransformer.__call__(self, x, text, mask)
    113 batch, seq_len = x.shape[0], x.shape[1]
    115 t = self.time_embed(mx.ones((batch,), dtype=mx.float32))
--> 117 text_embed = self.text_embed(text, seq_len)
    119 x = self.input_embed(x, text_embed)
    121 rope = self.rotary_embed.forward_from_seq_len(seq_len)

File ~/workspace/NotebookMLX/venv/lib/python3.12/site-packages/f5_tts_mlx/dit.py:59, in TextEmbedding.__call__(self, text, seq_len, drop_text)
     53 text = (
     54     text + 1
     55 )  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
     56 text = text[
     57     :, :seq_len
     58 ]  # curtail if character tokens are more than the mel spec tokens
---> 59 text = mx.pad(text, [(0, 0), (0, seq_len - text_len)], constant_values=0)
     61 if drop_text:  # cfg for text
     62     text = mx.zeros_like(text)

ValueError: Invalid high padding size (-185) passed to pad for axis 1. Padding sizes must be non-negative
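
The last frame of the trace points at the arithmetic behind the error. `TextEmbedding` curtails the token ids with `text[:, :seq_len]`, then sizes the `mx.pad` padding as `seq_len - text_len`. The definition of `text_len` is not shown in the trace; assuming it is the length before curtailing, any prompt with more character tokens than `seq_len` produces a negative pad width (here -185). In this call path `seq_len` is read from `inp`, which is the `cond` passed in by `F5TTS.sample` and appears to be derived from the reference audio rather than from the text, which would be consistent with shorter prompts avoiding the error. The list-based sketch below uses a hypothetical `pad_text_tokens` helper for a single sequence and sizes the padding from the curtailed length, so it cannot go negative. It only illustrates the arithmetic and is not the change made upstream.

```python
def pad_text_tokens(token_ids: list[int], seq_len: int, filler: int = 0) -> list[int]:
    """Hypothetical single-sequence version of the curtail-then-pad step."""
    # Shift ids by 1 so that 0 is free to act as the filler token
    shifted = [t + 1 for t in token_ids]
    # Curtail if there are more character tokens than frames
    curtailed = shifted[:seq_len]
    # Size the padding from the curtailed length, so it is never negative
    pad_width = seq_len - len(curtailed)
    return curtailed + [filler] * pad_width


# Unguarded arithmetic for a 10-token prompt and 4 frames
print(4 - 10)                                # -6, the kind of value mx.pad rejects
print(pad_text_tokens(list(range(10)), 4))  # [1, 2, 3, 4]
print(pad_text_tokens([1, 2, 3], 5))        # [2, 3, 4, 0, 0]
```
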

I managed to work around this by splitting the text into sentences and generating a separate audio segment for each one, using the same speaker.

Here is the modified code from Step-4-TTS-Workflow.ipynb:

import re

# `generate` comes from f5-tts-mlx (the module shown in the traceback above);
# MODEL is assumed to be defined earlier in the notebook.
from f5_tts_mlx.generate import generate


def split_into_sentences(text):
    # Split on whitespace that follows sentence-final punctuation (. ! ?),
    # keeping the punctuation attached to each sentence
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    # Drop empty strings and surrounding whitespace
    return [s.strip() for s in sentences if s.strip()]


def generate_speaker1_audio(text, output_path):
    sentences = split_into_sentences(text)
    # Strip the extension so each sentence gets its own numbered file
    base_path = str(output_path).rsplit('.', 1)[0]

    for j, sentence in enumerate(sentences):
        sentence_output_path = f"{base_path}_{j:02}.wav"
        generate(
            generation_text=sentence,
            model_name=MODEL,
            output_path=sentence_output_path
        )


def generate_speaker2_audio(text, output_path):
    sentences = split_into_sentences(text)
    base_path = str(output_path).rsplit('.', 1)[0]

    for j, sentence in enumerate(sentences):
        sentence_output_path = f"{base_path}_{j:02}.wav"
        # Second voice, taken from a reference clip and its transcript
        generate(
            generation_text=sentence,  # only the current sentence, not the whole passage
            model_name=MODEL,
            output_path=sentence_output_path,
            ref_audio_path="./resources/test_en_2_ref_short.wav",
            ref_audio_text="Some call me nature, others call me mother nature."
        )
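
Generating one file per sentence is the simplest split. A possible variant (a sketch, not part of the notebook) packs consecutive sentences into chunks under a character budget, so fewer, longer segments are produced. `chunk_sentences` and `max_chars` are hypothetical names, and the default of 200 characters is arbitrary: a safe budget would have to be found by experiment, since the thread doesn't give one.

```python
import re


def split_into_sentences(text: str) -> list[str]:
    # Split after . ! or ? followed by whitespace, keeping the punctuation
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    return [s.strip() for s in sentences if s.strip()]


def chunk_sentences(text: str, max_chars: int = 200) -> list[str]:
    """Greedily pack whole sentences into chunks of at most max_chars characters.

    A single sentence longer than max_chars becomes its own chunk (it is not cut).
    """
    chunks: list[str] = []
    current = ""
    for sentence in split_into_sentences(text):
        candidate = f"{current} {sentence}" if current else sentence
        if current and len(candidate) > max_chars:
            # Adding this sentence would exceed the budget: close the current chunk
            chunks.append(current)
            current = sentence
        else:
            current = candidate
    if current:
        chunks.append(current)
    return chunks


print(chunk_sentences("One. Two! Three? Four.", max_chars=9))
# ['One. Two!', 'Three?', 'Four.']
```

Each chunk would then be passed to `generate` in place of a single sentence.
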

I also changed the output file paths to use zero-padded numbers, so the segment files sort in the right order:

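import ast

from tqdm import tqdm

final_audio = None

# PODCAST_TEXT is defined earlier in the notebook.
# Segments are numbered from 1; {i:02} zero-pads to two digits so the names sort correctly.
for i, (speaker, text) in enumerate(
    tqdm(ast.literal_eval(PODCAST_TEXT), desc="Generating podcast segments", unit="segment"),
    start=1,
):
    output_path = f"./resources/segments/_podcast_segment_{i:02}.wav"
    if speaker == "Speaker 1":
        generate_speaker1_audio(text, output_path)
    else:  # Speaker 2
        generate_speaker2_audio(text, output_path)
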
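
Why the zero padding matters: file names sort as strings, so `_podcast_segment_10.wav` comes before `_podcast_segment_2.wav` unless the numbers have a fixed width. The small check below (plain Python, with a hypothetical `segment_names` helper) shows both orders. Note that a two-digit field only keeps the order correct up to 99 segments; beyond that a wider field such as `{i:03}` would be needed.

```python
def segment_names(count: int, width: int = 1) -> list[str]:
    """Hypothetical helper: segment file names numbered 1..count.

    width=1 means no zero padding; width=2 gives the same digits as f"{i:02}".
    """
    return [f"_podcast_segment_{str(i).zfill(width)}.wav" for i in range(1, count + 1)]


unpadded = segment_names(12)          # _1, _2, ..., _12
padded = segment_names(12, width=2)   # _01, _02, ..., _12

print(sorted(unpadded) == unpadded)   # False: "_10" sorts before "_2"
print(sorted(padded) == padded)       # True
print(sorted(unpadded)[:3])
# ['_podcast_segment_1.wav', '_podcast_segment_10.wav', '_podcast_segment_11.wav']
```
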
johnmai-dev commented 2 weeks ago

It seems that f5-tts-mlx does not support certain characters, which results in a negative value for seq_len - text_len (the padding size passed to mx.pad). See https://github.com/lucasnewman/f5-tts-mlx/issues/13
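
Following that diagnosis, one way to spot problem characters before calling `generate` is to compare the prompt against the model's character vocabulary. The sketch below assumes you already have that vocabulary as a set of characters; how f5-tts-mlx stores and loads it is not covered in this thread, so the tiny `vocab` here is purely illustrative, and `find_unsupported_chars` is a hypothetical helper.

```python
import string


def find_unsupported_chars(text: str, vocab: set[str]) -> list[str]:
    """Return the distinct characters of text missing from vocab, in first-seen order."""
    seen: dict[str, None] = {}
    for ch in text:
        if ch not in vocab:
            seen.setdefault(ch, None)
    return list(seen)


# Illustrative vocabulary only: ASCII letters, digits, punctuation and space
vocab = set(string.ascii_letters + string.digits + string.punctuation + " ")

prompt = "I’m your host, and I've got some incredible stories to share."
print(find_unsupported_chars(prompt, vocab))  # ['’']
```

With this toy vocabulary, the typographic apostrophe in "I’m" from the example prompt is flagged; whether the real model vocabulary lacks it is not established here.
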

johnmai-dev commented 2 weeks ago

Fixed!

To pick up the fix, upgrade with pip install -U f5-tts-mlx, or run pip install -r requirements.txt.
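
To confirm the upgrade took effect, the installed version can be read with the standard library's `importlib.metadata`. The thread doesn't say which f5-tts-mlx release contains the fix, so this sketch only reports what is installed; `installed_version` is a hypothetical helper name.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of package, or None if it isn't installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


print(installed_version("f5-tts-mlx"))
```
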