huggingface / parler-tts

Inference and training library for high-quality TTS models.
Apache License 2.0

Questions regarding attention masks in the ParlerTTSDecoder code #79

Open ajd12342 opened 4 months ago

ajd12342 commented 4 months ago

I have been reading through the code to understand the attention masking patterns and I have a few questions.

First, when running the code with the default hyperparameters, I get the following warning during the forward pass of the ParlerTTSDecoder: "prompt_attention_mask is specified but attention_mask is not. A full attention_mask will be created. Make sure this is the intended behaviour." I am unsure whether this is intended behavior and would appreciate clarification. At face value it seems incorrect, since there should ideally be a real attention_mask that masks out the audio pad tokens. Some docstrings in the code say that decoder_attention_mask is optional and by default masks out the pad token ids, but I could not find where this actually happens in the code (and if it did happen, the warning should not have appeared).
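For illustration only, here is a plain-Python sketch (not Parler-TTS code) of what replacing a real padding mask with a "full" all-ones attention_mask changes in a causal decoder. It assumes the usual Hugging Face convention (1 = real token, 0 = pad) and made-up sequence lengths; the helper names are hypothetical.

```python
def causal_mask(seq_len):
    # mask[i][j] is True when query position i may attend to key position j
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def combined_mask(attention_mask):
    """Combine a causal mask with a 1/0 padding mask (1 = real token, 0 = pad)."""
    n = len(attention_mask)
    causal = causal_mask(n)
    return [[causal[i][j] and attention_mask[j] == 1 for j in range(n)] for i in range(n)]


# Hypothetical right-padded sequence: 3 real audio frames followed by 2 pads.
real_mask = [1, 1, 1, 0, 0]
full_mask = [1] * len(real_mask)  # what a "full attention_mask" amounts to

with_padding = combined_mask(real_mask)
without_padding = combined_mask(full_mask)

n_real = sum(real_mask)
# Under right padding, rows for the real query positions are identical, because
# the causal mask already hides every later (pad) position from them.
same_for_real_tokens = all(with_padding[i] == without_padding[i] for i in range(n_real))
print(same_for_real_tokens)  # True

# Under left padding, a full mask lets real tokens attend to the leading pads.
left_mask = [0, 0, 1, 1, 1]
left_with = combined_mask(left_mask)
left_without = combined_mask([1] * len(left_mask))
print(left_with[4] == left_without[4])  # False
```

Under these assumptions, a full mask leaves the outputs at real positions unchanged when padding is on the right, but not when it is on the left. The thread does not establish which padding side the decoder sees during training or batched generation, so the sketch does not by itself settle whether the warning is harmless.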

Second, I notice that pad_token_id is the same as eos_token_id in the decoder config (https://huggingface.co/parler-tts/parler_tts_mini_v0.1/blob/main/config.json). Is masking handled so that the token at the EOS position is kept while the tokens at the PAD positions are masked out, even though both share the same token ID?
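A minimal sketch of why the shared ID matters, assuming a single flat sequence of codebook IDs in which PAD and EOS share the hypothetical value 1024 (check the linked config for the real value). A mask built from ids != pad_token_id also hides the real EOS, whereas masking only the positions after the first EOS keeps it. Neither function is taken from the Parler-TTS code base.

```python
PAD_AND_EOS_ID = 1024  # hypothetical shared value for pad_token_id and eos_token_id


def naive_mask(ids, pad_token_id):
    # Masks every position whose id equals pad_token_id, including a real EOS.
    return [0 if t == pad_token_id else 1 for t in ids]


def mask_keeping_first_eos(ids, eos_token_id):
    # Keeps everything up to and including the first EOS; masks what follows.
    mask = []
    seen_eos = False
    for t in ids:
        mask.append(0 if seen_eos else 1)
        if t == eos_token_id:
            seen_eos = True
    return mask


# Three audio tokens, one real EOS, then two pads that reuse the same id.
ids = [5, 17, 42, PAD_AND_EOS_ID, PAD_AND_EOS_ID, PAD_AND_EOS_ID]
print(naive_mask(ids, PAD_AND_EOS_ID))              # [1, 1, 1, 0, 0, 0]
print(mask_keeping_first_eos(ids, PAD_AND_EOS_ID))  # [1, 1, 1, 1, 0, 0]
```

With several codebooks and a delay pattern, the same logic would presumably have to be applied per codebook; that case is not shown here.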

zyy-fc commented 3 months ago

I have the same problem

mjaniec2013 commented 2 months ago

Same here:

`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour.

Additionally, with flash-attn temporarily uninstalled:

UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:555.)
  attn_output = torch.nn.functional.scaled_dot_product_attention(

This is with attn_implementation = "sdpa".

camba1 commented 1 month ago

Same here. I have not found a way to get around the warning either (prompt_attention_mask is specified but attention_mask is not. A full attention_mask will be created. Make sure this is the intended behaviour.) Is it safe to ignore, or can it have unexpected consequences?
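If the goal is only to hide the message, one generic option is to raise the level of the package's parent logger. This sketch rests on an unverified assumption: that Parler-TTS emits the warning through a standard logging.Logger named after its module (for example "parler_tts.modeling_parler_tts", a hypothetical name here) whose own level is left unset. It only hides the message; it does not change how the attention mask is built.

```python
import logging


def silence_package_warnings(package_name="parler_tts", level=logging.ERROR):
    """Raise the threshold of a package's parent logger.

    Assumption: submodule loggers are children of `package_name` and have no
    level of their own (NOTSET), so they inherit this threshold and their
    WARNING records are dropped before any handler sees them.
    """
    logging.getLogger(package_name).setLevel(level)


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    child = logging.getLogger("parler_tts.modeling_parler_tts")  # hypothetical name
    silence_package_warnings()
    child.warning("prompt_attention_mask is specified but attention_mask is not.")  # dropped
    child.error("still shown")  # ERROR passes the raised threshold
```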