PrefectHQ / marvin

✨ Build AI interfaces that spark joy
https://askmarvin.ai
Apache License 2.0
5.3k stars 348 forks source link

Fallback selection of model based on tokenized prompt input #451

Open twardoch opened 1 year ago

twardoch commented 1 year ago

First check

Describe the current behavior

I have code like this:

from functools import lru_cache
from typing import Any, List, Literal, Tuple, Optional

import marvin
from marvin import ai_fn, ai_model
from marvin.engine.language_models import ChatLLM
from pydantic import BaseModel, Field

def cached_method(method):
    """Memoize a method on ``(self, *args, **kwargs)`` with ``functools.lru_cache``.

    The instance itself is part of the cache key, so the owning class must be
    hashable and every argument must be hashable too.

    The returned wrapper keeps the original method's name and docstring and
    exposes ``cache_clear()`` / ``cache_info()`` so stale results can be
    invalidated (e.g. after mutating state a cached result depends on).

    NOTE(review): the cache is unbounded and shared by all instances, and it
    holds a strong reference to every ``self`` it has seen, so those instances
    are never garbage collected (ruff B019). This is fine for a few long-lived
    objects; switch to a per-instance cache if that changes.
    """
    import functools  # local import: the file header only imports lru_cache

    cached = functools.lru_cache(maxsize=None)(method)

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        return cached(self, *args, **kwargs)

    # Without these the lru_cache is unreachable and can never be cleared.
    wrapper.cache_clear = cached.cache_clear
    wrapper.cache_info = cached.cache_info
    return wrapper

class ChatLLMPlus(ChatLLM):
    """A ``ChatLLM`` that can check whether a prompt fits in its context window.

    A prompt is "safe" when its token count is at most
    ``context_size - reserve_tokens``. The token-count and safety checks are
    memoized with ``cached_method``, which keys on ``self`` and therefore
    requires ``__hash__``.
    """

    # Tokens that must stay free below ``context_size`` (see ``is_safe``);
    # presumably head-room for the model's reply — TODO confirm intended budget.
    reserve_tokens: int = Field(1000)

    def __init__(self, *args, reserve_tokens: int = 1000, **kwargs):
        """Forward everything except ``reserve_tokens`` to ``ChatLLM``, then set it."""
        super().__init__(*args, **kwargs)
        # NOTE(review): ``reserve_tokens`` is also a declared field, so passing
        # it through ``super().__init__`` would likely suffice — confirm against
        # the pydantic/marvin versions in use before simplifying.
        self.reserve_tokens = reserve_tokens

    def __hash__(self) -> int:
        # Required so instances can be part of ``cached_method``'s cache key.
        # ``reserve_tokens`` is not hashed; that only causes hash collisions,
        # because cache hits are still decided by ``__eq__`` (assuming
        # pydantic's field-wise equality).
        # NOTE(review): assumes ``ChatLLM`` defines ``name``, ``model``,
        # ``max_tokens`` and ``temperature`` — not visible here.
        return hash((self.name, self.model, self.max_tokens, self.temperature))

    @cached_method
    def count_tokens(self, text: str, **kwargs) -> int:
        """Return the number of tokens ``self.get_tokens`` produces for ``text``.

        ``kwargs`` are forwarded to ``get_tokens`` and must be hashable, since
        they are part of the cache key.
        """
        return len(self.get_tokens(text, **kwargs))

    @cached_method
    def is_safe(self, num_of_tokens: int) -> bool:
        """Return True if ``num_of_tokens <= context_size - reserve_tokens``."""
        # NOTE(review): memoized on (self, num_of_tokens). Mutating
        # ``reserve_tokens`` later does not change the hash, so earlier cached
        # answers for the same object would be returned stale.
        return num_of_tokens <= self.context_size - self.reserve_tokens

    @cached_method
    def check_len(self, text: str, **kwargs) -> Tuple[int, bool]:
        """Return ``(token_count, fits)`` for ``text``; ``fits`` is ``is_safe(token_count)``."""
        num_of_tokens = self.count_tokens(text, **kwargs)
        return num_of_tokens, self.is_safe(num_of_tokens)

