huggingface / parler-tts

Inference and training library for high-quality TTS models.
Apache License 2.0

GREAT MODELS, but a number of issues ... #125

Open apresence opened 2 weeks ago

apresence commented 2 weeks ago

First off -- AMAZING TTS!!!

I know I'm repeating several other issues that have already been opened, but I've spent several days testing and tweaking code to try to resolve the problems I found, and I wanted to share. Besides, rolling them all up into one place might be helpful.

It would be AWESOME if we could get this thing working reliably!

I've tried the following:

I wrote a program that works through a sampling of all of the above combinations, used it to generate 500 WAV files from the same paragraph of input text and description (description varies by model, of course), then randomly sampled about 10% of them.
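
For reference, here's a simplified sketch of the kind of sweep I ran (the parameter grid and names here are illustrative; the real script covers more dimensions and names the output files like the log lines further down):

import itertools

import soundfile as sf
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

device = "cuda:0"
model_name = "parler-tts/parler-tts-mini-v1"   # the real sweep also covers the other checkpoints
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "Hey, there, I'm Parly!"
description = "A female speaker delivers her words expressively with clear audio quality."

for attn_impl, seed in itertools.product(["eager", "sdpa"], range(5)):
    # Reload the model so the attention implementation actually changes.
    model = ParlerTTSForConditionalGeneration.from_pretrained(
        model_name, attn_implementation=attn_impl
    ).to(device)
    torch.manual_seed(seed)

    desc_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    generation = model.generate(input_ids=desc_ids, prompt_input_ids=prompt_ids)

    out_name = f"test.attn_impl={attn_impl}.seed={seed}.wav"
    sf.write(out_name, generation.cpu().numpy().squeeze(), model.config.sampling_rate)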

At least one of the following issues occurs regardless of the model or of which combination of the above is used:

Any more than about 50 input tokens and the issues get much worse.

I'm wondering whether there's an issue with the way attention or KV caching is implemented. That would fit as a cause for the issues.

For one thing, this message is logged on the first generation: "prompt_attention_mask is specified but attention_mask is not. A full attention_mask will be created. Make sure this is the intended behaviour."

A further hint towards attention issues is in the code examples: sometimes only the text mask is given, sometimes only the description mask, sometimes both, and sometimes neither. Sometimes they're padded, sometimes not.
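
For what it's worth, passing both masks explicitly makes that warning go away, and it matches what the padded examples do. Here's a minimal sketch of how I'm calling generate (variable names are mine):

import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_name = "parler-tts/parler-tts-mini-v1"
model = ParlerTTSForConditionalGeneration.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

desc = tokenizer("A female speaker with clear audio quality.", return_tensors="pt")
prm = tokenizer("Hey, there, I'm Parly!", return_tensors="pt")

# Pass both the description mask (attention_mask) and the prompt mask
# (prompt_attention_mask) so neither one has to be created implicitly.
generation = model.generate(
    input_ids=desc.input_ids.to(device),
    attention_mask=desc.attention_mask.to(device),
    prompt_input_ids=prm.input_ids.to(device),
    prompt_attention_mask=prm.attention_mask.to(device),
)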

Looking at the code, the input attention mask seems to be ignored in some cases and generated/re-generated/re-shaped/modified several times throughout the generation cycle, and the same goes for the cache. The code that manipulates them is spread around and repeated in different places. There are also multiple conditionals that check the torch version and the like and change how things are processed. Then there are comments like:

I do not intend this as a form of criticism -- the quality coming out of Parler is amazing! I highlight these in case anyone with the requisite knowledge might be able to review them. While I am a developer, I am brand new to transformers and don't really understand the underlying concepts at play here.

It's worth noting that I can see a lot of this code was copied and pasted from somewhere else (most notably, MusicGen), so many of these little wrinkles may have been pre-existing.

Here's my setup:

Hardware: RTX 4090 FE, 24GB VRAM
Drivers: 555.42.06, CUDA 12.5
OS: Ubuntu 22.04.4 LTS

I also tried it on an RTX 8000 48GB VRAM, same results.

Thanks!

apresence commented 2 weeks ago

Something definitely seems strange with the cache. I noticed the following warning when compiling:

V0901 20:36:51.238000 125202030948352 torch/_dynamo/guards.py:2611] [0/1] [__recompiles] Recompiling function forward in /work/parler-tts/parler_tts/modeling_parler_tts.py:2576
V0901 20:36:51.238000 125202030948352 torch/_dynamo/guards.py:2611] [0/1] [__recompiles]     triggered by the following guard failure(s):
V0901 20:36:51.238000 125202030948352 torch/_dynamo/guards.py:2611] [0/1] [__recompiles]     - tensor 'L['cache_position']' size mismatch at index 0. expected 51, actual 1

It's saying it expected cache_position to be of size 51, but it's actually 1. Using the handy Python module icecream, we can trace this through a generation. As you'll see, cache_position starts out as a tensor of size 51 containing the values 0 .. 50, but then, in prepare_inputs_for_generation(), it suddenly becomes a tensor with only one value. This tracks with the compiler warning.
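
For anyone who wants to reproduce the trace: icecream's ic() prints the source location, the expression, and its value, so a one-line call dropped into modeling_parler_tts.py is all it takes. Minimal example:

import torch
from icecream import ic

ic.configureOutput(includeContext=True)  # include file name and line number in the output

cache_position = torch.arange(51)
ic(cache_position)               # prints something like: ic| demo.py:7 in <module>- cache_position: tensor([ 0, 1, ..., 50])
ic(cache_position.reshape(-1, 1))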

Log snippet:

2024-09-01 19:08:14,490 [MainThread  ] [INFO ] >>> Performing inference into test_comp_inf_20240901190805.attn_impl=eager.model=expresso.pad_dir=left.pad_len=50.seed=42.tdev=cuda0.ttyp=torch.bfloat16_gen.wav
2024-09-01 19:08:14,492 [MainThread  ] [INFO ] Chunk 1/7: numtoks=11/50 text="Hey, there, I'm Parly!"
...
ic| modeling_parler_tts.py:2819 in prepare_inputs_for_generation()
    cache_position: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                            18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                            36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50],
                           device='cuda:0')
...
ic| modeling_parler_tts.py:2833 in prepare_inputs_for_generation()
    cache_position: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                            18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                            36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50],
                           device='cuda:0')
...
ic| modeling_parler_tts.py:1360 in forward()
    cache_position.unsqueeze(0): tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                                          18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                                          36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]],
                                        device='cuda:0')
...
ic| modeling_parler_tts.py:1607 in _update_causal_mask()
    cache_position: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                            18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                            36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50],
                           device='cuda:0')
ic| modeling_parler_tts.py:1608 in _update_causal_mask()
    cache_position.reshape(-1, 1): tensor([[ 0],
                                           [ 1],
                                           [ 2],
                                           [ 3],
                                           [ 4],
                                           [ 5],
                                           [ 6],
                                           [ 7],
                                           [ 8],
                                           [ 9],
                                           [10],
                                           [11],
                                           [12],
                                           [13],
                                           [14],
                                           [15],
                                           [16],
                                           [17],
                                           [18],
                                           [19],
                                           [20],
                                           [21],
                                           [22],
                                           [23],
                                           [24],
                                           [25],
                                           [26],
                                           [27],
                                           [28],
                                           [29],
                                           [30],
                                           [31],
                                           [32],
                                           [33],
                                           [34],
                                           [35],
                                           [36],
                                           [37],
                                           [38],
                                           [39],
                                           [40],
                                           [41],
                                           [42],
                                           [43],
                                           [44],
                                           [45],
                                           [46],
                                           [47],
                                           [48],
                                           [49],
                                           [50]], device='cuda:0')
ic| modeling_parler_tts.py:433 in forward()
    cache_position: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                            18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                            36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50],
                           device='cuda:0')
...
ic| modeling_parler_tts.py:433 in forward()- cache_position: None
ic| modeling_parler_tts.py:2801 in prepare_inputs_for_generation()
    cache_position[0] if cache_position is not None else past_key_values.get_seq_length(): tensor(51, device='cuda:0')
ic| modeling_parler_tts.py:2802 in prepare_inputs_for_generation()
    past_key_values.get_seq_length(): tensor(51, device='cuda:0')
ic| modeling_parler_tts.py:2819 in prepare_inputs_for_generation()
    cache_position: tensor([51], device='cuda:0')
...
ic| modeling_parler_tts.py:433 in forward()
    cache_position: tensor([51], device='cuda:0')
ic| modeling_parler_tts.py:2801 in prepare_inputs_for_generation()
    cache_position[0] if cache_position is not None else past_key_values.get_seq_length(): tensor(52, device='cuda:0')
ic| modeling_parler_tts.py:2802 in prepare_inputs_for_generation()
    past_key_values.get_seq_length(): tensor(52, device='cuda:0')
ic| modeling_parler_tts.py:2819 in prepare_inputs_for_generation()
    cache_position: tensor([52], device='cuda:0')
ic| modeling_parler_tts.py:2826 in prepare_inputs_for_generation()
    decoder_input_ids.shape[1]: 1
...
ic| modeling_parler_tts.py:433 in forward()
    cache_position: tensor([52], device='cuda:0')
ic| modeling_parler_tts.py:2801 in prepare_inputs_for_generation()
    cache_position[0] if cache_position is not None else past_key_values.get_seq_length(): tensor(53, device='cuda:0')
ic| modeling_parler_tts.py:2802 in prepare_inputs_for_generation()
    past_key_values.get_seq_length(): tensor(53, device='cuda:0')
ic| modeling_parler_tts.py:2819 in prepare_inputs_for_generation()
    cache_position: tensor([53], device='cuda:0')
ic| modeling_parler_tts.py:2826 in prepare_inputs_for_generation()
...
apresence commented 2 weeks ago

BTW, I used the expresso model for that, but the same thing happens with the other models. Here's the config I am using:

attention_implementation='eager'
padding_side='left' (for input)
padding='max_length', padding_size=50
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode='reduce-overhead', fullgraph=True)
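
In code, that setup looks roughly like this (abridged sketch; I'm showing the mini-v1 checkpoint here, but the behaviour is the same with the expresso model):

import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

device = "cuda:0"
model_name = "parler-tts/parler-tts-mini-v1"  # same behaviour with the expresso checkpoint

model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name, attn_implementation="eager"
).to(device, dtype=torch.bfloat16)
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# left-padded prompt input, padded to a fixed length of 50 tokens
prompt_tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
prompt = prompt_tokenizer("Hey, there, I'm Parly!", return_tensors="pt",
                          padding="max_length", max_length=50)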

Also, to clarify: the tensor remains 1x1 throughout the generation until the next generation, when it is reset to 51x1.

ylacombe commented 1 week ago

Hey @apresence, thanks for the thorough feedback, there's definitely a lot to unpack.

In theory, every issue regarding generation (inconsistency, dropped words, pauses, etc.) is explained by the data the model was trained on and the tokenizer that we used. The model is an LLM that learns to associate tokens with sounds. As such, it can have difficulty pronouncing infrequent tokens or infrequent sequences of tokens. Since it's an LLM, it also suffers from classic LLM issues: hallucinations, inconsistent behavior, etc.

Regarding the length of the generated audio: the model was trained on audio clips that are mostly under 20 seconds, and thus can't generalize to long prompts!

These are issues that we're aware of. Hopefully we'll solve some of them in a next version (if there is one!)

Also cc @eustlb regarding the compilation warning with the cache position

eustlb commented 1 week ago

Hey @apresence,

Thank you very much for your detailed feedback. Concerning the point you've raised about cache_position and recompilation, that's actually an expected result. Indeed, when running generation:

  1. for the first forward pass, the hidden states of the prompt text are pre-pended to the one-dimensional start-of-sequence tensor of the decoder. This way cache_position, which indicates where to store keys and values in the cache, is a tensor of length (number of prompt tokens + 1) → the value 51 in your example.
  2. after that, at each new time step, the position in the cache is only one value, since we auto-regressively generate new tokens (see the toy illustration below).
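
To make this concrete, here is a toy illustration of the shapes involved (not the actual model code):

import torch

num_prompt_tokens = 50

# 1. prefill: prompt hidden states + the decoder start-of-sequence token -> 51 positions
cache_position = torch.arange(num_prompt_tokens + 1)      # tensor([0, 1, ..., 50]), size 51

# 2. each auto-regressive decoding step only writes a single new position into the cache
cache_position = torch.tensor([num_prompt_tokens + 1])    # tensor([51]), size 1
# ...then [52], [53], and so on, which is why a second (expected) compilation is triggered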

It is therefore expected to see this recompilation during the warmup step: torch will first compile the case where cache_position holds 51 values and then recompile for the case where it holds only one. You can also read this issue for more information about this necessary warmup step: #93

Test it yourself using the following snippet 🤗: you'll see that there is no recompilation at the second generate call.

from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import torch

# debugging
torch._logging.set_logs(graph_breaks=True, recompiles=True, cudagraphs=True)

# reproducibility
torch.manual_seed(0)

# set-up device args
torch_device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
attn_implementation = "sdpa"

# model
model_name = "parler-tts/parler-tts-mini-v1"
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation
).to(torch_device, dtype=torch_dtype)
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="default", fullgraph=True)

# tokenizers
padding_side = "left"
description_tokenizer = AutoTokenizer.from_pretrained(model_name) 
prompt_tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=padding_side)

def tokenize_inputs(description, prompt):
    tokenized_description = description_tokenizer(description, return_tensors="pt", padding='max_length', max_length=50)
    input_ids = tokenized_description.input_ids.to(torch_device)
    attention_mask = tokenized_description.attention_mask.to(torch_device)

    tokenized_prompt = prompt_tokenizer(prompt, return_tensors="pt", padding='max_length', max_length=50)
    prompt_input_ids = tokenized_prompt.input_ids.to(torch_device)
    prompt_attention_mask = tokenized_prompt.attention_mask.to(torch_device)

    return input_ids, prompt_input_ids, attention_mask, prompt_attention_mask 

# first generation
prompt = "Hey, how are you doing today?"
description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."
input_ids, prompt_input_ids, attention_mask, prompt_attention_mask = tokenize_inputs(description, prompt)

_ = model.generate(
    input_ids=input_ids, 
    prompt_input_ids=prompt_input_ids,
    attention_mask=attention_mask,
    prompt_attention_mask=prompt_attention_mask
)

print("Completed first generate!")

# second generation, debugging parameters will show us if recompilation happens
prompt = "Hey, how are you doing?"
description = "A male speaker with a slightly low-pitched voice delivers his words quite expressively, in a very confined sounding environment with clear audio quality. He speaks very fast."
input_ids, prompt_input_ids, attention_mask, prompt_attention_mask = tokenize_inputs(description, prompt)

_ = model.generate(
    input_ids=input_ids, 
    prompt_input_ids=prompt_input_ids,
    attention_mask=attention_mask,
    prompt_attention_mask=prompt_attention_mask
)

print("Completed second generate!")

kunci115 commented 5 days ago

I've also experienced numbers being pronounced wrong or skipped, and spelling seems really hard for the model. For example, given "CCB", it will say something like "seb".