Skip to content

Commit

Permalink
pgcdc: normalize identifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
rockwotj committed Dec 13, 2024
1 parent 3854e27 commit 9e48cc7
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 26 deletions.
8 changes: 6 additions & 2 deletions internal/impl/postgresql/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -648,9 +648,13 @@ pg_stream:
slot_name: test_slot_native_decoder
stream_snapshot: true
include_transaction_markers: false
schema: public
# This is intentionally with uppercase - we want to validate
# we treat identifiers the same as Postgres Queries.
schema: PuBliC
tables:
- flights
# This is intentionally with uppercase - we want to validate
# we treat identifiers the same as Postgres Queries.
- FLIGHTS
`, databaseURL)

cacheConf := fmt.Sprintf(`
Expand Down
8 changes: 5 additions & 3 deletions internal/impl/postgresql/pglogicalstream/logical_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,18 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
return nil, err
}

if err := sanitize.ValidatePostgresIdentifier(config.DBSchema); err != nil {
schema, err := sanitize.NormalizePostgresIdentifier(config.DBSchema)
if err != nil {
return nil, fmt.Errorf("invalid schema name %q: %w", config.DBSchema, err)
}

tables := []TableFQN{}
for _, table := range config.DBTables {
if err := sanitize.ValidatePostgresIdentifier(table); err != nil {
normalized, err := sanitize.NormalizePostgresIdentifier(table)
if err != nil {
return nil, fmt.Errorf("invalid table name %q: %w", table, err)
}
tables = append(tables, TableFQN{Schema: config.DBSchema, Table: table})
tables = append(tables, TableFQN{Schema: schema, Table: normalized})
}
stream := &Stream{
pgConn: dbConn,
Expand Down
34 changes: 17 additions & 17 deletions internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,54 +381,54 @@ func QuotePostgresIdentifier(name string) string {
return quoted.String()
}

// ValidatePostgresIdentifier checks if a string is a valid PostgreSQL identifier
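// NormalizePostgresIdentifier checks if a string is a valid PostgreSQL identifier
// and returns its normalized form: an already-quoted identifier is returned unchanged,
// while an unquoted identifier is lowercased and passed through QuotePostgresIdentifier.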
// This follows PostgreSQL's standard naming rules
func ValidatePostgresIdentifier(name string) error {
func NormalizePostgresIdentifier(name string) (string, error) {
if len(name) == 0 {
return errors.New("empty identifier is not allowed")
return "", errors.New("empty identifier is not allowed")
}

// It's not fully clear to me if the max here applies before or after unescaping the quotes.
// We'll just play it safe and validate the length before unescaping quotes; it seems unlikely
// folks are using large identifiers.
if len(name) > MaxIdentifierLength {
return fmt.Errorf("identifier length exceeds maximum of %d characters", MaxIdentifierLength)
return "", fmt.Errorf("identifier length exceeds maximum of %d characters", MaxIdentifierLength)
}

// Handle quoted identifiers.
if strings.HasPrefix(name, `"`) && strings.HasSuffix(name, `"`) && len(name) >= 2 {
name := name[1 : len(name)-1]
if name == "" {
return errors.New("quoted identifiers cannot be empty")
unquoted := name[1 : len(name)-1]
if unquoted == "" {
return "", errors.New("quoted identifiers cannot be empty")
}
for i := 0; i < len(name); i++ {
if name[i] != '"' {
for i := 0; i < len(unquoted); i++ {
if unquoted[i] != '"' {
continue
}
if i+1 >= len(name) {
return fmt.Errorf("invalid quoted identifier: %s", name)
if i+1 >= len(unquoted) {
return "", fmt.Errorf("invalid quoted identifier: %s", unquoted)
}
if name[i+1] != '"' {
return fmt.Errorf("invalid quoted identifier: %s", name)
if unquoted[i+1] != '"' {
return "", fmt.Errorf("invalid quoted identifier: %s", unquoted)
}
i++ // Skip over the next character to handle triple quotes
}
return nil
return name, nil
}

// First character must be a letter or underscore
if !unicode.IsLetter(rune(name[0])) && name[0] != '_' {
return errors.New("identifier must start with a letter or underscore")
return "", errors.New("identifier must start with a letter or underscore")
}

// Subsequent characters must be letters, numbers, underscores, or dots
for i, char := range name {
if !unicode.IsLetter(char) && !unicode.IsDigit(char) && char != '_' && char != '.' {
return fmt.Errorf("invalid character '%c' at position %d in identifier '%s'", char, i, name)
return "", fmt.Errorf("invalid character '%c' at position %d in identifier '%s'", char, i, name)
}
}

// TODO(cdc): We should also ensure that this is not a reserved keyword.

return nil
return QuotePostgresIdentifier(strings.ToLower(name)), nil
}
23 changes: 19 additions & 4 deletions internal/impl/postgresql/pglogicalstream/sanitize/sanitize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
package sanitize_test

import (
"strconv"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -255,18 +256,31 @@ func TestQuerySanitize(t *testing.T) {
}

func TestIdentifierValidation(t *testing.T) {
successfulTests := []string{
quoted := []string{
`"FooBar"`,
`"Foo""Bar"`,
`"Foo""""Bar"`,
}

for _, i := range quoted {
i := i
t.Run(i, func(t *testing.T) {
_, err := sanitize.NormalizePostgresIdentifier(i)
require.NoError(t, err)
})
}

unquoted := []string{
`_Foobar`,
strings.Repeat("a", 63),
}

for _, i := range successfulTests {
for _, i := range unquoted {
i := i
t.Run(i, func(t *testing.T) {
require.NoError(t, sanitize.ValidatePostgresIdentifier(i))
normalized, err := sanitize.NormalizePostgresIdentifier(i)
require.NoError(t, err)
require.Equal(t, strconv.Quote(strings.ToLower(i)), normalized)
})
}

Expand All @@ -285,7 +299,8 @@ func TestIdentifierValidation(t *testing.T) {
for _, i := range errorTests {
i := i
t.Run(i, func(t *testing.T) {
require.Error(t, sanitize.ValidatePostgresIdentifier(i))
_, err := sanitize.NormalizePostgresIdentifier(i)
require.Error(t, err)
})
}
}

0 comments on commit 9e48cc7

Please sign in to comment.