
Get Table names from parsed query #18

Open
Hassnain-Alvi opened this issue Mar 22, 2019 · 4 comments

@Hassnain-Alvi

Hi.
How can I get table names from the query parser?
In pg_query (the Ruby gem) it's documented as follows:

PgQuery.parse("SELECT ? FROM x JOIN y USING (id) WHERE z = ?").tables

But when I try the same thing in pg_query_go:
tables := pg_query.Parse("SELECT ? FROM x JOIN y USING (id) WHERE z = ?").tables

it returns an error.
Can you please post an example for this?
Thanks

@elliotcourant

elliotcourant commented Mar 22, 2019

If I remember correctly, every table reference in the parse tree is the same node type (RangeVar). I have a method (albeit ugly) that uses reflection to walk the tree and return the list of table names based on that premise.

import (
	"reflect"

	pg_query "github.com/lfittl/pg_query_go/nodes"
)

// GetTables walks a parsed statement and returns the distinct relation
// names (RangeVar.Relname) it references, in order of first appearance.
func GetTables(stmt interface{}) []string {
	tables := make([]string, 0)
	seen := make(map[string]bool)
	for _, name := range examineTables(stmt) {
		if !seen[name] {
			seen[name] = true
			tables = append(tables, name)
		}
	}
	return tables
}

// examineTables recursively visits every value reachable from value via
// reflection and collects the Relname of each RangeVar node it finds.
func examineTables(value interface{}) []string {
	args := make([]string, 0)
	if value == nil {
		return args
	}

	// Every table reference in the parse tree is a RangeVar node.
	if rangeVar, ok := value.(pg_query.RangeVar); ok && rangeVar.Relname != nil {
		args = append(args, *rangeVar.Relname)
	}

	v := reflect.ValueOf(value)
	switch v.Kind() {
	case reflect.Ptr:
		if !v.IsNil() {
			args = append(args, examineTables(v.Elem().Interface())...)
		}
	case reflect.Array, reflect.Slice:
		for i := 0; i < v.Len(); i++ {
			args = append(args, examineTables(v.Index(i).Interface())...)
		}
	case reflect.Map:
		// Maps cannot be indexed by position; iterate over their keys instead.
		for _, key := range v.MapKeys() {
			args = append(args, examineTables(v.MapIndex(key).Interface())...)
		}
	case reflect.Struct:
		for i := 0; i < v.NumField(); i++ {
			// Skip unexported fields; calling Interface() on them panics.
			if f := v.Field(i); f.CanInterface() {
				args = append(args, examineTables(f.Interface())...)
			}
		}
	}
	return args
}

It probably isn't the best way to do it, but it works.

@elliotcourant

Here is an example test to help.

import (
	"github.com/lfittl/pg_query_go"
	pg_query2 "github.com/lfittl/pg_query_go/nodes"
	"github.com/stretchr/testify/assert"
	"testing"
)

var (
	tableTestQueries = []struct {
		Query  string
		Tables []string
	}{
		{
			Query:  "SELECT $1::text;",
			Tables: []string{},
		},
		{
			Query:  "SELECT e.typdelim FROM pg_catalog.pg_type t, pg_catalog.pg_type e WHERE t.oid = $1 and t.typelem = e.oid",
			Tables: []string{"pg_type"},
		},
		{
			Query:  "SELECT e.typdelim FROM pg_catalog.pg_type t, pg_catalog.pg_type e WHERE t.oid = $1 and t.typelem = e.oid AND $2=$3",
			Tables: []string{"pg_type"},
		},
		{
			Query:  "SELECT e.typdelim FROM pg_catalog.pg_type t, pg_catalog.pg_type e WHERE t.oid = $1 and t.typelem = e.oid AND $2=$1",
			Tables: []string{"pg_type"},
		},
		{
			Query:  "SELECT products.id FROM products JOIN types ON types.id=products.type_id",
			Tables: []string{"products", "types"},
		},
		{
			Query:  "SELECT products.id FROM products JOIN types ON types.id=products.type_id WHERE products.id IN (SELECT id FROM other)",
			Tables: []string{"products", "types", "other"},
		},
		{
			Query:  "INSERT INTO products (id) VALUES(1);",
			Tables: []string{"products"},
		},
		{
			Query:  "UPDATE variations SET id=4 WHERE id=3;",
			Tables: []string{"variations"},
		},
	}
)

func Test_GetTables(t *testing.T) {
	for _, item := range tableTestQueries {
		parsed, err := pg_query.Parse(item.Query)
		if err != nil {
			t.Fatal(err)
		}

		// Each parsed statement is wrapped in a RawStmt; unwrap the first one.
		stmt := parsed.Statements[0].(pg_query2.RawStmt).Stmt

		tables := GetTables(stmt)

		assert.Equal(t, item.Tables, tables, "tables do not match expected for query %q", item.Query)
	}
}

With some small modifications to the method I posted in my last comment, it could also include the schema in the table name array.

@Hassnain-Alvi

Thanks, it worked.

lfittl added the question label on Jan 3, 2024
@SerialVelocity

That seems to fail with CTE names (aliases): the CTE name "a" is returned as if it were a table, e.g.
WITH a AS (SELECT * FROM tmp) SELECT * FROM a
