pdhoolia / se-agent

Software Engineering Agent
GNU General Public License v3.0

Support for local models using Ollama #13

Open pdhoolia opened 10 hours ago

pdhoolia commented 10 hours ago

Let's add support for locally hosted models served through Ollama, using the langchain-ollama module. ChatOllama from langchain_ollama can be constructed much like ChatOpenAI. Unlike ChatOpenAI, however, it doesn't support structured output through response_format. For output structuring it works like WatsonxLLM: the parser has to be chained.

se-agent commented 10 hours ago

To add support for the Ollama provider using the langchain-ollama module, the code changes go in api.py, while model_configuration_manager.py only needs attention for configuration. Here are the steps and the necessary code changes:

1. Update api.py to Include Ollama:

First, you need to import the ChatOllama class from the langchain_ollama module. Then, modify the fetch_llm_for_task function to handle the Ollama provider.

# Add the import statement at the top
from langchain_ollama import ChatOllama

# Modify the fetch_llm_for_task function
def fetch_llm_for_task(task_name: TaskName, **kwargs) -> Union[BaseLanguageModel, BaseChatModel]:
    task_config = config.get_task_config(PROVIDER, task_name)

    if not task_config:
        raise ValueError(f"No task configuration found for provider: {PROVIDER}")

    model_name = task_config.model_name
    max_tokens = task_config.max_tokens

    if PROVIDER == "openai":
        return ChatOpenAI(model=model_name, max_tokens=max_tokens, **kwargs)
    elif PROVIDER == "watsonx":
        return WatsonxLLM(
            model_id=model_name,
            project_id=os.getenv("WATSONX_PROJECT_ID"),
            url=os.getenv("WATSONX_URL"),
            apikey=os.getenv("WATSONX_APIKEY"),
            params={"decoding_method": "greedy", "max_new_tokens": max_tokens},
        )
    elif PROVIDER == "ollama":
        # ChatOllama limits generated tokens with num_predict rather than max_tokens
        return ChatOllama(model=model_name, num_predict=max_tokens, **kwargs)
    else:
        raise ValueError(f"Unsupported LLM provider: {PROVIDER}")

# Modify the call_llm_for_task function for response handling
@retry_with_exponential_backoff
def call_llm_for_task(task_name: TaskName, messages: list, **kwargs):
    response_format = kwargs.pop('response_format', None)
    llm = fetch_llm_for_task(task_name, **kwargs)

    if response_format:
        if isinstance(llm, BaseChatModel) and PROVIDER == "openai":
            llm = llm.with_structured_output(response_format)
            return llm.invoke(transform_to_langchain_base_chat_model_format(messages))
        elif isinstance(llm, BaseLanguageModel) or PROVIDER == "ollama":
            parser = PydanticOutputParser(pydantic_object=response_format)
            chain = llm | parser
            return chain.invoke(input=transform_to_base_language_model_single_prompt_string(messages))
    else:
        if isinstance(llm, BaseChatModel):
            return llm.invoke(transform_to_langchain_base_chat_model_format(messages))
        elif isinstance(llm, BaseLanguageModel):
            response = llm.invoke(input=transform_to_base_language_model_single_prompt_string(messages))
            if isinstance(response, str):
                return AIMessage(content=response)
            elif isinstance(response, BaseMessage):
                return response
            else:
                raise ValueError(f"Unsupported response type: {type(response)}")

2. Update model_configuration_manager.py for Configuration:

No explicit changes are needed in model_configuration_manager.py unless you want to add Ollama-specific defaults. What it does need is a configuration entry for Ollama in the YAML or JSON configuration files that you load.
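As an illustration of such a configuration entry, here is a minimal sketch. The JSON layout, the TaskConfig dataclass, the load_task_config helper, the task name, and the model name are all assumptions made for this example. The only detail taken from the code above is that a task configuration exposes model_name and max_tokens.

```python
import json
from dataclasses import dataclass
from typing import Optional

# Hypothetical layout: provider -> task name -> settings.
# The real files read by model_configuration_manager.py may be shaped differently.
EXAMPLE_CONFIG = """
{
  "ollama": {
    "generate_suggestions": {"model_name": "llama3.1", "max_tokens": 2000}
  }
}
"""


@dataclass
class TaskConfig:
    model_name: str
    max_tokens: int


def load_task_config(raw: str, provider: str, task_name: str) -> Optional[TaskConfig]:
    """Return the settings for one provider/task pair, or None if absent."""
    entry = json.loads(raw).get(provider, {}).get(task_name)
    return TaskConfig(**entry) if entry else None


task_config = load_task_config(EXAMPLE_CONFIG, "ollama", "generate_suggestions")
print(task_config)
```

Returning None for a missing entry mirrors the `if not task_config` check in fetch_llm_for_task, which then raises a ValueError.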

Additional Steps:

These changes will enable your system to support locally hosted models using Ollama, leveraging the langchain-ollama package.

pdhoolia commented 8 hours ago

The changes to call_llm_for_task have one inconsistency. When a response_format is given, ChatOllama requires chaining the parser like WatsonxLLM, but it still needs the messages in the LangChain base chat model format.
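To make the difference between the two message formats concrete, here is a rough sketch. The two helpers below are simplified stand-ins written for this example, not the repository's transform_to_langchain_base_chat_model_format and transform_to_base_language_model_single_prompt_string, and the role/content dictionary shape of the input messages is an assumption.

```python
from typing import Dict, List, Tuple

Message = Dict[str, str]  # assumed shape: {"role": "...", "content": "..."}


def to_chat_format(messages: List[Message]) -> List[Tuple[str, str]]:
    """Keep one (role, content) pair per message, preserving turn boundaries."""
    return [(m["role"], m["content"]) for m in messages]


def to_single_prompt(messages: List[Message]) -> str:
    """Flatten every message into one prompt string."""
    return "\n\n".join(f"{m['role']}: {m['content']}" for m in messages)


messages = [
    {"role": "system", "content": "You are a code reviewer."},
    {"role": "user", "content": "Summarize this diff."},
]

# Per-message input, the kind a chat model such as ChatOllama works with
chat_input = to_chat_format(messages)
# Single-string input, the kind a completion-style LLM such as WatsonxLLM works with
prompt_input = to_single_prompt(messages)

print(chat_input)
print(prompt_input)
```

The point of the sketch is that the choice of input format follows the model type, independently of how the structured output is produced.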

se-agent commented 8 hours ago

To address the inconsistency, ChatOllama with a response_format should chain the parser as WatsonxLLM does, while still receiving the messages in the LangChain base chat model format. Here's what the corrected call_llm_for_task function should look like:

@retry_with_exponential_backoff
def call_llm_for_task(task_name: TaskName, messages: list, **kwargs):
    response_format = kwargs.pop('response_format', None)
    llm = fetch_llm_for_task(task_name, **kwargs)

    if response_format:
        if isinstance(llm, BaseChatModel) and PROVIDER == "openai":
            llm = llm.with_structured_output(response_format)
            return llm.invoke(transform_to_langchain_base_chat_model_format(messages))
        elif isinstance(llm, BaseLanguageModel) or PROVIDER in ["watsonx", "ollama"]:
            parser = PydanticOutputParser(pydantic_object=response_format)
            chain = llm | parser
            # For Ollama, transform messages to Langchain's base chat model format
            if PROVIDER == "ollama":
                return chain.invoke(transform_to_langchain_base_chat_model_format(messages))
            else:
                # For Watsonx, transform messages to a single prompt string
                return chain.invoke(input=transform_to_base_language_model_single_prompt_string(messages))
    else:
        if isinstance(llm, BaseChatModel):
            return llm.invoke(transform_to_langchain_base_chat_model_format(messages))
        elif isinstance(llm, BaseLanguageModel):
            response = llm.invoke(input=transform_to_base_language_model_single_prompt_string(messages))
            if isinstance(response, str):
                return AIMessage(content=response)
            elif isinstance(response, BaseMessage):
                return response
            else:
                raise ValueError(f"Unsupported response type: {type(response)}")

Key Changes:

With these adjustments, ChatOllama should get a chained parser for structured output while still receiving its messages in the LangChain base chat model format.
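The branching in call_llm_for_task can also be read as a small decision table. The sketch below restates it as a pure function so the three providers can be compared side by side. plan_call and its string labels are hypothetical names introduced for this example and do not exist in the repository.

```python
from typing import Tuple


def plan_call(provider: str, has_response_format: bool) -> Tuple[str, str]:
    """Return (structuring strategy, message format) for a provider."""
    if provider not in ("openai", "watsonx", "ollama"):
        raise ValueError(f"Unsupported LLM provider: {provider}")

    # Chat models take chat messages; WatsonxLLM takes a single prompt string.
    message_format = "single_prompt" if provider == "watsonx" else "chat_messages"

    if not has_response_format:
        strategy = "none"
    elif provider == "openai":
        strategy = "with_structured_output"
    else:
        # watsonx and ollama both chain a PydanticOutputParser after the model
        strategy = "chained_parser"
    return strategy, message_format


for provider in ("openai", "watsonx", "ollama"):
    print(provider, plan_call(provider, has_response_format=True))
```

Ollama is the one provider that combines the chained parser with chat-formatted messages, which is the case the first version of the function handled incorrectly.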