huggingface / parler-tts

Inference and training library for high-quality TTS models.

Poor quality when batch inferencing #13

Closed. hscspring closed this issue 4 days ago.

hscspring commented 1 month ago

Code to reproduce:


import torch
import soundfile as sf
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

prompt1 = "Hey, how are you doing today?"
prompt2 = "Hey, good."
description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."

# the same tokenizer is used for both the descriptions and the prompts
input_ids = tokenizer([description, description], return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer([prompt1, prompt2], padding=True, truncation=True, return_tensors="pt").input_ids.to(device)

generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
audio_arr = generation.cpu().numpy().squeeze()
for i in range(2):
    sf.write(f"{i}_parler_tts_out.wav", audio_arr[i].squeeze(), model.config.sampling_rate)

The results are not stable; the generated audio quality is poor when batching like this.

ylacombe commented 1 month ago

Hey @hscspring, thanks for opening the issue. This is something that needs to be improved, but note that during training I actually used the same tokenizer with two different sets of parameters: https://github.com/huggingface/parler-tts/blob/10016fb0300c0dc31a0fb70e26f3affee7b62f16/training/run_parler_tts_training.py#L900-L918

So to do batching we should probably load two tokenizers:

# descriptions were right-padded during training
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1", padding_side="right")
# transcript prompts were left-padded during training
prompt_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1", padding_side="left")

and modify the prompt_input_ids line to:

prompt_input_ids = prompt_tokenizer([prompt1, prompt2], padding=True, truncation=True, return_tensors="pt").input_ids.to(device)
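
Putting it together, here is a minimal sketch of batched inference with the two tokenizers. Passing attention_mask and prompt_attention_mask to generate is an assumption on my part (it should let the model ignore the padding tokens in the batched inputs); the exact keyword arguments may differ across parler-tts versions:

import torch
import soundfile as sf
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device)

# right-padding for descriptions, left-padding for transcript prompts
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1", padding_side="right")
prompt_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1", padding_side="left")

prompt1 = "Hey, how are you doing today?"
prompt2 = "Hey, good."
description = "A female speaker with a slightly low-pitched voice delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks very fast."

inputs = tokenizer([description, description], padding=True, return_tensors="pt").to(device)
prompts = prompt_tokenizer([prompt1, prompt2], padding=True, truncation=True, return_tensors="pt").to(device)

generation = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    prompt_input_ids=prompts.input_ids,
    prompt_attention_mask=prompts.attention_mask,
)

# batched outputs are padded to the longest audio in the batch,
# so the shorter clip may carry trailing silence worth trimming
audio_arr = generation.cpu().numpy().squeeze()
for i in range(2):
    sf.write(f"{i}_parler_tts_out.wav", audio_arr[i].squeeze(), model.config.sampling_rate)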

hscspring commented 4 days ago

thanks, great work