huggingface / parler-tts

Inference and training library for high-quality TTS models.
Apache License 2.0

Error when using sdpa and flash_attention_2 #168

Open aixingxy opened 1 week ago

aixingxy commented 1 week ago

Hello, thanks for this great work! I followed the instructions in INFERENCE but ran into some problems.

from parler_tts import ParlerTTSForConditionalGeneration
import torch
from transformers import AutoTokenizer
import soundfile as sf

torch_device = "cuda:0" # use "mps" for Mac
torch_dtype = torch.float32
model_name = "parler-tts/parler-tts-mini-v1"

attn_implementation = "sdpa" # "sdpa" or "flash_attention_2"

model = ParlerTTSForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch_dtype, attn_implementation=attn_implementation).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "Hey, how are you doing today?"
description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."

input_ids = tokenizer(description, return_tensors="pt").input_ids.to(torch_device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device)

generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
audio_arr = generation.cpu().to(torch.float32).numpy().squeeze()

sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)

When I set attn_implementation="sdpa", I get this error:

ValueError: T5EncoderModel does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`

and when I set attn_implementation="flash_attention_2", I get this error:

ValueError: T5EncoderModel does not support Flash Attention 2.0 yet. Please request to add support where the model is hosted, on its model hub page: https://huggingface.co/google/flan-t5-large/discussions/new or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new
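Both errors are raised as `ValueError` while loading the T5 text encoder, and the first message itself suggests loading with `attn_implementation="eager"` in the meantime. Below is a minimal sketch of that interim workaround, assuming you want to try the faster implementations first and fall back automatically. The helper `load_with_attn_fallback` is hypothetical (it is not part of parler-tts or transformers), and it only catches `ValueError`, so other failures, such as a missing flash-attn package, still propagate.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def load_with_attn_fallback(
    loader: Callable[[str], T],
    preferences: Sequence[str] = ("flash_attention_2", "sdpa", "eager"),
) -> Tuple[T, str]:
    """Try each attention implementation in order; return (model, impl_used).

    `loader` takes an attn_implementation string and returns a loaded model.
    Hypothetical helper: only ValueError (the error type reported in this
    issue for unsupported implementations) triggers a fallback.
    """
    errors: List[str] = []
    for impl in preferences:
        try:
            return loader(impl), impl
        except ValueError as exc:
            # Record why this implementation was rejected, then try the next one.
            errors.append(f"{impl}: {exc}")
    raise RuntimeError(
        "No attention implementation could be loaded:\n" + "\n".join(errors)
    )


# Usage with the setup from this issue (needs parler_tts and network access):
# model, impl = load_with_attn_fallback(
#     lambda impl: ParlerTTSForConditionalGeneration.from_pretrained(
#         model_name, torch_dtype=torch_dtype, attn_implementation=impl
#     ).to(torch_device)
# )
# print("loaded with", impl)
```

Falling back to "eager" presumably applies to the whole model rather than only the T5 encoder, so it likely gives up the sdpa/Flash Attention speedups for the decoder as well; treat it as a stopgap rather than a fix.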

I am using an A100 GPU. My environment is:

transformers                4.46.1
torch                       2.3.0
flash-attn                  2.5.8
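When comparing environments across reports like this one, it can help to print the installed versions in the same layout. A small sketch using the standard-library `importlib.metadata`; the function name `package_versions` is illustrative, and the package list simply mirrors the three entries above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def package_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return {distribution name: installed version, or None if missing}."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None  # not installed in this environment
    return result


if __name__ == "__main__":
    # Print in the same two-column layout used in the issue report.
    for name, ver in package_versions(["transformers", "torch", "flash-attn"]).items():
        print(f"{name:<27} {ver or 'not installed'}")
```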

Am I missing some important configuration information?

remichu-ai commented 1 day ago

I am encountering this error too. I would appreciate any help.

aixingxy commented 1 day ago

Your message has been received.

jack-richards commented 8 hours ago

I am getting the same error, and I also followed the tutorial exactly.