class ChatLLMs:
    """A registry of ``ChatLLMPlus`` models that picks the smallest one a prompt fits in.

    ``auto_model`` walks the models in ascending ``context_size`` order and
    returns the first whose window can hold the text after ``reserve_tokens``.
    """

    def __init__(
        self,
        reserve_tokens: int = 1000,
        model_names: Optional[List[str]] = None,
    ):
        """Build one ``ChatLLMPlus`` per model name.

        Args:
            reserve_tokens: Tokens each model must keep free (forwarded to
                ``ChatLLMPlus``).
            model_names: Models to choose from. Defaults to ``gpt-3.5-turbo``,
                ``gpt-4`` and ``gpt-3.5-turbo-16k``. The order given does not
                matter; models are sorted by ``context_size``.

        Raises:
            ValueError: If ``model_names`` is empty.
        """
        if model_names is None:
            model_names = ["gpt-3.5-turbo", "gpt-4", "gpt-3.5-turbo-16k"]
        if not model_names:
            raise ValueError("ChatLLMs needs at least one model name")
        self.models = {
            name: ChatLLMPlus(model=name, reserve_tokens=reserve_tokens)
            for name in model_names
        }
        # Sort by context size instead of trusting a hand-maintained order, so
        # auto_model always tries the smallest model first. sorted() is stable,
        # so models with equal context sizes keep their given order.
        self.model_names = sorted(
            self.models, key=lambda name: self.models[name].context_size
        )

    def auto_model(self, text: str, prefer=None) -> Tuple[Optional[ChatLLMPlus], int]:
        """Return the smallest model whose context can hold ``text``.

        Args:
            text: Prompt to measure.
            prefer: Optional model name to start the search from; smaller
                models are skipped. Unknown names are ignored and the search
                starts at the smallest model.

        Returns:
            ``(model, num_of_tokens)`` for the first model the text fits in, or
            ``(None, num_of_tokens)`` counted by the largest model when the text
            fits nowhere.
        """
        start_index = self.model_names.index(prefer) if prefer in self.models else 0
        for model_name in self.model_names[start_index:]:
            model = self.models[model_name]
            num_of_tokens, is_safe = model.check_len(text)
            if is_safe:
                return model, num_of_tokens

        # Nothing fits: count with the largest model explicitly instead of
        # relying on the loop variable leaking out of the for-loop.
        largest = self.models[self.model_names[-1]]
        return None, largest.count_tokens(text)

llms = ChatLLMs(reserve_tokens=1000)

# Make gpt-3.5-turbo-16k marvin's global default model.
_default_llm = llms.models["gpt-3.5-turbo-16k"]
marvin.settings.llm_model = _default_llm.model
del _default_llm  # temporary only; leaves the module namespace unchanged

from pathlib import Path  # used by the `file_changed` annotation; was never imported

class DiffInterpretation(BaseModel):
    # The LLM's interpretation of the change to a single file.
    # Documented with comments rather than a docstring: pydantic turns a class
    # docstring into the JSON-schema description, which would presumably change
    # what ai_model sends to the LLM.
    change_impact: Literal["low", "medium", "high"] = Field(
        ..., description="How much the change affects the code"
    )
    change_ui: bool = Field(..., description="True if the change affects the UI")
    change_content: str = Field(
        ...,
        description="Extensive and detailed human-readable interpretation of the diff change",
    )
    # `str | Path` is evaluated at class creation, so this needs Python 3.10+.
    file_changed: str | Path
    repo_name: str

