huggingface / parler-tts

Inference and training library for high-quality TTS models.
Apache License 2.0

Does not work on macOS with device="mps": "Can't infer missing attention mask on `mps` device" #148

Open ChristianWeyer opened 1 month ago

ChristianWeyer commented 1 month ago

This is my simple test script:

import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf

torch_device = "mps:0"
torch_dtype = torch.bfloat16
model_name = "parler-tts/parler-tts-mini-v1"

attn_implementation = "eager" # "sdpa" or "flash_attention_2"

model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation=attn_implementation
).to(torch_device, dtype=torch_dtype)

tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "Hey, how are you doing today?"
description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."

input_ids = tokenizer(description, return_tensors="pt").input_ids.to(torch_device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(torch_device)

generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
audio_arr = generation.cpu().numpy().squeeze()

sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)

I get this error:

ValueError: Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device.

Any idea what could be wrong? Thanks!
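The error text itself suggests passing an explicit `attention_mask`. A minimal sketch of that follows. It assumes `generate` accepts `attention_mask` and `prompt_attention_mask` keyword arguments, as in the parler-tts README examples; the helper name `build_generate_kwargs` is hypothetical, and whether this alone is enough on MPS is untested here.

```python
def build_generate_kwargs(description_enc, prompt_enc):
    """Collect ids and masks from two tokenizer outputs into generate() kwargs.

    Both arguments are mappings with "input_ids" and "attention_mask" keys,
    which is what a Hugging Face tokenizer call returns.
    """
    return {
        "input_ids": description_enc["input_ids"],
        "attention_mask": description_enc["attention_mask"],
        "prompt_input_ids": prompt_enc["input_ids"],
        "prompt_attention_mask": prompt_enc["attention_mask"],
    }


# Intended use with the script above (assumption, not executed here):
#   description_enc = tokenizer(description, return_tensors="pt").to(torch_device)
#   prompt_enc = tokenizer(prompt, return_tensors="pt").to(torch_device)
#   generation = model.generate(**build_generate_kwargs(description_enc, prompt_enc))
```

Keeping the whole tokenizer output, rather than only `.input_ids`, is what makes the masks available in the first place.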

tulas75 commented 1 month ago

Same problem for me. I tried transformers version 4.44.2 (not supported by parler-tts); it seems to use the GPU, but at the end, when saving the WAV file, I get this error:

NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
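The fallback named in the error message is an environment variable, and it is generally set before `torch` is imported so that it is visible when the MPS backend initializes. A minimal sketch, with no guarantee that the fallback actually covers this particular convolution:

```python
import os

# Set the variable first; importing torch earlier may mean the setting is not picked up.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# import torch  # import only after the variable is set (left commented in this sketch)
```

Exporting the variable in the shell before launching the script has the same effect.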

ylacombe commented 3 weeks ago

I bumped the transformers version in the latest commit. Could you try again after reinstalling from scratch?

tulas75 commented 3 weeks ago

Same error, even with transformers 4.46.1:

NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

There's one weird thing: I set torch_dtype = torch.bfloat16, but when I run the code (the same code as @ChristianWeyer's) I get the following logs, and it seems to use float32.

Flash attention 2 is not installed
/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
  WeightNorm.apply(module, name, dim)

Config of the text_encoder: <class 'transformers.models.t5.modeling_t5.T5EncoderModel'> is overwritten by shared text_encoder config: T5Config { "_name_or_path": "google/flan-t5-large", "architectures": [ "T5ForConditionalGeneration" ], "classifier_dropout": 0.0, "d_ff": 2816, "d_kv": 64, "d_model": 1024, "decoder_start_token_id": 0, "dense_act_fn": "gelu_new", "dropout_rate": 0.1, "eos_token_id": 1, "feed_forward_proj": "gated-gelu", "initializer_factor": 1.0, "is_encoder_decoder": true, "is_gated_act": true, "layer_norm_epsilon": 1e-06, "model_type": "t5", "n_positions": 512, "num_decoder_layers": 24, "num_heads": 16, "num_layers": 24, "output_past": true, "pad_token_id": 0, "relative_attention_max_distance": 128, "relative_attention_num_buckets": 32, "tie_word_embeddings": false, "transformers_version": "4.46.1", "use_cache": true, "vocab_size": 32128 }

Config of the audio_encoder: <class 'parler_tts.dac_wrapper.modeling_dac.DACModel'> is overwritten by shared audio_encoder config: DACConfig { "_name_or_path": "parler-tts/dac_44khZ_8kbps", "architectures": [ "DACModel" ], "codebook_size": 1024, "frame_rate": 86, "latent_dim": 1024, "model_bitrate": 8, "model_type": "dac_on_the_hub", "num_codebooks": 9, "sampling_rate": 44100, "torch_dtype": "float32", "transformers_version": "4.46.1" }

Config of the decoder: <class 'parler_tts.modeling_parler_tts.ParlerTTSForCausalLM'> is overwritten by shared decoder config: ParlerTTSDecoderConfig { "_name_or_path": "/fsx/yoach/tmp/artefacts/parler-tts-mini/decoder", "activation_dropout": 0.0, "activation_function": "gelu", "add_cross_attention": true, "architectures": [ "ParlerTTSForCausalLM" ], "attention_dropout": 0.0, "bos_token_id": 1025, "codebook_weights": null, "cross_attention_implementation_strategy": null, "dropout": 0.1, "eos_token_id": 1024, "ffn_dim": 4096, "hidden_size": 1024, "initializer_factor": 0.02, "is_decoder": true, "layerdrop": 0.0, "max_position_embeddings": 4096, "model_type": "parler_tts_decoder", "num_attention_heads": 16, "num_codebooks": 9, "num_cross_attention_key_value_heads": 16, "num_hidden_layers": 24, "num_key_value_heads": 16, "pad_token_id": 1024, "rope_embeddings": false, "rope_theta": 10000.0, "scale_embedding": false, "tie_word_embeddings": false, "torch_dtype": "float32", "transformers_version": "4.46.1", "use_cache": true, "use_fused_lm_heads": false, "vocab_size": 1088 }

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's attention_mask to obtain reliable results.

Traceback (most recent call last):
  File "/Users/tulas/Projects/parler-tts/main.py", line 39, in <module>
    generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/parler_tts/modeling_parler_tts.py", line 3633, in generate
    sample = self.audio_encoder.decode(audio_codes=sample[None, ...], **single_audio_decode_kwargs).audio_values
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/parler_tts/dac_wrapper/modeling_dac.py", line 139, in decode
    audio_values = self.model.decode(audio_values)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/dac/model/dac.py", line 266, in decode
    return self.decoder(z)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/dac/model/dac.py", line 144, in forward
    return self.model(x)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/container.py", line 250, in forward
    input = module(input)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/dac/model/dac.py", line 112, in forward
    return self.block(x)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/container.py", line 250, in forward
    input = module(input)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/dac/model/dac.py", line 36, in forward
    y = self.block(x)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/container.py", line 250, in forward
    input = module(input)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 375, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/Users/tulas/Projects/parler-tts/env/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 370, in _conv_forward
    return F.conv1d(
NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
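The "torch_dtype": "float32" entries in these logs appear to be config fields printed while the checkpoint loads, so they may not reflect the dtype after `.to(torch_device, dtype=torch_dtype)`. One way to check what the weights actually use is to count parameter dtypes directly. This is a sketch only; `summarize_dtypes` is a hypothetical helper, not part of parler-tts.

```python
from collections import Counter


def summarize_dtypes(named_parameters):
    """Count parameters per dtype from an iterable of (name, parameter) pairs."""
    return Counter(str(param.dtype) for _, param in named_parameters)


# Intended use with the script above (assumption, not executed here):
#   print(summarize_dtypes(model.named_parameters()))
```

A result with a single bfloat16 entry would indicate the cast took effect; a mix of dtypes would show which share of the parameters stayed in float32.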

chigkim commented 2 weeks ago

It looks like a problem in PyTorch's MPS backend. :(

https://github.com/pytorch/pytorch/issues/134416

They just changed the output message to "Output channels > 65536 not supported at the MPS device." removing the message "As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1".

// TODO: MPS convolution kernel currently does not support output channels > 2^16

https://github.com/pytorch/pytorch/commit/aa3ae50c07dd5c397fc430fab0c0ce5196bb1791

hvaara commented 1 week ago

Feel free to follow https://github.com/pytorch/pytorch/issues/140722 for updates on a fix in PyTorch. Tentative fix in https://github.com/pytorch/pytorch/pull/140726.

hvaara commented 1 week ago

The channel size issue has been fixed in PyTorch on macOS 15.1. It should be available in PyTorch nightly in < 24h.

While testing the fix I discovered that descript-audiotools, which parler-tts depends on transitively, requires torch.distributed for types. I don't know why, but unfortunately torch.distributed is disabled by default in PyTorch on macOS. This should be the last remaining step to get parler-tts working on macOS with PyTorch/MPS.

The most straightforward approach is probably to handle unavailability gracefully in descript-audiotools. I'm quite curious why support was removed for macOS though, since this was definitely supported in the past (ref https://github.com/pytorch/pytorch/issues/20380#issuecomment-531917214).
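Handling the missing backend gracefully could look roughly like the guard below. This is only a sketch: `distributed_available` is a hypothetical helper and not code from descript-audiotools. It relies on `torch.distributed.is_available()`, which PyTorch provides, and treats a failed import the same as an unavailable backend.

```python
import importlib


def distributed_available() -> bool:
    """Return True only if torch.distributed imports and reports itself available."""
    try:
        dist = importlib.import_module("torch.distributed")
    except ImportError:
        # torch, or its distributed package, is not importable in this environment.
        return False
    is_available = getattr(dist, "is_available", None)
    return bool(is_available()) if callable(is_available) else False


# Callers can branch on the result instead of touching torch.distributed types directly.
HAS_DISTRIBUTED = distributed_available()
```

Code that only needs distributed types for annotations could then skip them when `HAS_DISTRIBUTED` is false.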