cpacker / MemGPT

Letta (fka MemGPT) is a framework for creating stateful LLM services.
https://letta.com
Apache License 2.0

Add authentication middleware to REST API #898

Open cpacker opened 8 months ago

cpacker commented 8 months ago

Add optional authentication middleware to the REST API, allowing clients to authenticate with bearer tokens.
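Bearer tokens are sent in the Authorization request header as "Bearer <token>" (RFC 6750). The sketch below shows the header parsing such a middleware would perform, independent of any framework. parse_bearer_token is a hypothetical helper and is not part of MemGPT or FastAPI.

```python
from typing import Optional


def parse_bearer_token(authorization: Optional[str]) -> Optional[str]:
    """Return the token from an 'Authorization: Bearer <token>' header value.

    Hypothetical helper: returns None when the header is missing, uses a
    different scheme, or carries no token.
    """
    if not authorization:
        return None
    scheme, _, token = authorization.strip().partition(" ")
    # Auth scheme names are case-insensitive (RFC 7235), so "bearer" also matches.
    if scheme.lower() != "bearer":
        return None
    token = token.strip()
    return token or None


print(parse_bearer_token("Bearer abc123"))       # abc123
print(parse_bearer_token("bearer   abc123"))     # abc123
print(parse_bearer_token("Basic dXNlcjpwYXNz"))  # None
print(parse_bearer_token(None))                  # None
```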

arduenify commented 8 months ago

@cpacker What do you think of this implementation? Token storage is currently in memory; we can make it persistent if desired.
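If the in-memory dict were swapped for persistent storage, one option is a small SQLite table keyed by a hash of each token, so a leaked database file does not expose usable tokens. This is a sketch under assumptions: it uses the standard-library sqlite3 module and a hypothetical SQLiteTokenStore class, and it is not MemGPT's metadata store.

```python
import hashlib
import secrets
import sqlite3
import uuid
from typing import Optional


class SQLiteTokenStore:
    """Hypothetical persistent token store: maps SHA-256(token) -> user ID."""

    def __init__(self, path: str = ":memory:") -> None:
        self.conn = sqlite3.connect(path)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS tokens ("
            " token_hash TEXT PRIMARY KEY,"
            " user_id TEXT NOT NULL)"
        )
        self.conn.commit()

    @staticmethod
    def _hash(token: str) -> str:
        # Only the digest is stored, so the raw bearer token never hits disk.
        return hashlib.sha256(token.encode("utf-8")).hexdigest()

    def issue(self, user_id: uuid.UUID) -> str:
        token = secrets.token_urlsafe()
        with self.conn:  # commits on success, rolls back on error
            self.conn.execute(
                "INSERT INTO tokens (token_hash, user_id) VALUES (?, ?)",
                (self._hash(token), str(user_id)),
            )
        return token

    def verify(self, token: str) -> Optional[uuid.UUID]:
        row = self.conn.execute(
            "SELECT user_id FROM tokens WHERE token_hash = ?",
            (self._hash(token),),
        ).fetchone()
        return uuid.UUID(row[0]) if row else None

    def revoke(self, token: str) -> None:
        with self.conn:
            self.conn.execute(
                "DELETE FROM tokens WHERE token_hash = ?", (self._hash(token),)
            )


store = SQLiteTokenStore()  # pass a file path instead of ":memory:" to persist
uid = uuid.uuid4()
tok = store.issue(uid)
print(store.verify(tok) == uid)  # True
store.revoke(tok)
print(store.verify(tok))         # None
```

A fast digest such as SHA-256 is generally adequate here because the tokens are long random values rather than user-chosen passwords, so there is nothing to brute-force from the hash.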

Changes

server.py (SyncServer)

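import secrets
import uuid
from typing import Optional

# Methods on SyncServer. User and self.ms come from the existing SyncServer code;
# __init__ must also initialize the token map: self.active_tokens = {}

def authenticate_user(self) -> str:
    """
    Creates a new user with a random ID, generates a secure random bearer
    token, and stores the token-to-user mapping in memory.
    """
    user_id = uuid.uuid4()
    token = secrets.token_urlsafe()
    user = User(id=user_id)

    try:
        self.ms.create_user(user)
    except ValueError:
        # user already exists (not expected with a freshly generated uuid4)
        pass
    self.active_tokens[token] = user_id
    return token

def verify_token(self, token: str) -> Optional[uuid.UUID]:
    """Returns the user ID mapped to the token, or None if the token is unknown."""
    return self.active_tokens.get(token)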
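To see the token flow in isolation, here is a self-contained sketch that mirrors authenticate_user and verify_token, with a plain dict standing in for the metadata store. InMemoryAuth and its users dict are hypothetical stand-ins for SyncServer and self.ms. The sketch also makes one property of the proposal visible: because user_id is a fresh uuid4 on every call, each request to /auth creates a separate user.

```python
import secrets
import uuid
from typing import Dict, Optional


class InMemoryAuth:
    """Hypothetical stand-in for the SyncServer auth methods in the proposal."""

    def __init__(self) -> None:
        self.users: Dict[uuid.UUID, dict] = {}          # stands in for self.ms
        self.active_tokens: Dict[str, uuid.UUID] = {}   # token -> user ID

    def authenticate_user(self) -> str:
        user_id = uuid.uuid4()  # a fresh ID on every call
        self.users.setdefault(user_id, {"id": user_id})
        token = secrets.token_urlsafe()
        self.active_tokens[token] = user_id
        return token

    def verify_token(self, token: str) -> Optional[uuid.UUID]:
        return self.active_tokens.get(token)


auth = InMemoryAuth()
t1 = auth.authenticate_user()
t2 = auth.authenticate_user()
print(auth.verify_token(t1) != auth.verify_token(t2))  # True: two distinct users
print(auth.verify_token("bogus"))                       # None
print(len(auth.users))                                  # 2
```

Since active_tokens lives only in process memory, every issued token stops verifying after a server restart, which is the main practical argument for persistent storage.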

auth/index.py

from fastapi import APIRouter, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, Field

# SyncServer and QueuingInterface are imported from the existing MemGPT server modules.

router = APIRouter()
security = HTTPBearer()

class AuthResponse(BaseModel):
    token: str = Field(..., description="Bearer token for the authenticated user")

def setup_auth_router(server: SyncServer, interface: QueuingInterface):
    @router.get("/auth", tags=["auth"], response_model=AuthResponse)
    def authenticate_user():
        """
        Creates a new user and returns a bearer token for it.
        """
        interface.clear()
        try:
            token = server.authenticate_user()
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")
        return AuthResponse(token=token)

    # Dependency for protected routes, e.g. user = Depends(get_current_user)
    def get_current_user(credentials: HTTPAuthorizationCredentials = Security(security)):
        token = credentials.credentials
        user_id = server.verify_token(token)
        if not user_id:
            # token was never issued, or was lost when the server restarted
            raise HTTPException(status_code=403, detail="Invalid authentication credentials")
        user = server.ms.get_user(user_id)
        if not user:
            raise HTTPException(status_code=404, detail="User not found")
        return user
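The checks in get_current_user run in a fixed order: a request without usable credentials is rejected by HTTPBearer, an unknown token gets 403, and a known token whose user no longer exists gets 404. The framework-free sketch below traces that chain; AuthError and resolve_user are hypothetical names, and the status used for a missing header is an assumption, since HTTPBearer decides that case itself.

```python
import uuid
from typing import Dict, Optional


class AuthError(Exception):
    """Hypothetical error carrying an HTTP status, like FastAPI's HTTPException."""

    def __init__(self, status_code: int, detail: str) -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


def resolve_user(
    token: Optional[str],
    active_tokens: Dict[str, uuid.UUID],
    users: Dict[uuid.UUID, dict],
) -> dict:
    # Step 1: no credentials. In the proposal HTTPBearer rejects this before
    # get_current_user runs; 403 here is an assumption.
    if not token:
        raise AuthError(403, "Not authenticated")
    # Step 2: token not issued by this server (mirrors server.verify_token).
    user_id = active_tokens.get(token)
    if not user_id:
        raise AuthError(403, "Invalid authentication credentials")
    # Step 3: token is known but its user is missing from the store.
    user = users.get(user_id)
    if not user:
        raise AuthError(404, "User not found")
    return user


uid = uuid.uuid4()
tokens = {"good": uid, "orphan": uuid.uuid4()}
users = {uid: {"id": uid}}
print(resolve_user("good", tokens, users)["id"] == uid)  # True
for bad in (None, "bogus", "orphan"):
    try:
        resolve_user(bad, tokens, users)
    except AuthError as e:
        print(bad, e.status_code)  # prints: None 403, bogus 403, orphan 404
```

For comparison, RFC 6750 specifies 401 with a WWW-Authenticate: Bearer response header when the access token is invalid.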