pganalyze / pg_query_go

Go library to parse and normalize SQL queries using the PostgreSQL query parser
BSD 3-Clause "New" or "Revised" License
660 stars 80 forks source link

Get table names from parsed query #18

Open Hassnain-Alvi opened 5 years ago

Hassnain-Alvi commented 5 years ago

Hi. How can I get table names from the query parser? In pg_query it's shown as follows:

PgQuery.parse("SELECT ? FROM x JOIN y USING (id) WHERE z = ?").tables

But when I do the same in pg_query_go: tables := pg_query.Parse("SELECT ? FROM x JOIN y USING (id) WHERE z = ?").tables

It returns an error. Could you please post an example of how to do this? Thanks.

elliotcourant commented 5 years ago

If I remember correctly, all the table names are stored in the parse tree as the same node type. I have a method (albeit ugly) that uses reflection to return a list of table names based on that premise.

import (
    "fmt"
    "github.com/readystock/golinq"
    "github.com/lfittl/pg_query_go/nodes"
    "reflect"
    "strings"
)

func GetTables(stmt interface{}) []string {
    tables := make([]string, 0)
    linq.From(examineTables(stmt, 0)).Distinct().ToSlice(&tables)
    return tables
}

// examineTables recursively walks value using reflection and returns the
// Relname of every pg_query.RangeVar node it reaches, in traversal order.
// Duplicates are kept; callers such as GetTables remove them.
//
// depth is the current recursion depth. It only controls the indentation of
// the optional debug trace below and has no effect on the result.
//
// Values that cannot be inspected safely are skipped instead of panicking:
// unexported struct fields (reflect refuses Interface() on them) and channels
// (reading them would consume their elements).
//
// NOTE(review): a CTE reference such as "WITH a AS (...) SELECT * FROM a" has
// been reported to yield "a" as if it were a table; CTE names are not filtered.
func examineTables(value interface{}, depth int) []string {
	// Flip to true to dump the walked tree to stdout while debugging.
	const traceTables = false
	trace := func(level int, msg string, a ...interface{}) {
		if !traceTables {
			return
		}
		if level < 0 {
			level = 0 // strings.Repeat panics on a negative count.
		}
		fmt.Printf("%s%s\n", strings.Repeat("\t", level), fmt.Sprintf(msg, a...))
	}

	tables := make([]string, 0)
	if value == nil {
		return tables
	}

	// walk recurses into v unless reflect forbids extracting it (unexported field).
	walk := func(v reflect.Value, level int) {
		if v.CanInterface() {
			tables = append(tables, examineTables(v.Interface(), level)...)
		}
	}

	// Table names are carried by RangeVar nodes. Relname is a pointer, so
	// guard against nil before dereferencing it.
	if rangeVar, ok := value.(pg_query.RangeVar); ok && rangeVar.Relname != nil {
		tables = append(tables, *rangeVar.Relname)
	}

	v := reflect.ValueOf(value)
	switch v.Kind() {
	case reflect.Ptr:
		// A nil pointer has no valid Elem; there is nothing to visit.
		if elem := v.Elem(); elem.IsValid() {
			walk(elem, depth+1)
		}
	case reflect.Array, reflect.Slice:
		if v.Len() == 0 {
			trace(depth-1, "[]")
		} else {
			trace(depth-1, "[")
			for i := 0; i < v.Len(); i++ {
				elem := v.Index(i)
				trace(depth, "[%d] Type {%s} {", i, elem.Type().String())
				walk(elem, depth+1)
				trace(depth, "},")
			}
			trace(depth-1, "]")
		}
	case reflect.Map:
		// reflect.Value.Index panics on maps, so visit the values by key.
		// MapKeys returns keys in unspecified order, so names found here
		// come back in no particular order.
		for _, key := range v.MapKeys() {
			walk(v.MapIndex(key), depth+1)
		}
	case reflect.Struct:
		t := v.Type()
		for i := 0; i < v.NumField(); i++ {
			field := v.Field(i)
			trace(depth, "[%d] Field {%s} Type {%s} Kind {%s}", i, t.Field(i).Name, t.Field(i).Type.String(), field.Kind().String())
			walk(field, depth+1)
		}
	}
	return tables
}

It probably isn't the best way to do it, but it works.

elliotcourant commented 5 years ago

Here is an example test to help.

import (
    "github.com/lfittl/pg_query_go"
    pg_query2 "github.com/lfittl/pg_query_go/nodes"
    "github.com/stretchr/testify/assert"
    "testing"
)

// tableTestQueries pairs each SQL statement with the distinct table names
// GetTables is expected to report for it, in order of first appearance.
var tableTestQueries = []struct {
	Query  string
	Tables []string
}{
	{Query: "SELECT $1::text;", Tables: []string{}},
	{Query: "SELECT e.typdelim FROM pg_catalog.pg_type t, pg_catalog.pg_type e WHERE t.oid = $1 and t.typelem = e.oid", Tables: []string{"pg_type"}},
	{Query: "SELECT e.typdelim FROM pg_catalog.pg_type t, pg_catalog.pg_type e WHERE t.oid = $1 and t.typelem = e.oid AND $2=$3", Tables: []string{"pg_type"}},
	{Query: "SELECT e.typdelim FROM pg_catalog.pg_type t, pg_catalog.pg_type e WHERE t.oid = $1 and t.typelem = e.oid AND $2=$1", Tables: []string{"pg_type"}},
	{Query: "SELECT products.id FROM products JOIN types ON types.id=products.type_id", Tables: []string{"products", "types"}},
	{Query: "SELECT products.id FROM products JOIN types ON types.id=products.type_id WHERE products.id IN (SELECT id FROM other)", Tables: []string{"products", "types", "other"}},
	{Query: "INSERT INTO products (id) VALUES(1);", Tables: []string{"products"}},
	{Query: "UPDATE variations SET id=4 WHERE id=3;", Tables: []string{"variations"}},
}

// Test_GetTables checks GetTables against every case in tableTestQueries.
// Each query runs as its own subtest so a failure names the offending SQL,
// and a failing case stops only that subtest rather than the whole loop.
func Test_GetTables(t *testing.T) {
	for _, item := range tableTestQueries {
		t.Run(item.Query, func(t *testing.T) {
			parsed, err := pg_query.Parse(item.Query)
			if err != nil {
				t.Fatalf("parsing %q: %v", item.Query, err)
			}
			// Guard the index and type assertion so an unexpected parse
			// result fails the subtest instead of panicking.
			if len(parsed.Statements) == 0 {
				t.Fatalf("parsing %q produced no statements", item.Query)
			}
			raw, ok := parsed.Statements[0].(pg_query2.RawStmt)
			if !ok {
				t.Fatalf("first statement of %q is %T, want pg_query2.RawStmt", item.Query, parsed.Statements[0])
			}

			tables := GetTables(raw.Stmt)

			assert.Equal(t, item.Tables, tables, "tables do not match expected")
		})
	}
}

With some small modifications to the method I posted in my last comment, it could also include the schema in the table name array.

Hassnain-Alvi commented 5 years ago

Thanks, it worked.

SerialVelocity commented 6 months ago

That seems to fail with CTE aliases, e.g. WITH a AS (SELECT * FROM tmp) SELECT * FROM a, where the CTE name a is presumably reported as if it were a table.