MadcowD / ell

A language model programming library.
http://docs.ell.so/
MIT License

Cache completions #389

Open abrichr opened 6 days ago

abrichr commented 6 days ago

Is there some mechanism to avoid hitting the API if the prompt hasn't changed at all?

For example:

import ell

@ell.simple(model="gpt-4o")
def hello(name: str):
    """You are a helpful assistant.""" # System prompt
    return f"Say hello to {name}!" # User prompt

greeting = hello("Sam Altman")
print(greeting)

If we run this script twice, the second API call is unnecessary: we could simply persist the result of the first call to disk and reuse it.
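To make "persist the result to disk" concrete, here is a minimal hand-rolled sketch that uses only the standard library. It is not how ell or joblib work internally. `disk_cache`, `fake_hello`, and the one-pickle-file-per-key layout are hypothetical. The key covers only the function name and JSON-serializable arguments, and `fake_hello` stands in for the real API call.

```python
import functools
import hashlib
import json
import pickle
import tempfile
from pathlib import Path


def disk_cache(cache_dir):
    """Hypothetical decorator: persist each result to disk, keyed on name + arguments."""
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Arguments must be JSON-serializable for this simple key scheme.
            payload = json.dumps([func.__qualname__, args, kwargs], sort_keys=True)
            key = hashlib.sha256(payload.encode("utf-8")).hexdigest()
            path = cache_dir / f"{key}.pkl"
            if path.exists():  # cache hit: reuse the stored result
                return pickle.loads(path.read_bytes())
            result = func(*args, **kwargs)  # cache miss: make the real call
            path.write_bytes(pickle.dumps(result))
            return result

        return wrapper

    return decorator


calls = []  # records every "API call" that actually happens


def fake_hello(name):
    calls.append(name)  # stands in for the network request
    return f"Hello, {name}!"


cache_dir = tempfile.mkdtemp()
# Two separately decorated copies sharing one directory simulate two script runs.
first_run = disk_cache(cache_dir)(fake_hello)
second_run = disk_cache(cache_dir)(fake_hello)

print(first_run("Sam Altman"))   # cache miss: calls fake_hello
print(second_run("Sam Altman"))  # cache hit: read from disk, no new call
print(calls)                     # ['Sam Altman']
```

This key never looks at the function body. Editing the prompt text inside `fake_hello` would therefore not invalidate a stored result.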

Normally we can accomplish this with joblib.Memory:

from joblib import Memory
import ell

memory = Memory("./cache")

@memory.cache()
@ell.simple(model="gpt-4o")
def hello(name: str):
    """You are a helpful assistant.""" # System prompt
    return f"Say hello to {name}!" # User prompt

greeting = hello("Sam Altman")
print(greeting)

Now, if we run this script twice, the API is not called on the second run.

This behaves as expected when the arguments change. For example, calling hello("Sam") hits the API because the arguments differ from those of the cached call.

However, if we change a prompt literal inside the function body, joblib does not detect the change and returns the stale cached result.
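A different approach is to key the cache on the rendered prompt rather than on the Python function. This sketch is neither joblib's behavior nor an ell API. The key is a hash of exactly what would be sent: the model name plus the messages. Any edit to a prompt literal then changes the key automatically. It assumes you can obtain the model and messages before the request goes out. `build_messages`, `cached_completion`, `call_api`, and `fake_api` are hypothetical names.

```python
import hashlib
import json


def prompt_cache_key(model, messages):
    """Hash exactly what would be sent to the API: the model name and the messages."""
    payload = json.dumps({"model": model, "messages": messages}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def build_messages(system_prompt, name):
    """Hypothetical stand-in for the messages that hello() renders."""
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Say hello to {name}!"},
    ]


_cache = {}  # in memory for brevity; a real version would persist to disk


def cached_completion(model, messages, call_api):
    """Call `call_api` only if this exact (model, messages) pair has not been seen."""
    key = prompt_cache_key(model, messages)
    if key not in _cache:
        _cache[key] = call_api(model, messages)
    return _cache[key]


api_calls = []


def fake_api(model, messages):
    api_calls.append(messages[0]["content"])  # record which system prompt was sent
    return f"[{model}] {messages[1]['content']}"


old = build_messages("You are a helpful assistant.", "Sam Altman")
new = build_messages("You are a concise assistant.", "Sam Altman")

cached_completion("gpt-4o", old, fake_api)
cached_completion("gpt-4o", old, fake_api)  # same prompt: served from cache
cached_completion("gpt-4o", new, fake_api)  # edited system prompt: new call
print(len(api_calls))  # 2
```

Because the key is built from the final messages, this approach ignores cosmetic edits to the function (comments, whitespace) that do not change what is sent.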

Any suggestions for avoiding unnecessary API calls would be appreciated!

abrichr commented 6 days ago

Workaround: incorporate a hash of the function’s source code (including the prompt) into the cache key.

import functools
import hashlib
import inspect
from joblib import Memory
import ell

memory = Memory("./cache")

def hash_source_code(func):
    """Hash the entire function's source code, including its docstring."""
    source = inspect.getsource(func)
    return hashlib.sha256(source.encode("utf-8")).hexdigest()

# Defined once at module level instead of on every call. `func` is excluded
# from the cache key via joblib's `ignore`, so entries are keyed on the
# source hash plus the call arguments.
@memory.cache(ignore=["func"])
def cached_call(func, func_hash, func_args, func_kwargs):
    return func(*func_args, **func_kwargs)

def cache_source(func):
    """Decorator to cache a function, including its source code in the hash."""
    # Compute the hash of the source code once, when the decorator is applied
    func_hash = hash_source_code(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return cached_call(func, func_hash, args, kwargs)

    return wrapper

# Example usage
@cache_source
@ell.simple(model="gpt-4o")
def hello(name: str):
    """You are a helpful assistant."""
    return f"Say hello to {name}!"

greeting = hello("Sam Altman")
print(greeting)

Now, modifying the prompt changes the source hash, and therefore the cache key, so the stale API result is no longer reused.
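To check that a source hash really does change when only the prompt literal changes, this standard-library sketch leaves out ell and joblib. It writes two versions of `hello` that differ only in the system-prompt docstring to temporary files. It then imports each one with `importlib` and compares keys built from `inspect.getsource`. `TEMPLATE`, `load_hello`, and `cache_key` are hypothetical helpers for illustration.

```python
import hashlib
import importlib.util
import inspect
import tempfile
from pathlib import Path

# Source for hello(); only the docstring (system prompt) varies between versions.
TEMPLATE = '''
def hello(name: str):
    """{system}"""
    return f"Say hello to {{name}}!"
'''


def load_hello(directory, module_name, system_prompt):
    """Write one version of hello() to its own .py file and import it."""
    path = Path(directory) / f"{module_name}.py"
    path.write_text(TEMPLATE.format(system=system_prompt), encoding="utf-8")
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.hello


def hash_source_code(func):
    """SHA-256 of the function's source text, docstring included."""
    return hashlib.sha256(inspect.getsource(func).encode("utf-8")).hexdigest()


def cache_key(func, args, kwargs):
    """Hypothetical key: source hash plus the call arguments."""
    return (hash_source_code(func), args, tuple(sorted(kwargs.items())))


with tempfile.TemporaryDirectory() as tmp:
    v1 = load_hello(tmp, "hello_v1", "You are a helpful assistant.")
    v2 = load_hello(tmp, "hello_v2", "You are a terse assistant.")
    key_v1 = cache_key(v1, ("Sam Altman",), {})
    key_v1_again = cache_key(v1, ("Sam Altman",), {})
    key_v2 = cache_key(v2, ("Sam Altman",), {})

print(key_v1 == key_v1_again)  # True: same source, same arguments
print(key_v1 == key_v2)        # False: only the prompt literal changed
```

The arguments are identical in both keys; only the source-hash component differs.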