huggingface / parler-tts

Inference and training library for high-quality TTS models.
Apache License 2.0

Torch compile problem and some more ideas #107

Closed sang-nguyen-ts closed 3 months ago

sang-nguyen-ts commented 3 months ago

@ylacombe Hi there! I've tried the latest version (mini-v1) with torch.compile and SDPA attention on an A100. The results are very good: about 80 ms to generate 200 ms of audio, which is excellent for streaming. However, I believe there's still room for improvement.
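For reference, a minimal sketch of the kind of setup described above, roughly following the repo's inference recipe (checkpoint name, dtype, prompts, and the exact compile flags are assumptions; the real benchmark configuration may differ):

import torch
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration

device = "cuda:0"

# Load the mini-v1 checkpoint with SDPA attention (assumed settings).
model = ParlerTTSForConditionalGeneration.from_pretrained(
    "parler-tts/parler-tts-mini-v1",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
).to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")

# Static KV cache plus a compiled forward pass; the first few calls are slow
# (warm-up) while the graphs are captured.
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead")

description = tokenizer("A female speaker with a calm, clear voice.", return_tensors="pt").to(device)
prompt = tokenizer("Hello, this is a streaming test.", return_tensors="pt").to(device)

audio = model.generate(
    input_ids=description.input_ids,
    attention_mask=description.attention_mask,
    prompt_input_ids=prompt.input_ids,
    prompt_attention_mask=prompt.attention_mask,
)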

Here are some ideas I have for optimization:

  1. Quantization of the KV cache: this could help reduce the KV-cache memory footprint (see the sketch after this list).

  2. Export to ONNX using Optimum: I found an implementation for MusicGen, and I think it will be similar for this model.

  3. Implement PagedAttention: this could help reduce wasted VRAM. I found a vLLM implementation for Whisper, and I think it would be similar to the current static-cache implementation, which is based on Whisper. Maybe someday we can serve Parler-TTS like other LLMs in vLLM; a study from our team showed we can serve an 8B LLM with sub-second latency for ~20 concurrent users (CCUs).
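On idea 1, recent Transformers versions expose a quantized KV cache directly through generate(). Whether Parler-TTS's generation loop accepts these kwargs is untested, so treat this as a hedged sketch (backend and bit-width are assumptions; it reuses the model and inputs from the sketch above and needs the quanto package installed):

# Hedged sketch for idea 1: quantized KV cache via Transformers' generate().
audio = model.generate(
    input_ids=description.input_ids,
    prompt_input_ids=prompt.input_ids,
    cache_implementation="quantized",
    cache_config={"backend": "quanto", "nbits": 4},
)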

Please discuss and help determine a feasible approach we can take. Of course, I'm willing to contribute in any way I can.

dgm3333 commented 3 months ago

What about a C++ implementation similar to llama.cpp (which has a server implementation) or whisper.cpp? Because it's precompiled, there should be no warm-up effect (the warm-up is presumably due to the JIT interpreter stabilising?). https://github.com/ggerganov/llama.cpp https://github.com/ggerganov/whisper.cpp It also handles quants and other optimisations.

sang-nguyen-ts commented 3 months ago

> What about a C++ implementation similar to llama.cpp (which has a server implementation) or whisper.cpp? Because it's precompiled, there should be no warm-up effect (the warm-up is presumably due to the JIT interpreter stabilising?). https://github.com/ggerganov/llama.cpp https://github.com/ggerganov/whisper.cpp It also handles quants and other optimisations.

Yeah, this would be a good one, but I'm not familiar with ggml yet. Maybe I'll try it someday, or we can do it together :3

sang-nguyen-ts commented 3 months ago

Following up on the torch.compile problem: I found that the shape of the decoder_attention_mask tensor changes over time with the sequence length, which may cause each generation step to get its own CUDA graph:

generated_length = past_key_values_length - prompt_attention_mask.shape[1] + 1
decoder_attention_mask = torch.ones(
    (input_shape[0], generated_length), device=self.device, dtype=prompt_attention_mask.dtype
)
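For illustration, a hedged sketch of the usual static-shape workaround (the maximum length and the in-place update are assumptions, not the library's current code): allocate the mask once at a fixed maximum width and update it in place, so every compiled step sees the same shape.

import torch

# Illustrative sketch only: torch.compile with mode="reduce-overhead" specializes
# its CUDA graphs on tensor shapes, so a mask whose width grows by one each step
# forces a new capture at every generation step. Keeping the mask at a fixed,
# pre-allocated width avoids that.
batch_size, max_length = 1, 2580  # assumed values for illustration
device = "cuda"

decoder_attention_mask = torch.zeros(batch_size, max_length, dtype=torch.long, device=device)

def update_mask(step: int) -> torch.Tensor:
    # Mark the newly generated position in place instead of re-allocating the tensor.
    decoder_attention_mask[:, step] = 1
    return decoder_attention_mask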