Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
0b48b63
feat: add sorting to sql adapter
Lutherwaves Feb 17, 2026
bdf9c0a
refactor: add SortingDirection type and SortDirectionKey constant
Lutherwaves Feb 20, 2026
c38344b
refactor: use maps.Copy and flatParams in sql extractParams
Lutherwaves Feb 20, 2026
c86c808
refactor: use SortingDirection type in sql adapter
Lutherwaves Feb 20, 2026
edb279e
refactor: use switch and maps.Equal in sql sorting helpers
Lutherwaves Feb 20, 2026
8f1e293
refactor: use SortingDirection type in cosmosdb adapter
Lutherwaves Feb 20, 2026
1a05bc9
refactor: deduplicate extractParams and extractSortDirection as packa…
Lutherwaves Feb 20, 2026
5d6c85e
feat: add validateSortKey to guard against ORDER BY injection
Lutherwaves Feb 23, 2026
3a314bd
fix: allow underscore-prefixed sort keys, add edge case tests
Lutherwaves Feb 23, 2026
2bc9da7
fix: validate sortKey in CosmosDB executePaginatedQuery to prevent in…
Lutherwaves Feb 23, 2026
7b508f6
fix: validate sortKey in SQL executePaginatedQuery to prevent injection
Lutherwaves Feb 23, 2026
5d5f92c
chore(test): restore sqlAdapterInstance singleton after TestListRejec…
Lutherwaves Feb 23, 2026
83e52f2
fix: warn on cursor field extraction failure, remove redundant else b…
Lutherwaves Feb 23, 2026
b72b96e
refactor(style): use snake_case log keys in cursor extraction warnings
Lutherwaves Feb 23, 2026
0cec523
fix: validate sortKey in DynamoDB List and Search to prevent PartiQL …
Lutherwaves Feb 23, 2026
b25e988
chore(docs): add godoc to executePaginatedQuery in SQL and CosmosDB a…
Lutherwaves Feb 23, 2026
4b4931c
chore(docs): document sortKey and SortDirectionKey on StorageAdapter …
Lutherwaves Feb 23, 2026
d31ee23
refactor: more explicit error msgs on failure to list/search
Lutherwaves Feb 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 22 additions & 35 deletions storage/cosmosdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func (s *CosmosDBAdapter) GetLatestMigration() (int, error) {

func (s *CosmosDBAdapter) Create(item any, params ...map[string]any) error {
// Extract provider-specific parameters
paramMap := s.extractParams(params...)
paramMap := extractParams(params...)

containerName := s.getContainerName(item)
containerClient, err := s.databaseClient.NewContainer(containerName)
Expand Down Expand Up @@ -230,7 +230,7 @@ func (s *CosmosDBAdapter) Get(dest any, filter map[string]any, params ...map[str
}

// Extract provider-specific parameters
paramMap := s.extractParams(params...)
paramMap := extractParams(params...)

containerName := s.getContainerName(dest)
containerClient, err := s.databaseClient.NewContainer(containerName)
Expand Down Expand Up @@ -302,7 +302,7 @@ func (s *CosmosDBAdapter) Update(item any, filter map[string]any, params ...map[
}

// Extract provider-specific parameters
paramMap := s.extractParams(params...)
paramMap := extractParams(params...)

containerName := s.getContainerName(item)
containerClient, err := s.databaseClient.NewContainer(containerName)
Expand Down Expand Up @@ -388,7 +388,7 @@ func (s *CosmosDBAdapter) Delete(item any, filter map[string]any, params ...map[
}

// Extract provider-specific parameters
paramMap := s.extractParams(params...)
paramMap := extractParams(params...)

containerName := s.getContainerName(item)
containerClient, err := s.databaseClient.NewContainer(containerName)
Expand Down Expand Up @@ -428,8 +428,11 @@ func (s *CosmosDBAdapter) Delete(item any, filter map[string]any, params ...map[

func (s *CosmosDBAdapter) List(dest any, sortKey string, filter map[string]any, limit int, cursor string, params ...map[string]any) (string, error) {
// Extract sort direction from params
paramMap := s.extractParams(params...)
sortDirection := s.extractSortDirection(paramMap)
paramMap := extractParams(params...)
sortDirection, err := extractSortDirection(paramMap)
if err != nil {
return "", fmt.Errorf("failed to list: %w", err)
}

return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, filter, params...)
}
Expand All @@ -441,8 +444,11 @@ func (s *CosmosDBAdapter) Search(dest any, sortKey string, query string, limit i
// For custom queries, use the Query method instead

// Extract sort direction from params
paramMap := s.extractParams(params...)
sortDirection := s.extractSortDirection(paramMap)
paramMap := extractParams(params...)
sortDirection, err := extractSortDirection(paramMap)
if err != nil {
return "", fmt.Errorf("failed to search: %w", err)
}

// Use executePaginatedQuery with empty filter (the query parameter is ignored for CosmosDB)
return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, map[string]any{}, params...)
Expand Down Expand Up @@ -514,17 +520,23 @@ func (s *CosmosDBAdapter) Query(dest any, statement string, limit int, cursor st
return nextCursor, nil
}

// executePaginatedQuery runs a cursor-paginated Cosmos DB query against the container for dest.
// The cursor is a Cosmos DB continuation token.
func (s *CosmosDBAdapter) executePaginatedQuery(
dest any,
sortKey string,
sortDirection string,
sortDirection SortingDirection,
limit int,
cursor string,
filter map[string]any,
params ...map[string]any,
) (string, error) {
if err := validateSortKey(sortKey); err != nil {
return "", err
}

// Extract provider-specific parameters
paramMap := s.extractParams(params...)
paramMap := extractParams(params...)

containerName := s.getContainerName(dest)
containerClient, err := s.databaseClient.NewContainer(containerName)
Expand Down Expand Up @@ -706,31 +718,6 @@ func (s *CosmosDBAdapter) executeQuery(
return page, err
}

// extractParams merges all provided parameter maps into a single map
func (s *CosmosDBAdapter) extractParams(params ...map[string]any) map[string]any {
paramMap := make(map[string]any)
for _, param := range params {
for k, v := range param {
paramMap[k] = v
}
}
return paramMap
}

// extractSortDirection extracts and validates sort direction from params
func (s *CosmosDBAdapter) extractSortDirection(paramMap map[string]any) string {
sortDirection := "ASC" // Default to ASC
if dir, exists := paramMap["sort_direction"]; exists {
if dirStr, ok := dir.(string); ok {
sortDirection = strings.ToUpper(dirStr)
if sortDirection != "ASC" && sortDirection != "DESC" {
sortDirection = "ASC" // Fallback to ASC for invalid values
}
}
}
return sortDirection
}

// buildFilter constructs WHERE clause conditions from filter map
func (s *CosmosDBAdapter) buildFilter(filter map[string]any, paramIndex *int) (string, []azcosmos.QueryParameter) {
conditions := []string{}
Expand Down
6 changes: 6 additions & 0 deletions storage/dynamodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ func (s *DynamoDBAdapter) executePaginatedQuery(
}

func (s *DynamoDBAdapter) List(dest any, sortKey string, filter map[string]any, limit int, cursor string, params ...map[string]any) (string, error) {
if err := validateSortKey(sortKey); err != nil {
return "", err
}
return s.executePaginatedQuery(dest, limit, cursor, func(input *dynamodb.ExecuteStatementInput) *dynamodb.ExecuteStatementInput {
query := fmt.Sprintf(`SELECT * FROM "%s"`, s.getTableName(dest))

Expand All @@ -235,6 +238,9 @@ func (s *DynamoDBAdapter) List(dest any, sortKey string, filter map[string]any,
}

func (s *DynamoDBAdapter) Search(dest any, sortKey string, query string, limit int, cursor string, params ...map[string]any) (string, error) {
if err := validateSortKey(sortKey); err != nil {
return "", err
}
return s.executePaginatedQuery(dest, limit, cursor, func(input *dynamodb.ExecuteStatementInput) *dynamodb.ExecuteStatementInput {
// Parse Lucene query
destType := reflect.TypeOf(dest).Elem().Elem()
Expand Down
41 changes: 33 additions & 8 deletions storage/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,19 @@ func (s *SQLAdapter) Delete(item any, filter map[string]any, params ...map[strin
return result.Error
}

// executePaginatedQuery runs a cursor-paginated SELECT using the provided query builder scope.
// The cursor is base64-encoded from the sortKey field value of the last returned row.
func (s *SQLAdapter) executePaginatedQuery(
dest any,
sortKey string,
sortDirection SortingDirection,
limit int,
cursor string,
builder queryBuilder,
) (string, error) {
if err := validateSortKey(sortKey); err != nil {
return "", err
}
var cursorValue string
if cursor != "" {
bytes, err := base64.StdEncoding.DecodeString(cursor)
Expand All @@ -231,10 +237,14 @@ func (s *SQLAdapter) executePaginatedQuery(
}
q := s.DB.Model(dest).Scopes(builder)

q = q.Limit(limit + 1).Order(fmt.Sprintf("%s ASC", sortKey))
q = q.Limit(limit + 1).Order(fmt.Sprintf("%s %s", sortKey, sortDirection))

if cursorValue != "" {
q = q.Where(fmt.Sprintf("%s > ?", sortKey), cursorValue)
cursorOp := ">"
if sortDirection == Descending {
cursorOp = "<"
}
q = q.Where(fmt.Sprintf("%s %s ?", sortKey, cursorOp), cursorValue)
}

if result := q.Find(dest); result.Error != nil {
Expand All @@ -251,19 +261,29 @@ func (s *SQLAdapter) executePaginatedQuery(
if destSlice.Len() > limit {
lastItem := destSlice.Index(limit - 1)
field := reflect.Indirect(lastItem).FieldByName(sortKey)
if field.IsValid() && field.Kind() == reflect.String {
if !field.IsValid() {
slog.Warn("cursor extraction failed: sort_key does not match any exported struct field",
"sort_key", sortKey,
"hint", "sort_key must be the Go struct field name (e.g. 'CreatedAt'), not the DB column name (e.g. 'created_at')")
} else if field.Kind() != reflect.String {
slog.Warn("cursor extraction failed: struct field is not a string",
"sort_key", sortKey,
"field_kind", field.Kind().String())
} else {
nextCursor = base64.StdEncoding.EncodeToString([]byte(field.String()))
}
destSlice.Set(destSlice.Slice(0, limit))
} else {
nextCursor = ""
}

return nextCursor, nil
}

func (s *SQLAdapter) List(dest any, sortKey string, filter map[string]any, limit int, cursor string, params ...map[string]any) (string, error) {
return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB {
sortDirection, err := extractSortDirection(extractParams(params...))
if err != nil {
return "", fmt.Errorf("failed to list: %w", err)
}
return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, func(q *gorm.DB) *gorm.DB {
if len(filter) > 0 {
query, bindings := s.buildQuery(filter)
return q.Where(query, bindings)
Expand All @@ -273,8 +293,12 @@ func (s *SQLAdapter) List(dest any, sortKey string, filter map[string]any, limit
}

func (s *SQLAdapter) Search(dest any, sortKey string, query string, limit int, cursor string, params ...map[string]any) (string, error) {
sortDirection, err := extractSortDirection(extractParams(params...))
if err != nil {
return "", fmt.Errorf("failed to search: %w", err)
}
if query == "" {
return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB {
return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, func(q *gorm.DB) *gorm.DB {
return q
})
}
Expand All @@ -296,7 +320,7 @@ func (s *SQLAdapter) Search(dest any, sortKey string, query string, limit int, c

slog.Debug(fmt.Sprintf(`Where clause: %s, with params %s`, whereClause, queryParams))

return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB {
return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, func(q *gorm.DB) *gorm.DB {
if whereClause != "" {
return q.Where(whereClause, queryParams...)
}
Expand Down Expand Up @@ -324,6 +348,7 @@ func (s *SQLAdapter) Query(dest any, statement string, limit int, cursor string,
return "", fmt.Errorf("not implemented yet")
}


func (s *SQLAdapter) buildQuery(filter map[string]any) (string, map[string]any) {
clauses := []string{}
bindings := make(map[string]any)
Expand Down
106 changes: 106 additions & 0 deletions storage/sql_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package storage

import (
"maps"
"testing"
)

func TestListRejectsMaliciousSortKey(t *testing.T) {
// Reset singleton so we get a fresh SQLite adapter
prev := sqlAdapterInstance
sqlAdapterInstance = nil
t.Cleanup(func() { sqlAdapterInstance = prev })
adapter := GetSQLAdapterInstance(map[string]string{
"provider": "sqlite",
})
type Row struct {
ID string `gorm:"primaryKey"`
}
_ = adapter.DB.AutoMigrate(&Row{})

var rows []Row
_, err := adapter.List(&rows, "id; DROP TABLE rows", map[string]any{}, 10, "")
if err == nil {
t.Error("expected error for malicious sortKey, got nil")
}
}

func TestValidateSortKey(t *testing.T) {
tests := []struct {
name string
input string
wantErr bool
}{
{"simple column", "id", false},
{"snake_case column", "created_at", false},
{"mixed case", "createdAt", false},
{"with numbers", "field1", false},
{"empty string", "", true},
{"leading digit", "1field", true},
{"dot notation injection", "id; DROP TABLE users", true},
{"semicolon", "id;DROP", true},
{"space", "col name", true},
{"table.column dot", "t.col", true},
{"SQL comment", "id--", true},
{"single quote", "id'", true},
{"underscore prefix", "_ts", false}, // CosmosDB system fields like _ts are valid
{"null byte", "id\x00DROP", true}, // null byte injection rejected
{"unicode lookalike", "iа", true}, // Cyrillic а (U+0430) rejected, not ASCII
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateSortKey(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("validateSortKey(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr)
}
})
}
}

func TestExtractSortDirection(t *testing.T) {
tests := []struct {
name string
input map[string]any
expected SortingDirection
wantErr bool
}{
{"default when missing", map[string]any{}, Ascending, false},
{"ASC explicit", map[string]any{SortDirectionKey: "ASC"}, Ascending, false},
{"DESC", map[string]any{SortDirectionKey: "DESC"}, Descending, false},
{"lowercase desc", map[string]any{SortDirectionKey: "desc"}, Descending, false},
{"invalid returns error", map[string]any{SortDirectionKey: "SIDEWAYS"}, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := extractSortDirection(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("extractSortDirection(%v) error = %v, wantErr %v", tt.input, err, tt.wantErr)
return
}
if !tt.wantErr && got != tt.expected {
t.Errorf("extractSortDirection(%v) = %q; want %q", tt.input, got, tt.expected)
}
})
}
}

func TestExtractParams(t *testing.T) {
tests := []struct {
name string
input []map[string]any
expected map[string]any
}{
{"empty input", []map[string]any{}, map[string]any{}},
{"single map", []map[string]any{{"a": 1}}, map[string]any{"a": 1}},
{"two maps merged", []map[string]any{{"a": 1}, {"b": 2}}, map[string]any{"a": 1, "b": 2}},
{"later map wins on collision", []map[string]any{{"a": 1}, {"a": 2}}, map[string]any{"a": 2}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractParams(tt.input...)
if !maps.Equal(got, tt.expected) {
t.Errorf("extractParams(%v) = %v; want %v", tt.input, got, tt.expected)
}
})
}
}
Loading