NumbersStationAI / DuckDB-NSQL

DuckDB NSQL Model
Apache License 2.0

Canonical example produces unexpected output #12

Open jaraco opened 1 month ago

jaraco commented 1 month ago

I've managed to get the example to work by tweaking the requirements, but when I run it:

import duckdb
from llama_cpp import Llama
from wurlitzer import pipes

from examples.utils import generate_sql

# Set up client with model path and context size
with pipes() as (out, err):
    client = Llama(
        model_path="examples/DuckDB-NSQL-7B-v0.1-q8_0.gguf",
        n_ctx=2048,
    )

# Connect to DuckDB database
con = duckdb.connect("examples/nyc.duckdb")

# Sample question for SQL generation
question = "get all columns from taxi table starting with 'a'"

# Generate SQL, check validity, and print
query = generate_sql(question, con, client)
print(query)

print(con.execute(query).fetchdf())

The generated query and its output don't seem to match the intent of the question:

 DuckDB-NSQL debt/just-install @ py -m example
 SELECT COLUMNS('a.*') FROM taxi;
 SELECT COLUMNS('a.*') FROM taxi;
    tpep_pickup_datetime tpep_dropoff_datetime  passenger_count  ...  total_amount  congestion_surcharge airport_fee
0    2022-11-04 00:51:52   2022-11-04 01:02:08              1.0  ...         15.80                   2.5        0.00
1    2022-11-04 00:25:29   2022-11-04 00:39:51              5.0  ...         19.56                   2.5        0.00
2    2022-11-04 00:43:21   2022-11-04 00:54:51              5.0  ...         18.36                   2.5        0.00
3    2022-11-04 00:05:49   2022-11-04 00:21:23              1.0  ...         18.96                   2.5        0.00
4    2022-11-04 00:35:49   2022-11-04 00:35:53              1.0  ...         -5.05                   0.0       -1.25
..                   ...                   ...              ...  ...           ...                   ...         ...
995  2022-11-04 00:40:37   2022-11-04 00:48:56              1.0  ...         13.30                   2.5        0.00
996  2022-11-04 00:57:24   2022-11-04 01:27:29              1.0  ...         32.80                   2.5        0.00
997  2022-11-04 01:29:40   2022-11-04 01:56:05              1.0  ...         25.70                   2.5        0.00
998  2022-11-04 01:44:59   2022-11-04 01:53:23              1.0  ...         12.80                   2.5        0.00
999  2022-11-04 01:06:47   2022-11-04 01:24:46              2.0  ...         25.55                   2.5        0.00

[1000 rows x 18 columns]

It's not outputting only the columns that start with 'a'. It's not even outputting rows restricted to columns that start with 'a'. In fact, it seems to be emitting every row of the taxi table.
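A possible explanation for the column selection, offered as a sketch under an assumption: if DuckDB's `COLUMNS()` applies its regular expression as an unanchored search rather than a full match, then `'a.*'` selects every column whose name merely *contains* an `a`, not only names that *start* with one. The 18-column result above is consistent with that reading, but it is not confirmed here. The snippet uses Python's standard `re` module as a stand-in for DuckDB's matcher (an assumption) and only the six column names visible in the printed DataFrame. `matching_columns` is a hypothetical helper, not part of DuckDB or `examples.utils`.

```python
import re

# Column names visible in the DataFrame printed above; the remaining columns
# were elided by pandas ("...") and are not reconstructed here.
visible_columns = [
    "tpep_pickup_datetime",
    "tpep_dropoff_datetime",
    "passenger_count",
    "total_amount",
    "congestion_surcharge",
    "airport_fee",
]


def matching_columns(pattern, columns):
    """Hypothetical helper: keep names where `pattern` matches anywhere.

    Uses an unanchored re.search, mirroring the *assumed* DuckDB COLUMNS()
    behaviour; it is not DuckDB's actual implementation.
    """
    regex = re.compile(pattern)
    return [name for name in columns if regex.search(name)]


# The model's pattern: an 'a' anywhere in the name is enough to match.
print(matching_columns("a.*", visible_columns))  # → all six names
# Anchoring the pattern keeps only names that begin with 'a'.
print(matching_columns("^a", visible_columns))  # → ['airport_fee']
```

Under the same assumption, an anchored pattern such as `SELECT COLUMNS('^a') FROM taxi;` would presumably be closer to what the question asks for.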

If the LLM can't produce basic queries, I have little hope of it doing anything more useful.