langchain-ai / langchainjs

🦜🔗 Build context-aware reasoning applications 🦜🔗
https://js.langchain.com/docs/
MIT License

[Feature Request] Google Gemini Vertex AI Function Calling #4253

Closed natapol2547 closed 6 months ago

natapol2547 commented 8 months ago

The new @google-cloud/vertexai library makes it very easy to use function calling with Google Gemini models. It would be awesome to see more models besides OpenAI that support function calling in LangChain JS 👍

dosubot[bot] commented 8 months ago

🤖

Hey there, @natapol2547! Good to see you back with more ideas.

Thank you for your feature request, natapol2547. It's indeed possible to integrate the @google-cloud/vertexai library into LangChainJS to support function calling with Google Gemini Models.

LangChainJS is designed to be flexible and extensible, allowing for integration with external libraries like @google-cloud/vertexai. As an example, here's a hypothetical code snippet showing how you might use the VertexAI client within the LangChainJS framework:

import { ChatGoogleVertexAI } from "@langchain/community/chat_models/googlevertexai";
import { VertexAI } from "@google-cloud/vertexai";

// Initialize the Google Vertex AI client
const vertexClient = new VertexAI();

const model = new ChatGoogleVertexAI({
  temperature: 0.7,
  vertexClient: vertexClient, // Pass the Vertex AI client to the LangChain model
});

In this example, the Vertex AI client is passed to the LangChain model, allowing the model to make function calls using the Vertex AI client. Please note that this is a hypothetical example and the actual implementation may vary based on the specific requirements and the actual API of the LangChainJS and @google-cloud/vertexai libraries.

As for function calling with OpenAI models, LangChainJS currently handles this using the call method of the OpenAI class. The call method takes a prompt and an options object, which can contain various settings for the call. Here's an example:

import { OpenAI } from "@langchain/openai";

const model = new OpenAI({ temperature: 0 });

const res = await model.call(
  "What would be a good company name for a company that makes colorful socks?",
  {
    options: {
      headers: {
        "User-Id": "123",
      },
    },
  }
);
console.log(res);

In this code, an instance of the OpenAI class is created with a temperature of 0. The call method is then used to make a request to the OpenAI API with the provided prompt and options. The response from the API is logged to the console.

I hope this information is helpful. If you have any further questions or need more clarification, please don't hesitate to ask.

Best regards, Dosu

This response is meant to be useful and save you time. It is not meant to be a precise solution, but rather a starting point for your own research.

Help me be more useful! Please leave a 👍 if this is helpful and 👎 if it is irrelevant.

If you want to continue the conversation, start your reply with @dosu-bot.

afirstenberg commented 8 months ago

Working on it! (But always appreciate help.)

My biggest challenge is that I haven't used OpenAI's or LangChainJS' implementation, so I'm not entirely sure I'm capturing how it is used correctly. Some simple examples and test cases of what you're looking to see would be appreciated.

I won't be using the Vertex library itself; I'll be using REST. See the @langchain/google-common package for the current base implementation.

natapol2547 commented 8 months ago

@afirstenberg Yeah, I'm also not that familiar with the LangChain JS implementation, but I'll share what I know. (If I'm wrong on some parts, I'm sorry 🙏)

Function Calling with Langchain JS

I believe there are two ways:

  1. Pass the functions or tools as call options to .invoke:

     const result = await model.invoke([new HumanMessage("What a beautiful day!")], {
       functions: [extractionFunctionSchema],
       function_call: { name: "extractor" },
     });

  2. Bind the functions or tools first using .bind and then use .invoke (.bind basically attaches any call arguments passed in to all future calls):

     const model = new ChatOpenAI({ modelName: "gpt-4" }).bind({
       functions: [extractionFunctionSchema],
       function_call: { name: "extractor" },
     });
     const result = await model.invoke([new HumanMessage("What a beautiful day!")]);

     The tools argument is defined in ChatOpenAICallOptions, so we can basically update GoogleVertexAIChatInput to accept a tools argument of type StructuredToolInterface[] | ChatVertexAITool[] as well. A rough sketch of that option surface follows the reference link below.

For reference: Langchain JS OpenAI Implementation: https://github.com/langchain-ai/langchainjs/blob/5f5507ddab86b1ce5b4c6b558d7cc10264f81a01/libs/langchain-openai/src/chat_models.ts#L433-L435
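
Purely as a sketch of what that option surface might look like (names follow the draft further down; they are not an existing API yet):

import type { StructuredToolInterface } from "@langchain/core/tools";
import type { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";

// Illustrative placeholder for the Gemini-native tool shape (OpenAPI-style schema).
interface ChatVertexAITool {
  function_declarations: Array<{
    name: string;
    description?: string;
    parameters?: Record<string, unknown>;
  }>;
}

// Call options could accept either LangChain structured tools or raw Gemini tools.
interface ChatGoogleVertexAICallOptions extends BaseLanguageModelCallOptions {
  tools?: StructuredToolInterface[] | ChatVertexAITool[];
}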

Then we convert the structured tools to the Gemini function schema, which uses OpenAPI Schema. For reference: Vertex AI Function Calling documentation.
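
A minimal conversion sketch, assuming zod-to-json-schema as in the draft further down (the field names are my assumption, not a finalized API):

import { zodToJsonSchema } from "zod-to-json-schema";
import type { StructuredToolInterface } from "@langchain/core/tools";

// Convert LangChain structured tools into Gemini-style function declarations.
function toGeminiFunctionDeclarations(tools: StructuredToolInterface[]) {
  return tools.map((tool) => ({
    name: tool.name,
    description: tool.description,
    parameters: zodToJsonSchema(tool.schema),
  }));
}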

We may need to create a .createTool method in BaseChatGoogleVertexAI and update .createInstance to handle multi-turn function calling (sending the function output back) and to support ToolMessage as well, so it would be easier to build an AgentExecutor in the future. A rough multi-turn sketch follows the reference below.

For reference: Gemini Request body API
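
Purely illustrative sketch of that multi-turn flow on the LangChain side, assuming the proposed model surfaces the Gemini call in additional_kwargs.function_call and accepts a FunctionMessage for the tool result (model and myTool are hypothetical stand-ins):

import { HumanMessage, FunctionMessage } from "@langchain/core/messages";

const question = new HumanMessage("Search Wikipedia for Google DeepMind");

// First turn: the model answers with a function call instead of text.
const first = await model.invoke([question]);
const call = first.additional_kwargs.function_call;

if (call) {
  // Run the matching tool ourselves, then send the result back as a FunctionMessage.
  const toolOutput = await myTool.invoke(JSON.parse(call.arguments));
  const second = await model.invoke([
    question,
    first,
    new FunctionMessage({
      name: call.name,
      content: JSON.stringify({ result: toolOutput }),
    }),
  ]);
  console.log(second.content);
}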

Lastly, we need to update .convertPrediction and GoogleVertexAIChatMessage to support content with a role of "functionCall" and "function".

I believe that in the OpenAI implementation, the function and tool calls are placed as additional kwargs on an AIMessage. https://github.com/langchain-ai/langchainjs/blob/5f5507ddab86b1ce5b4c6b558d7cc10264f81a01/libs/langchain-openai/src/chat_models.ts#L101-L109
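
For illustration, the resulting message might carry the call roughly like this (mirroring the OpenAI convention linked above; the values are made up):

import { AIMessage } from "@langchain/core/messages";

// Hypothetical example of a Gemini function call surfaced on an AIMessage.
const aiMessage = new AIMessage({
  content: "",
  additional_kwargs: {
    function_call: {
      name: "extractor",
      arguments: JSON.stringify({ query: "Google DeepMind" }),
    },
  },
});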

I hope this is useful 👍

natapol2547 commented 7 months ago

@afirstenberg

Here is my first draft of ChatGoogleVertexAI 👍👍 It uses @google-cloud/vertexai version 0.3.0 (0.3.1 of the library seems to have a bug where the response returns a promise instead).

The tools along with streaming are working.

The convert_system_message_to_human option is also here, just like in LangChain Python.

chat_google_vertex_ai.ts

import {
    type BaseChatModelParams,
    BaseChatModel,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
    AIMessageChunk,
    HumanMessage,
    type BaseMessage,
    AIMessage,
} from "@langchain/core/messages";
import { ChatGenerationChunk, type ChatResult } from "@langchain/core/outputs";
import * as GoogleVertexAI from "@google-cloud/vertexai";
import type { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";
import type { StructuredToolInterface } from "@langchain/core/tools";
import { zodToJsonSchema } from "zod-to-json-schema";

export interface ChatGoogleVertexAIInput
    extends BaseChatModelParams,
        GoogleVertexAI.VertexInit,
        GoogleVertexAI.GenerationConfig {
    model?: string;
    tools?: StructuredToolInterface[] | GoogleVertexAI.Tool[];
    safety_settings?: GoogleVertexAI.SafetySetting[];
    convert_system_message_to_human?: boolean;
}

export interface ChatGoogleVertexAICallOptions
    extends BaseLanguageModelCallOptions {
    tools?: StructuredToolInterface[] | GoogleVertexAI.Tool[];
}

function convertToGoogleVertexAITools(
    tools: StructuredToolInterface[]
): GoogleVertexAI.Tool[] {
    return [
        {
            function_declarations: tools.map((tool) => {
                return {
                    name: tool.name,
                    description: tool.description,
                    parameters: zodToJsonSchema(tool.schema) as any,
                };
            }),
        },
    ];
}

export class ChatGoogleVertexAI<
    CallOptions extends ChatGoogleVertexAICallOptions = ChatGoogleVertexAICallOptions
> extends BaseChatModel<CallOptions> {
    model: string;

    vertexAI: GoogleVertexAI.VertexAI;

    safety_settings?: GoogleVertexAI.SafetySetting[];

    generation_config?: GoogleVertexAI.GenerationConfig;

    convert_system_message_to_human?: boolean;

    constructor(fields: ChatGoogleVertexAIInput) {
        super(fields);
        this.model = fields.model || "gemini-pro";
        this.vertexAI = new GoogleVertexAI.VertexAI({
            project: fields.project,
            location: fields.location,
            apiEndpoint: fields.apiEndpoint,
            googleAuthOptions: fields.googleAuthOptions,
        });
        this.safety_settings = fields.safety_settings;
        this.generation_config = {
            candidate_count: fields.candidate_count,
            stop_sequences: fields.stop_sequences,
            max_output_tokens: fields.max_output_tokens,
            temperature: fields.temperature,
            top_p: fields.top_p,
            top_k: fields.top_k,
        };
        this.convert_system_message_to_human =
            fields.convert_system_message_to_human;
    }

    _llmType() {
        return "chat_google_vertex_ai";
    }

    async _generate(
        messages: BaseMessage[],
        options: this["ParsedCallOptions"],
        runManager?: CallbackManagerForLLMRun | undefined
    ): Promise<ChatResult> {
        if (!messages.length) {
            throw new Error("No messages provided.");
        }
        if (typeof messages[0].content !== "string") {
            throw new Error("Multimodal messages are not supported.");
        }
        if (this.convert_system_message_to_human) {
            messages = this.convertSystemMessages(messages);
        }
        const generative_model = this.vertexAI.preview.getGenerativeModel({
            ...this.invocationParams(options),
        });
        // console.log(this.invocationParams(options))
        const vertexAI_content = this.mapChatMessagesToVertexAIContent(messages);
        const bug = await generative_model.generateContent({
            contents: vertexAI_content,
            ...this.invocationParams(options),
        });
        const response = (await bug.response).candidates[0];
        switch (response.finishReason) {
            case "STOP":
                if (response.content.parts && response.content.parts[0].functionCall) {
                    // Is function Call
                    return {
                        generations: [
                            {
                                message: new AIMessage({
                                    content: "",
                                    additional_kwargs: {
                                        function_call: {
                                            name: response.content.parts[0].functionCall.name,
                                            arguments: JSON.stringify(
                                                response.content.parts[0].functionCall.args
                                            ),
                                        },
                                        finishReason: response.finishReason,
                                    },
                                }),
                                text: "",
                            },
                        ],
                    };
                } else {
                    const response_text = response.content.parts[0].text || "";
                    return {
                        generations: [
                            {
                                message: new AIMessage({
                                    content: response_text,
                                    additional_kwargs: {
                                        finishReason: response.finishReason,
                                    },
                                }),
                                text: response_text,
                            },
                        ],
                    };
                }
            case "RECITATION":
                return {
                    generations: [
                        {
                            message: new AIMessage({
                                content: "โปรดดูเอกสารด้านล่างได้เลยครับ",
                                additional_kwargs: {
                                    finishReason: response.finishReason,
                                },
                            }),
                            text: "โปรดดูเอกสารด้านล่างได้เลยครับ",
                        },
                    ],
                };
            default:
                return {
                    generations: [
                        {
                            message: new AIMessage({
                                content: "",
                                additional_kwargs: {
                                    finishReason: response.finishReason,
                                },
                            }),
                            text: "",
                        },
                    ],
                };
        }
    }

    async *_streamResponseChunks(
        messages: BaseMessage[],
        options: this["ParsedCallOptions"],
        runManager?: CallbackManagerForLLMRun
    ): AsyncGenerator<ChatGenerationChunk> {
        if (!messages.length) {
            throw new Error("No messages provided.");
        }

        const generative_model = this.vertexAI.preview.getGenerativeModel({
            ...this.invocationParams(options),
        });
        const vertexAI_content = this.mapChatMessagesToVertexAIContent(messages);
        const streamingResp = await generative_model.generateContentStream({
            contents: vertexAI_content,
            ...this.invocationParams(options),
        });
        console.log(generative_model.tools);
        for await (const item of streamingResp.stream) {
            const text_chunk = item.candidates[0].content.parts[0].text || "";
            yield new ChatGenerationChunk({
                message: new AIMessageChunk({
                    content: text_chunk,
                }),
                text: text_chunk,
            });
            await runManager?.handleLLMNewToken(text_chunk);
        }
    }

    invocationParams(
        options?: this["ParsedCallOptions"]
    ): GoogleVertexAI.ModelParams {
        function isStructuredToolArray(
            tools?: unknown[]
        ): tools is StructuredToolInterface[] {
            return (
                tools !== undefined &&
                tools.every((tool) =>
                    Array.isArray((tool as StructuredToolInterface).lc_namespace)
                )
            );
        }
        const params: GoogleVertexAI.ModelParams = {
            model: this.model,
            generation_config: this.generation_config,
            safety_settings: this.safety_settings,
            tools: isStructuredToolArray(options?.tools)
                ? convertToGoogleVertexAITools(options?.tools)
                : options?.tools,
        };
        // if (params.tools) {
        //  console.log(params.tools[0].function_declarations[0].parameters);
        // }

        return params;
    }
    _combineLLMOutput() {
        return {};
    }

    private mapChatMessagesToVertexAIContent(
        messages: BaseMessage[]
    ): GoogleVertexAI.Content[] {
        let vertexAIContent: GoogleVertexAI.Content[] = [];
        messages.forEach((message) => {
            vertexAIContent.push(this.mapMessageToVertexAIPart(message));
        });
        return vertexAIContent;
    }

    private mapMessageToVertexAIPart(
        message: BaseMessage
    ): GoogleVertexAI.Content {
        const type = message._getType();
        if (typeof message.content !== "string") {
            throw new Error("Multimodal messages are not supported yet.");
        }
        switch (type) {
            case "system":
                throw new Error(`Google Vertex AI does not support system message`);
            case "ai":
                if (!message.additional_kwargs.function_call) {
                    const ai_part: GoogleVertexAI.TextPart = {
                        text: message.content,
                    };
                    return {
                        role: "model",
                        parts: [ai_part],
                    };
                } else {
                    const ai_part: GoogleVertexAI.FunctionCallPart = {
                        functionCall: {
                            name: message.additional_kwargs.function_call.name,
                            args: JSON.parse(
                                message.additional_kwargs.function_call.arguments
                            ),
                        },
                    };
                    return {
                        role: "model",
                        parts: [ai_part],
                    };
                }

            case "human":
                const human_part: GoogleVertexAI.TextPart = {
                    text: message.content,
                };
                return {
                    role: "user",
                    parts: [human_part],
                };
            case "function":
                if (!message.name || !message.content) {
                    throw new Error(`name or content is not defined in FunctionMessage`);
                }
                const function_part: GoogleVertexAI.FunctionResponsePart = {
                    functionResponse: {
                        name: message.name,
                        response: JSON.parse(message.content),
                    },
                };
                return {
                    role: "functionCall",
                    parts: [function_part],
                };
            default:
                throw new Error(`Unknown message type: ${type}`);
        }
    }

    private convertSystemMessages(messages: BaseMessage[]): BaseMessage[] {
        const convertedMessages: BaseMessage[] = [];

        for (const message of messages) {
            if (message._getType() === "system") {
                if (typeof message.content !== "string") {
                    throw new Error(`SystemMessage content is not a string`);
                }
                convertedMessages.push(new HumanMessage(message.content));
                convertedMessages.push(new AIMessage("Understood"));
            } else {
                convertedMessages.push(message);
            }
        }

        return convertedMessages;
    }

    private isStructuredToolArray(
        tools?: unknown[]
    ): tools is StructuredToolInterface[] {
        return (
            tools !== undefined &&
            tools.every((tool) =>
                Array.isArray((tool as StructuredToolInterface).lc_namespace)
            )
        );
    }
}

Install the dependencies:

pnpm add langchain zod-to-json-schema @google-cloud/vertexai@0.3.0

Try testing it 👍

import { ChatGoogleVertexAI } from "./chat_google_vertex_ai";

import { WikipediaQueryRun } from "@langchain/community/tools/wikipedia_query_run";

// Note to developers: You need to authenticate first. Take a look at https://cloud.google.com/docs/authentication/client-libraries
// Or use the following instead (Create a service account and get PROJECT_ID PRIVATE_KEY CLIENT_EMAIL in service account json)
// const llm = new ChatGoogleVertexAI({
//  project: FB_PROJECT_ID,
//  location: "asia-southeast1",
//  googleAuthOptions: {
//      credentials: {
//          project_id: PROJECT_ID,
//          private_key: PRIVATE_KEY,
//          client_email: CLIENT_EMAIL,
//      },
//  },
// });

const llm = new ChatGoogleVertexAI({
    project: FB_PROJECT_ID,
    location: "asia-southeast1",
});

export async function test(query: string) {
    const res1 = await llm.invoke(query);

    console.log(res1);

    const stream = await llm.stream(query);

    for await (const chunk of stream) {
        console.log(chunk);
    }

    const tools = [
        new WikipediaQueryRun({
            topKResults: 3,
            maxDocContentLength: 4000,
        }),
    ];

    const llm_with_tool = llm.bind({ tools });

    const res3 = await llm_with_tool.invoke("On Wikipedia search for Google Deepmind");

    console.log(res3);
}

yadavj2008 commented 6 months ago

Hello, I'm also looking for this feature. Wondering if there is any immediate plan to add this feature to main?

BleedingDev commented 6 months ago

Right now I am working on rewriting Vercel AI RSC demo to use LangChain including function calling. It would be awesome to have this done with more than just ChatOpenAI. :) Looking forward to this PR!

afirstenberg commented 6 months ago

> Right now I am working on rewriting Vercel AI RSC demo to use LangChain including function calling. It would be awesome to have this done with more than just ChatOpenAI. :) Looking forward to this PR!

Good timing! See #4864 which I just submitted the PR for.

BleedingDev commented 6 months ago

Awesome, I love that! So if I understand correctly, this snippet should work when I swap ChatOpenAI for ChatGoogle, or is there anything more to edit? :)

Maybe even createOpenAIFunctionsAgent? Because it is tied to OpenAI (at least by name).

// `functions` and `messages` are assumed to come from the surrounding application.
import { ChatOpenAI } from "@langchain/openai";
import { pull } from "langchain/hub";
import type { ChatPromptTemplate } from "@langchain/core/prompts";
import { DynamicStructuredTool, type StructuredToolInterface } from "@langchain/core/tools";
import { createOpenAIFunctionsAgent, AgentExecutor } from "langchain/agents";

const llm = new ChatOpenAI({
  streaming: true,
  temperature: 0,
})
const prompt = await pull<ChatPromptTemplate>("hwchase17/openai-functions-agent")
const tools: StructuredToolInterface[] = functions.map(
  (fn) =>
    new DynamicStructuredTool({
      name: fn.name,
      description: fn.description || "",
      schema: fn.parameters,
      func: (fn as any).func,
    }),
)

const agent = await createOpenAIFunctionsAgent({
  llm,
  tools,
  prompt,
})

const agentExecutor = new AgentExecutor({
  agent,
  tools,
})

const stream = await agentExecutor.stream({
  input: messages.at(-1)?.content,
  chatHistory: messages.slice(0, -1),
})
afirstenberg commented 6 months ago

I'll be honest, @BleedingDev - I'm kinda turning to you for the answer to your question. 😁

I've tried to make it compatible with the OpenAI way of doing things, but the OpenAI way is something of a mess, so I'm not sure if I've gotten it right.

I'm wondering if createOpenAIFunctionsAgent() was replaced with llm.withStructuredOutput(), but @jacoblee93 is working on adding that part and probably knows better than I do.
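
If that lands, usage might look roughly like this (a speculative sketch only, since that API is still being added and its final signature may differ):

import { z } from "zod";
import { ChatOpenAI } from "@langchain/openai";

const schema = z.object({
  tone: z.string().describe("Overall tone of the text"),
  entity: z.string().describe("The main entity mentioned"),
});

const llm = new ChatOpenAI({ temperature: 0 });

// Wrap the model so invoke() returns a parsed object matching `schema`.
const structuredLlm = llm.withStructuredOutput(schema);
const result = await structuredLlm.invoke("What a beautiful day!");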

BleedingDev commented 6 months ago

Oh, great. I will try to validate this. To be honest it is tough to operate in this field - it is still not well supported. But I am doing this for my client, so research is important. :)

Let's see what @jacoblee93 tells us about it. Would be great to be as universal as possible.

jacoblee93 commented 6 months ago

We're looking into a more general way of performing agent outputs that takes advantage of structured output but doesn't directly use tool/function calling, but it's not ready yet!

@afirstenberg I haven't looked over the PR yet but yeah you'd theoretically be able to drop in Gemini for OpenAI for the function calling agent if it supports the same interface.

afirstenberg commented 6 months ago

Well, hopefully I supported the interface correctly then! {: That was certainly my goal. (As painful as it was.)

If not - let's identify where it fell short and fix it.

SebastianStehle commented 6 months ago

When are you planning to release that? It does not seem to be part of the current version.

afirstenberg commented 6 months ago

@SebastianStehle - This should be in the latest versions of @langchain/google-vertexai and @langchain/google-vertexai-web.
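
For anyone following along, basic usage against the released package should look roughly like this (a hedged sketch; check the current docs for exact option names):

import { ChatVertexAI } from "@langchain/google-vertexai";
import { WikipediaQueryRun } from "@langchain/community/tools/wikipedia_query_run";

// Assumes Google Cloud credentials are already configured in the environment;
// the default model is a Gemini model, set it explicitly if you need a specific one.
const llm = new ChatVertexAI({ temperature: 0 });

const tools = [new WikipediaQueryRun({ topKResults: 3, maxDocContentLength: 4000 })];

// Bind the tools so the model can respond with a function/tool call.
const llmWithTools = llm.bind({ tools });

const res = await llmWithTools.invoke("On Wikipedia, search for Google DeepMind");
console.log(res);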

What are you seeing?

SebastianStehle commented 6 months ago

I had a look at the changelog and could not see it there, but I will test it ASAP.