spring-projects / spring-ai

An Application Framework for AI Engineering
https://docs.spring.io/spring-ai/reference/index.html
Apache License 2.0

Chat history breaks function calling #1301

Open LiveNathan opened 2 months ago

LiveNathan commented 2 months ago

Bug description

When chat history is enabled, the AI provides incorrect information about available functions, seemingly relying on outdated data from the chat history instead of querying the current available functions.

Environment

Steps to reproduce

  1. Enable chat history in the application.
  2. Ask the AI, "What functions are available?"
  3. Modify the available functions.
  4. Ask the AI, "What functions are available?"
  5. Observe that the response includes incorrect or outdated functions that are not part of the currently available functions.
  6. Disable chat history.
  7. Ask the same question, "What functions are available?"
  8. Observe that the response correctly lists the current available functions.

Expected behavior

When chat history is enabled, the AI should provide the correct list of available functions by querying the current data, rather than relying on outdated or incorrect information from the chat history.

Minimal Complete Reproducible example

Below is a snippet of the code demonstrating the setup of the SpringAiAdapter and the issue. Notice that chat history is disabled here: the VectorStoreChatMemoryAdvisor line is commented out.

public class SpringAiAdapter implements AiGateway {
    private final ChatClient chatClient;
    private final ProjectRepository projectRepository;
    private static final Logger log = LoggerFactory.getLogger(SpringAiAdapter.class);

    public SpringAiAdapter(ChatClient.Builder builder, ProjectRepository projectRepository, VectorStore vectorStore) {
        this.projectRepository = projectRepository;
        this.chatClient = builder.defaultSystem("""
                        You are an assistant project manager expert at managing many resources and schedules.
                        Adopt the user's tone to make them feel comfortable with you. If they are playful and silly, so are you. If they are professional and matter-of-fact, so are you.
                        Keep your responses short and direct because people need your help in a hurry, but for complex tasks, think out loud by writing each step.
                        For questions about long documents, pull the most relevant quote from the document and consider whether it answers the user's question or whether it lacks sufficient detail.
                        Today is {current_date}. This message was sent by {user_name} at exactly {message_creation_time}.
                        Available projects are: {available_projects}. The project name is its natural identifier.
                        When calling functions, always use the exact name of the project as provided here. For example, a user's request may reference `projct a`, `12345`, or simply `A`, but if `Project A (12345)` is on the list of available projects, then function calls should be made with `Project A (12345)`. However, if the user's request references a significantly different project name like `projct b`, `54333`, or simply `B`, then the request should be rejected.""")
                .defaultAdvisors(
                        // new VectorStoreChatMemoryAdvisor(vectorStore),
                        new LoggingAdvisor())
                .build();
    }

    @Override
    public Flux<String> sendMessageAndReceiveReplies(UserMessageDto userMessageDto) {
        String projectNames = String.join(", ", projectRepository.findAllNames());
        return chatClient.prompt()
                .system(sp -> sp.params(Map.of(
                        "current_date", LocalDate.now().toString(),
                        "message_creation_time", userMessageDto.creationTime().toString(),
                        "user_name", userMessageDto.userName(),
                        "available_projects", projectNames
                )))
                .functions("clockInFunction", "clockOutFunction", "findAllProjectNamesFunction", "updateProjectFunction")
                .user(userMessageDto.userMessageText())
                .advisors(advisorSpec -> advisorSpec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, userMessageDto.chatId())
                        .param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 100))
                .stream()
                .content()
                .onErrorResume(WebClientResponseException.class, e -> {
                    log.error("WebClient request failed with status: {} and response body: {}",
                            e.getStatusCode(), e.getResponseBodyAsString(), e);
                    return Mono.error(new RuntimeException("Failed to communicate with AI service", e));
                });
    }
}

Reproduction output examples

With chat history disabled:

Nathan
Aug 31, 2024, 4:24 PM
What functions are available?
DroidComm
Aug 31, 2024, 4:24 PM
Here are the available functions:

1. **clockOutFunction** - Update an existing timesheet entry for a resource by adding a clock-out time.
2. **updateProjectFunction** - Update an existing timesheet entry's project.
3. **clockInFunction** - Create a new timesheet entry for a resource, optionally including a project name.
4. **findAllProjectNamesFunction** - Fetch the list of available project names.

Let me know how I can assist you with any of these!

With chat history enabled:

Nathan
Aug 31, 2024, 4:25 PM
What functions are available?
DroidComm
Aug 31, 2024, 4:25 PM
The available functions are:

1. **clockInFunction**: Create a new timesheet entry for a resource.
2. **updateProjectFunction**: Update an existing timesheet entry's project.
3. **findAllProjectNamesFunction**: Fetch the list of available project names.
4. **unfulfilledRequestFunction**: Handle requests or commands that the AI cannot fulfill by logging them as potential feature requests or bug reports.

If you need assistance with any of these functions, just let me know!

The response with chat history enabled shows an incorrect list of functions, including "unfulfilledRequestFunction," which is not actually available in the current setup.
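
One plausible reading of the two transcripts, offered as an assumption rather than a confirmed diagnosis: a chat memory advisor replays earlier turns into the prompt, so an older assistant answer that named `unfulfilledRequestFunction` is still in front of the model even though that function is no longer registered. The sketch below is plain Java with no Spring AI dependency. `Turn`, `buildPrompt`, and the stored answer text are hypothetical stand-ins, not Spring AI types.

```java
import java.util.ArrayList;
import java.util.List;

final class StaleHistorySketch {

    // Hypothetical stand-in for one stored chat turn.
    record Turn(String role, String text) {}

    // Conceptual model of a memory advisor: replay stored turns, then add
    // the currently registered tools and the new user question.
    static String buildPrompt(List<Turn> history, List<String> registeredFunctions, String question) {
        StringBuilder prompt = new StringBuilder();
        for (Turn turn : history) {
            prompt.append(turn.role()).append(": ").append(turn.text()).append('\n');
        }
        prompt.append("tools: ").append(String.join(", ", registeredFunctions)).append('\n');
        prompt.append("user: ").append(question).append('\n');
        return prompt.toString();
    }

    public static void main(String[] args) {
        // History recorded while unfulfilledRequestFunction was still registered.
        List<Turn> history = new ArrayList<>();
        history.add(new Turn("user", "What functions are available?"));
        history.add(new Turn("assistant",
                "clockInFunction, updateProjectFunction, findAllProjectNamesFunction, unfulfilledRequestFunction"));

        // The function set as registered in the adapter shown above.
        List<String> registered = List.of(
                "clockInFunction", "clockOutFunction", "findAllProjectNamesFunction", "updateProjectFunction");

        String prompt = buildPrompt(history, registered, "What functions are available?");
        System.out.println(prompt);
        System.out.println("stale name still in prompt: " + prompt.contains("unfulfilledRequestFunction"));
    }
}
```

If this reading is right, the model sees two conflicting lists and may simply repeat its earlier answer.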

LiveNathan commented 2 months ago

wow, this is VERY annoying. Every time I update the functions I have to completely delete the history.
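
A possible workaround for having to delete the history by hand, sketched under the assumption that losing earlier context is acceptable whenever the function set changes: derive the conversation ID from both the chat ID and the set of function names, so a changed function list starts a fresh conversation. `ConversationIds` and `versioned` are hypothetical names, not part of Spring AI.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.HexFormat;
import java.util.List;
import java.util.TreeSet;

final class ConversationIds {

    // Returns a conversation ID that changes whenever the set of function names changes.
    static String versioned(String chatId, List<String> functionNames) {
        // Sort and de-duplicate so registration order does not affect the ID.
        String canonical = String.join(",", new TreeSet<>(functionNames));
        try {
            byte[] digest = MessageDigest.getInstance("SHA-256")
                    .digest(canonical.getBytes(StandardCharsets.UTF_8));
            // The first four bytes (eight hex characters) are used as a short version tag.
            return chatId + "-" + HexFormat.of().formatHex(digest, 0, 4);
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("SHA-256 should be available on every Java platform", e);
        }
    }

    public static void main(String[] args) {
        String before = versioned("696", List.of("clockInFunction", "updateProjectFunction"));
        String after = versioned("696", List.of("clockInFunction", "clockOutFunction", "updateProjectFunction"));
        System.out.println(before);
        System.out.println(after);
        System.out.println("same conversation: " + before.equals(after));
    }
}
```

In the adapter, the result would be passed as the value of `CHAT_MEMORY_CONVERSATION_ID_KEY` in place of the raw `chatId()`. The old history stays in the store under the old ID, so it is bypassed rather than deleted.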

markpollack commented 2 months ago

Thanks for reporting this, we will look into it ASAP.

tzolov commented 2 months ago

Hi @LiveNathan, can you please share which Spring AI version and which chat model you are using?

tzolov commented 2 months ago

Also, is this observed only with the VectorStoreChatMemoryAdvisor or with the other ChatMemory advisors as well?

LiveNathan commented 2 months ago

Hi @LiveNathan, can you please share which Spring AI version and which chat model you are using?

1.0.0-M2

<properties>
    <java.version>21</java.version>
    <vaadin.version>24.4.9</vaadin.version>
    <spring-ai.version>1.0.0-M2</spring-ai.version>
    <eclipsestore.version>1.4.0</eclipsestore.version>
</properties>

<parent>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-parent</artifactId>
    <version>3.3.3</version>
</parent>

tzolov commented 2 months ago

@LiveNathan, I've written a dedicated integration test (below) that applies all 3 chat memory advisors and mimics your system message and configuration. It works fine on the latest Spring AI 1.0.0-SN and OpenAI GPT-4o-mini:

/*
 * Copyright 2023 - 2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.ai.openai.chat.client;

import static org.assertj.core.api.Assertions.assertThat;

import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import javax.sql.DataSource;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.VectorStoreChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.vectorstore.PgVectorStore;
import org.springframework.ai.vectorstore.PgVectorStore.PgDistanceType;
import org.springframework.ai.vectorstore.PgVectorStore.PgIndexType;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Description;
import org.springframework.context.annotation.Primary;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.util.StringUtils;
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

import com.zaxxer.hikari.HikariDataSource;

@SpringBootTest(classes = SpringAiIssue1301IT.Conf.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
@Testcontainers
class SpringAiIssue1301IT {

    private static final Logger logger = LoggerFactory.getLogger(SpringAiIssue1301IT.class);

    @Container
    @SuppressWarnings("resource")
    static PostgreSQLContainer<?> postgresContainer = new PostgreSQLContainer<>("pgvector/pgvector:pg16")
        .withUsername("postgres")
        .withPassword("postgres");

    @Autowired
    protected ChatModel chatModel;

    @Autowired
    protected EmbeddingModel embeddingModel;

    @Autowired
    protected VectorStore vectorStore;

    @Test
    void functionCallTest() {

        ChatMemory chatMemory = new InMemoryChatMemory();

        // @formatter:off
        var chatClient = ChatClient.builder(chatModel)
            .defaultSystem("""
                        You are an assistant project manager expert at managing many resources and schedules.
                        Adopt the user's tone to make them feel comfortable with you. If they are playful and silly, so are you. If they are professional and matter-of-fact, so are you.
                        Keep your responses short and direct because people need your help in a hurry, but for complex tasks, think out loud by writing each step.
                        For questions about long documents, pull the most relevant quote from the document and consider whether it answers the user's question or whether it lacks sufficient detail.
                        Today is {current_date}. This message was sent by {user_name} at exactly {message_creation_time}.
                        Available projects are: {available_projects}. The project name is its natural identifier.
                        When calling functions, always use the exact name of the project as provided here. For example, a user's request may reference `project a`, `12345`, or simply `A`,
                        but if `Project A (12345)` is on the list of available projects, then function calls should be made with `Project A (12345)`. However, if the user's request references
                        a significantly different project name like `project b`, `54333`, or simply `B`, then the request should be rejected.
                        """)            
            .defaultAdvisors(new MessageChatMemoryAdvisor(chatMemory))
            .defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory))
            .defaultAdvisors(new VectorStoreChatMemoryAdvisor(vectorStore))
            .build();           
            // @formatter:on

        for (int i = 0; i < 5; i++) {
        // @formatter:off
            String response = chatClient.prompt()
                .system(sp -> sp.params(Map.of(
                        "current_date", LocalDate.now().toString(),
                        "message_creation_time", LocalDateTime.now(),
                        "user_name", "Alice",
                        "available_projects", List.of("Project A (12345)", "Project B (54333)")
                )))
                .user(u -> u.text("What functions are available? List the function names."))
                .functions("getCurrentWeather", "clockInFunction", "clockOutFunction", "findAllProjectNamesFunction", "updateProjectFunction")
                .advisors(advisorSpec -> advisorSpec.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, "696")
                                                    .param(AbstractChatMemoryAdvisor.CHAT_MEMORY_RETRIEVE_SIZE_KEY, 100))
                .call()
                .content();
                // @formatter:on

            logger.info("Response: {}", response);

            assertThat(response).contains("getCurrentWeather", "clockInFunction", "clockOutFunction",
                    "findAllProjectNamesFunction", "updateProjectFunction");

        }
    }

    @Test
    void streamFunctionCallTest() {

        ChatMemory chatMemory = new InMemoryChatMemory();

        // @formatter:off
        var chatClient = ChatClient.builder(chatModel)
            .defaultSystem("""
                        You are an assistant project manager expert at managing many resources and schedules.
                        Adopt the user's tone to make them feel comfortable with you. If they are playful and silly, so are you. If they are professional and matter-of-fact, so are you.
                        Keep your responses short and direct because people need your help in a hurry, but for complex tasks, think out loud by writing each step.
                        For questions about long documents, pull the most relevant quote from the document and consider whether it answers the user's question or whether it lacks sufficient detail.
                        Today is {current_date}. This message was sent by {user_name} at exactly {message_creation_time}.
                        Available projects are: {available_projects}. The project name is its natural identifier.
                        When calling functions, always use the exact name of the project as provided here. For example, a user's request may reference `project a`, `12345`, or simply `A`,
                        but if `Project A (12345)` is on the list of available projects, then function calls should be made with `Project A (12345)`. However, if the user's request references
                        a significantly different project name like `project b`, `54333`, or simply `B`, then the request should be rejected.
                        """)            
            .defaultAdvisors(new MessageChatMemoryAdvisor(chatMemory))
            .defaultAdvisors(new PromptChatMemoryAdvisor(chatMemory))
            .defaultAdvisors(new VectorStoreChatMemoryAdvisor(vectorStore))
            .build();           
            // @formatter:on

        for (int i = 0; i < 5; i++) {
        // @formatter:off
            String response = chatClient.prompt()
                .system(sp -> sp.params(Map.of(
                        "current_date", LocalDate.now().toString(),
                        "message_creation_time", LocalDateTime.now(),
                        "user_name", "Alice",
                        "available_projects", List.of("Project A (12345)", "Project B (54333)")
                )))
                .user(u -> u.text("What functions are available? List the function names."))
                .functions("getCurrentWeather", "clockInFunction", "clockOutFunction", "findAllProjectNamesFunction", "updateProjectFunction")
                .advisors(advisorSpec -> advisorSpec.param(AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, "696")
                                                    .param(AbstractChatMemoryAdvisor.CHAT_MEMORY_RETRIEVE_SIZE_KEY, 100))
                .stream()
                .content().collectList().block().stream().collect(Collectors.joining());
                // @formatter:on

            logger.info("Response: {}", response);

            assertThat(response).contains("getCurrentWeather", "clockInFunction", "clockOutFunction",
                    "findAllProjectNamesFunction", "updateProjectFunction");
        }
    }

    @Configuration
    @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class })
    public static class Conf {

        record Request(String text) {
        }

        record Response(String text) {
        }

        @Bean
        @Description("Clock in")
        public Function<Request, Response> clockInFunction() {
            return request -> new Response("clockInFunction " + request.text());
        }

        @Bean
        @Description("Clock out")
        public Function<Request, Response> clockOutFunction() {
            return request -> new Response("clockOutFunction " + request.text());
        }

        @Bean
        @Description("Find all project names")
        public Function<Request, Response> findAllProjectNamesFunction() {
            return request -> new Response("findAllProjectNamesFunction " + request.text());
        }

        @Bean
        @Description("Update project")
        public Function<Request, Response> updateProjectFunction() {
            return request -> new Response("updateProjectFunction " + request.text());
        }

        record Location(String location) {
        }

        @Bean
        @Description("Get the weather in location")
        public Function<Location, Response> getCurrentWeather() {
            return request -> new Response("Current temperature is 11C");
        }

        @Bean
        public OpenAiApi openAiApi() {
            String apiKey = System.getenv("OPENAI_API_KEY");
            if (!StringUtils.hasText(apiKey)) {
                throw new IllegalArgumentException(
                        "You must provide an API key.  Put it in an environment variable under the name OPENAI_API_KEY");
            }

            return new OpenAiApi(apiKey);
        }

        @Bean
        public OpenAiChatModel openAiChatModel(OpenAiApi api, FunctionCallbackContext functionCallbackContext) {
            OpenAiChatModel openAiChatModel = new OpenAiChatModel(api,
                    OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O_MINI).build(),
                    functionCallbackContext, RetryUtils.DEFAULT_RETRY_TEMPLATE);
            return openAiChatModel;
        }

        @Bean
        public OpenAiEmbeddingModel openAiEmbeddingModel(OpenAiApi api) {
            return new OpenAiEmbeddingModel(api);
        }

        @Bean
        @ConditionalOnMissingBean
        public FunctionCallbackContext springAiFunctionManager(ApplicationContext context) {
            FunctionCallbackContext manager = new FunctionCallbackContext();
            manager.setApplicationContext(context);
            return manager;
        }

        @Bean
        public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
            return new PgVectorStore(jdbcTemplate, embeddingModel, PgVectorStore.INVALID_EMBEDDING_DIMENSION,
                    PgDistanceType.COSINE_DISTANCE, true, PgIndexType.HNSW, true);
        }

        @Bean
        public JdbcTemplate myJdbcTemplate(DataSource dataSource) {
            return new JdbcTemplate(dataSource);
        }

        @Bean
        @Primary
        public DataSourceProperties dataSourceProperties() {
            var datasource = new DataSourceProperties();
            datasource.setUrl(String.format("jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(),
                    postgresContainer.getMappedPort(5432), "postgres"));
            datasource.setUsername(postgresContainer.getUsername());
            datasource.setPassword(postgresContainer.getPassword());
            return datasource;
        }

        @Bean
        public HikariDataSource dataSource(DataSourceProperties dataSourceProperties) {
            return dataSourceProperties.initializeDataSourceBuilder().type(HikariDataSource.class).build();
        }

    }
}

LiveNathan commented 2 months ago

Also, is this observed only with the VectorStoreChatMemoryAdvisor or with the other ChatMemory advisors as well?

IDK. I have only used the vector memory store and the in-memory chat history. Of course, the same problem does not occur with the in-memory chat history because it resets every time the app starts.

tzolov commented 2 months ago

As you can see in my test, I'm running the same query 5 times (it worked with 10 times as well).
What model are you using?

LiveNathan commented 2 months ago

@LiveNathan, I've written a dedicated integration test (below) that applies all 3 chat memory advisors and mimics your system message and configuration. It works fine on the latest Spring AI 1.0.0-SN and OpenAI GPT-4o-mini:

Wow, thanks! So you are effectively:

  1. Getting the list of functions so that it ends up in the history
  2. Changing the list of functions
  3. Getting the list of functions again

And it successfully fetches the updated list of functions and doesn't get confused by the history? Hmmm, ok, I'm not sure what to say, then. For me it was pretty clear that it didn't work until I turned off the vector store history. I'm using Chroma, but that shouldn't matter, I guess.
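
The three steps above only expose the problem if the test also checks that names which are no longer registered are absent from the answer. An assertion such as `assertThat(response).contains(...)` passes as long as the current names appear, even if a stale one appears next to them. The helper below is a minimal sketch of such a check. It assumes, as in the original report, that function names are identifiers ending in `Function`. `FunctionListCheck` and `unregisteredNames` are hypothetical names.

```java
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

final class FunctionListCheck {

    // Assumption: every function name is an identifier that ends in "Function".
    private static final Pattern FUNCTION_NAME =
            Pattern.compile("\\b[A-Za-z_][A-Za-z0-9_]*Function\\b");

    // Returns the function-like names mentioned in the response that are not registered.
    static Set<String> unregisteredNames(String response, Set<String> registered) {
        Set<String> unknown = new LinkedHashSet<>();
        Matcher matcher = FUNCTION_NAME.matcher(response);
        while (matcher.find()) {
            String name = matcher.group();
            if (!registered.contains(name)) {
                unknown.add(name);
            }
        }
        return unknown;
    }

    public static void main(String[] args) {
        // Text taken from the "chat history enabled" transcript in the report.
        String response = """
                1. **clockInFunction**: Create a new timesheet entry for a resource.
                2. **updateProjectFunction**: Update an existing timesheet entry's project.
                3. **findAllProjectNamesFunction**: Fetch the list of available project names.
                4. **unfulfilledRequestFunction**: Handle requests or commands that the AI cannot fulfill.
                """;
        Set<String> registered = Set.of(
                "clockInFunction", "clockOutFunction", "findAllProjectNamesFunction", "updateProjectFunction");
        System.out.println(unregisteredNames(response, registered));
    }
}
```

In an AssertJ-based test, `assertThat(unregisteredNames(response, registered)).isEmpty()` or a direct `doesNotContain` on the response string would serve the same purpose.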

LiveNathan commented 2 months ago

As you can see in my test, I'm running the same query 5 times (it worked with 10 times as well). What model are you using?

spring.ai.openai.chat.options.model=gpt-4o-mini

If it helps, I just added you as a contributor to https://github.com/LiveNathan/ai-function-calling/ so you can see the code. It's just a demo project for now, but it may turn into a paid product so I'm not quite ready to make it public.

tzolov commented 2 months ago

FYI, my user question is "What functions are available? List the function names." I.e., I'm asking explicitly for the names. Not sure how important this is. My test is not dynamically changing the list of functions (not sure you mentioned this before).
I'm afraid I won't have time to review your project. But perhaps you can rewrite my test to reproduce the issue?

LiveNathan commented 2 months ago

FYI, my user question is "What functions are available? List the function names." I.e., I'm asking explicitly for the names. Not sure how important this is. My test is not dynamically changing the list of functions (not sure you mentioned this before).

You're right. I was not clear in my original post. I have updated it. Here's the change.

Steps to reproduce

  1. Enable chat history in the application.
  2. Ask the AI, "What functions are available?"
  3. Modify the available functions.
  4. Ask the AI, "What functions are available?"
  5. Observe that the response includes incorrect or outdated functions that are not part of the currently available functions.
  6. Disable chat history.
  7. Ask the same question, "What functions are available?"
  8. Observe that the response correctly lists the current available functions.

I'm afraid I won't have time to review your project.

Ok, I rescinded the collaboration invitation.

But perhaps you can rewrite my test to reproduce the issue?

Ok, that's a little intimidating. I'm just a beginner. I'll give it a shot. :)

LiveNathan commented 2 months ago

Ok, I tried to get the test going for you, but I ran into a snag trying to get ChromaDB set up using Testcontainers. I tried to use the implementation I found on the Testcontainers website, but I've never written a test against ChromaDB or any other vector store for that matter, so I'm not sure how to do it.

Tests failed: Caused by: org.springframework.beans.factory.support.BeanDefinitionOverrideException: Invalid bean definition with name 'vectorStore' defined in class path resource [org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.class]: Cannot register bean definition [Root bean: class [null]; scope=; abstract=false; lazyInit=null; autowireMode=3; dependencyCheck=0; autowireCandidate=true; primary=false; factoryBeanName=org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaVectorStoreAutoConfiguration; factoryMethodName=vectorStore; initMethodNames=null; destroyMethodNames=[(inferred)]; defined in class path resource [org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.class]] for bean 'vectorStore' since there is already [Root bean: class [null]; scope=; abstract=false; lazyInit=null; autowireMode=3; dependencyCheck=0; autowireCandidate=true; primary=false; factoryBeanName=springAiIssue1301IT.Conf; factoryMethodName=vectorStore; initMethodNames=null; destroyMethodNames=[(inferred)]; defined in dev.nathanlively.adapter.out.ai.SpringAiIssue1301IT$Conf] bound.
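
The exception names two definitions of a bean called `vectorStore`: one from `ChromaVectorStoreAutoConfiguration` and one from the test's own `Conf` class. One way to avoid the clash, sketched here as an assumption rather than a verified fix, is to exclude that auto-configuration by name so that only the test's bean is registered. `excludeName` is a standard attribute of `@EnableAutoConfiguration`, and the class name is copied from the error message. `Issue1301Conf` is a hypothetical name for the test configuration.

```java
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import org.springframework.context.annotation.Configuration;

// Configuration fragment: keeps the existing DataSource exclusion and additionally
// excludes the Chroma auto-configuration named in the exception.
@Configuration
@EnableAutoConfiguration(
        exclude = { DataSourceAutoConfiguration.class },
        excludeName = "org.springframework.ai.autoconfigure.vectorstore.chroma.ChromaVectorStoreAutoConfiguration")
public class Issue1301Conf {
    // The test's own vectorStore bean and the other beans would be declared here.
}
```

Alternatively, dropping the hand-written `vectorStore` bean and letting the auto-configuration create it should also leave a single definition.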

package dev.nathanlively.adapter.out.ai;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.VectorStoreChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.chroma.ChromaApi;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.vectorstore.ChromaVectorStore;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Description;
import org.springframework.util.StringUtils;
import org.testcontainers.chromadb.ChromaDBContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;

import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY;
import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_RETRIEVE_SIZE_KEY;

@SpringBootTest(classes = SpringAiIssue1301IT.Conf.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
@Testcontainers
public class SpringAiIssue1301IT {
    private static final Logger log = LoggerFactory.getLogger(SpringAiIssue1301IT.class);

    // @Testcontainers only starts and stops fields annotated with @Container;
    // a static field is shared by all test methods in the class.
    @Container
    static ChromaDBContainer chromadb = new ChromaDBContainer("chromadb/chroma:0.4.22");

    @Autowired
    protected ChatModel chatModel;

    @Autowired
    protected EmbeddingModel embeddingModel;

    @Autowired
    protected VectorStore vectorStore;

    @Test
    void functionCallTest() throws Exception {
        ChatMemory chatMemory = new InMemoryChatMemory();

        var chatClient = ChatClient.builder(chatModel).defaultSystem("""
                        You are an assistant project manager expert at managing many resources and schedules.
                        Adopt the user's tone to make them feel comfortable with you. If they are playful and silly, so are you. If they are professional and matter-of-fact, so are you.
                        Keep your responses short and direct because people need your help in a hurry, but for complex tasks, think out loud by writing each step.
                        For questions about long documents, pull the most relevant quote from the document and consider whether it answers the user's question or whether it lacks sufficient detail.
                        Today is {current_date}. This message was sent by {user_name} at exactly {message_creation_time} instant with {message_creation_timezone} timezone.
                        Available projects are: {available_projects}. The project name is its natural identifier.""")
                .defaultFunctions("clockIn")
                .defaultAdvisors(
                        new MessageChatMemoryAdvisor(chatMemory),
                        new VectorStoreChatMemoryAdvisor(vectorStore))
                .build();

        String response = chatClient.prompt()
                .system(sp -> sp.params(Map.of(
                        "current_date", LocalDate.now().toString(),
                        "message_creation_time", LocalDateTime.now(),
                        "message_creation_timezone", ZoneId.systemDefault().toString(),
                        "user_name", "Alice",
                        "available_projects", List.of("Project A (12345)", "Project B (54333)")
                )))
                .user(u -> u.text("What functions are available? List the function names."))
                .advisors(advisorSpec -> advisorSpec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, "696").param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 100))
                .functions("clockIn")
                .call().content();

        log.info("Response: {}", response);

        assertThat(response).contains("clockIn");

        // Update the available functions.
        response = chatClient.prompt()
                .system(sp -> sp.params(Map.of(
                        "current_date", LocalDate.now().toString(),
                        "message_creation_time", LocalDateTime.now(),
                        "message_creation_timezone", ZoneId.systemDefault().toString(),
                        "user_name", "Alice",
                        "available_projects", List.of("Project A (12345)", "Project B (54333)")
                )))
                .user(u -> u.text("What functions are available? List the function names."))
                .advisors(advisorSpec -> advisorSpec.param(CHAT_MEMORY_CONVERSATION_ID_KEY, "696").param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 100))
                .functions("clockIn", "clockOut")
                .call().content();

        // Assert that the response contains the new functions. Expect assertion to fail as long as vector store is active.
        assertThat(response).contains("clockIn");
        assertThat(response).contains("clockOut");
    }

    @Configuration
    @EnableAutoConfiguration(exclude = {DataSourceAutoConfiguration.class})
    public class Conf {
        record Request(String text) {
        }

        record Response(String text) {
        }

        @Bean
        @Description("Clock in")
        public Function<Request, Response> clockIn() {
            return request -> new Response("clockIn " + request.text());
        }

        @Bean
        @Description("Clock out")
        public Function<Request, Response> clockOut() {
            return request -> new Response("clockOut " + request.text());
        }

        @Bean
        public OpenAiApi openAiApi() {
            String apiKey = System.getenv("OPENAI_API_KEY");
            if (!StringUtils.hasText(apiKey)) {
                throw new IllegalArgumentException(
                        "You must provide an API key.  Put it in an environment variable under the name OPENAI_API_KEY");
            }

            return new OpenAiApi(apiKey);
        }

        @Bean
        public OpenAiChatModel openAiChatModel(OpenAiApi api, FunctionCallbackContext functionCallbackContext) {
            return new OpenAiChatModel(api,
                    OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O_MINI).build(),
                    functionCallbackContext, RetryUtils.DEFAULT_RETRY_TEMPLATE);
        }

        @Bean
        public OpenAiEmbeddingModel openAiEmbeddingModel(OpenAiApi api) {
            return new OpenAiEmbeddingModel(api);
        }

        @Bean
        @ConditionalOnMissingBean
        public FunctionCallbackContext springAiFunctionManager(ApplicationContext context) {
            FunctionCallbackContext manager = new FunctionCallbackContext();
            manager.setApplicationContext(context);
            return manager;
        }

        @Bean
        public VectorStore vectorStore(EmbeddingModel embeddingModel, ChromaApi chromaApi) {
            return new ChromaVectorStore(embeddingModel, chromaApi, true);
        }

        @Bean
        public ChromaApi chromaApi() {
            String chromaDbHost = chromadb.getHost();
            int chromaDbPort = chromadb.getMappedPort(8000);  // Adjust port if necessary

            String baseUrl = String.format("http://%s:%d", chromaDbHost, chromaDbPort);
            return new ChromaApi(baseUrl);
        }
    }
}
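
One plausible reading of the reproduction above is that both calls share the conversation ID "696", so the memory advisors replay the first answer (which names only clockIn) alongside the updated function list. The thread does not confirm this. The sketch below is a toy model in plain Java, not Spring AI code: `StaleHistorySketch`, `Request`, `buildRequest`, and `remember` are hypothetical names, and the stored assistant reply is made up for illustration. It only shows how a stale answer can sit in the replayed messages while the function list is already up to date.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Toy model of a chat-memory layer. NOT Spring AI code; every name here is hypothetical. */
public class StaleHistorySketch {

    /** One request as a model might see it: replayed messages plus the current function list. */
    record Request(List<String> messages, List<String> functionNames) {}

    // conversationId -> every message exchanged so far in that conversation
    private final Map<String, List<String>> memory = new HashMap<>();

    /** Builds the request for one turn: replayed history first, then the new user message. */
    Request buildRequest(String conversationId, String userText, List<String> functionNames) {
        List<String> messages = new ArrayList<>(memory.getOrDefault(conversationId, List.of()));
        messages.add("user: " + userText);
        return new Request(List.copyOf(messages), List.copyOf(functionNames));
    }

    /** Stores a finished turn so that later turns in the same conversation replay it. */
    void remember(String conversationId, String userText, String assistantText) {
        List<String> history = memory.computeIfAbsent(conversationId, id -> new ArrayList<>());
        history.add("user: " + userText);
        history.add("assistant: " + assistantText);
    }

    public static void main(String[] args) {
        StaleHistorySketch chat = new StaleHistorySketch();
        String question = "What functions are available? List the function names.";

        // Turn 1: only clockIn is registered; the (made-up) answer is stored.
        chat.remember("696", question, "The available function is clockIn.");

        // Turn 2: clockOut is registered too, but the old answer is still replayed.
        Request second = chat.buildRequest("696", question, List.of("clockIn", "clockOut"));
        second.messages().forEach(System.out::println);
        System.out.println("functions: " + second.functionNames());

        // A different conversation ID starts without the stale answer.
        Request fresh = chat.buildRequest("697", question, List.of("clockIn", "clockOut"));
        System.out.println("replayed messages in fresh conversation: " + (fresh.messages().size() - 1));
    }
}
```

If this reading is right, rerunning the test with a conversation ID that has no stored history (and an empty vector store) would be a quick way to check whether the stale answer comes from replayed memory rather than from function registration.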

russdanner commented 1 week ago

I am seeing something similar when using the QuestionAnswerAdvisor or RetrievalAugmentationAdvisor together with a custom function. The advisors run, but the functions do not. If I remove the advisors and re-issue the prompt, the functions are used.

        def checkAvailabilityFuncCallWrapper = FunctionCallbackWrapper.builder(new CheckAvailabilityTool())
                .withName("CheckAvailability")
                .withDescription("Returns true if rooms are available")
                .withResponseConverter((response) -> "" + response.available())
                .build()

        def advisor = RetrievalAugmentationAdvisor.builder()
                .documentRetriever(this.vectorStoreDocumentRetriever)
                .build()

        def response = this.chatClient.prompt()
                .user(userPrompt)
                .functions(checkAvailabilityFuncCallWrapper)
                .advisors(advisor)
                .call()
        def answer = response.content()

I am using 1.0.0-SNAPSHOT.

russdanner commented 1 week ago

Disregard my previous message if you are using the latest build from source. Advisors and Functions look to be working as expected when building from source.