spyder-ide / spyder

Official repository for Spyder - The Scientific Python Development Environment
https://www.spyder-ide.org
MIT License

Add Ollama API support (Copilot-like with local LLMs) #21879

Open · chhu opened 7 months ago

chhu commented 7 months ago

Hi, I was recently trying VS Code with the Continue plugin, configured to use my own Ollama server and LLMs (https://ollama.ai), and was amazed at how well it works.

I'm not a big fan of shipping my code to companies so they can train the models they sell, so I'm no fan of Copilot and the like. But with the option of local open-source LLMs, this becomes a game changer. Related to #20632
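
For reference, a quick sanity check that a local Ollama server is reachable from Python before wiring it into Spyder (a minimal sketch; the model name and the default port 11434 are assumptions):

# Sanity check: talk to a local Ollama server via langchain_community.
# Assumes `ollama serve` is running on http://localhost:11434 and that the
# model has been pulled, e.g. `ollama pull codellama`.
from langchain_community.chat_models.ollama import ChatOllama

llm = ChatOllama(model="codellama")
reply = llm.invoke("Complete this Python line: a=3 b=4 print")
print(reply.content)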

Datagniel commented 1 month ago

In the langchain_provider folder, you need to modify the client.py and provider.py files.

new client.py:


# Copyright © Spyder Project Contributors
# Licensed under the terms of the MIT License

"""Langchain completions HTTP client."""

# Standard library imports
import json
import logging

# Third party imports
# Changed: ChatOpenAI (langchain_community.chat_models) -> ChatOllama
from langchain_community.chat_models.ollama import ChatOllama
from langchain.chains import LLMChain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from qtpy.QtCore import QObject, QThread, Signal, QMutex, Slot

# Spyder imports
from spyder.plugins.completion.api import CompletionRequestTypes, CompletionItemKind

logger = logging.getLogger(__name__)

LANG_COMPLETION = "Langchain"
LANG_ICON_SCALE = 1

class LangchainClient(QObject):
    sig_response_ready = Signal(int, dict)
    sig_client_started = Signal()
    sig_client_error = Signal(str)
    sig_perform_request = Signal(dict)
    sig_perform_status_request = Signal(str)
    sig_status_response_ready = Signal((str,), (dict,))
    sig_onboarding_response_ready = Signal(str)

    def __init__(self, parent, template, model_name, language="python"):
        QObject.__init__(self, parent)
        self.requests = {}
        self.language = language
        self.mutex = QMutex()
        self.opened_files = {}
        self.opened_files_status = {}
        self.thread_started = False
        self.thread = QThread(None)
        self.moveToThread(self.thread)
        self.thread.started.connect(self.started)
        self.sig_perform_request.connect(self.handle_msg)
        self.sig_perform_status_request.connect(self.get_status)

        self.template = template
        self.model_name = model_name
        self.chain = None

    def start(self):
        if not self.thread_started:
            self.thread.start()
        logger.debug("Starting LangChain session...")
        system_message_prompt = SystemMessagePromptTemplate.from_template(self.template)
        code_template = "{text}"
        code_message_prompt = HumanMessagePromptTemplate.from_template(
            code_template,
        )
        try:
            # Changed from ChatOpenAI to ChatOllama; note the keyword is
            # `model`, not `model_name`.
            llm = ChatOllama(
                model=self.model_name,
            )
            chat_prompt = ChatPromptTemplate.from_messages(
                [system_message_prompt, code_message_prompt]
            )
            chain = LLMChain(
                llm=llm,
                prompt=chat_prompt,
            )
            self.chain = chain
            self.sig_client_started.emit()
        except ValueError as e:
            logger.debug(e)
            # Ollama needs no API key, so a ValueError here usually means a
            # bad model name or an unreachable server.
            self.sig_client_error.emit("Ollama configuration error")
        except Exception as e:
            logger.debug(e)
            self.sig_client_error.emit("Unexpected error")

    def started(self):
        self.thread_started = True

    def stop(self):
        if self.thread_started:
            logger.debug("Closing LangChain session...")
            self.thread.quit()
            self.thread.wait()
            self.thread_started = False

    def update_configuration(self, model_name, template):
        self.stop()
        self.model_name = model_name
        self.template = template
        self.start()

    def get_status(self, filename):
        """Get langchain status for a given filename."""
        # The status shown is simply the configured model name.
        langchain_status = self.model_name
        self.sig_status_response_ready[str].emit(langchain_status)

    def run_chain(self, params=None):
        try:
            raw_response = self.chain.invoke(params)["text"]
            # The model sometimes returns only the inner key/value pair,
            # e.g. '"suggestions": [...]'; wrap it in braces so it parses.
            if raw_response.startswith('"'):
                response = json.loads("{" + raw_response + "}")
            else:
                response = json.loads(raw_response)
            return response
        except Exception:
            self.sig_client_error.emit("No suggestions available")
            return {"suggestions": []}

    def send(self, params):
        return self.run_chain(params=params)

    @Slot(dict)
    def handle_msg(self, message):
        """Handle one message"""
        msg_type, _id, file, msg = [message[k] for k in ("type", "id", "file", "msg")]
        logger.debug("Perform request {0} with id {1}".format(msg_type, _id))
        if msg_type == CompletionRequestTypes.DOCUMENT_DID_OPEN:
            self.opened_files[msg["file"]] = msg["text"]
        elif msg_type == CompletionRequestTypes.DOCUMENT_DID_CHANGE:
            self.opened_files[msg["file"]] = msg["text"]
        elif msg_type == CompletionRequestTypes.DOCUMENT_COMPLETION:
            response = self.send(self.opened_files[msg["file"]])
            logger.debug(response)
            if response is None:
                # Emit an empty result; a return value from a slot goes nowhere.
                self.sig_response_ready.emit(_id, {"params": []})
                return
            spyder_completions = []
            completions = response["suggestions"]
            if completions is not None:
                for i, completion in enumerate(completions):
                    entry = {
                        "kind": CompletionItemKind.TEXT,
                        "label": completion,
                        "insertText": completion,
                        "filterText": "",
                        # Use the returned ordering
                        "sortText": (0, i),
                        "documentation": completion,
                        "provider": LANG_COMPLETION,
                        "icon": ("langchain", LANG_ICON_SCALE),
                    }
                    spyder_completions.append(entry)
            self.sig_response_ready.emit(_id, {"params": spyder_completions})
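
A note on the ChatOllama construction above: it defaults to http://localhost:11434. If your Ollama server runs on another host or port, langchain_community's client accepts a base_url keyword (a sketch; the address below is a placeholder):

# Sketch: point the client at a remote Ollama server instead of localhost.
llm = ChatOllama(
    model=self.model_name,
    base_url="http://my-ollama-host:11434",  # placeholder address
)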

new provider.py:


# Copyright © Spyder Project Contributors
# Licensed under the terms of the MIT License

"""Langchain completion HTTP client."""

# Standard library imports
import logging
import os

# Qt imports
from qtpy.QtCore import Slot

# Local imports
from langchain_provider.client import LangchainClient
from langchain_provider.widgets import LangchainStatusWidget

# Spyder imports
from spyder.api.config.decorators import on_conf_change
from spyder.config.base import running_under_pytest, get_module_data_path
from spyder.plugins.completion.api import SpyderCompletionProvider
from spyder.utils.image_path_manager import IMAGE_PATH_MANAGER

logger = logging.getLogger(__name__)

class LangchainProvider(SpyderCompletionProvider):
    COMPLETION_PROVIDER_NAME = "langchain"
    DEFAULT_ORDER = 1
    SLOW = True
    CONF_VERSION = "1.0.0"
    CONF_DEFAULTS = [
        ("suggestions", 4),
        ("language", "Python"),
        ("model_name", "codellama"),  # Replace with your Ollama model name
    ]
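    # Note (assumption: standard Ollama CLI): the model configured above must
    # already exist locally, e.g. `ollama pull codellama`; `ollama list` shows
    # which models are available.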
    TEMPLATE_PARAM = """You are a helpful assistant that completes the following {0} code
                  based on the preceding text.
                  You always complete the code on the same line and give {1} suggestions.
                  Example : a=3 b=4 print
                  AI : "suggestions": ["print(a)", "print(b)", "print(a+b)"]
                  Example : a=3 b=4 c
                  AI : "suggestions": ["c=a+b", "c=a-b", "c=5"]
                  Format the output as JSON with the following key:
                      suggestions
                  """

    def __init__(self, parent, config):
        super().__init__(parent, config)
        IMAGE_PATH_MANAGER.add_image_path(
            get_module_data_path("langchain_provider", relpath="images")
        )
        self.available_languages = []
        self.client = LangchainClient(
            None,
            model_name=self.get_conf("model_name"),
            template=self.TEMPLATE_PARAM.format(
                self.get_conf("language"), self.get_conf("suggestions")
            ),
        )

        # Signals
        self.client.sig_client_started.connect(
            lambda: self.sig_provider_ready.emit(self.COMPLETION_PROVIDER_NAME)
        )
        self.client.sig_client_error.connect(self.set_status_error)
        self.client.sig_status_response_ready[str].connect(self.set_status)
        self.client.sig_status_response_ready[dict].connect(self.set_status)
        self.client.sig_response_ready.connect(
            lambda _id, resp: self.sig_response_ready.emit(
                self.COMPLETION_PROVIDER_NAME, _id, resp
            )
        )

        # Status bar widget
        self.STATUS_BAR_CLASSES = [self.create_statusbar]
        self.started = False

    # ------------------ SpyderCompletionProvider methods ---------------------
    def get_name(self):
        return "LangChain"

    def send_request(self, language, req_type, req, req_id):
        request = {"type": req_type, "file": req["file"], "id": req_id, "msg": req}
        self.client.sig_perform_request.emit(request)

    def start_completion_services_for_language(self, language):
        return self.started

    def start(self):
        if not self.started:
            self.client.start()
            self.started = True

    def shutdown(self):
        if self.started:
            self.client.stop()
            self.started = False

    @Slot(str)
    @Slot(dict)
    def set_status(self, status):
        """Show Langchain status for the current file."""
        self.sig_call_statusbar.emit(
            LangchainStatusWidget.ID, "set_value", (status,), {}
        )

    def set_status_error(self, error_message):
        """Show Langchain status for the current file."""
        self.sig_call_statusbar.emit(
            LangchainStatusWidget.ID, "set_value", (error_message,), {}
        )

    def file_opened_closed_or_updated(self, filename, _language):
        """Request status for the given file."""
        self.client.sig_perform_status_request.emit(filename)

    @on_conf_change(section="completions", option=("enabled_providers", "langchain"))
    def on_langchain_enable_changed(self, value):
        self.sig_call_statusbar.emit(LangchainStatusWidget.ID, "set_value", (None,), {})

    @on_conf_change
    def update_langchain_configuration(self, config):
        if running_under_pytest():
            if not os.environ.get("SPY_TEST_USE_INTROSPECTION"):
                return
        self.client.update_configuration(
            self.get_conf("model_name"),
            self.TEMPLATE_PARAM.format(
                self.get_conf("language"), self.get_conf("suggestions")
            ),
        )

    def create_statusbar(self, parent):
        return LangchainStatusWidget(parent, self)
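
To ship this as an installable package instead of patching Spyder in-tree, one option is a setup.py that exposes the provider through an entry point. My understanding is that Spyder's completion plugin discovers third-party providers via the "spyder.completions" entry point group, but verify that against your Spyder version; the package name and version below are placeholders:

# setup.py sketch for packaging langchain_provider.
# Assumption: Spyder discovers external completion providers through the
# "spyder.completions" entry point group.
from setuptools import setup, find_packages

setup(
    name="langchain-provider",  # placeholder name
    version="0.1.0",
    packages=find_packages(),
    install_requires=["langchain", "langchain-community"],
    entry_points={
        "spyder.completions": [
            "langchain = langchain_provider.provider:LangchainProvider"
        ]
    },
)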

Datagniel commented 1 month ago

I did a pretty hacky fork here.