arcee-ai / fastmlx

FastMLX is a high-performance, production-ready API for hosting MLX models.
https://arcee-ai.github.io/fastmlx/

Implement Model Loading State Tracker #9

Open Blaizzy opened 4 months ago

Blaizzy commented 4 months ago

Description:

We want to add a feature that tracks and reports the loading state of individual AI models in our FastMLX application. This will allow users to check the status of specific models they're interested in using.

Objective:

Create a system to track and report the loading state of individual models, with the ability to query the state of a single model or all models.

Tasks:

  1. Add a ModelState enum in fastmlx.py with states like LOADING, READY, and ERROR.
  2. Modify the ModelProvider class to include a state attribute for each model.
  3. Update the model loading process to set appropriate states.
  4. Add two endpoints:
    • /v1/model_status to report the current state of all models.
    • /v1/model_status/{model_name} to report the state of a specific model.
  5. Modify existing endpoints to check model state before processing requests.
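The example implementation does not cover task 5. Here is a minimal sketch of a guard that existing endpoints (for example, chat completions) could call before running inference. It assumes the `model_states` dict from the example. `ModelNotReady`, `ensure_model_usable`, and the chosen status codes (503 while loading, 500 after a failed load) are hypothetical choices, not existing FastMLX behavior. An untracked model is passed through as `None`, so an endpoint that loads models on demand could still trigger a load. `ModelState` is repeated so the block runs on its own, and a plain exception keeps it free of FastAPI.

```python
from enum import Enum
from typing import Dict, Optional


class ModelState(Enum):
    LOADING = "loading"
    READY = "ready"
    ERROR = "error"


class ModelNotReady(Exception):
    """Hypothetical error carrying the HTTP status an endpoint should return."""

    def __init__(self, status_code: int, detail: str) -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


def ensure_model_usable(
    model_states: Dict[str, ModelState], model_name: str
) -> Optional[ModelState]:
    """Reject requests for models that are still loading or failed to load.

    Returns the current state (None if the model is not tracked yet), so the
    caller can decide whether to start loading an untracked model.
    """
    state = model_states.get(model_name)
    if state is ModelState.LOADING:
        # 503 signals a temporary condition; the client may retry later
        raise ModelNotReady(503, f"Model '{model_name}' is still loading")
    if state is ModelState.ERROR:
        # The last load attempt failed; report it instead of running inference
        raise ModelNotReady(500, f"Model '{model_name}' failed to load")
    return state
```

Inside a FastAPI handler, the exception would be translated at the top of the request, e.g. `except ModelNotReady as exc: raise HTTPException(status_code=exc.status_code, detail=exc.detail)`.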

Example Implementation:

from enum import Enum
from typing import Dict, Optional

from fastapi import FastAPI, HTTPException

class ModelState(Enum):
    LOADING = "loading"
    READY = "ready"
    ERROR = "error"

class ModelProvider:
    def __init__(self):
        self.models = {}        # model_name -> loaded model
        self.model_states = {}  # model_name -> ModelState

    async def load_model(self, model_name: str):
        self.model_states[model_name] = ModelState.LOADING
        try:
            # Existing model loading logic: `load_model` here is the existing
            # async module-level loader, not this method
            self.models[model_name] = await load_model(model_name)
            self.model_states[model_name] = ModelState.READY
        except Exception:
            # Record the failure so status queries report it, then re-raise
            self.model_states[model_name] = ModelState.ERROR
            raise

    async def get_model_status(self, model_name: Optional[str] = None) -> Dict[str, str]:
        if model_name is not None:
            if model_name not in self.model_states:
                raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
            return {model_name: self.model_states[model_name].value}
        return {model: state.value for model, state in self.model_states.items()}

# In FastAPI app:
app = FastAPI()
model_provider = ModelProvider()

@app.get("/v1/model_status")
async def get_all_model_status():
    return await model_provider.get_model_status()

# The `:path` converter lets names containing "/" (e.g. "org/model" repo IDs) match
@app.get("/v1/model_status/{model_name:path}")
async def get_specific_model_status(model_name: str):
    return await model_provider.get_model_status(model_name)
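The state transitions can be exercised without MLX or a running server. The sketch below is a trimmed variant of the example, built under a few assumptions. The loader is injected through the constructor. `fake_loader` is hypothetical: it only sleeps, and it fails for the name `"broken-model"`. `get_model_status` raises `KeyError` instead of `HTTPException`, so the block needs only the standard library. It shows that a status query arriving mid-load sees `loading`, and that a failed load leaves `error` behind.

```python
import asyncio
from enum import Enum
from typing import Awaitable, Callable, Dict, List, Optional


class ModelState(Enum):
    LOADING = "loading"
    READY = "ready"
    ERROR = "error"


class ModelProvider:
    def __init__(self, loader: Callable[[str], Awaitable[object]]) -> None:
        # `loader` stands in for the real (hypothetically async) MLX loading code
        self._loader = loader
        self.models: Dict[str, object] = {}
        self.model_states: Dict[str, ModelState] = {}

    async def load_model(self, model_name: str) -> None:
        self.model_states[model_name] = ModelState.LOADING
        try:
            self.models[model_name] = await self._loader(model_name)
            self.model_states[model_name] = ModelState.READY
        except Exception:
            self.model_states[model_name] = ModelState.ERROR
            raise

    def get_model_status(self, model_name: Optional[str] = None) -> Dict[str, str]:
        if model_name is not None:
            if model_name not in self.model_states:
                raise KeyError(model_name)  # the API layer would map this to 404
            return {model_name: self.model_states[model_name].value}
        return {name: state.value for name, state in self.model_states.items()}


async def fake_loader(model_name: str) -> object:
    await asyncio.sleep(0.01)  # simulate slow weight loading
    if model_name == "broken-model":
        raise RuntimeError("weights not found")
    return object()


async def demo() -> List[Dict[str, str]]:
    provider = ModelProvider(fake_loader)
    snapshots = []
    task = asyncio.create_task(provider.load_model("good-model"))
    await asyncio.sleep(0)  # yield once so the load starts and hits its sleep
    snapshots.append(provider.get_model_status("good-model"))
    await task
    snapshots.append(provider.get_model_status("good-model"))
    try:
        await provider.load_model("broken-model")
    except RuntimeError:
        pass  # the provider re-raises after recording ERROR
    snapshots.append(provider.get_model_status())
    return snapshots


print(asyncio.run(demo()))
# [{'good-model': 'loading'}, {'good-model': 'ready'}, {'good-model': 'ready', 'broken-model': 'error'}]
```

Injecting the loader keeps the state bookkeeping testable in isolation. The production `ModelProvider` could keep calling the real loader directly.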

Guidelines:

Resources:

Definition of Done:

We're excited to see your implementation of this feature! It will provide users with more granular control and information about model availability. If you have any questions or need clarification, please don't hesitate to ask in the comments. Good luck!