Status: Open — opened by chhu 7 months ago.
To use a local Ollama model, modify the `client.py` and `provider.py` files in the `langchain_provider` folder as shown below.
New `client.py`:
# Copyright © Spyder Project Contributors
# Licensed under the terms of the MIT License
"""Langchain completions HTTP client."""
# Standard library imports
import json
import logging
# Third party imports
# from langchain_community.chat_models import ChatOpenAI changed to
from langchain_community.chat_models.ollama import ChatOllama
from langchain.chains import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from qtpy.QtCore import QObject, QThread, Signal, QMutex, Slot
# Spyder imports
from spyder.plugins.completion.api import CompletionRequestTypes, CompletionItemKind
# Module logger, plus the provider label and icon scale stamped on every
# completion item this client produces.
logger = logging.getLogger(name=__name__)
LANG_COMPLETION, LANG_ICON_SCALE = "Langchain", 1
class LangchainClient(QObject):
    """Qt client that runs a LangChain ``LLMChain`` backed by ``ChatOllama``.

    The object is moved to its own ``QThread``. Requests arrive through
    ``sig_perform_request`` / ``sig_perform_status_request``, and results
    are emitted through ``sig_response_ready`` / ``sig_status_response_ready``.
    """

    sig_response_ready = Signal(int, dict)
    sig_client_started = Signal()
    sig_client_error = Signal(str)
    sig_perform_request = Signal(dict)
    sig_perform_status_request = Signal(str)
    sig_status_response_ready = Signal((str,), (dict,))
    sig_onboarding_response_ready = Signal(str)

    def __init__(self, parent, template, model_name, language="python"):
        QObject.__init__(self, parent)
        self.requests = {}
        self.language = language
        self.mutex = QMutex()
        # Latest known text of each document, keyed by file name.
        self.opened_files = {}
        self.opened_files_status = {}
        self.thread_started = False
        self.thread = QThread(None)
        self.moveToThread(self.thread)
        self.thread.started.connect(self.started)
        self.sig_perform_request.connect(self.handle_msg)
        self.sig_perform_status_request.connect(self.get_status)
        self.template = template
        self.model_name = model_name
        self.chain = None

    def start(self):
        """Start the worker thread and build the Ollama-backed chain.

        Emits ``sig_client_started`` on success and ``sig_client_error``
        if the chat model or chain cannot be created.
        """
        if not self.thread_started:
            self.thread.start()
        logger.debug("Starting LangChain session...")
        system_message_prompt = SystemMessagePromptTemplate.from_template(
            self.template
        )
        # The whole document text is passed as the single human message.
        code_message_prompt = HumanMessagePromptTemplate.from_template("{text}")
        try:
            llm = ChatOllama(model=self.model_name)
            chat_prompt = ChatPromptTemplate.from_messages(
                [system_message_prompt, code_message_prompt]
            )
            self.chain = LLMChain(llm=llm, prompt=chat_prompt)
            self.sig_client_started.emit()
        except ValueError as e:
            logger.debug(e)
            self.sig_client_error.emit("Missing Ollama API key or configuration")
        except Exception as e:
            logger.debug(e)
            self.sig_client_error.emit("Unexpected error")

    def started(self):
        """Record that the worker thread is running."""
        self.thread_started = True

    def stop(self):
        """Stop the worker thread and wait for it to finish."""
        if self.thread_started:
            logger.debug("Closing LangChain session...")
            self.thread.quit()
            self.thread.wait()
            self.thread_started = False

    def update_configuration(self, model_name, template):
        """Restart the client with a new model name and system prompt."""
        self.stop()
        self.model_name = model_name
        self.template = template
        self.start()

    def get_status(self, filename):
        """Emit the LangChain status for ``filename``.

        The status does not depend on the file: it is always the configured
        model name.
        """
        self.sig_status_response_ready[str].emit(self.model_name)

    @staticmethod
    def _parse_response(text):
        """Recover a JSON object from the raw model reply ``text``.

        Handles replies wrapped in Markdown code fences, replies preceded by
        prose, and the brace-less form ``"suggestions": [...]`` used in the
        prompt examples.

        Raises ``ValueError`` (including ``json.JSONDecodeError``) when no
        JSON object can be recovered.
        """
        text = text.strip()
        if text.startswith("```"):
            # Strip a ```json ... ``` fence.
            text = text.strip("`")
            if text.lower().startswith("json"):
                text = text[len("json"):]
            text = text.strip()
        if text.startswith('"'):
            # The prompt examples omit the outer braces.
            text = "{" + text + "}"
        else:
            # Drop any text around the outermost JSON object.
            start = text.find("{")
            end = text.rfind("}")
            if start != -1 and end > start:
                text = text[start:end + 1]
        response = json.loads(text)
        if not isinstance(response, dict):
            raise ValueError("Model reply is not a JSON object")
        return response

    def run_chain(self, params=None):
        """Invoke the chain on ``params`` and return the parsed reply.

        On any failure (chain not built, model error, unparsable reply) the
        error is logged, ``sig_client_error`` is emitted, and
        ``{"suggestions": []}`` is returned.
        """
        try:
            if self.chain is None:
                raise RuntimeError("LangChain chain has not been created")
            raw_text = self.chain.invoke(params)["text"]
            return self._parse_response(raw_text)
        except Exception as error:
            logger.debug(error)
            self.sig_client_error.emit("No suggestions available")
            return {"suggestions": []}

    def send(self, params):
        """Run the chain on ``params`` and return the parsed reply."""
        return self.run_chain(params=params)

    def _handle_completion(self, _id, filename):
        """Compute suggestions for ``filename`` and emit them for ``_id``.

        A response is always emitted, possibly empty, so the editor never
        waits forever on this request id.
        """
        text = self.opened_files.get(filename)
        response = self.send(text) if text is not None else None
        logger.debug(response)
        suggestions = []
        if isinstance(response, dict):
            suggestions = response.get("suggestions") or []
        if not isinstance(suggestions, list):
            suggestions = []
        # Only strings can be used as labels / insert text.
        suggestions = [s for s in suggestions if isinstance(s, str)]
        spyder_completions = [
            {
                "kind": CompletionItemKind.TEXT,
                "label": completion,
                "insertText": completion,
                "filterText": "",
                # Use the returned ordering
                "sortText": (0, i),
                "documentation": completion,
                "provider": LANG_COMPLETION,
                "icon": ("langchain", LANG_ICON_SCALE),
            }
            for i, completion in enumerate(suggestions)
        ]
        self.sig_response_ready.emit(_id, {"params": spyder_completions})

    @Slot(dict)
    def handle_msg(self, message):
        """Handle one request dict sent by the provider.

        Open/change notifications cache the document text; completion
        requests trigger a model call and always emit a response.
        """
        msg_type, _id, _file, msg = [
            message[k] for k in ("type", "id", "file", "msg")
        ]
        logger.debug("Perform request %s with id %s", msg_type, _id)
        if msg_type in (
            CompletionRequestTypes.DOCUMENT_DID_OPEN,
            CompletionRequestTypes.DOCUMENT_DID_CHANGE,
        ):
            self.opened_files[msg["file"]] = msg["text"]
        elif msg_type == CompletionRequestTypes.DOCUMENT_COMPLETION:
            self._handle_completion(_id, msg["file"])
