timhagel / MeloTTS-Docker-API-Server

A docker image to access MeloTTS through API calls

unload model after x minutes of inactivity? #7

Open jozsefszalma opened 2 days ago

jozsefszalma commented 2 days ago

Hey Tim,

As far as I can tell, the model hangs around in GPU memory indefinitely, which might be undesirable for certain use cases. I could submit a PR that introduces a timeout env variable (defaulting to 15 minutes if not set), something like this (not yet tested):

import gc
import os
import tempfile
import threading
import time

import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, Body, Depends
from fastapi.responses import FileResponse
from melo.api import TTS
from pydantic import BaseModel
from starlette.background import BackgroundTask

load_dotenv()
DEFAULT_SPEED = float(os.getenv("DEFAULT_SPEED"))
DEFAULT_LANGUAGE = os.getenv("DEFAULT_LANGUAGE")
DEFAULT_SPEAKER_ID = os.getenv("DEFAULT_SPEAKER_ID")
device = "auto"  # Will automatically use GPU if available

# Idle timeout for unloading the model (in seconds), defaulting to 15 minutes
MODEL_IDLE_TIMEOUT = int(os.getenv("MODEL_IDLE_TIMEOUT", 15 * 60))

class TextModel(BaseModel):
    text: str
    speed: float = DEFAULT_SPEED
    language: str = DEFAULT_LANGUAGE
    speaker_id: str = DEFAULT_SPEAKER_ID

app = FastAPI()

class ModelManager:
    def __init__(self):
        self.model = None
        self.last_used = time.time()
        self.lock = threading.Lock()
        self._start_cleanup_thread()

    def _start_cleanup_thread(self):
        # Start a background thread that periodically checks for inactivity and unloads the model
        def cleanup():
            while True:
                with self.lock:
                    if self.model and (time.time() - self.last_used) > MODEL_IDLE_TIMEOUT:
                        print("Unloading model due to inactivity...")
                        self.model = None
                        # Dropping the reference alone does not release CUDA memory;
                        # collect garbage and clear the cached allocator as well
                        gc.collect()
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()
                time.sleep(60)  # Check every minute

        thread = threading.Thread(target=cleanup, daemon=True)
        thread.start()

    def get_model(self, language):
        # Lazily (re)load the model and refresh the last-used timestamp.
        # Note: the cached model is reused as-is even if a different language is requested.
        with self.lock:
            if not self.model:
                print("Loading TTS model...")
                self.model = TTS(language=language, device=device)
            self.last_used = time.time()
            return self.model

# Instantiate the model manager
model_manager = ModelManager()

def get_tts_model(body: TextModel):
    return model_manager.get_model(body.language)

@app.post("/convert/tts")
async def create_upload_file(
    body: TextModel = Body(...), model: TTS = Depends(get_tts_model)
):
    speaker_ids = model.hps.data.spk2id

    # Synthesize into a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        output_path = tmp.name
        model.tts_to_file(
            body.text, speaker_ids[body.speaker_id], output_path, speed=body.speed
        )

    # Return the audio file; the background task removes the temp file
    # only after the response has been sent
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=os.path.basename(output_path),
        background=BackgroundTask(os.remove, output_path),
    )

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8080)
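
To check the behavior, hitting the endpoint once and then leaving the server idle should show the model dropping out of VRAM after the timeout. A quick client sketch (port and field names as in the code above; the language and speaker_id values are just example MeloTTS English voices, and the port mapping may differ when running in Docker):

import requests

# Trigger one synthesis so the model gets loaded, then watch nvidia-smi
# after MODEL_IDLE_TIMEOUT seconds of inactivity
resp = requests.post(
    "http://localhost:8080/convert/tts",
    json={"text": "Hello from MeloTTS", "language": "EN", "speaker_id": "EN-US", "speed": 1.0},
)
resp.raise_for_status()
with open("output.wav", "wb") as f:
    f.write(resp.content)
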
timhagel commented 2 days ago

@jozsefszalma I agree; adding this would be beneficial. The only thing I would suggest is to add a way to keep the model loaded in case someone needs it for responsiveness. I’d set it up so that if the variable is set to -1, the model stays loaded indefinitely.
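
For example, the inner cleanup loop from the snippet above could treat a negative value as "never unload" (untested sketch, reusing the same names; `self` is the ModelManager from the closure in _start_cleanup_thread):

def cleanup():
    while True:
        # A MODEL_IDLE_TIMEOUT of -1 disables unloading entirely
        if MODEL_IDLE_TIMEOUT >= 0:
            with self.lock:
                if self.model and (time.time() - self.last_used) > MODEL_IDLE_TIMEOUT:
                    print("Unloading model due to inactivity...")
                    self.model = None
        time.sleep(60)  # Check every minute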

jozsefszalma commented 1 day ago

FYI, this is a tad more complex than I initially expected; the MeloTTS code loads some additional tensors during inference that are not cleaned up afterwards. So even if I move the model to CPU, delete the model object, empty the torch cache, and run a garbage collect, there is still ~1.2 GB of VRAM consumed.
I guess I will raise the topic with them and see.
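
For reference, the teardown I tried is roughly the sketch below (force_unload is a hypothetical helper operating on the ModelManager from the proposal; it assumes melo.api.TTS is an nn.Module so .to("cpu") works):

import gc
import torch

def force_unload(manager):
    # Try to fully evict the TTS model from the GPU
    with manager.lock:
        if manager.model is not None:
            manager.model.to("cpu")  # move parameters/buffers to host memory
            manager.model = None
    gc.collect()                     # drop Python-side references
    torch.cuda.empty_cache()         # release cached allocator blocks
    # Even after all of this, roughly 1.2 GB of VRAM remains allocated,
    # held by tensors MeloTTS creates during inference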