class CodeChangesBase(BaseModel):
    # Structured result the LLM fills in for one diff; subclassed below once per
    # model. Documented with comments rather than a docstring because pydantic
    # turns a class docstring into the JSON-schema description, which would
    # presumably change the prompt ai_model builds.
    build_number: int = Field(
        ...,
        description="`VERS_BUILD_NUMBER` if it was changed in `version_build.h` of the repo, otherwise 0",
    )
    # Presumably one entry per changed file (DiffInterpretation carries
    # `file_changed`) — confirm with the prompt's intent.
    diff_interpretations: List[DiffInterpretation]

# Shared system instructions for all three size-specific analyzers.
_CPP_ANALYZER_INSTRUCTIONS = "You're an expert analyzer of C++ source code."


# Small-context analyzer (gpt-3.5-turbo).
@ai_model(model=llms.models["gpt-3.5-turbo"], instructions=_CPP_ANALYZER_INSTRUCTIONS)
class CodeChangesS(CodeChangesBase):
    ...


# Medium-context analyzer (gpt-4).
@ai_model(model=llms.models["gpt-4"], instructions=_CPP_ANALYZER_INSTRUCTIONS)
class CodeChangesM(CodeChangesBase):
    ...


# Large-context analyzer (gpt-3.5-turbo-16k).
@ai_model(model=llms.models["gpt-3.5-turbo-16k"], instructions=_CPP_ANALYZER_INSTRUCTIONS)
class CodeChangesL(CodeChangesBase):
    ...

def get_code_changes(diff):
    """Analyze ``diff`` with the smallest model whose context can hold it.

    ``llms.auto_model`` picks the model, and the matching ``ai_model``-decorated
    ``CodeChanges*`` class then analyzes ``diff``.

    Raises:
        ValueError: If ``diff`` is too long for every configured model.
    """
    code_changes_classes = {
        "gpt-3.5-turbo": CodeChangesS,
        "gpt-4": CodeChangesM,
        "gpt-3.5-turbo-16k": CodeChangesL,
    }
    auto_model, num_of_tokens = llms.auto_model(diff)
    if auto_model is None:
        # auto_model returns None when nothing fits; without this check the
        # next line fails with an opaque "'NoneType' has no attribute 'model'".
        raise ValueError(
            f"Diff is too long ({num_of_tokens} tokens) for every configured model"
        )
    return code_changes_classes[auto_model.model](diff)

It does the job but it’s stupid. What would be a better way to implement this?

Describe the proposed behavior

I’m not sure how I could write this more elegantly. I’m experimenting with various patterns to utilize Marvin, because I very much like the abstraction and elegance.

But because of the clever decorators, I’m not sure how to do it without code duplication. Basically, what would be a simple way to pass both the model config AND a Pydantic model to ai_model?

Example Use

No response

Additional context

No response

twardoch commented 1 year ago

OK, I think I’ve managed to do this with a simple change: I no longer use ai_model as a decorator and call it as a plain function instead.

#!/usr/bin/env python3

from functools import lru_cache
from typing import Any, List, Literal, Tuple, Optional, Union

import marvin
from git import Repo
from marvin import ai_fn, ai_model
from marvin.engine.language_models import ChatLLM
from pydantic import BaseModel, Field

def cached_method(method):
    """Memoize a method on ``(self, *args, **kwargs)`` with ``functools.lru_cache``.

    The instance itself is part of the cache key, so the owning class must be
    hashable and every argument must be hashable too.

    The returned wrapper keeps the original method's name and docstring and
    exposes ``cache_clear()`` / ``cache_info()`` so stale results can be
    invalidated (e.g. after mutating state a cached result depends on).

    NOTE(review): the cache is unbounded and shared by all instances, and it
    holds a strong reference to every ``self`` it has seen, so those instances
    are never garbage collected (ruff B019). This is fine for a few long-lived
    objects; switch to a per-instance cache if that changes.
    """
    import functools  # local import: the file header only imports lru_cache

    cached = functools.lru_cache(maxsize=None)(method)

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        return cached(self, *args, **kwargs)

    # Without these the lru_cache is unreachable and can never be cleared.
    wrapper.cache_clear = cached.cache_clear
    wrapper.cache_info = cached.cache_info
    return wrapper

