vocodedev / vocode-core

🤖 Build voice-based LLM agents. Modular + open source.
https://vocode.dev

[Feature]: Groq Agent #526

Open Arunprakaash opened 3 months ago

Arunprakaash commented 3 months ago

Brief Description

Add support for a ChatGroq agent.

Rationale

  1. Faster streaming responses (see the minimal sketch below).
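
For context, a minimal sketch of what streaming from Groq's SDK looks like on its own, outside of vocode (assumes the groq package is installed and GROQ_API_KEY is set; the prompt and model name are illustrative):

import asyncio
import os

from groq import AsyncGroq

async def demo():
    client = AsyncGroq(api_key=os.environ["GROQ_API_KEY"])
    # stream=True yields chunks as they are generated, which is what lets a
    # voice agent start synthesizing speech before the full response is done
    stream = await client.chat.completions.create(
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        model="mixtral-8x7b-32768",
        stream=True,
    )
    async for chunk in stream:
        # the final chunk's delta.content can be None
        print(chunk.choices[0].delta.content or "", end="", flush=True)

asyncio.run(demo())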

Suggested Implementation

vocode/streaming/agent/groq_agent.py

import logging
from typing import AsyncGenerator, Optional, Tuple

from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    HumanMessagePromptTemplate,
)
from langchain.schema import ChatMessage, AIMessage, HumanMessage
from langchain_groq import ChatGroq

from vocode import getenv
from vocode.streaming.agent.base_agent import RespondAgent
from vocode.streaming.agent.utils import get_sentence_from_buffer
from vocode.streaming.models.agent import ChatGroqAgentConfig

SENTENCE_ENDINGS = [".", "!", "?"]

class ChatGroqAgent(RespondAgent[ChatGroqAgentConfig]):
    def __init__(
            self,
            agent_config: ChatGroqAgentConfig,
            logger: Optional[logging.Logger] = None,
            groq_api_key: Optional[str] = None,
    ):
        super().__init__(agent_config=agent_config, logger=logger)
        # imported lazily so the groq package is only required when this agent is used
        from groq import AsyncGroq

        groq_api_key = groq_api_key or getenv("GROQ_API_KEY")
        if not groq_api_key:
            raise ValueError(
                "GROQ_API_KEY must be set in environment or passed in"
            )
        self.prompt = ChatPromptTemplate.from_messages(
            [
                MessagesPlaceholder(variable_name="history"),
                HumanMessagePromptTemplate.from_template("{input}"),
            ]
        )

        self.llm = ChatGroq(
            model_name=agent_config.model_name,
            groq_api_key=groq_api_key,
        )

        self.groq_client = (
            AsyncGroq(api_key=groq_api_key) if agent_config.generate_responses else None
        )

        self.memory = ConversationBufferMemory(return_messages=True)
        self.memory.chat_memory.messages.append(
            HumanMessage(content=self.agent_config.prompt_preamble)
        )
        if agent_config.initial_message:
            self.memory.chat_memory.messages.append(
                AIMessage(content=agent_config.initial_message.text)
            )

        self.conversation = ConversationChain(
            memory=self.memory, prompt=self.prompt, llm=self.llm
        )

    async def respond(
            self,
            human_input,
            conversation_id: str,
            is_interrupt: bool = False,
    ) -> Tuple[str, bool]:
        text = await self.conversation.apredict(input=human_input)
        self.logger.debug(f"LLM response: {text}")
        return text, False

    async def generate_response(
            self,
            human_input,
            conversation_id: str,
            is_interrupt: bool = False,
    ) -> AsyncGenerator[Tuple[str, bool], None]:
        self.memory.chat_memory.messages.append(HumanMessage(content=human_input))

        bot_memory_message = AIMessage(content="")
        self.memory.chat_memory.messages.append(bot_memory_message)
        # _create_message_dicts is a private langchain method; this relies on its
        # current behavior of converting chat messages to OpenAI-style dicts
        prompt = self.llm._create_message_dicts(self.memory.chat_memory.messages, None)[0]

        if self.groq_client:
            streamed_response = await self.groq_client.chat.completions.create(
                messages=prompt,
                model=self.agent_config.model_name,
                stream=True,
                max_tokens=self.agent_config.max_tokens_to_sample,
                stop=None
            )

            buffer = ""
            async for completion in streamed_response:
                buffer += completion.choices[0].delta.content
                sentence, remainder = get_sentence_from_buffer(buffer)
                if sentence:
                    bot_memory_message.content = bot_memory_message.content + sentence
                    buffer = remainder
                    yield sentence, True
                continue

    def update_last_bot_message_on_cut_off(self, message: str):
        for memory_message in self.memory.chat_memory.messages[::-1]:
            if (
                    isinstance(memory_message, ChatMessage)
                    and memory_message.role == "assistant"
            ) or isinstance(memory_message, AIMessage):
                memory_message.content = message
                return
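
For readers unfamiliar with get_sentence_from_buffer: the agent relies on it to split the token stream into speakable sentences. A rough, illustrative equivalent (not vocode's actual implementation) could look like this:

import re
from typing import Optional, Tuple

def get_sentence_from_buffer_sketch(buffer: str) -> Tuple[Optional[str], str]:
    # split at the first sentence-ending punctuation mark; everything after it
    # stays in the buffer for the next streamed chunk
    match = re.search(r"[.!?]", buffer)
    if match is None:
        return None, buffer
    end = match.end()
    return buffer[:end], buffer[end:]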

vocode/streaming/models/agent.py

from enum import Enum
from typing import List, Optional, Union

from langchain.prompts import PromptTemplate
from pydantic import validator

from vocode.streaming.models.actions import ActionConfig
from vocode.streaming.models.message import BaseMessage
from .model import TypedModel, BaseModel
from .vector_db import VectorDBConfig

FILLER_AUDIO_DEFAULT_SILENCE_THRESHOLD_SECONDS = 0.5
LLM_AGENT_DEFAULT_TEMPERATURE = 1.0
LLM_AGENT_DEFAULT_MAX_TOKENS = 256
LLM_AGENT_DEFAULT_MODEL_NAME = "text-curie-001"
CHAT_GPT_AGENT_DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
ACTION_AGENT_DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613"
CHAT_ANTHROPIC_DEFAULT_MODEL_NAME = "claude-v1"
CHAT_VERTEX_AI_DEFAULT_MODEL_NAME = "chat-bison@001"
AZURE_OPENAI_DEFAULT_API_TYPE = "azure"
AZURE_OPENAI_DEFAULT_API_VERSION = "2023-03-15-preview"
AZURE_OPENAI_DEFAULT_ENGINE = "gpt-35-turbo"
CHAT_GROQ_DEFAULT_MODEL_NAME = "mixtral-8x7b-32768"

class AgentType(str, Enum):
    BASE = "agent_base"
    LLM = "agent_llm"
    CHAT_GPT_ALPHA = "agent_chat_gpt_alpha"
    CHAT_GPT = "agent_chat_gpt"
    CHAT_ANTHROPIC = "agent_chat_anthropic"
    CHAT_GROQ = "agent_chat_groq"
    CHAT_VERTEX_AI = "agent_chat_vertex_ai"
    ECHO = "agent_echo"
    GPT4ALL = "agent_gpt4all"
    LLAMACPP = "agent_llamacpp"
    INFORMATION_RETRIEVAL = "agent_information_retrieval"
    RESTFUL_USER_IMPLEMENTED = "agent_restful_user_implemented"
    WEBSOCKET_USER_IMPLEMENTED = "agent_websocket_user_implemented"
    ACTION = "agent_action"

class FillerAudioConfig(BaseModel):
    silence_threshold_seconds: float = FILLER_AUDIO_DEFAULT_SILENCE_THRESHOLD_SECONDS
    use_phrases: bool = True
    use_typing_noise: bool = False

    @validator("use_typing_noise")
    def typing_noise_excludes_phrases(cls, v, values):
        if v and values.get("use_phrases"):
            values["use_phrases"] = False
        if not v and not values.get("use_phrases"):
            raise ValueError("must use either typing noise or phrases for filler audio")
        return v

class WebhookConfig(BaseModel):
    url: str

class AzureOpenAIConfig(BaseModel):
    api_type: str = AZURE_OPENAI_DEFAULT_API_TYPE
    api_version: Optional[str] = AZURE_OPENAI_DEFAULT_API_VERSION
    engine: str = AZURE_OPENAI_DEFAULT_ENGINE

class AgentConfig(TypedModel, type=AgentType.BASE.value):
    initial_message: Optional[BaseMessage] = None
    generate_responses: bool = True
    allowed_idle_time_seconds: Optional[float] = None
    allow_agent_to_be_cut_off: bool = True
    end_conversation_on_goodbye: bool = False
    send_filler_audio: Union[bool, FillerAudioConfig] = False
    webhook_config: Optional[WebhookConfig] = None
    track_bot_sentiment: bool = False
    actions: Optional[List[ActionConfig]] = None

class CutOffResponse(BaseModel):
    messages: List[BaseMessage] = [BaseMessage(text="Sorry?")]

class LLMAgentConfig(AgentConfig, type=AgentType.LLM.value):
    prompt_preamble: str
    expected_first_prompt: Optional[str] = None
    model_name: str = LLM_AGENT_DEFAULT_MODEL_NAME
    temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE
    max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS
    cut_off_response: Optional[CutOffResponse] = None

class ChatGPTAgentConfig(AgentConfig, type=AgentType.CHAT_GPT.value):
    prompt_preamble: str
    expected_first_prompt: Optional[str] = None
    model_name: str = CHAT_GPT_AGENT_DEFAULT_MODEL_NAME
    temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE
    max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS
    cut_off_response: Optional[CutOffResponse] = None
    azure_params: Optional[AzureOpenAIConfig] = None
    vector_db_config: Optional[VectorDBConfig] = None

class ChatAnthropicAgentConfig(AgentConfig, type=AgentType.CHAT_ANTHROPIC.value):
    prompt_preamble: str
    model_name: str = CHAT_ANTHROPIC_DEFAULT_MODEL_NAME
    max_tokens_to_sample: int = 200

class ChatGroqAgentConfig(AgentConfig, type=AgentType.CHAT_GROQ.value):
    prompt_preamble: str
    model_name: str = CHAT_GROQ_DEFAULT_MODEL_NAME
    max_tokens_to_sample: int = 200
    generate_responses: bool = True

class ChatVertexAIAgentConfig(AgentConfig, type=AgentType.CHAT_VERTEX_AI.value):
    prompt_preamble: str
    model_name: str = CHAT_VERTEX_AI_DEFAULT_MODEL_NAME
    generate_responses: bool = False  # Google Vertex AI doesn't support streaming

class LlamacppAgentConfig(AgentConfig, type=AgentType.LLAMACPP.value):
    prompt_preamble: str
    llamacpp_kwargs: dict = {}
    prompt_template: Optional[Union[PromptTemplate, str]] = None

class InformationRetrievalAgentConfig(
    AgentConfig, type=AgentType.INFORMATION_RETRIEVAL.value
):
    recipient_descriptor: str
    caller_descriptor: str
    goal_description: str
    fields: List[str]
    # TODO: add fields for IVR, voicemail

class EchoAgentConfig(AgentConfig, type=AgentType.ECHO.value):
    pass

class GPT4AllAgentConfig(AgentConfig, type=AgentType.GPT4ALL.value):
    prompt_preamble: str
    model_path: str
    generate_responses: bool = False

class RESTfulUserImplementedAgentConfig(
    AgentConfig, type=AgentType.RESTFUL_USER_IMPLEMENTED.value
):
    class EndpointConfig(BaseModel):
        url: str
        method: str = "POST"

    respond: EndpointConfig
    generate_responses: bool = False
    # generate_response: Optional[EndpointConfig]

class RESTfulAgentInput(BaseModel):
    conversation_id: str
    human_input: str

class RESTfulAgentOutputType(str, Enum):
    BASE = "restful_agent_base"
    TEXT = "restful_agent_text"
    END = "restful_agent_end"

class RESTfulAgentOutput(TypedModel, type=RESTfulAgentOutputType.BASE):
    pass

class RESTfulAgentText(RESTfulAgentOutput, type=RESTfulAgentOutputType.TEXT):
    response: str

class RESTfulAgentEnd(RESTfulAgentOutput, type=RESTfulAgentOutputType.END):
    pass

Streaming conversation usage of the Groq agent

import asyncio
import logging
import signal
from dotenv import load_dotenv

from vocode.streaming.agent.groq_agent import ChatGroqAgent

load_dotenv()

from vocode.streaming.streaming_conversation import StreamingConversation
from vocode.helpers import create_streaming_microphone_input_and_speaker_output
from vocode.streaming.transcriber import *
from vocode.streaming.agent import *
from vocode.streaming.synthesizer import *
from vocode.streaming.models.transcriber import *
from vocode.streaming.models.agent import *
from vocode.streaming.models.synthesizer import *
from vocode.streaming.models.message import BaseMessage

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

async def main():
    (
        microphone_input,
        speaker_output,
    ) = create_streaming_microphone_input_and_speaker_output(
        use_default_devices=False,
        logger=logger,
        use_blocking_speaker_output=True,  # this moves the playback to a separate thread, set to False to use the main thread
    )

    conversation = StreamingConversation(
        output_device=speaker_output,
        transcriber=DeepgramTranscriber(
            DeepgramTranscriberConfig.from_input_device(
                microphone_input,
                endpointing_config=PunctuationEndpointingConfig(),
            )
        ),
        agent=ChatGroqAgent(
            ChatGroqAgentConfig(
                initial_message=BaseMessage(text="What up"),
                prompt_preamble="""The AI is having a pleasant conversation about life""",
            )
        ),
        synthesizer=AzureSynthesizer(
            AzureSynthesizerConfig.from_output_device(speaker_output)
        ),
        logger=logger,
    )
    await conversation.start()
    print("Conversation started, press Ctrl+C to end")
    signal.signal(
        signal.SIGINT, lambda _0, _1: asyncio.create_task(conversation.terminate())
    )
    while conversation.is_active():
        chunk = await microphone_input.get_audio()
        conversation.receive_audio(chunk)

if __name__ == "__main__":
    asyncio.run(main())
spikecodes commented 3 months ago

I'd recommend opening this as a PR so changes can be reviewed and merged into the project

Kevin7744 commented 2 months ago

Hey @Arunprakaash, are you not getting errors from ConversationChain complaining that self.llm is not a BaseLanguageModel?

Arunprakaash commented 2 months ago

I have not tested that yet. I have just replaced the ChatGPT agent with this one and got it working. Once I've figured out everything, I'll create a pull request.

github-actions[bot] commented 1 week ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

Scylla2020 commented 1 week ago

This gives an error:

  File "C:\Users\UserX\AppData\Local\Programs\Python\Python310\lib\site-packages\pydantic\deprecated\class_validators.py", line 249, in root_validator
    raise PydanticUserError(
pydantic.errors.PydanticUserError: If you use `@root_validator` with pre=False (the default) you MUST specify `skip_on_failure=True`. Note that `@root_validator` is deprecated and should be replaced with `@model_validator`.
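
The error above comes from running Pydantic v1-style validators under Pydantic v2: the deprecated @root_validator shim now requires skip_on_failure=True when pre=False. A quick workaround is to pin pydantic<2, since this code targets the v1 validator API. The longer-term fix the error message suggests is migrating to @model_validator; a minimal sketch of that style (the model and fields here are illustrative, not vocode's actual code):

from pydantic import BaseModel, model_validator

class ExampleFillerAudioConfig(BaseModel):
    use_phrases: bool = True
    use_typing_noise: bool = False

    # Pydantic v2 replacement for a root/field validator that cross-checks fields
    @model_validator(mode="after")
    def typing_noise_excludes_phrases(self):
        if self.use_typing_noise and self.use_phrases:
            self.use_phrases = False
        if not self.use_typing_noise and not self.use_phrases:
            raise ValueError("must use either typing noise or phrases for filler audio")
        return self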