huggingface / parler-tts

Inference and training library for high-quality TTS models.
Apache License 2.0
4.18k stars, 411 forks

[Performance] Usage with `optimum-quanto`? #92

Open N3RDIUM opened 1 month ago

N3RDIUM commented 1 month ago

Hey there. I'm trying to use parler-tts for near-real-time text-to-speech on CPU, just fast enough for conversation. To speed up inference, I'm trying to quantize your model to int8 with the following code:

import torch
from parler_tts import ParlerTTSForConditionalGeneration
from optimum.quanto import quantize, freeze, qint8
from transformers import AutoTokenizer, set_seed
import soundfile as sf

device = "cpu"

model = ParlerTTSForConditionalGeneration.from_pretrained(
    "parler-tts/parler-tts-mini-expresso"
).to(device)
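# Swap the model's layers for quantized versions (int8 weights and activations),
# then freeze to replace the float weights with quantized weights.
quantize(model, weights=qint8, activations=qint8)
freeze(model)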
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-expresso")

prompt = "How am I speaking so fast? It's almost unreal."
description = "Jerry speaks fast in a sarcastic, low-pitched tone, with emphasis and high quality audio."

print("Started")

input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
print("Tokenized!")

set_seed(42)
with torch.inference_mode():
    print(dir(model))
    generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()
    print("Generation complete!")

sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)
print("Saved to wavfile.")

It gives me this error:

✦ /mnt/Code/Code/jarvis main* 15s
.venv ❯ /mnt/Code/Code/jarvis/.venv/bin/python /mnt/Code/Code/jarvis/tts.py
Flash attention 2 is not installed
/mnt/Code/Code/jarvis/.venv/lib/python3.12/site-packages/torch/nn/utils/weight_norm.py:134: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.
  WeightNorm.apply(module, name, dim)
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Started
Tokenized!
The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Traceback (most recent call last):
  File "/mnt/Code/Code/jarvis/tts.py", line 28, in <module>
    generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/Code/Code/jarvis/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/Code/Code/jarvis/.venv/lib/python3.12/site-packages/parler_tts/modeling_parler_tts.py", line 3461, in generate
    outputs = self._sample(
              ^^^^^^^^^^^^^
  File "/mnt/Code/Code/jarvis/.venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 2932, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/Code/Code/jarvis/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/Code/Code/jarvis/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/Code/Code/jarvis/.venv/lib/python3.12/site-packages/parler_tts/modeling_parler_tts.py", line 2728, in forward
    decoder_outputs = self.decoder(
                      ^^^^^^^^^^^^^
  File "/mnt/Code/Code/jarvis/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/Code/Code/jarvis/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/Code/Code/jarvis/.venv/lib/python3.12/site-packages/parler_tts/modeling_parler_tts.py", line 1794, in forward
    lm_logits = torch.stack([head(hidden_states) for head in self.lm_heads], dim=1)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/Code/Code/jarvis/.venv/lib/python3.12/site-packages/optimum/quanto/tensor/qtensor.py", line 93, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/Code/Code/jarvis/.venv/lib/python3.12/site-packages/optimum/quanto/tensor/qbytes.py", line 130, in __torch_dispatch__
    return qdispatch(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/Code/Code/jarvis/.venv/lib/python3.12/site-packages/optimum/quanto/tensor/qbytes_ops.py", line 262, in stack
    return qfallback(inputs, dim)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/Code/Code/jarvis/.venv/lib/python3.12/site-packages/optimum/quanto/tensor/qtensor.py", line 29, in qfallback
    return callable(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'list' object is not callable

Can someone help me figure out what I'm doing wrong?
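Reading the last frames of the traceback: inside quanto's `stack` op the fallback is invoked as `qfallback(inputs, dim)`, while `qfallback` calls its first argument (`callable(*args, **kwargs)`). That suggests the list of stacked tensors is being called as if it were a function, which may point to the library's fallback path rather than to the script itself. The plain-Python sketch below reproduces only that calling pattern; apart from the name `qfallback`, every name here (`stack_as_in_traceback`, `stack_passing_op`, the `op` parameter, `fake_stack`) is a hypothetical stand-in, not quanto's actual code.

```python
# Hypothetical stand-ins mirroring the last two traceback frames
# (qbytes_ops.py `stack` and qtensor.py `qfallback`); not quanto's real code.


def qfallback(callable, *args, **kwargs):
    # As in the traceback frame, the first argument is invoked as the function.
    # (The parameter name shadows the builtin `callable` to match the frame.)
    return callable(*args, **kwargs)


def stack_as_in_traceback(op, inputs, dim=0):
    # The frame shows `qfallback(inputs, dim)`: the list of tensors lands in
    # the function slot, so qfallback ends up calling a list.
    return qfallback(inputs, dim)


def stack_passing_op(op, inputs, dim=0):
    # Passing the op first lets the fallback call it on the inputs.
    return qfallback(op, inputs, dim)


def fake_stack(items, dim):
    # Stand-in for torch.stack on plain lists; `dim` is ignored here.
    return [list(item) for item in items]


message = None
try:
    stack_as_in_traceback(fake_stack, [[1.0], [2.0], [3.0]], 0)
except TypeError as err:
    message = str(err)

print(message)  # 'list' object is not callable
print(stack_passing_op(fake_stack, [[1.0], [2.0], [3.0]], 0))
```

The same `TypeError` message appears as in the report, which is consistent with (but does not prove) a mismatch between how `stack` calls `qfallback` and what `qfallback` expects.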

ylacombe commented 1 month ago

I've never used Quanto, to be honest, so I'm not sure I can help here, but have you tried excluding the `lm_heads`? Referring to the code snippet here, you could pass `exclude="lm_heads"`. Let me know if it helps.
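The `exclude` argument of `quantize` takes module-name patterns. Assuming quanto compares each pattern with Python's `fnmatch` against the full dotted names from `model.named_modules()` (worth confirming for the installed version), a bare `"lm_heads"` only matches a module whose full name is exactly `lm_heads`, whereas the traceback shows the heads living under `self.decoder`. The module names in this sketch are hypothetical, inferred from those frames:

```python
from fnmatch import fnmatch

# Hypothetical module names inferred from the traceback frames
# (`self.decoder(...)` -> `self.lm_heads`); check model.named_modules() for the real ones.
module_names = [
    "decoder.model.decoder.layers.0.self_attn.q_proj",
    "decoder.lm_heads",
    "decoder.lm_heads.0",
    "decoder.lm_heads.1",
]


def excluded(name, patterns):
    """Return True if any shell-style pattern matches the full module name."""
    if isinstance(patterns, str):
        patterns = [patterns]
    return any(fnmatch(name, pattern) for pattern in patterns)


bare = [n for n in module_names if excluded(n, "lm_heads")]
wildcard = [n for n in module_names if excluded(n, "*lm_heads*")]
print(bare)      # []
print(wildcard)  # ['decoder.lm_heads', 'decoder.lm_heads.0', 'decoder.lm_heads.1']
```

If the real names follow this shape, a wildcard pattern such as `exclude="*lm_heads*"` would be needed to skip the heads; whether skipping them avoids the `torch.stack` error is not established in this thread.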

N3RDIUM commented 1 month ago

Thanks for your reply! Sadly, it still throws the same error.