salesforce / WikiSQL

A large annotated semantic parsing corpus for developing natural language interfaces.
BSD 3-Clause "New" or "Revised" License
1.63k stars 323 forks source link

How to get the original SQL query from the "sql"/"query" field which is in json format #102

Open lxw0109 opened 1 year ago

lxw0109 commented 1 year ago

How can I convert the following content of the "sql" field back into the original SQL query?

{"phase": 1, "table_id": "1-1000181-1", "question": "Tell me what the notes are for South Australia ", "sql": {"sel": 5, "conds": [[3, 0, "SOUTH AUSTRALIA"]], "agg": 0}}

I tried using the lib.query.Query.from_dict method, but got SELECT col5 FROM table WHERE col3 = SOUTH AUSTRALIA. I also tried the lib.dbengine.DBEngine.execute_query method, but got SELECT col5 AS result FROM table_1_1000181_1 WHERE col3 = :col3. Neither of these two methods produces the correct SQL query, so how can I get it? Can anybody help?

magic-YuanTian commented 1 year ago

same question

skyrise-l commented 1 year ago

same question

3cham commented 1 year ago

You could add the following code to the Query class and pass in the column types to get the correct query, e.g.:

class Query
    ...
    def to_query(self, types):
        """Render this query as a SQL string, quoting text condition values.

        Args:
            types: Sequence indexed by column number. A condition value is
                emitted as a single-quoted SQL string literal when its
                column's entry equals ``"text"``; otherwise it is emitted
                verbatim.

        Returns:
            ``SELECT colN FROM table`` (or ``SELECT AGG (colN) FROM table``
            when an aggregation is set), followed by a ``WHERE`` clause that
            joins every condition with ``AND`` when conditions exist.
        """
        agg = self.agg_ops[self.agg_index]
        sel = f'col{self.sel_index}'
        if agg:
            rep = f'SELECT {agg} ({sel}) FROM table'
        else:
            rep = f'SELECT {sel} FROM table'
        if self.conditions:
            cond_strings = []
            for i, o, v in self.conditions:
                if types[i] == "text":
                    # Double embedded single quotes so a value such as
                    # O'Brien still forms a valid SQL string literal.
                    escaped = str(v).replace("'", "''")
                    cond_strings.append(f"col{i} {self.cond_ops[o]} '{escaped}'")
                else:
                    cond_strings.append(f"col{i} {self.cond_ops[o]} {v}")
            rep += ' WHERE ' + ' AND '.join(cond_strings)
        return rep
    ...
Gyyz commented 1 month ago

Try this, and modify it as needed:

import re
import json

class SQLToJsonConverter:
    """Convert ``colN``-style SQL strings into a dict / JSON description.

    The accepted shape is ``SELECT colN FROM ...`` or
    ``SELECT AGG(colN) FROM ...``, optionally followed by
    ``WHERE colN <op> value AND colM <op> value ...``. Column indices are
    translated to names through ``column_names``.

    Limitations: conditions are separated on the literal, upper-case
    ``' AND '``, so a value containing that exact text is split, and all
    condition values are returned as strings.
    """

    # SELECT clause: group 1 = aggregation name, group 2 = column index of an
    # aggregated column, group 3 = column index of a plain column.
    _SELECT_RE = re.compile(
        r'SELECT\s+(?:(\w+)\s*\(\s*col(\d+)\s*\)|col(\d+))\s+FROM',
        re.IGNORECASE,
    )
    _WHERE_RE = re.compile(r'WHERE\s+(.+)', re.IGNORECASE)
    # One condition: the operator is a run of comparison characters (or the
    # placeholder word OP), so it can never swallow the value's opening quote
    # or a leading minus sign.
    _COND_RE = re.compile(r'col(\d+)\s*([<>=!]+|OP\b)\s*(.+)')

    def __init__(self, table_name, column_names):
        """Store the table name and the index-to-name column mapping.

        Args:
            table_name: Value placed under the ``"table"`` key of the output.
            column_names: Sequence where position ``N`` is the name for
                ``colN``.
        """
        self.table_name = table_name
        self.column_names = column_names
        self.agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']  # Aggregation operations
        self.cond_ops = ['=', '>', '<', 'OP']  # Condition operations

    def parse_sql(self, sql_query):
        """Parse ``sql_query`` into a dict.

        Args:
            sql_query: SQL text using ``colN`` column placeholders.

        Returns:
            A dict with keys ``"table"``, ``"agg"`` (index into
            ``self.agg_ops``, 0 when absent or unknown), ``"sel"`` (selected
            column name) and ``"conds"`` (list of ``[column_name, op, value]``
            where ``op`` is an index into ``self.cond_ops``, or the raw
            operator text when it is not listed there).

        Raises:
            IndexError: If a referenced column index is outside
                ``column_names``.
        """
        query_json = {
            "table": self.table_name,
            "agg": 0,
            "sel": 0,
            "conds": [],
        }

        select_match = self._SELECT_RE.search(sql_query)
        if select_match:
            agg_name, agg_col, plain_col = select_match.groups()
            if agg_name:
                agg_name = agg_name.upper()
                if agg_name in self.agg_ops:
                    query_json["agg"] = self.agg_ops.index(agg_name)
                query_json["sel"] = int(agg_col)
            else:
                # The group already holds only the digits; it must not be
                # sliced as if it still carried the "col" prefix.
                query_json["sel"] = int(plain_col)

        where_match = self._WHERE_RE.search(sql_query)
        if where_match:
            for cond in where_match.group(1).split(' AND '):
                cond_match = self._COND_RE.search(cond)
                if not cond_match:
                    continue
                col_index = int(cond_match.group(1))
                op = cond_match.group(2)
                value = cond_match.group(3).strip().strip("'")
                if op in self.cond_ops:
                    op = self.cond_ops.index(op)
                query_json["conds"].append([col_index, op, value])

        # Convert column indices to column names
        query_json["sel"] = self.column_names[query_json["sel"]]
        query_json["conds"] = [
            [self.column_names[col], op, value]
            for col, op, value in query_json["conds"]
        ]

        return query_json

    def sql_to_json(self, sql_query):
        """Return ``parse_sql(sql_query)`` serialized as indented JSON text."""
        parsed_query = self.parse_sql(sql_query)
        return json.dumps(parsed_query, indent=4)

# Example usage: declare the inputs first, then convert and print the result.
table_name = "table"
# Position N in this list is the name used for colN (col0, col1, ...).
column_names = [
    "id",
    "name",
    "value",
    "count",
]
sql_query = "SELECT MAX(col1) FROM table WHERE col2 = 'value1' AND col3 > 10"

converter = SQLToJsonConverter(table_name, column_names)
json_output = converter.sql_to_json(sql_query)
print(json_output)