diff --git a/storage/cosmosdb.go b/storage/cosmosdb.go index 3474b6f..af464fc 100644 --- a/storage/cosmosdb.go +++ b/storage/cosmosdb.go @@ -164,7 +164,7 @@ func (s *CosmosDBAdapter) GetLatestMigration() (int, error) { func (s *CosmosDBAdapter) Create(item any, params ...map[string]any) error { // Extract provider-specific parameters - paramMap := s.extractParams(params...) + paramMap := extractParams(params...) containerName := s.getContainerName(item) containerClient, err := s.databaseClient.NewContainer(containerName) @@ -230,7 +230,7 @@ func (s *CosmosDBAdapter) Get(dest any, filter map[string]any, params ...map[str } // Extract provider-specific parameters - paramMap := s.extractParams(params...) + paramMap := extractParams(params...) containerName := s.getContainerName(dest) containerClient, err := s.databaseClient.NewContainer(containerName) @@ -302,7 +302,7 @@ func (s *CosmosDBAdapter) Update(item any, filter map[string]any, params ...map[ } // Extract provider-specific parameters - paramMap := s.extractParams(params...) + paramMap := extractParams(params...) containerName := s.getContainerName(item) containerClient, err := s.databaseClient.NewContainer(containerName) @@ -388,7 +388,7 @@ func (s *CosmosDBAdapter) Delete(item any, filter map[string]any, params ...map[ } // Extract provider-specific parameters - paramMap := s.extractParams(params...) + paramMap := extractParams(params...) containerName := s.getContainerName(item) containerClient, err := s.databaseClient.NewContainer(containerName) @@ -428,8 +428,11 @@ func (s *CosmosDBAdapter) Delete(item any, filter map[string]any, params ...map[ func (s *CosmosDBAdapter) List(dest any, sortKey string, filter map[string]any, limit int, cursor string, params ...map[string]any) (string, error) { // Extract sort direction from params - paramMap := s.extractParams(params...) - sortDirection := s.extractSortDirection(paramMap) + paramMap := extractParams(params...) + sortDirection, err := extractSortDirection(paramMap) + if err != nil { + return "", fmt.Errorf("failed to list: %w", err) + } return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, filter, params...) } @@ -441,8 +444,11 @@ func (s *CosmosDBAdapter) Search(dest any, sortKey string, query string, limit i // For custom queries, use the Query method instead // Extract sort direction from params - paramMap := s.extractParams(params...) - sortDirection := s.extractSortDirection(paramMap) + paramMap := extractParams(params...) + sortDirection, err := extractSortDirection(paramMap) + if err != nil { + return "", fmt.Errorf("failed to search: %w", err) + } // Use executePaginatedQuery with empty filter (the query parameter is ignored for CosmosDB) return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, map[string]any{}, params...) @@ -514,17 +520,23 @@ func (s *CosmosDBAdapter) Query(dest any, statement string, limit int, cursor st return nextCursor, nil } +// executePaginatedQuery runs a cursor-paginated Cosmos DB query against the container for dest. +// The cursor is a Cosmos DB continuation token. func (s *CosmosDBAdapter) executePaginatedQuery( dest any, sortKey string, - sortDirection string, + sortDirection SortingDirection, limit int, cursor string, filter map[string]any, params ...map[string]any, ) (string, error) { + if err := validateSortKey(sortKey); err != nil { + return "", err + } + // Extract provider-specific parameters - paramMap := s.extractParams(params...) + paramMap := extractParams(params...) containerName := s.getContainerName(dest) containerClient, err := s.databaseClient.NewContainer(containerName) @@ -706,31 +718,6 @@ func (s *CosmosDBAdapter) executeQuery( return page, err } -// extractParams merges all provided parameter maps into a single map -func (s *CosmosDBAdapter) extractParams(params ...map[string]any) map[string]any { - paramMap := make(map[string]any) - for _, param := range params { - for k, v := range param { - paramMap[k] = v - } - } - return paramMap -} - -// extractSortDirection extracts and validates sort direction from params -func (s *CosmosDBAdapter) extractSortDirection(paramMap map[string]any) string { - sortDirection := "ASC" // Default to ASC - if dir, exists := paramMap["sort_direction"]; exists { - if dirStr, ok := dir.(string); ok { - sortDirection = strings.ToUpper(dirStr) - if sortDirection != "ASC" && sortDirection != "DESC" { - sortDirection = "ASC" // Fallback to ASC for invalid values - } - } - } - return sortDirection -} - // buildFilter constructs WHERE clause conditions from filter map func (s *CosmosDBAdapter) buildFilter(filter map[string]any, paramIndex *int) (string, []azcosmos.QueryParameter) { conditions := []string{} diff --git a/storage/dynamodb.go b/storage/dynamodb.go index c7cbc33..4774381 100644 --- a/storage/dynamodb.go +++ b/storage/dynamodb.go @@ -216,6 +216,9 @@ func (s *DynamoDBAdapter) executePaginatedQuery( } func (s *DynamoDBAdapter) List(dest any, sortKey string, filter map[string]any, limit int, cursor string, params ...map[string]any) (string, error) { + if err := validateSortKey(sortKey); err != nil { + return "", err + } return s.executePaginatedQuery(dest, limit, cursor, func(input *dynamodb.ExecuteStatementInput) *dynamodb.ExecuteStatementInput { query := fmt.Sprintf(`SELECT * FROM "%s"`, s.getTableName(dest)) @@ -235,6 +238,9 @@ func (s *DynamoDBAdapter) List(dest any, sortKey string, filter map[string]any, } func (s *DynamoDBAdapter) Search(dest any, sortKey string, query string, limit int, cursor string, params ...map[string]any) (string, error) { + if err := validateSortKey(sortKey); err != nil { + return "", err + } return s.executePaginatedQuery(dest, limit, cursor, func(input *dynamodb.ExecuteStatementInput) *dynamodb.ExecuteStatementInput { // Parse Lucene query destType := reflect.TypeOf(dest).Elem().Elem() diff --git a/storage/sql.go b/storage/sql.go index 8313e11..8827f35 100644 --- a/storage/sql.go +++ b/storage/sql.go @@ -214,13 +214,19 @@ func (s *SQLAdapter) Delete(item any, filter map[string]any, params ...map[strin return result.Error } +// executePaginatedQuery runs a cursor-paginated SELECT using the provided query builder scope. +// The cursor is base64-encoded from the sortKey field value of the last returned row. func (s *SQLAdapter) executePaginatedQuery( dest any, sortKey string, + sortDirection SortingDirection, limit int, cursor string, builder queryBuilder, ) (string, error) { + if err := validateSortKey(sortKey); err != nil { + return "", err + } var cursorValue string if cursor != "" { bytes, err := base64.StdEncoding.DecodeString(cursor) @@ -231,10 +237,14 @@ func (s *SQLAdapter) executePaginatedQuery( } q := s.DB.Model(dest).Scopes(builder) - q = q.Limit(limit + 1).Order(fmt.Sprintf("%s ASC", sortKey)) + q = q.Limit(limit + 1).Order(fmt.Sprintf("%s %s", sortKey, sortDirection)) if cursorValue != "" { - q = q.Where(fmt.Sprintf("%s > ?", sortKey), cursorValue) + cursorOp := ">" + if sortDirection == Descending { + cursorOp = "<" + } + q = q.Where(fmt.Sprintf("%s %s ?", sortKey, cursorOp), cursorValue) } if result := q.Find(dest); result.Error != nil { @@ -251,19 +261,29 @@ func (s *SQLAdapter) executePaginatedQuery( if destSlice.Len() > limit { lastItem := destSlice.Index(limit - 1) field := reflect.Indirect(lastItem).FieldByName(sortKey) - if field.IsValid() && field.Kind() == reflect.String { + if !field.IsValid() { + slog.Warn("cursor extraction failed: sort_key does not match any exported struct field", + "sort_key", sortKey, + "hint", "sort_key must be the Go struct field name (e.g. 'CreatedAt'), not the DB column name (e.g. 'created_at')") + } else if field.Kind() != reflect.String { + slog.Warn("cursor extraction failed: struct field is not a string", + "sort_key", sortKey, + "field_kind", field.Kind().String()) + } else { nextCursor = base64.StdEncoding.EncodeToString([]byte(field.String())) } destSlice.Set(destSlice.Slice(0, limit)) - } else { - nextCursor = "" } return nextCursor, nil } func (s *SQLAdapter) List(dest any, sortKey string, filter map[string]any, limit int, cursor string, params ...map[string]any) (string, error) { - return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB { + sortDirection, err := extractSortDirection(extractParams(params...)) + if err != nil { + return "", fmt.Errorf("failed to list: %w", err) + } + return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, func(q *gorm.DB) *gorm.DB { if len(filter) > 0 { query, bindings := s.buildQuery(filter) return q.Where(query, bindings) @@ -273,8 +293,12 @@ func (s *SQLAdapter) List(dest any, sortKey string, filter map[string]any, limit } func (s *SQLAdapter) Search(dest any, sortKey string, query string, limit int, cursor string, params ...map[string]any) (string, error) { + sortDirection, err := extractSortDirection(extractParams(params...)) + if err != nil { + return "", fmt.Errorf("failed to search: %w", err) + } if query == "" { - return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB { + return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, func(q *gorm.DB) *gorm.DB { return q }) } @@ -296,7 +320,7 @@ func (s *SQLAdapter) Search(dest any, sortKey string, query string, limit int, c slog.Debug(fmt.Sprintf(`Where clause: %s, with params %s`, whereClause, queryParams)) - return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB { + return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, func(q *gorm.DB) *gorm.DB { if whereClause != "" { return q.Where(whereClause, queryParams...) } @@ -324,6 +348,7 @@ func (s *SQLAdapter) Query(dest any, statement string, limit int, cursor string, return "", fmt.Errorf("not implemented yet") } + func (s *SQLAdapter) buildQuery(filter map[string]any) (string, map[string]any) { clauses := []string{} bindings := make(map[string]any) diff --git a/storage/sql_test.go b/storage/sql_test.go new file mode 100644 index 0000000..788e558 --- /dev/null +++ b/storage/sql_test.go @@ -0,0 +1,106 @@ +package storage + +import ( + "maps" + "testing" +) + +func TestListRejectsMaliciousSortKey(t *testing.T) { + // Reset singleton so we get a fresh SQLite adapter + prev := sqlAdapterInstance + sqlAdapterInstance = nil + t.Cleanup(func() { sqlAdapterInstance = prev }) + adapter := GetSQLAdapterInstance(map[string]string{ + "provider": "sqlite", + }) + type Row struct { + ID string `gorm:"primaryKey"` + } + _ = adapter.DB.AutoMigrate(&Row{}) + + var rows []Row + _, err := adapter.List(&rows, "id; DROP TABLE rows", map[string]any{}, 10, "") + if err == nil { + t.Error("expected error for malicious sortKey, got nil") + } +} + +func TestValidateSortKey(t *testing.T) { + tests := []struct { + name string + input string + wantErr bool + }{ + {"simple column", "id", false}, + {"snake_case column", "created_at", false}, + {"mixed case", "createdAt", false}, + {"with numbers", "field1", false}, + {"empty string", "", true}, + {"leading digit", "1field", true}, + {"dot notation injection", "id; DROP TABLE users", true}, + {"semicolon", "id;DROP", true}, + {"space", "col name", true}, + {"table.column dot", "t.col", true}, + {"SQL comment", "id--", true}, + {"single quote", "id'", true}, + {"underscore prefix", "_ts", false}, // CosmosDB system fields like _ts are valid + {"null byte", "id\x00DROP", true}, // null byte injection rejected + {"unicode lookalike", "iа", true}, // Cyrillic а (U+0430) rejected, not ASCII + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateSortKey(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("validateSortKey(%q) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestExtractSortDirection(t *testing.T) { + tests := []struct { + name string + input map[string]any + expected SortingDirection + wantErr bool + }{ + {"default when missing", map[string]any{}, Ascending, false}, + {"ASC explicit", map[string]any{SortDirectionKey: "ASC"}, Ascending, false}, + {"DESC", map[string]any{SortDirectionKey: "DESC"}, Descending, false}, + {"lowercase desc", map[string]any{SortDirectionKey: "desc"}, Descending, false}, + {"invalid returns error", map[string]any{SortDirectionKey: "SIDEWAYS"}, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractSortDirection(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("extractSortDirection(%v) error = %v, wantErr %v", tt.input, err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.expected { + t.Errorf("extractSortDirection(%v) = %q; want %q", tt.input, got, tt.expected) + } + }) + } +} + +func TestExtractParams(t *testing.T) { + tests := []struct { + name string + input []map[string]any + expected map[string]any + }{ + {"empty input", []map[string]any{}, map[string]any{}}, + {"single map", []map[string]any{{"a": 1}}, map[string]any{"a": 1}}, + {"two maps merged", []map[string]any{{"a": 1}, {"b": 2}}, map[string]any{"a": 1, "b": 2}}, + {"later map wins on collision", []map[string]any{{"a": 1}, {"a": 2}}, map[string]any{"a": 2}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractParams(tt.input...) + if !maps.Equal(got, tt.expected) { + t.Errorf("extractParams(%v) = %v; want %v", tt.input, got, tt.expected) + } + }) + } +} diff --git a/storage/storage.go b/storage/storage.go index 2414a3f..d1fdd9d 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -3,6 +3,10 @@ package storage import ( "embed" "errors" + "fmt" + "maps" + "regexp" + "strings" ) var ConfigFs embed.FS @@ -22,7 +26,15 @@ type StorageAdapter interface { Get(dest any, filter map[string]any, params ...map[string]any) error Update(item any, filter map[string]any, params ...map[string]any) error Delete(item any, filter map[string]any, params ...map[string]any) error + // List returns a page of items matching filter, ordered by sortKey. + // sortKey must be the Go struct field name (e.g. "CreatedAt"), not the DB column name. + // Pass SortDirectionKey via params to control order; defaults to Ascending. + // Returns a cursor for the next page, or "" on the final page. List(dest any, sortKey string, filter map[string]any, limit int, cursor string, params ...map[string]any) (string, error) + // Search returns a page of items matching a query string, ordered by sortKey. + // sortKey must be the Go struct field name (e.g. "CreatedAt"), not the DB column name. + // Pass SortDirectionKey via params to control order; defaults to Ascending. + // Returns a cursor for the next page, or "" on the final page. Search(dest any, sortKey string, query string, limit int, cursor string, params ...map[string]any) (string, error) Count(dest any, filter map[string]any, params ...map[string]any) (int64, error) Query(dest any, statement string, limit int, cursor string, params ...map[string]any) (string, error) @@ -47,6 +59,56 @@ const ( COSMOSDB_PROVIDER StorageProviders = "cosmosdb" ) +type SortingDirection string + +const ( + Ascending SortingDirection = "ASC" + Descending SortingDirection = "DESC" +) + +const SortDirectionKey = "sort_direction" + +// extractParams merges all provided parameter maps into a single flat map. +// When keys collide, later maps win. +func extractParams(params ...map[string]any) map[string]any { + flatParams := make(map[string]any) + for _, param := range params { + maps.Copy(flatParams, param) + } + return flatParams +} + +// extractSortDirection reads SortDirectionKey from paramMap and returns the corresponding SortingDirection. +// Defaults to Ascending when the key is absent. +// Returns an error if the value is present but not a valid SortingDirection ("ASC" or "DESC", case-insensitive). +func extractSortDirection(paramMap map[string]any) (SortingDirection, error) { + if dir, exists := paramMap[SortDirectionKey]; exists { + if dirStr, ok := dir.(string); ok { + switch SortingDirection(strings.ToUpper(dirStr)) { + case Ascending: + return Ascending, nil + case Descending: + return Descending, nil + } + } + return "", fmt.Errorf("invalid sort direction: %v", dir) + } + return Ascending, nil +} + +// validColumnName matches identifiers safe to interpolate as SQL/NoSQL column names. +// Allows letters, digits, and underscores; may start with a letter or underscore. +var validColumnName = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + +// validateSortKey returns an error if key contains characters that could enable +// injection via ORDER BY or similar clauses where parameterization is not available. +func validateSortKey(key string) error { + if !validColumnName.MatchString(key) { + return fmt.Errorf("invalid sort key %q: must match [a-zA-Z_][a-zA-Z0-9_]*", key) + } + return nil +} + func (s StorageAdapterFactory) GetInstance(adapterType StorageAdapterType, config any) (StorageAdapter, error) { if config == nil { config = make(map[string]string)