class ChatLLMPlus(ChatLLM):
    """A ``ChatLLM`` that can check whether a prompt fits in its context window.

    A prompt is "safe" when its token count is at most
    ``context_size - reserve_tokens``. The token-count and safety checks are
    memoized with ``cached_method``, which keys on ``self`` and therefore
    requires ``__hash__``.
    """

    # Tokens that must stay free below ``context_size`` (see ``is_safe``);
    # presumably head-room for the model's reply — TODO confirm intended budget.
    reserve_tokens: int = Field(1000)

    def __init__(self, *args, reserve_tokens: int = 1000, **kwargs):
        """Forward everything except ``reserve_tokens`` to ``ChatLLM``, then set it."""
        super().__init__(*args, **kwargs)
        # NOTE(review): ``reserve_tokens`` is also a declared field, so passing
        # it through ``super().__init__`` would likely suffice — confirm against
        # the pydantic/marvin versions in use before simplifying.
        self.reserve_tokens = reserve_tokens

    def __hash__(self) -> int:
        # Required so instances can be part of ``cached_method``'s cache key.
        # ``reserve_tokens`` is not hashed; that only causes hash collisions,
        # because cache hits are still decided by ``__eq__`` (assuming
        # pydantic's field-wise equality).
        # NOTE(review): assumes ``ChatLLM`` defines ``name``, ``model``,
        # ``max_tokens`` and ``temperature`` — not visible here.
        return hash((self.name, self.model, self.max_tokens, self.temperature))

    @cached_method
    def count_tokens(self, text: str, **kwargs) -> int:
        """Return the number of tokens ``self.get_tokens`` produces for ``text``.

        ``kwargs`` are forwarded to ``get_tokens`` and must be hashable, since
        they are part of the cache key.
        """
        return len(self.get_tokens(text, **kwargs))

    @cached_method
    def is_safe(self, num_of_tokens: int) -> bool:
        """Return True if ``num_of_tokens <= context_size - reserve_tokens``."""
        # NOTE(review): memoized on (self, num_of_tokens). Mutating
        # ``reserve_tokens`` later does not change the hash, so earlier cached
        # answers for the same object would be returned stale.
        return num_of_tokens <= self.context_size - self.reserve_tokens

    @cached_method
    def check_len(self, text: str, **kwargs) -> Tuple[int, bool]:
        """Return ``(token_count, fits)`` for ``text``; ``fits`` is ``is_safe(token_count)``."""
        num_of_tokens = self.count_tokens(text, **kwargs)
        return num_of_tokens, self.is_safe(num_of_tokens)