New `provider.py`:
# Copyright © Spyder Project Contributors
# Licensed under the terms of the MIT License
"""Langchain completion HTTP client."""
# Standard library imports
import logging
import os
# Qt imports
from qtpy.QtCore import Slot
# Local imports
from langchain_provider.client import LangchainClient
from langchain_provider.widgets import LangchainStatusWidget
# Spyder imports
from spyder.api.config.decorators import on_conf_change
from spyder.config.base import running_under_pytest, get_module_data_path
from spyder.plugins.completion.api import SpyderCompletionProvider
from spyder.utils.image_path_manager import IMAGE_PATH_MANAGER
# Module-level logger for the provider.
logger = logging.getLogger(name=__name__)
class LangchainProvider(SpyderCompletionProvider):
    """Spyder completion provider that asks a local Ollama model for code."""

    COMPLETION_PROVIDER_NAME = "langchain"
    DEFAULT_ORDER = 1
    SLOW = True
    CONF_VERSION = "1.0.0"
    CONF_DEFAULTS = [
        ("suggestions", 4),
        ("language", "Python"),
        ("model_name", "codellama"),  # Replace with your Ollama model name
    ]
    # System prompt; {0} is the language and {1} the number of suggestions.
    # It must contain no other braces: it goes through str.format and is
    # then used as a LangChain prompt template.
    TEMPLATE_PARAM = """You are a helpful assistant in completing following {0} code based
on the previous sentence.
You always complete the code in same line and give {1} suggestions.
Example : a=3 b=4 print
AI : "suggestions": ["print(a)", "print(b)", "print(a+b)"]
Example : a=3 b=4 c
AI : "suggestions": ["c=a+b", "c=a-b", "c=5"]
Format the output as JSON with the following key:
suggestions
"""

    def __init__(self, parent, config):
        super().__init__(parent, config)
        IMAGE_PATH_MANAGER.add_image_path(
            get_module_data_path("langchain_provider", relpath="images")
        )
        self.available_languages = []
        self.client = LangchainClient(
            None,
            model_name=self.get_conf("model_name"),
            template=self._build_template(),
        )
        # Signals
        self.client.sig_client_started.connect(
            lambda: self.sig_provider_ready.emit(self.COMPLETION_PROVIDER_NAME)
        )
        self.client.sig_client_error.connect(self.set_status_error)
        self.client.sig_status_response_ready[str].connect(self.set_status)
        self.client.sig_status_response_ready[dict].connect(self.set_status)
        self.client.sig_response_ready.connect(
            lambda _id, resp: self.sig_response_ready.emit(
                self.COMPLETION_PROVIDER_NAME, _id, resp
            )
        )
        # Status bar widget
        self.STATUS_BAR_CLASSES = [self.create_statusbar]
        self.started = False

    def _build_template(self):
        """Return ``TEMPLATE_PARAM`` filled with the configured options."""
        return self.TEMPLATE_PARAM.format(
            self.get_conf("language"), self.get_conf("suggestions")
        )

    # ------------------ SpyderCompletionProvider methods ---------------------
    def get_name(self):
        return "LangChain"

    def send_request(self, language, req_type, req, req_id):
        """Forward a completion request to the client thread."""
        request = {"type": req_type, "file": req["file"], "id": req_id, "msg": req}
        self.client.sig_perform_request.emit(request)

    def start_completion_services_for_language(self, language):
        return self.started

    def start(self):
        if not self.started:
            self.client.start()
            self.started = True

    def shutdown(self):
        if self.started:
            self.client.stop()
            self.started = False

    @Slot(str)
    @Slot(dict)
    def set_status(self, status):
        """Show Langchain status for the current file."""
        self.sig_call_statusbar.emit(
            LangchainStatusWidget.ID, "set_value", (status,), {}
        )

    def set_status_error(self, error_message):
        """Show an error message reported by the client in the status bar."""
        self.sig_call_statusbar.emit(
            LangchainStatusWidget.ID, "set_value", (error_message,), {}
        )

    def file_opened_closed_or_updated(self, filename, _language):
        """Request status for the given file."""
        self.client.sig_perform_status_request.emit(filename)

    @on_conf_change(section="completions", option=("enabled_providers", "langchain"))
    def on_langchain_enable_changed(self, value):
        self.sig_call_statusbar.emit(LangchainStatusWidget.ID, "set_value", (None,), {})

    @on_conf_change
    def update_langchain_configuration(self, config):
        """Restart the client whenever an option of this section changes."""
        if running_under_pytest() and not os.environ.get(
            "SPY_TEST_USE_INTROSPECTION"
        ):
            return
        self.client.update_configuration(
            self.get_conf("model_name"), self._build_template()
        )

    def create_statusbar(self, parent):
        return LangchainStatusWidget(parent, self)
Hi, I recently tried VS Code with the Continue plugin, configured to use my own Ollama server and LLMs (https://ollama.ai), and was amazed at how well it works.
I'm not a big fan of shipping my code to companies so they can train the models they sell, so I'm not a fan of Copilot and the like. But with the option of local, open-source LLMs, this becomes a game changer. Related to #20632.