Open ajd12342 opened 4 months ago
I have the same problem
Same here:
`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour.
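Judging by the warning text, it is only emitted when `attention_mask` is missing, so passing the description's mask explicitly should avoid it. A minimal sketch, assuming the usual Parler-TTS generation entry point (the model name, description, and prompt below are just placeholders, and I haven't verified this silences the warning in every code path):

```python
import torch
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

# Tokenize the voice description and the prompt, keeping both attention masks.
description = tokenizer("A female speaker with a clear voice.", return_tensors="pt").to(device)
prompt = tokenizer("Hey, how are you doing today?", return_tensors="pt").to(device)

# Passing attention_mask explicitly (instead of only prompt_attention_mask)
# should avoid the "A full attention_mask will be created" warning.
generation = model.generate(
    input_ids=description.input_ids,
    attention_mask=description.attention_mask,
    prompt_input_ids=prompt.input_ids,
    prompt_attention_mask=prompt.attention_mask,
)
audio = generation.cpu().numpy().squeeze()
```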
Additionally, with flash-attn temporarily uninstalled:
```
UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:555.)
  attn_output = torch.nn.functional.scaled_dot_product_attention(
```
This happens while `attn_implementation = "sdpa"` is set.
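One way to make the attention backend choice explicit is to set it when loading the model. A rough sketch, assuming Parler-TTS accepts the standard `attn_implementation` kwarg from transformers' `from_pretrained` (I have not verified every backend works for this model):

```python
import torch
from parler_tts import ParlerTTSForConditionalGeneration

# "eager" avoids the scaled_dot_product_attention path entirely; "sdpa" keeps it,
# but without flash kernels PyTorch falls back to the math/mem-efficient backends,
# which is what the UserWarning above is reporting.
model = ParlerTTSForConditionalGeneration.from_pretrained(
    "parler-tts/parler_tts_mini_v0.1",
    attn_implementation="eager",  # or "sdpa" / "flash_attention_2" if available
    torch_dtype=torch.float16,
)
```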
Same here. I have not been able to find a way to get around the warning either (`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour.)
Is it safe to ignore, or does it have unexpected consequences?
I have been reading through the code to understand the attention masking patterns and I have a few questions.
First, when running the code with the default hyperparams, I get the following warning during the forward pass of the ParlerTTSDecoder: "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour." I am unsure whether this is intended behavior or not and would appreciate a clarification. At face value, it seems incorrect, since there should ideally be a real `attention_mask` that masks out the audio pad tokens. I see some docstrings in the code that say the `decoder_attention_mask` is optional and by default masks out the pad token ids, but I could not find whether this actually happens in the code (and if it did, the warning should not have appeared).

Second, I notice that the `pad_token_id` is the same as the `eos_token_id` in the decoder config at https://huggingface.co/parler-tts/parler_tts_mini_v0.1/blob/main/config.json. Is masking correctly handled such that the token at the EOS position is not masked out, but the tokens at the PAD positions are, even though the token IDs are the same for both?