class ChatLLMs:
    """A registry of ``ChatLLMPlus`` models that picks the smallest one a prompt fits in.

    ``auto_model`` walks the models in ascending ``context_size`` order and
    returns the first whose window can hold the text after ``reserve_tokens``.
    """

    def __init__(
        self,
        reserve_tokens: int = 1000,
        model_names: Optional[List[str]] = None,
    ):
        """Build one ``ChatLLMPlus`` per model name.

        Args:
            reserve_tokens: Tokens each model must keep free (forwarded to
                ``ChatLLMPlus``).
            model_names: Models to choose from. Defaults to ``gpt-3.5-turbo``,
                ``gpt-4`` and ``gpt-3.5-turbo-16k``. The order given does not
                matter; models are sorted by ``context_size``.

        Raises:
            ValueError: If ``model_names`` is empty.
        """
        if model_names is None:
            model_names = ["gpt-3.5-turbo", "gpt-4", "gpt-3.5-turbo-16k"]
        if not model_names:
            raise ValueError("ChatLLMs needs at least one model name")
        self.models = {
            name: ChatLLMPlus(model=name, reserve_tokens=reserve_tokens)
            for name in model_names
        }
        # Sort by context size instead of trusting a hand-maintained order, so
        # auto_model always tries the smallest model first. sorted() is stable,
        # so models with equal context sizes keep their given order.
        self.model_names = sorted(
            self.models, key=lambda name: self.models[name].context_size
        )

    def auto_model(self, text: str, prefer=None) -> Tuple[Optional[ChatLLMPlus], int]:
        """Return the smallest model whose context can hold ``text``.

        Args:
            text: Prompt to measure.
            prefer: Optional model name to start the search from; smaller
                models are skipped. Unknown names are ignored and the search
                starts at the smallest model.

        Returns:
            ``(model, num_of_tokens)`` for the first model the text fits in, or
            ``(None, num_of_tokens)`` counted by the largest model when the text
            fits nowhere.
        """
        start_index = self.model_names.index(prefer) if prefer in self.models else 0
        for model_name in self.model_names[start_index:]:
            model = self.models[model_name]
            num_of_tokens, is_safe = model.check_len(text)
            if is_safe:
                return model, num_of_tokens

        # Nothing fits: count with the largest model explicitly instead of
        # relying on the loop variable leaking out of the for-loop.
        largest = self.models[self.model_names[-1]]
        return None, largest.count_tokens(text)

# Shared model registry that get_code_changes selects from.
llms = ChatLLMs(reserve_tokens=1000)

from pathlib import Path  # used by the `file_changed` annotation; was never imported

class DiffInterpretation(BaseModel):
    # The LLM's interpretation of the change to a single file.
    # Documented with comments rather than a docstring: pydantic turns a class
    # docstring into the JSON-schema description, which would presumably change
    # what ai_model sends to the LLM.
    change_impact: Literal["low", "medium", "high"] = Field(
        ..., description="How much the change affects the final app"
    )
    change_ui: bool = Field(..., description="True if the change affects the UI")
    change_content: str = Field(
        ...,
        description="Extensive and detailed human-readable interpretation of the diff change",
    )
    # Union (already imported) instead of `str | Path`, which is evaluated at
    # class creation and would require Python 3.10+.
    file_changed: Union[str, Path]
    repo_name: str

class CodeChanges(BaseModel):
    # Structured result the LLM fills in for one diff; passed to ai_model at
    # call time in get_code_changes. Documented with comments rather than a
    # docstring because pydantic turns a class docstring into the JSON-schema
    # description, which would presumably change the prompt ai_model builds.
    build_number: int = Field(
        ...,
        description="`VERS_BUILD_NUMBER` if it was changed in `version_build.h` of the repo, otherwise 0",
    )
    # Presumably one entry per changed file (DiffInterpretation carries
    # `file_changed`) — confirm with the prompt's intent.
    diff_interpretations: List[DiffInterpretation]

def get_code_changes(text: str) -> CodeChanges:
    """Analyze ``text`` (a diff) with the smallest model whose context can hold it.

    The model is chosen by ``llms.auto_model``. ``ai_model`` is applied to
    ``CodeChanges`` as a plain function rather than as a decorator, so the
    chosen model can be supplied at call time.

    Raises:
        ValueError: If ``text`` does not fit in any configured model.
    """
    auto_model, num_of_tokens = llms.auto_model(text)
    if auto_model is None:
        # The previous fallback, CodeChanges(build_number=0,
        # diff_interpretations=[text]), could never validate: a str is not a
        # DiffInterpretation. Fail explicitly with the measured size instead.
        raise ValueError(
            f"Diff is too long ({num_of_tokens} tokens) for every configured model"
        )
    analyzer = ai_model(
        model=auto_model,
        instructions="You're an expert analyzer of Qt C++ source code.",
    )(CodeChanges)
    return analyzer(text)

It’s not very elegant, but it’s less ugly than the stupid version I had above. I still wonder whether this could be made nicer in general. I do wish that a more robust and transparent "fallback model selection" method were available upstream :)