diff --git a/conn.go b/conn.go index 5eb17df..907001a 100644 --- a/conn.go +++ b/conn.go @@ -4,6 +4,7 @@ import ( "context" "database/sql/driver" "errors" + "regexp" "time" "github.com/aws/aws-sdk-go/aws" @@ -48,10 +49,9 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error) } return newRows(rowsConfig{ - Athena: c.athena, - QueryID: queryID, - // todo add check for ddl queries to not skip header(#10) - SkipHeader: true, + Athena: c.athena, + QueryID: queryID, + SkipHeader: isDDLQuery(query), }) } @@ -136,3 +136,11 @@ func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { var _ driver.Queryer = (*conn)(nil) var _ driver.Execer = (*conn)(nil) + +// supported DDL statements by Athena +// https://docs.aws.amazon.com/athena/latest/ug/language-reference.html +var ddlQueryRegex = regexp.MustCompile(`(?i)^(ALTER|CREATE|DESCRIBE|DROP|MSCK|SHOW)`) + +func isDDLQuery(query string) bool { + return ddlQueryRegex.Match([]byte(query)) +} diff --git a/db_test.go b/db_test.go index 8583810..738c987 100644 --- a/db_test.go +++ b/db_test.go @@ -126,6 +126,25 @@ func TestOpen(t *testing.T) { require.NoError(t, err, "Query") } +func TestDDLQuery(t *testing.T) { + harness := setup(t) + defer harness.teardown() + + rows := harness.mustQuery("show tables") + + output := make([]string, 0) + for rows.Next() { + var table string + + err := rows.Scan(&table) + assert.NoError(t, err, "rows.Scan()") + + output = append(output, table) + } + + assert.Equal(t, 1, len(output), "query output") +} + type dummyRow struct { NullValue *struct{} `json:"nullValue"` SmallintType int `json:"smallintType"`