idiap / coqui-ai-TTS

🐸💬 - a deep learning toolkit for Text-to-Speech, battle-tested in research and production
https://coqui-tts.readthedocs.io
Mozilla Public License 2.0

[Bug] Streaming doesn't work | inference_stream is as fast as normal inference #97

Open Fledermaus-20 opened 1 week ago

Fledermaus-20 commented 1 week ago

Describe the bug

When running tests against the TTS endpoint, I've observed that streaming the audio response takes nearly the same amount of time as receiving a fully generated audio file. This seems counterintuitive, as streaming should typically deliver the response faster, starting with the first available data chunk. Below is the code for the streaming endpoint.

To Reproduce

model_manager.py

import asyncio
import os
import torch
import numpy as np
import wave
from io import BytesIO
from fastapi import HTTPException, status
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
import logging

formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)

class ModelManager:
    def __init__(self):
        self.tts_model = None
        self.model_loading_lock = asyncio.Lock()
        self.tts_model_path = '/code/models/TTS/'

    async def load_tts_model(self):
        if self.tts_model is None:
            async with self.model_loading_lock:
                if self.tts_model is None:
                    logger.info("Loading TTS model...")
                    device = "cuda" if torch.cuda.is_available() else "cpu"
                    logger.info(f"Using {device} as device.")
                    self.config = XttsConfig()
                    self.config.load_json(os.path.join(self.tts_model_path, "config.json"))
                    self.tts_model = Xtts.init_from_config(self.config)
                    self.tts_model.load_checkpoint(self.config, checkpoint_dir=self.tts_model_path, eval=True)
                    self.tts_model.to(device)
                    self.gpt, self.speaker = self.tts_model.get_conditioning_latents(self.tts_model_path + "speaker.wav")
                    logger.info("TTS model loaded.")

        return self.tts_model, self.speaker, self.gpt

tts_streaming.py

import torch
import numpy as np
import wave
from io import BytesIO
from fastapi import HTTPException, status, Request, Header
from model_manager import ModelManager
import logging

# Setup logging
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')

handler = logging.StreamHandler()
handler.setFormatter(formatter)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)

model_manager = ModelManager()

def wav_data_generator(frame_input, sample_rate=24000, sample_width=2, channels=1):
    wav_buf = BytesIO()
    with wave.open(wav_buf, "wb") as vfout:
        vfout.setnchannels(channels)
        vfout.setsampwidth(sample_width)
        vfout.setframerate(sample_rate)
        vfout.writeframes(frame_input)
    wav_buf.seek(0)
    return wav_buf.read()

def postprocess(wav):
    if isinstance(wav, list):
        wav = torch.cat(wav, dim=0)
    wav = wav.clone().detach().cpu().numpy()
    wav = wav[None, : int(wav.shape[0])]
    wav = np.clip(wav, -1, 1)
    wav = (wav * 32767).astype(np.int16)
    return wav

async def text_to_speech_stream(text: str, language: str, chunk: int = 20):
    if not is_language_supported(language):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported language: {language}")

    tts_model, speaker, gpt = await model_manager.load_tts_model()

    add_wav_header = True

    try:
        chunks = tts_model.inference_stream(
            text=text,
            language=language,
            gpt_cond_latent=gpt,
            speaker_embedding=speaker,
            stream_chunk_size=chunk,
            enable_text_splitting=True
        )

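        # Use a separate loop variable so the `chunk` size parameter is not shadowed.
        for i, audio_chunk in enumerate(chunks):
            audio_chunk = postprocess(audio_chunk)
            if i == 0 and add_wav_header:
                # Send a WAV header with no frames first, then the raw PCM data.
                yield wav_data_generator(b"")
            yield audio_chunk.tobytes()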

    except Exception as e:
        logger.error(f"Error in text-to-speech-stream: {e}")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error in text-to-speech")

main.py

from fastapi import FastAPI, Request, Header
from fastapi.responses import StreamingResponse
from tts_streaming import text_to_speech_stream

app = FastAPI()

@app.post("/text-to-speech")
async def speech_route(
    stream_chunk: int = Header(20),
    text: str = Header(...),
    language: str = Header(...),
):
    return StreamingResponse(
        text_to_speech_stream(text=text, language=language, chunk=stream_chunk),
        media_type="audio/wav"
    )

Expected behavior

I expect the first chunks of the TTS stream to arrive much sooner than the finished file would if I requested it in full.
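
To make that expectation measurable, the time until the first chunk can be recorded separately from the total time. The following is only a minimal sketch: `time_stream` and `fake_stream` are hypothetical helpers that are not part of the code in this report, and `fake_stream` merely stands in for a real chunk generator such as the one returned by `inference_stream`.

```python
import time
from typing import Iterable, Iterator, Tuple


def time_stream(chunks: Iterable[bytes]) -> Tuple[float, float, int]:
    """Consume an iterable of chunks and report latency figures.

    Returns (seconds until the first chunk, total seconds, number of chunks).
    """
    start = time.perf_counter()
    first_chunk_at = None
    count = 0
    for _ in chunks:
        if first_chunk_at is None:
            # Latency that matters for streaming: time until the first chunk.
            first_chunk_at = time.perf_counter() - start
        count += 1
    total = time.perf_counter() - start
    if first_chunk_at is None:
        # Nothing was yielded, so fall back to the total time.
        first_chunk_at = total
    return first_chunk_at, total, count


def fake_stream(n_chunks: int = 3, delay: float = 0.05) -> Iterator[bytes]:
    """Hypothetical stand-in for a streaming generator (assumption)."""
    for _ in range(n_chunks):
        time.sleep(delay)  # simulate per-chunk synthesis time
        yield b"\x00\x00" * 100  # 100 silent 16-bit samples


if __name__ == "__main__":
    ttfc, total, n = time_stream(fake_stream())
    print(f"first chunk after {ttfc:.3f}s, {n} chunks in {total:.3f}s")
```

If streaming works as intended, the first figure should be clearly smaller than the second. If both are nearly equal, the chunks are probably being buffered somewhere before they reach the consumer.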

Logs

No response

Environment

It's running in a Docker container.

{
    "CUDA": {
        "GPU": [],
        "available": false,
        "version": "12.1"
    },
    "Packages": {
        "PyTorch_debug": false,
        "PyTorch_version": "2.4.1+cu121",
        "TTS": "0.24.2",
        "numpy": "1.26.4"
    },
    "System": {
        "OS": "Linux",
        "architecture": [
            "64bit",
            ""
        ],
        "processor": "x86_64",
        "python": "3.11.0rc1",
        "version": "#45~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Sep 11 15:25:05 UTC 2"
    }
}

Additional context

Thanks in advance for the help.

eginhard commented 1 week ago

You didn't share any timings, so it's hard to say what's going on. Note that streaming inference may take longer in total than normal inference, just the first chunks should arrive faster. Could you share (once models are loaded/warmed up):

Fledermaus-20 commented 1 week ago

Sorry, I forgot. The times after the model is loaded/warmed up are: