HamaWhiteGG / langchain-java

Java version of LangChain, while empowering LLM for Big Data.
Apache License 2.0

"Question" appearing in the SQLQuery using Ollama #158

Open rwankar opened 6 months ago


I'm using the latest version of langchain-java (0.2.0-SNAPSHOT) to try out a simple SQLDatabaseChain test against a locally installed Ollama running llama2.

I've created a PostgreSQL database called test with two tables, artist and album. The album table has an artist_id column that is a foreign key to the id column of the artist table.

My test program looks like this:

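import com.hw.langchain.base.language.BaseLanguageModel; // package path assumed
import com.hw.langchain.chains.sql.database.base.SQLDatabaseChain;
import com.hw.langchain.llms.ollama.Ollama; // package path assumed
import com.hw.langchain.sql.database.SQLDatabase;

class SQLDatabaseChainTest {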

    protected static SQLDatabase database;
    protected static BaseLanguageModel llm;
    protected static SQLDatabaseChain chain;

    public static void setup() {
        database = SQLDatabase.fromUri("jdbc:postgresql://localhost:5432/test", "test", "");

        llm = Ollama.builder()
                .temperature(0f)
                .model("llama2")
                .build()
                .init();

        chain = SQLDatabaseChain.fromLLM(llm, database);
    }

    public static void main(String[] args) {
        setup();
        System.out.println("Setup Complete");
        String actual = chain.run("How many artists are there?");
        System.out.println(actual);

        database.close();
    }
}

When I run it, I get the following output and error. The problem is that the "Question:" line ends up at the start of the SQL command sent to PostgreSQL, even though the generated query itself (SELECT COUNT(*) FROM artist;) is correct.

Setup Complete
main                     Apr 28 09:17 DEBUG LLMChain                  * Prompt after formatting:
You are a PostgreSQL expert. Given an input question, first create a syntactically correct PostgreSQL query to run, then look at the results of the query and return the answer to the input question. 
 Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per PostgreSQL. You can order the results to return the most informative data in the database.Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".Use the following format:Question: Question hereSQLQuery: SQL Query to runSQLResult: Result of the SQLQueryAnswer: Final answer hereOnly use the following tables: 
CREATE TABLE album (
        id int4(10),
        vdb_id int4(10),
        ts timestamp(29,6),
        name varchar(200),
        artist_id int4(10)
)

/*
3 rows from album table:
id      vdb_id  ts      name    artist_id
1       1       2022-01-01 00:00:00     Master of Puppets       1
2       1       2022-01-01 00:00:00     Load    1
3       2       2023-01-01 00:00:00     The Joshua Tree 2
*/

CREATE TABLE artist (
        id int4(10),
        vdb_id int4(10),
        ts timestamp(29,6),
        name varchar(200),
        email varchar(100)
)

/*
3 rows from artist table:
id      vdb_id  ts      name    email
1       1       2022-01-01 00:00:00     Metallica       james@metallica.com
2       1       2023-01-01 00:00:00     U2      bono@u2.com
*/ Question: How many artists are there?
main                     Apr 28 09:24 INFO  SQLDatabaseChain          * SQL command:
 Question: How many artists are there?

SQLQuery: SELECT COUNT(*) FROM artist;
Exception in thread "main" org.postgresql.util.PSQLException: ERROR: syntax error at or near "Question"
  Position: 1
        at org.postgresql.core.v3.QueryExecutorImpl.receiveErrorResponse(QueryExecutorImpl.java:2725)
        at org.postgresql.core.v3.QueryExecutorImpl.processResults(QueryExecutorImpl.java:2412)
        at org.postgresql.core.v3.QueryExecutorImpl.execute(QueryExecutorImpl.java:371)
        at org.postgresql.jdbc.PgStatement.executeInternal(PgStatement.java:502)
        at org.postgresql.jdbc.PgStatement.execute(PgStatement.java:419)
        at org.postgresql.jdbc.PgStatement.executeWithFlags(PgStatement.java:341)
        at org.postgresql.jdbc.PgStatement.executeCachedSql(PgStatement.java:326)
        at org.postgresql.jdbc.PgStatement.executeWithFlags(PgStatement.java:302)
        at org.postgresql.jdbc.PgStatement.execute(PgStatement.java:297)
        at com.hw.langchain.sql.database.SQLDatabase.run(SQLDatabase.java:225)
        at com.hw.langchain.chains.sql.database.base.SQLDatabaseChain.innerCall(SQLDatabaseChain.java:138)
        at com.hw.langchain.chains.base.Chain.call(Chain.java:117)
        at com.hw.langchain.chains.base.Chain.call(Chain.java:103)
        at com.hw.langchain.chains.base.Chain.run(Chain.java:225)
        at SQLDatabaseChainTest.main(SQLDatabaseChainTest.java:57)
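One possible client-side workaround, sketched here under assumptions, is to post-process the raw completion before it is executed: keep only the text after the "SQLQuery:" marker and drop anything from "SQLResult:" onward. SqlExtraction and extractSql are hypothetical names, not part of langchain-java. Wiring this into the chain would require changing how SQLDatabaseChain.innerCall (visible in the stack trace) handles the completion; that part is not shown.

```java
// Hypothetical helper class: not part of langchain-java.
final class SqlExtraction {

    // Returns the SQL statement from a raw LLM completion that may echo
    // the "Question:" line before the "SQLQuery:" marker.
    static String extractSql(String completion) {
        String sql = completion;

        // Keep only what follows the first "SQLQuery:" marker, if present.
        int queryIdx = sql.indexOf("SQLQuery:");
        if (queryIdx >= 0) {
            sql = sql.substring(queryIdx + "SQLQuery:".length());
        }

        // Drop anything the model produced after the query itself.
        int resultIdx = sql.indexOf("SQLResult:");
        if (resultIdx >= 0) {
            sql = sql.substring(0, resultIdx);
        }
        return sql.trim();
    }

    public static void main(String[] args) {
        // The SQL command exactly as logged by SQLDatabaseChain above.
        String logged = " Question: How many artists are there?\n\nSQLQuery: SELECT COUNT(*) FROM artist;";
        System.out.println(extractSql(logged)); // prints: SELECT COUNT(*) FROM artist;
    }
}
```

Cutting at "SQLResult:" as well keeps the helper usable even if the model keeps generating past the query.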

If I copy the prompt and paste it directly into Ollama, I get the following output:

Question: How many artists are there?

SQLQuery: SELECT COUNT(*) FROM artist;

SQLResult: 3

Answer: There are 3 artists in the database.
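This direct run also hints at why the extra line survives. Assume the chain stops generation at a "\nSQLResult:" stop sequence; Python LangChain's SQLDatabaseChain uses that stop sequence, and whether langchain-java does the same is an assumption here (the logged command ending right after the query is consistent with it). The sketch below simulates that truncation on the output above: the text before the stop marker still begins with the echoed "Question:" line, so it would reach SQLDatabase.run unchanged. StopSequenceDemo and truncateAtStop are hypothetical names.

```java
import java.util.List;

// Hypothetical demo class: not part of langchain-java.
final class StopSequenceDemo {

    // Emulates how an LLM client cuts a completion at the first
    // occurrence of any of the given stop sequences.
    static String truncateAtStop(String completion, List<String> stops) {
        int cut = completion.length();
        for (String stop : stops) {
            int idx = completion.indexOf(stop);
            if (idx >= 0 && idx < cut) {
                cut = idx;
            }
        }
        return completion.substring(0, cut);
    }

    public static void main(String[] args) {
        // The output obtained by pasting the prompt into Ollama directly.
        String completion = "Question: How many artists are there?\n\n"
                + "SQLQuery: SELECT COUNT(*) FROM artist;\n\n"
                + "SQLResult: 3\n\n"
                + "Answer: There are 3 artists in the database.";

        // Assumed stop sequence (the one Python LangChain's SQLDatabaseChain uses).
        String sqlCommand = truncateAtStop(completion, List.of("\nSQLResult:"));

        // The echoed question line is still at the start, which matches the
        // PostgreSQL error: syntax error at or near "Question".
        System.out.println(sqlCommand.startsWith("Question:")); // prints: true
    }
}
```

In other words, a stop sequence only trims what comes after the query; it cannot remove text the model echoes before it.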

Any suggestions on how to resolve this?