From cb698b10693bd208b9ac9acb9d175b5501b73c88 Mon Sep 17 00:00:00 2001 From: Riyadh Al Nur Date: Tue, 31 Dec 2019 18:31:53 +0800 Subject: [PATCH] [FIX] Skip header row when not a DDL query DDL queries on Athena don't return a header row. If query contains a DDL statement, the first row is not skipped when doing a scan. Signed-off-by: Riyadh Al Nur --- conn.go | 16 ++++++++++++---- db_test.go | 19 +++++++++++++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/conn.go b/conn.go index 5eb17df..907001a 100644 --- a/conn.go +++ b/conn.go @@ -4,6 +4,7 @@ import ( "context" "database/sql/driver" "errors" + "regexp" "time" "github.com/aws/aws-sdk-go/aws" @@ -48,10 +49,9 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error) } return newRows(rowsConfig{ - Athena: c.athena, - QueryID: queryID, - // todo add check for ddl queries to not skip header(#10) - SkipHeader: true, + Athena: c.athena, + QueryID: queryID, + SkipHeader: isDDLQuery(query), }) } @@ -136,3 +136,11 @@ func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { var _ driver.Queryer = (*conn)(nil) var _ driver.Execer = (*conn)(nil) + +// supported DDL statements by Athena +// https://docs.aws.amazon.com/athena/latest/ug/language-reference.html +var ddlQueryRegex = regexp.MustCompile(`(?i)^(ALTER|CREATE|DESCRIBE|DROP|MSCK|SHOW)`) + +func isDDLQuery(query string) bool { + return ddlQueryRegex.Match([]byte(query)) +} diff --git a/db_test.go b/db_test.go index 8583810..738c987 100644 --- a/db_test.go +++ b/db_test.go @@ -126,6 +126,25 @@ func TestOpen(t *testing.T) { require.NoError(t, err, "Query") } +func TestDDLQuery(t *testing.T) { + harness := setup(t) + defer harness.teardown() + + rows := harness.mustQuery("show tables") + + output := make([]string, 0) + for rows.Next() { + var table string + + err := rows.Scan(&table) + assert.NoError(t, err, "rows.Scan()") + + output = append(output, table) + } + + assert.Equal(t, 1, len(output), "query output") +} + type dummyRow struct { NullValue *struct{} `json:"nullValue"` SmallintType int `json:"smallintType"`