coqui-ai / TTS

🐸💬 - a deep learning toolkit for Text-to-Speech, battle-tested in research and production
http://coqui.ai
Mozilla Public License 2.0

[Bug] Error with torch.isin() in Docker Container with transformers Library #3833

Closed Fledermaus-20 closed 3 months ago

Fledermaus-20 commented 3 months ago

Describe the bug

When the application runs inside a Docker container, a TypeError is raised from a torch.isin() call made within the transformers library. The error does not occur when the application runs locally (outside the container), which suggests a dependency incompatibility inside the container.

To Reproduce

Build the Docker image using the provided Dockerfile.

Dockerfile:

FROM python:3.11.8-slim

ENV PYTHONUNBUFFERED=1

# Install system dependencies and Rust
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    build-essential \
    curl \
    libsndfile1 \
    libgomp1 \
    pkg-config \
    libssl-dev && \
    curl https://sh.rustup.rs -sSf | sh -s -- -y

ENV PATH="/root/.cargo/bin:${PATH}"
ENV COQUI_TOS_AGREED=1

# Update pip to the latest version
RUN pip install --upgrade pip

# Install Python dependencies
RUN pip install --no-cache-dir fastapi uvicorn torch==2.2.0 torchaudio==2.2.0 transformers==4.43.1 numpy==1.24.3 TTS==0.22.0 sudachipy cutlet
RUN pip install --upgrade transformers

# Copy the FastAPI application code
COPY main.py /app/main.py

WORKDIR /app

EXPOSE 8001

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8001"]

main.py:


import io
import os
import wave
import torch
import numpy as np
from fastapi import FastAPI, Request, Header, Body
from fastapi.responses import StreamingResponse
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.generic_utils import get_user_data_dir
from TTS.utils.manage import ModelManager

# Set the number of threads and device
torch.set_num_threads(int(os.environ.get("NUM_THREADS", os.cpu_count())))
device = torch.device("cuda" if torch.cuda.is_available() and os.environ.get("USE_CPU", "0") == "0" else "cpu")

# Load custom model if available, otherwise download the default model
custom_model_path = os.environ.get("CUSTOM_MODEL_PATH", "/app/tts_models")
if os.path.exists(custom_model_path) and os.path.isfile(custom_model_path + "/config.json"):
    model_path = custom_model_path
    print("Loading custom model from", model_path, flush=True)
else:
    print("Loading default model", flush=True)
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    print("Downloading XTTS Model:", model_name, flush=True)
    ModelManager().download_model(model_name)
    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
    print("XTTS Model downloaded", flush=True)

# Load model configuration and model
print("Loading XTTS", flush=True)
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)
# device is a torch.device, so compare its .type; comparing the object to the string "cuda" is never true
model.load_checkpoint(config, checkpoint_dir=model_path, eval=True, use_deepspeed=device.type == "cuda")
model.to(device)
print("XTTS Loaded.", flush=True)

# Initialize FastAPI
app = FastAPI(
    title="XTTS Streaming server",
    description="XTTS Streaming server",
    version="0.0.1",
    docs_url="/",
)

# Helper functions
def postprocess(wav):
    if isinstance(wav, list):
        wav = torch.cat(wav, dim=0)
    wav = wav.clone().detach().cpu().numpy()
    wav = wav[None, : int(wav.shape[0])]
    wav = np.clip(wav, -1, 1)
    wav = (wav * 32767).astype(np.int16)
    return wav

def wav_data_generator(frame_input, sample_rate=24000, sample_width=2, channels=1):
    wav_buf = io.BytesIO()
    with wave.open(wav_buf, "wb") as vfout:
        vfout.setnchannels(channels)
        vfout.setsampwidth(sample_width)
        vfout.setframerate(sample_rate)
        vfout.writeframes(frame_input)

    wav_buf.seek(0)
    return wav_buf.read()

# Streaming generator
def predict_streaming_generator(text, language, add_wav_header, stream_chunk_size):

    speaker_name = "Alison Dietlinde"
    speaker_raw = model.speaker_manager.speakers[speaker_name]["speaker_embedding"].cpu().squeeze().half().tolist()
    gpt_raw = model.speaker_manager.speakers[speaker_name]["gpt_cond_latent"].cpu().squeeze().half().tolist()

    speaker_embedding = torch.tensor(speaker_raw).unsqueeze(0).unsqueeze(-1)
    gpt_cond_latent = torch.tensor(gpt_raw).reshape((-1, 1024)).unsqueeze(0)

    chunks = model.inference_stream(
        text,
        language,
        gpt_cond_latent,
        speaker_embedding,
        stream_chunk_size=int(stream_chunk_size),
        enable_text_splitting=True
    )

    for i, chunk in enumerate(chunks):
        chunk = postprocess(chunk)
        if i == 0 and add_wav_header:
            # Emit a WAV header (with no frames) before the first audio chunk
            yield wav_data_generator(b"")
        yield chunk.tobytes()

# FastAPI endpoint for streaming
@app.post("/tts_stream")
async def predict_streaming_endpoint(
    text: str = Header(...),
    language: str = Header(...),
    add_wav_header: bool = Header(True),
    stream_chunk_size: str = Header("20")
):
    # Exceptions propagate to FastAPI unchanged, so no try/except is needed here
    return StreamingResponse(
        predict_streaming_generator(text, language, add_wav_header, stream_chunk_size),
        media_type="audio/wav"
    )

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)

Start the Docker container, then make a POST request to the /tts_stream endpoint with the appropriate headers.

test.py:


import argparse
import json
import shutil
import subprocess
import sys
import time
from typing import Iterator

import requests

def is_installed(lib_name: str) -> bool:
    lib = shutil.which(lib_name)
    if lib is None:
        return False
    return True

def save(audio: bytes, filename: str) -> None:
    with open(filename, "wb") as f:
        f.write(audio)

def stream_ffplay(audio_stream, output_file, save=True):
    if not save:
        ffplay_cmd = ["ffplay", "-nodisp", "-probesize", "1024", "-autoexit", "-"]
    else:
        print("Saving to ", output_file)
        ffplay_cmd = ["ffmpeg", "-probesize", "1024", "-i", "-", output_file]

    ffplay_proc = subprocess.Popen(ffplay_cmd, stdin=subprocess.PIPE)
    for chunk in audio_stream:
        if chunk is not None:
            ffplay_proc.stdin.write(chunk)

    # close on finish
    ffplay_proc.stdin.close()
    ffplay_proc.wait()

def tts(text, language, server_url, stream_chunk_size) -> Iterator[bytes]:
    start = time.perf_counter()

    headers = {
        "text": text,
        "language": language,
        "add_wav_header": "False",
        "stream_chunk_size": stream_chunk_size,
    }

    res = requests.post(
        f"{server_url}/tts_stream",
        headers=headers, 
        stream=True
    )
    end = time.perf_counter()
    print(f"Time to make POST: {end-start}s", file=sys.stderr)

    if res.status_code != 200:
        print("Error:", res.text)
        sys.exit(1)

    first = True
    for chunk in res.iter_content(chunk_size=512):
        if first:
            end = time.perf_counter()
            print(f"Time to first chunk: {end-start}s", file=sys.stderr)
            first = False
        if chunk:
            yield chunk

    print("⏱️ response.elapsed:", res.elapsed)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text",
        default="It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
        help="text input for TTS"
    )
    parser.add_argument(
        "--language",
        default="en",
        help="Language to use; default is 'en' (English)"
    )
    parser.add_argument(
        "--output_file",
        default=None,
        help="Save TTS output to given filename"
    )
    parser.add_argument(
        "--ref_file",
        default=None,
        help="Reference audio file to use, when not given will use default"
    )
    parser.add_argument(
        "--server_url",
        default="http://localhost:8000",
        help="Server URL; default is http://localhost:8000, change it to your server location"
    )
    parser.add_argument(
        "--stream_chunk_size",
        default="20",
        help="Stream chunk size; default is 20. Reducing it lowers latency but may degrade quality"
    )
    args = parser.parse_args()

    with open("./default_speaker.json", "r") as file:
        speaker = json.load(file)

    if args.ref_file is not None:
        print("Computing the latents for a new reference...")

    audio = stream_ffplay(
        tts(
            args.text,
            args.language,
            args.server_url,
            args.stream_chunk_size
        ), 
        args.output_file,
        save=bool(args.output_file)
    )

Command: python test.py --text "This is a Test." --language en --server_url "http://localhost:8001" --stream_chunk_size 145

Expected behavior

No response

Logs

TypeError: isin() received an invalid combination of arguments - got (test_elements=int, elements=Tensor, ), but expected one of:
 * (Tensor elements, Tensor test_elements, *, bool assume_unique, bool invert, Tensor out)
 * (Number element, Tensor test_elements, *, bool assume_unique, bool invert, Tensor out)
 * (Tensor elements, Number test_element, *, bool assume_unique, bool invert, Tensor out)
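
For readers unfamiliar with the overloads listed in the error: the call was made with the keyword test_elements bound to a plain int, but the scalar overload names that parameter test_element (singular), so none of the three signatures matches. The sketch below shows one way to call torch.isin() that matches the Tensor/Tensor overload whatever the input type. The helper name isin_scalar_safe is hypothetical, and this is not necessarily how transformers or the fork resolved it.

```python
import torch


def isin_scalar_safe(elements: torch.Tensor, value) -> torch.Tensor:
    """Hypothetical helper: membership test that always uses the
    (Tensor elements, Tensor test_elements) overload of torch.isin()."""
    # Wrap ints, lists, or tensors in a tensor on the same device as `elements`
    test_elements = torch.as_tensor(value, device=elements.device)
    return torch.isin(elements, test_elements)


if __name__ == "__main__":
    tokens = torch.tensor([5, 7, 9, 7])
    # A plain int is converted to a 0-dim tensor before the call
    print(isin_scalar_safe(tokens, 7).tolist())
    # A list of candidate values works the same way
    print(isin_scalar_safe(tokens, [5, 9]).tolist())
```

Converting the scalar up front avoids depending on which keyword name a given torch version accepts for the scalar overload.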

Environment

transformers: 4.43.1
torch: 2.2.0
torchaudio: 2.2.0
TTS: 0.22.0
Platform: Docker

Additional context

No response

eginhard commented 3 months ago

It's fixed in our fork, available via pip install coqui-tts

Fledermaus-20 commented 3 months ago

Thanks for the information. To fix this issue, I installed transformers==4.40.2. However, I will try your suggested solution and see if it works as well. If I have an update on this, I will post it again.
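
To confirm which versions actually ended up inside the container, a small startup check can help. This is a minimal sketch using only the standard library. The function names are hypothetical, and the 4.40.2 threshold is simply the version reported to work in this thread, not a verified boundary for where the error starts.

```python
from importlib import metadata

# Version reported as working in this thread (an assumption, not a verified cutoff)
KNOWN_GOOD_TRANSFORMERS = (4, 40, 2)


def parse_version(text):
    """Turn '4.43.1' into (4, 43, 1), ignoring suffixes such as '+cu121' or 'rc1'."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def installed_version(dist_name):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def describe_environment():
    """Collect one line per relevant package, plus a warning if applicable."""
    lines = []
    for name in ("TTS", "coqui-tts", "transformers", "torch"):
        lines.append(f"{name}: {installed_version(name) or 'not installed'}")
    tf_version = installed_version("transformers")
    if installed_version("TTS") and tf_version:
        if parse_version(tf_version) > KNOWN_GOOD_TRANSFORMERS:
            lines.append(
                "warning: transformers is newer than 4.40.2 alongside the "
                "original TTS package; the isin() error may occur"
            )
    return lines


if __name__ == "__main__":
    print("\n".join(describe_environment()))
```

Running this inside the container, for example at the top of main.py, shows whether the later pip install --upgrade transformers step in the Dockerfile replaced the pinned version.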

Fledermaus-20 commented 3 months ago

Quick note: it works well with the fork package.