andialbrecht / sqlparse

A non-validating SQL parser module for Python
BSD 3-Clause "New" or "Revised" License

How to extract column names along with table names from a complex query with subqueries #680

Open cehsonu100 opened 2 years ago

cehsonu100 commented 2 years ago

How can I extract the column names from a complex query with subqueries, like the one below?

SELECT "Gender" AS "Gender", "Ethnicity" AS "Ethnicity", count(*) AS "count" FROM (Select distinct avg(t3."Age") As "Age", t1."Ethnicity" As "Ethnicity", t2."Gender" As "Gender" from (SELECT ethnicityid, genderid, age As "Age" FROM schema1.t3 FINAL where (__USER_TABLE_SECURITY_t3_USER_TABLE_SECURITY__) ) as t3 inner join (SELECT name As "Ethnicity", ethnicityid FROM schema1.t1 FINAL where (t1.code = 'AP' and __USER_TABLE_SECURITY_t1_USER_TABLE_SECURITY__) ) as t1 on t3.ethnicityid = t1.ethnicityid inner join (SELECT name As "Gender", genderid FROM schema1.t2 FINAL where (__USER_TABLE_SECURITY_t2_USER_TABLE_SECURITY__) ) as t2 on t3.genderid = t2.genderid where (1 = 1) group by "Ethnicity", "Gender") AS "virtual_table" GROUP BY "Gender", "Ethnicity" ORDER BY "count" DESC LIMIT 1000;

ChrisJaunes commented 2 years ago

You can use "sqlparse.parse" to parse this query, then iterate over the tokens and collect the columns from the SELECT token and the FROM token. Note that each kind of column (for example a plain identifier, an aliased expression, or a function call) has to be handled differently.

ChrisJaunes commented 2 years ago

For more details, see https://github.com/andialbrecht/sqlparse/blob/master/examples/extract_table_names.py

ChrisJaunes commented 2 years ago

If you only need the outermost (top-level) columns, some modifications are required.

For example:

import sqlparse
from sqlparse import sql, tokens as T


def extract_column_names(token_stream):
    # Adapted from examples/extract_table_names.py: yields the alias (or,
    # if there is none, the real name) of each select-list item.
    for item in token_stream:
        if isinstance(item, sql.IdentifierList):
            for identifier in item.get_identifiers():
                yield identifier.get_name()
        elif isinstance(item, sql.Identifier):
            yield item.get_name()
        elif item.ttype is T.Keyword:
            # Names the lexer treats as keywords land here (so would DISTINCT).
            yield item.value


if __name__ == "__main__":
    sql_text = """SELECT "Gender" AS "Gender", "Ethnicity" AS "Ethnicity", count(*) AS "count" FROM (Select distinct avg(t3."Age") As "Age", t1."Ethnicity" As "Ethnicity", t2."Gender" As "Gender" from (SELECT ethnicityid, genderid, age As "Age" FROM schema1.t3 FINAL where (__USER_TABLE_SECURITY_t3_USER_TABLE_SECURITY__) ) as t3 inner join (SELECT name As "Ethnicity", ethnicityid FROM schema1.t1 FINAL where (t1.code = 'AP' and __USER_TABLE_SECURITY_t1_USER_TABLE_SECURITY__) ) as t1 on t3.ethnicityid = t1.ethnicityid inner join (SELECT name As "Gender", genderid FROM schema1.t2 FINAL where (__USER_TABLE_SECURITY_t2_USER_TABLE_SECURITY__) ) as t2 on t3.genderid = t2.genderid where (1 = 1) group by "Ethnicity", "Gender") AS "virtual_table" GROUP BY "Gender", "Ethnicity" ORDER BY "count" DESC LIMIT 1000;"""
    statement = sqlparse.parse(sql_text)[0]

    # Collect the top-level tokens between the outer SELECT and its FROM.
    # Each subquery is grouped into a single token, so the first top-level
    # FROM keyword belongs to the outermost query.
    raw_columns = []
    in_select = False
    for token in statement.tokens:
        if token.ttype is T.Keyword and token.normalized == "FROM":
            break
        if in_select and not (token.is_whitespace or token.match(T.Punctuation, ",")):
            raw_columns.append(token)
        if token.ttype is T.DML and token.normalized == "SELECT":
            in_select = True

    print(list(extract_column_names(raw_columns)))

Output:

['Gender', 'Ethnicity', 'count']