coqui-ai / TTS

🐸💬 - a deep learning toolkit for Text-to-Speech, battle-tested in research and production
http://coqui.ai
Mozilla Public License 2.0
35.66k stars 4.37k forks source link

[Bug] Streaming doesn't work #4011

Closed Fledermaus-20 closed 2 weeks ago

Fledermaus-20 commented 1 month ago

Describe the bug

When running tests against the TTS endpoint, I've observed that streaming the audio response takes nearly the same amount of time as receiving a fully generated audio file. This seems counterintuitive, as streaming should typically deliver the response faster, starting with the first available data chunk. Below are the code for the streaming endpoint

To Reproduce

model_manager.py

import asyncio
import os
import torch
import numpy as np
import wave
from io import BytesIO
from fastapi import HTTPException, status
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
import logging

formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)

class ModelManager:
    def __init__(self):
        self.tts_model = None
        self.model_loading_lock = asyncio.Lock()
        self.tts_model_path = '/code/models/TTS/'

    async def load_tts_model(self):
        if self.tts_model is None:
            async with self.model_loading_lock:
                if self.tts_model is None:
                    logger.info("Loading TTS model...")
                    device = "cuda" if torch.cuda.is_available() else "cpu"
                    logger.info(f"Using {device} as device.")
                    self.config = XttsConfig()
                    self.config.load_json(os.path.join(self.tts_model_path, "config.json"))
                    self.tts_model = Xtts.init_from_config(self.config)
                    self.tts_model.load_checkpoint(self.config, checkpoint_dir=self.tts_model_path, eval=True)
                    self.tts_model.to(device)
                    self.gpt, self.speaker = self.tts_model.get_conditioning_latents(self.tts_model_path + "speaker.wav")
                    logger.info("TTS model loaded.")

        return self.tts_model, self.speaker, self.gpt

tts_streaming.py

import torch
import numpy as np
import wave
from io import BytesIO
from fastapi import HTTPException, status, Request, Header
from model_manager import ModelManager
import logging

# Setup logging
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(name)s - %(message)s')

handler = logging.StreamHandler()
handler.setFormatter(formatter)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)

model_manager = ModelManager()

def wav_data_generator(frame_input, sample_rate=24000, sample_width=2, channels=1):
    wav_buf = BytesIO()
    with wave.open(wav_buf, "wb") as vfout:
        vfout.setnchannels(channels)
        vfout.setsampwidth(sample_width)
        vfout.setframerate(sample_rate)
        vfout.writeframes(frame_input)
    wav_buf.seek(0)
    return wav_buf.read()

def postprocess(wav):
    if isinstance(wav, list):
        wav = torch.cat(wav, dim=0)
    wav = wav.clone().detach().cpu().numpy()
    wav = wav[None, : int(wav.shape[0])]
    wav = np.clip(wav, -1, 1)
    wav = (wav * 32767).astype(np.int16)
    return wav

async def text_to_speech_stream(text: str, language: str, chunk: int = 20):
    if not is_language_supported(language):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported language: {language}")

    tts_model, speaker, gpt = await model_manager.load_tts_model()

    add_wav_header = True

    try:
        chunks = tts_model.inference_stream(
            text=text,
            language=language,
            gpt_cond_latent=gpt,
            speaker_embedding=speaker,
            stream_chunk_size=chunk,
            enable_text_splitting=True
        )

        for i, chunk in enumerate(chunks):
            chunk = postprocess(chunk)
            if i == 0 and add_wav_header:
                yield wav_data_generator(b"")
                yield chunk.tobytes()
            else:
                yield chunk.tobytes()

    except Exception as e:
        logger.error(f"Error in text-to-speech-stream: {e}")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error in text-to-speech")

main.py

from fastapi import FastAPI, Request, Header
from fastapi.responses import StreamingResponse
from tts_streaming import text_to_speech_stream

app = FastAPI()

@app.post("/text-to-speech")
async def speech_route(
    stream_chunk: int = Header(20),
    text: str = Header(...),
    language: str = Header(...),
):
    return StreamingResponse(
        text_to_speech_stream(text=text, language=language, chunk=stream_chunk),
        media_type="audio/wav"
    )

Expected behavior

The behavior I'm expecting is that I get the tts stream back much sooner than if I request the finished file.

Logs

No response

Environment

It's Running in a Docker Contianer

{
    "CUDA": {
        "GPU": [],
        "available": false,
        "version": "12.1"
    },
    "Packages": {
        "PyTorch_debug": false,
        "PyTorch_version": "2.4.1+cu121",
        "TTS": "0.24.2",
        "numpy": "1.26.4"
    },
    "System": {
        "OS": "Linux",
        "architecture": [
            "64bit",
            ""
        ],
        "processor": "x86_64",
        "python": "3.11.0rc1",
        "version": "#45~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Sep 11 15:25:05 UTC 2"
    }
}

Additional context

Thanks for the help here in advance

eginhard commented 1 month ago

Duplicate of https://github.com/idiap/coqui-ai-TTS/issues/97

stale[bot] commented 2 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. You might also look our discussion channels.