tjake / Jlama

Jlama is a modern LLM inference engine for Java
Apache License 2.0

Bug: when using tools, the model tries to use the tools way too many times #93

Closed ricoapon closed 2 weeks ago

ricoapon commented 1 month ago

I first want to say that I love JLama! It makes AI very accessible; JLama is the reason I first toyed around with this at all. Many people and companies (including me) do not like using external parties, because we send data to those parties and have no clue what is done with it. I love the fact that this runs fully locally :).

Problem

I tried to create an extremely simple example in which a model uses "tools". In my case I used a sum and a square root method, the same examples documented in https://docs.langchain4j.dev/tutorials/tools/.

The problem is that the AI calls these tools many times in a row, when calling them once should be enough. My guess is that this is an issue with the model itself, and I am not sure whether you want me to report bugs in models. But since this is such an extremely simple case, I thought it was worth reporting.

Versions

I am using langchain4j 0.35.0 and jlama 0.7.0, on a Windows machine with Java 21.

I am using model tjake/Llama-3.2-1B-Instruct-JQ4.

Reproduction path

I used this Java class:

import dev.langchain4j.agent.tool.Tool;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.jlama.JlamaChatModel;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.Result;

public class Main {
    public static void main(String[] args) {
        String modelName = "tjake/Llama-3.2-1B-Instruct-JQ4";
        String prompt = """
                What is the sum of 12444 and 6532?
                """.trim();

        System.out.println(promptModelUsingLangchain(prompt, modelName));
    }

    interface Assistant {
        Result<String> chat(String userMessage);
    }

    private static String promptModelUsingLangchain(String prompt, String modelName) {
        ChatLanguageModel model = JlamaChatModel.builder()
                .modelName(modelName)
                .build();

        Assistant assistant = AiServices.builder(Assistant.class)
                .chatLanguageModel(model)
                .tools(new MyTools())
                .build();

        Result<String> result = assistant.chat(prompt);
        return result.content();
    }

    static class MyTools {
        @Tool
        int add(int a, int b) {
            System.out.println("Method was called with " + a + ", " + b);
            return a + b;
        }

        @Tool
        public double squareRoot(double x) {
            System.out.println("Method was called with " + x);
            return Math.sqrt(x);
        }
    }
}

The behavior varies slightly if you change the prompt a little. For example, sometimes the model doesn't use its tools at all, so I tried adding "use your tools" to the prompt, and so on.
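For clarity, this is the kind of prompt variant I mean (the wording of the nudge is just one example; only the original question comes from the code above):

```java
public class PromptVariant {
    public static void main(String[] args) {
        // Same question as in the reproduction, with an explicit nudge
        // appended ("use your tools" is one of the variants I tried).
        String prompt = """
                What is the sum of 12444 and 6532?
                Use your tools.
                """.trim();
        System.out.println(prompt);
    }
}
```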

On most runs, I get logging like this:

Method was called with 6532, 12444
Method was called with 6532, 12444
Method was called with 6532, 18976
Method was called with 6532, 25508
Method was called with 6532, 32040
Method was called with 6532, 38572
Method was called with 6532, 45104
Method was called with 6532, 51636
Method was called with 6532, 58168
Method was called with 6532, 64700

And then the actual answer eventually pops up:

def add(arg1, arg0):
    return arg1 + arg0

result = add(12444, 6532)
print(result)

This is not even an answer to the question. You can clearly see that the model recognizes the tool output and does something with it, but it never turns it into a final answer.
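Incidentally, the looping arguments above follow a clear pattern: after the duplicated first call, the second argument of each `add` call is the previous call's return value (12444 + 6532 = 18976, 18976 + 6532 = 25508, and so on). A minimal sketch that reproduces the logged sequence (minus the repeated first line):

```java
// Reproduces the logged argument sequence: the model appears to feed each
// add(6532, x) result back in as the next call's second argument.
public class LogPattern {
    public static void main(String[] args) {
        int running = 12444; // second argument of the first logged call
        for (int i = 0; i < 9; i++) {
            System.out.println("Method was called with 6532, " + running);
            running += 6532; // the tool's return value becomes the next argument
        }
    }
}
```

So the model is chaining tool calls on its own output rather than stopping after the first result.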

When using the square root tool, I get something similar. Prompt: "What is the square root of 1244?". Logging:

Method was called with 1244.0
Method was called with 1244.0
Method was called with 1244.0
Method was called with 1244.0
Method was called with 1244.0
Method was called with 1244.0
Method was called with 1244.0
Method was called with 1244.0
Method was called with 1244.0
Method was called with 1244.0
Method was called with 1244.0
Method was called with 1244.0
Method was called with 1244.0
Method was called with 1244.0

I stopped the program, because it didn't give me any response after running for more than two minutes.
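As a crude stopgap until this is fixed (a hypothetical guard of my own, not a langchain4j or Jlama feature), one could cap the number of tool invocations inside the tool object itself, so a looping model fails fast instead of hanging:

```java
// Hypothetical stopgap: cap how many times a tool may be invoked, so a
// looping model throws quickly instead of running for minutes.
public class ToolCallGuard {
    private final int maxCalls;
    private int calls = 0;

    public ToolCallGuard(int maxCalls) {
        this.maxCalls = maxCalls;
    }

    // Call this at the top of every @Tool method.
    public void check() {
        calls++;
        if (calls > maxCalls) {
            throw new IllegalStateException(
                    "Tool called " + calls + " times; aborting likely loop");
        }
    }

    public static void main(String[] args) {
        ToolCallGuard guard = new ToolCallGuard(3);
        for (int i = 0; i < 3; i++) guard.check(); // within budget, fine
        try {
            guard.check(); // fourth call trips the guard
        } catch (IllegalStateException e) {
            System.out.println("guard tripped: " + e.getMessage());
        }
    }
}
```

This doesn't fix the underlying model behavior, of course; it just bounds the damage.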

tjake commented 1 month ago

Hello!

Yes, I noticed this too and have a fix. I will be updating langchain4j soon.