From d4231736fad6af76e005d7c757b4e9a653c20b37 Mon Sep 17 00:00:00 2001 From: lutherwaves Date: Wed, 24 Dec 2025 22:47:31 +0000 Subject: [PATCH 01/13] feat: implement Lucene query parser with PostgreSQL and DynamoDB support Implements Apache Lucene query syntax parser using go-lucene library with custom drivers for PostgreSQL (JSONB) and DynamoDB PartiQL. Supports field:value queries, wildcards, ranges, boolean operators, quoted phrases, fuzzy search, implicit search expansion, and JSONB field notation. Includes field validation, security limits, and comprehensive test coverage. --- go.mod | 1 + go.sum | 2 + storage/search/lucene/driver.go | 581 +++++++++++++++++ storage/search/lucene/parser.go | 908 +++++++++++++++++---------- storage/search/lucene/parser_test.go | 880 ++++++++++++++++++++++++++ storage/sql.go | 6 +- 6 files changed, 2038 insertions(+), 340 deletions(-) create mode 100644 storage/search/lucene/driver.go create mode 100644 storage/search/lucene/parser_test.go diff --git a/go.mod b/go.mod index 0039d51..ff6e0a7 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/grindlemire/go-lucene v0.0.26 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.6.0 // indirect diff --git a/go.sum b/go.sum index b736785..85fc238 100644 --- a/go.sum +++ b/go.sum @@ -75,6 +75,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grindlemire/go-lucene v0.0.26 
h1:81ttZkMvU3rFD0TfmjdIZT2U0Fd4TT7buDy+xq1x5EQ= +github.com/grindlemire/go-lucene v0.0.26/go.mod h1:INRJBdhkLjS4jc7XgkGPfzC5wuFg3BHDukXMTc+OTbc= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= diff --git a/storage/search/lucene/driver.go b/storage/search/lucene/driver.go new file mode 100644 index 0000000..8320901 --- /dev/null +++ b/storage/search/lucene/driver.go @@ -0,0 +1,581 @@ +package lucene + +import ( + "fmt" + "strings" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/grindlemire/go-lucene/pkg/driver" + "github.com/grindlemire/go-lucene/pkg/lucene/expr" +) + +// PostgresJSONBDriver is a custom PostgreSQL driver that supports JSONB field notation. +// It extends the base PostgreSQL driver to handle field->>'subfield' syntax. 
+type PostgresJSONBDriver struct { + driver.Base + fields map[string]FieldInfo // Map of field names to their metadata +} + +func NewPostgresJSONBDriver(fields []FieldInfo) *PostgresJSONBDriver { + fieldMap := make(map[string]FieldInfo) + for _, f := range fields { + fieldMap[f.Name] = f + } + + // RenderFNs map - we handle most operators in renderParamInternal + // Only keeping base implementations for operators we don't intercept + fns := map[expr.Operator]driver.RenderFN{ + expr.Literal: driver.Shared[expr.Literal], + expr.And: driver.Shared[expr.And], + expr.Or: driver.Shared[expr.Or], + expr.Not: driver.Shared[expr.Not], + expr.Equals: driver.Shared[expr.Equals], + expr.Range: driver.Shared[expr.Range], + expr.Must: driver.Shared[expr.Must], + expr.MustNot: driver.Shared[expr.MustNot], + expr.Wild: driver.Shared[expr.Wild], + expr.Regexp: driver.Shared[expr.Regexp], + expr.Like: driver.Shared[expr.Like], + expr.Greater: driver.Shared[expr.Greater], + expr.GreaterEq: driver.Shared[expr.GreaterEq], + expr.Less: driver.Shared[expr.Less], + expr.LessEq: driver.Shared[expr.LessEq], + expr.In: driver.Shared[expr.In], + expr.List: driver.Shared[expr.List], + } + + return &PostgresJSONBDriver{ + Base: driver.Base{ + RenderFNs: fns, + }, + fields: fieldMap, + } +} + +// RenderParam renders the expression with PostgreSQL-style $N placeholders. +func (p *PostgresJSONBDriver) RenderParam(e *expr.Expression) (string, []any, error) { + // Process JSONB field notation before rendering + p.processJSONBFields(e) + + // Use our custom rendering logic + str, params, err := p.renderParamInternal(e) + if err != nil { + return "", nil, err + } + + // Convert ? to $N format + str = convertToPostgresPlaceholders(str) + + return str, params, nil +} + +// renderParamInternal dispatches to specialized renderers based on operator type. 
+func (p *PostgresJSONBDriver) renderParamInternal(e *expr.Expression) (string, []any, error) { + if e == nil { + return "", nil, nil + } + + switch e.Op { + case expr.Like, expr.Wild: + return p.renderLikeOrWild(e) + case expr.Fuzzy: + return p.renderFuzzy(e) + case expr.Boost: + return "", nil, fmt.Errorf("boost operator (^) is not supported in SQL filtering; it only affects ranking/scoring") + case expr.Range: + return p.renderRange(e) + case expr.Equals, expr.Greater, expr.Less, expr.GreaterEq, expr.LessEq: + return p.renderComparison(e) + case expr.And, expr.Or, expr.Must, expr.MustNot: + return p.renderBinary(e) + default: + // Use base implementation for all other operators + return p.Base.RenderParam(e) + } +} + +// renderLikeOrWild converts LIKE and Wild operators to PostgreSQL ILIKE for case-insensitive matching. +func (p *PostgresJSONBDriver) renderLikeOrWild(e *expr.Expression) (string, []any, error) { + leftStr, leftParams, err := p.serializeColumn(e.Left) + if err != nil { + return "", nil, err + } + + rightStr, rightParams, err := p.serializeValue(e.Right) + if err != nil { + return "", nil, err + } + + params := append(leftParams, rightParams...) + + if isJSONBSyntax(leftStr) { + return fmt.Sprintf("%s ILIKE %s", leftStr, rightStr), params, nil + } + return fmt.Sprintf("%s::text ILIKE %s", leftStr, rightStr), params, nil +} + +// renderFuzzy handles fuzzy search using PostgreSQL similarity() function. +// Requires pg_trgm extension. 
+// For queries like "name:roam~2", the structure is: +// - Op: Fuzzy +// - Left: Equals expression (name:roam) with Left=Column("name"), Right=Literal("roam") +// - Right: nil (distance stored in unexported fuzzyDistance field) +func (p *PostgresJSONBDriver) renderFuzzy(e *expr.Expression) (string, []any, error) { + leftExpr, ok := e.Left.(*expr.Expression) + if !ok || leftExpr.Op != expr.Equals { + return "", nil, fmt.Errorf("fuzzy operator requires field:value syntax (e.g., name:roam~2)") + } + + colStr, colParams, err := p.serializeColumn(leftExpr.Left) + if err != nil { + return "", nil, err + } + + termStr, termParams, err := p.serializeValue(leftExpr.Right) + if err != nil { + return "", nil, err + } + + params := append(colParams, termParams...) + + // Use threshold 0.3 (lower = more matches, higher = stricter). + // The fuzzy distance from go-lucene is unexported, so we use a reasonable default. + threshold := 0.3 + + if isJSONBSyntax(colStr) { + return fmt.Sprintf("similarity(%s, %s) > %f", colStr, termStr, threshold), params, nil + } + return fmt.Sprintf("similarity(%s::text, %s) > %f", colStr, termStr, threshold), params, nil +} + +// renderComparison handles comparison operators with IS NULL support for null values. +func (p *PostgresJSONBDriver) renderComparison(e *expr.Expression) (string, []any, error) { + leftStr, leftParams, err := p.serializeColumn(e.Left) + if err != nil { + return "", nil, err + } + + if isNullValue(e.Right) { + if e.Op == expr.Equals { + return fmt.Sprintf("%s IS NULL", leftStr), leftParams, nil + } + return "", nil, fmt.Errorf("cannot use comparison operators (>, <, >=, <=) with null value") + } + + rightStr, rightParams, err := p.serializeValue(e.Right) + if err != nil { + return "", nil, err + } + + params := append(leftParams, rightParams...) 
+ + var opSymbol string + switch e.Op { + case expr.Equals: + opSymbol = "=" + case expr.Greater: + opSymbol = ">" + case expr.Less: + opSymbol = "<" + case expr.GreaterEq: + opSymbol = ">=" + case expr.LessEq: + opSymbol = "<=" + } + + return fmt.Sprintf("%s %s %s", leftStr, opSymbol, rightStr), params, nil +} + +// renderBinary handles binary and unary logical operators recursively. +// Note: Must and MustNot are unary (only Left operand), while And and Or are binary. +func (p *PostgresJSONBDriver) renderBinary(e *expr.Expression) (string, []any, error) { + switch e.Op { + case expr.Must, expr.MustNot: + if e.Left == nil { + return "", nil, fmt.Errorf("%s operator requires a left operand", e.Op) + } + + if leftExpr, ok := e.Left.(*expr.Expression); ok { + leftStr, leftParams, err := p.renderParamInternal(leftExpr) + if err != nil { + return "", nil, err + } + + if e.Op == expr.Must { + return leftStr, leftParams, nil + } + return fmt.Sprintf("NOT (%s)", leftStr), leftParams, nil + } + + leftStr, leftParams, err := p.serializeColumn(e.Left) + if err != nil { + leftStr, leftParams, err = p.serializeValue(e.Left) + if err != nil { + return p.Base.RenderParam(e) + } + } + + if e.Op == expr.Must { + return leftStr, leftParams, nil + } + return fmt.Sprintf("NOT (%s)", leftStr), leftParams, nil + + case expr.And, expr.Or: + if e.Left == nil || e.Right == nil { + return "", nil, fmt.Errorf("%s operator requires both left and right operands", e.Op) + } + + leftExpr, leftIsExpr := e.Left.(*expr.Expression) + rightExpr, rightIsExpr := e.Right.(*expr.Expression) + + if !leftIsExpr || !rightIsExpr { + return p.Base.RenderParam(e) + } + + leftStr, leftParams, err := p.renderParamInternal(leftExpr) + if err != nil { + return "", nil, err + } + + rightStr, rightParams, err := p.renderParamInternal(rightExpr) + if err != nil { + return "", nil, err + } + + params := append(leftParams, rightParams...) 
+ + if e.Op == expr.And { + return fmt.Sprintf("(%s) AND (%s)", leftStr, rightStr), params, nil + } + return fmt.Sprintf("(%s) OR (%s)", leftStr, rightStr), params, nil + + default: + return "", nil, fmt.Errorf("unsupported operator: %v", e.Op) + } +} + +func (p *PostgresJSONBDriver) serializeColumn(in any) (string, []any, error) { + switch v := in.(type) { + case expr.Column: + colStr := string(v) + if isJSONBSyntax(colStr) { + return colStr, nil, nil + } + return fmt.Sprintf(`"%s"`, colStr), nil, nil + case string: + if isJSONBSyntax(v) { + return v, nil, nil + } + return fmt.Sprintf(`"%s"`, v), nil, nil + case *expr.Expression: + if v.Op == expr.Literal && v.Left != nil { + if col, ok := v.Left.(expr.Column); ok { + colStr := string(col) + if isJSONBSyntax(colStr) { + return colStr, nil, nil + } + return fmt.Sprintf(`"%s"`, colStr), nil, nil + } + } + return p.renderParamInternal(v) + default: + return "", nil, fmt.Errorf("unexpected column type: %T", v) + } +} + +// serializeValue converts Lucene wildcards (* and ?) to SQL wildcards (% and _). +func (p *PostgresJSONBDriver) serializeValue(in any) (string, []any, error) { + switch v := in.(type) { + case string: + return "?", []any{convertWildcards(v)}, nil + case *expr.Expression: + if v.Op == expr.Literal && v.Left != nil { + literalVal := fmt.Sprintf("%v", v.Left) + return "?", []any{convertWildcards(literalVal)}, nil + } + if v.Op == expr.Wild && v.Left != nil { + literalVal := fmt.Sprintf("%v", v.Left) + return "?", []any{convertWildcards(literalVal)}, nil + } + return p.renderParamInternal(v) + case nil: + return "", nil, fmt.Errorf("nil value in expression") + default: + return "?", []any{v}, nil + } +} + +// processJSONBFields recursively processes the expression tree to convert +// field.subfield notation to PostgreSQL JSONB syntax field->>'subfield'. 
+func (p *PostgresJSONBDriver) processJSONBFields(e *expr.Expression) { + if e == nil { + return + } + + // Process left side if it's a column + if col, ok := e.Left.(expr.Column); ok { + e.Left = p.formatFieldName(string(col)) + } + + // Recursively process expressions + if leftExpr, ok := e.Left.(*expr.Expression); ok { + p.processJSONBFields(leftExpr) + } + if rightExpr, ok := e.Right.(*expr.Expression); ok { + p.processJSONBFields(rightExpr) + } + + // Process expression slices + if exprs, ok := e.Left.([]*expr.Expression); ok { + for _, ex := range exprs { + p.processJSONBFields(ex) + } + } + if exprs, ok := e.Right.([]*expr.Expression); ok { + for _, ex := range exprs { + p.processJSONBFields(ex) + } + } +} + +// formatFieldName converts field.subfield to JSONB syntax if the base field is JSONB. +func (p *PostgresJSONBDriver) formatFieldName(fieldName string) expr.Column { + parts := strings.SplitN(fieldName, ".", 2) + if len(parts) == 2 { + baseField := parts[0] + subField := parts[1] + + if field, exists := p.fields[baseField]; exists && field.IsJSONB { + // Return as JSONB operator syntax + return expr.Column(fmt.Sprintf("%s->>'%s'", baseField, subField)) + } + } + return expr.Column(fieldName) +} + +// Helper functions for DRY and cleaner code + +// convertWildcards converts Lucene wildcards to SQL wildcards. +// * (any characters) → % (SQL wildcard) +// ? (single character) → _ (SQL wildcard) +// +// Note: go-lucene's base driver also converts wildcards, but only for expr.Like operators. +// We need this function because we also convert wildcards for expr.Wild expressions +// and when serializing values for fuzzy search and other operators. 
+func convertWildcards(s string) string { + // Use a builder for efficient string manipulation + var result strings.Builder + result.Grow(len(s)) + + for i := 0; i < len(s); i++ { + c := s[i] + switch c { + case '*': + result.WriteByte('%') + case '?': + result.WriteByte('_') + default: + result.WriteByte(c) + } + } + return result.String() +} + +func isJSONBSyntax(col string) bool { + return strings.Contains(col, "->>") +} + +// isNullValue checks if a value represents null in Lucene query syntax. +// Supports: null, NULL, Null (case-insensitive) +// Note: This is a SQL-specific extension (vanilla Lucene doesn't support NULL values). +// We intentionally do NOT support "empty" or "nil" as they could be legitimate search values. +func isNullValue(v any) bool { + strVal := extractStringValue(v) + if strVal == "" { + return false + } + lower := strings.ToLower(strVal) + return lower == "null" +} + +func extractStringValue(v any) string { + switch val := v.(type) { + case string: + return val + case *expr.Expression: + if val.Op == expr.Literal && val.Left != nil { + if strVal, ok := val.Left.(string); ok { + return strVal + } + } + } + return "" +} + +func extractLiteralValue(v any) string { + if v == nil { + return "" + } + + // If it's an expression, try to extract the Left value (for LITERAL expressions) + if ex, ok := v.(*expr.Expression); ok { + if ex.Op == expr.Literal && ex.Left != nil { + // LITERAL expressions store the actual value in Left + return fmt.Sprintf("%v", ex.Left) + } + // For other expression types, return the string representation + return fmt.Sprintf("%v", v) + } + + // For non-expression types, return as string + return fmt.Sprintf("%v", v) +} + +// renderRange handles range queries including open-ended ranges with wildcards (*). 
+func (p *PostgresJSONBDriver) renderRange(e *expr.Expression) (string, []any, error) { + colStr, _, err := p.serializeColumn(e.Left) + if err != nil { + return "", nil, err + } + + rangeBoundary, ok := e.Right.(*expr.RangeBoundary) + if !ok { + return "", nil, fmt.Errorf("invalid range expression structure: expected *expr.RangeBoundary, got %T", e.Right) + } + + var minVal, maxVal string + var params []any + + if rangeBoundary.Min != nil { + minVal = extractLiteralValue(rangeBoundary.Min) + } + + if rangeBoundary.Max != nil { + maxVal = extractLiteralValue(rangeBoundary.Max) + } + + if minVal == "*" && maxVal == "*" { + return "", nil, fmt.Errorf("both range bounds cannot be wildcards") + } + + if minVal == "*" { + params = append(params, maxVal) + if rangeBoundary.Inclusive { + return fmt.Sprintf("%s <= ?", colStr), params, nil + } + return fmt.Sprintf("%s < ?", colStr), params, nil + } + + if maxVal == "*" { + params = append(params, minVal) + if rangeBoundary.Inclusive { + return fmt.Sprintf("%s >= ?", colStr), params, nil + } + return fmt.Sprintf("%s > ?", colStr), params, nil + } + + params = append(params, minVal, maxVal) + if rangeBoundary.Inclusive { + return fmt.Sprintf("%s BETWEEN ? AND ?", colStr), params, nil + } + return fmt.Sprintf("(%s > ? AND %s < ?)", colStr, colStr), params, nil +} + +// DynamoDBPartiQLDriver converts Lucene queries to DynamoDB PartiQL. 
+type DynamoDBPartiQLDriver struct { + driver.Base + fields map[string]FieldInfo +} + +func NewDynamoDBPartiQLDriver(fields []FieldInfo) *DynamoDBPartiQLDriver { + fieldMap := make(map[string]FieldInfo) + for _, f := range fields { + fieldMap[f.Name] = f + } + + fns := map[expr.Operator]driver.RenderFN{ + expr.Literal: driver.Shared[expr.Literal], + expr.And: driver.Shared[expr.And], + expr.Or: driver.Shared[expr.Or], + expr.Not: driver.Shared[expr.Not], + expr.Equals: driver.Shared[expr.Equals], + expr.Range: driver.Shared[expr.Range], + expr.Must: driver.Shared[expr.Must], + expr.MustNot: driver.Shared[expr.MustNot], + expr.Wild: driver.Shared[expr.Wild], + expr.Regexp: driver.Shared[expr.Regexp], + expr.Like: dynamoDBLike, // Custom LIKE for DynamoDB functions + expr.Greater: driver.Shared[expr.Greater], + expr.GreaterEq: driver.Shared[expr.GreaterEq], + expr.Less: driver.Shared[expr.Less], + expr.LessEq: driver.Shared[expr.LessEq], + expr.In: driver.Shared[expr.In], + expr.List: driver.Shared[expr.List], + } + + return &DynamoDBPartiQLDriver{ + Base: driver.Base{ + RenderFNs: fns, + }, + fields: fieldMap, + } +} + +// RenderPartiQL renders the expression to DynamoDB PartiQL with AttributeValue parameters. +func (d *DynamoDBPartiQLDriver) RenderPartiQL(e *expr.Expression) (string, []types.AttributeValue, error) { + // Use base rendering with ? placeholders + str, params, err := d.RenderParam(e) + if err != nil { + return "", nil, err + } + + // Convert params to DynamoDB AttributeValues + attrValues := make([]types.AttributeValue, len(params)) + for i, param := range params { + attrValues[i] = &types.AttributeValueMemberS{Value: fmt.Sprintf("%v", param)} + } + + return str, attrValues, nil +} + +// dynamoDBLike implements LIKE using DynamoDB's begins_with and contains functions. 
+func dynamoDBLike(left, right string) (string, error) { + // Remove quotes from right side to analyze pattern + pattern := strings.Trim(right, "'") + + // Replace wildcards for analysis + hasPrefix := strings.HasPrefix(pattern, "%") + hasSuffix := strings.HasSuffix(pattern, "%") + + if hasPrefix && hasSuffix { + // %value% -> contains(field, value) + value := strings.Trim(pattern, "%") + return fmt.Sprintf("contains(%s, '%s')", left, value), nil + } else if !hasPrefix && hasSuffix { + // value% -> begins_with(field, value) + value := strings.TrimSuffix(pattern, "%") + return fmt.Sprintf("begins_with(%s, '%s')", left, value), nil + } else if hasPrefix && !hasSuffix { + // %value -> contains(field, value) (DynamoDB doesn't have ends_with) + value := strings.TrimPrefix(pattern, "%") + return fmt.Sprintf("contains(%s, '%s')", left, value), nil + } + + // Exact match + return fmt.Sprintf("%s = %s", left, right), nil +} + +// convertToPostgresPlaceholders converts ? placeholders to PostgreSQL's $N format. +func convertToPostgresPlaceholders(query string) string { + paramIndex := 1 + result := strings.Builder{} + for i := 0; i < len(query); i++ { + if query[i] == '?' 
{ + result.WriteString(fmt.Sprintf("$%d", paramIndex)) + paramIndex++ + } else { + result.WriteByte(query[i]) + } + } + return result.String() +} diff --git a/storage/search/lucene/parser.go b/storage/search/lucene/parser.go index 616ef61..5f95e81 100644 --- a/storage/search/lucene/parser.go +++ b/storage/search/lucene/parser.go @@ -8,52 +8,72 @@ import ( "strings" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + lucene "github.com/grindlemire/go-lucene" + "github.com/grindlemire/go-lucene/pkg/lucene/expr" ) -type FieldInfo struct { - Name string - IsJSONB bool -} - -type Parser struct { - DefaultFields []FieldInfo -} - -type NodeType int - +// Safety limits for query parsing const ( - NodeTerm NodeType = iota - NodeWildcard - NodeLogical + DefaultMaxQueryLength = 10000 // 10KB - prevents memory exhaustion + DefaultMaxDepth = 20 // Prevents stack overflow from deep nesting + DefaultMaxTerms = 100 // Prevents CPU exhaustion from complex queries ) -type LogicalOperator string +// FieldInfo describes a searchable field and its properties. +type FieldInfo struct { + Name string + IsJSONB bool + ImplicitSearch bool // Whether this field is included in unfielded/implicit queries +} -const ( - AND LogicalOperator = "AND" - OR LogicalOperator = "OR" - NOT LogicalOperator = "NOT" -) +// Parser provides Lucene query parsing with security limits. 
+type Parser struct { + Fields []FieldInfo // All searchable fields -type MatchType int + // Security limits (configurable with safe defaults) + MaxQueryLength int // Maximum query string length (default: 10KB) + MaxDepth int // Maximum nesting depth (default: 20) + MaxTerms int // Maximum number of terms (default: 100) -const ( - matchExact MatchType = iota - matchStartsWith - matchEndsWith - matchContains -) + // Field lookup maps for O(1) validation + fieldMap map[string]FieldInfo // All fields by name + jsonbFields map[string]bool // JSONB field names for sub-field validation -type Node struct { - Type NodeType - Field string - Value string - Operator LogicalOperator - Children []*Node - Negate bool - MatchType MatchType + // Custom drivers for different backends + postgresDriver *PostgresJSONBDriver + dynamoDriver *DynamoDBPartiQLDriver } +// NewParserFromType creates a parser by introspecting a struct's fields. +// This is the recommended approach for initializing parsers as it: +// - Works with any backend (PostgreSQL, MySQL, DynamoDB, etc.) 
+// - Zero database overhead +// - Compile-time safety +// - Auto-detects JSONB fields from gorm tags +// - Auto-sets string fields for implicit search (ImplicitSearch=true) +// +// Example: +// +// type Task struct { +// ID string `json:"id"` +// Name string `json:"name"` // Auto: ImplicitSearch=true +// Description string `json:"description"` // Auto: ImplicitSearch=true +// Status string `json:"status" lucene:"explicit"` // Explicit: ImplicitSearch=false +// CreatedAt time.Time `json:"created_at"` // Auto: ImplicitSearch=false (not string) +// Labels JSONB `json:"labels" gorm:"type:jsonb"` // Auto: IsJSONB=true, ImplicitSearch=false +// } +// +// parser, err := lucene.NewParserFromType(Task{}) +// +// Struct tag controls: +// - lucene:"implicit" - Force ImplicitSearch=true (include in unfielded queries) +// - lucene:"explicit" - Force ImplicitSearch=false (require field:value syntax) +// - gorm:"type:jsonb" - Auto-detected as JSONB field +// +// Auto-detection rules (when no lucene tag): +// - String fields: ImplicitSearch=true (included in unfielded queries) +// - Non-string fields (int, time.Time, uuid, etc.): ImplicitSearch=false +// - JSONB fields: ImplicitSearch=false (require field.subfield syntax) func NewParserFromType(model any) (*Parser, error) { fields, err := getStructFields(model) if err != nil { @@ -62,10 +82,52 @@ func NewParserFromType(model any) (*Parser, error) { return NewParser(fields), nil } -func NewParser(defaultFields []FieldInfo) *Parser { - return &Parser{DefaultFields: defaultFields} +func NewParser(fields []FieldInfo) *Parser { + fieldMap := make(map[string]FieldInfo, len(fields)) + jsonbFields := make(map[string]bool) + for _, f := range fields { + fieldMap[f.Name] = f + if f.IsJSONB { + jsonbFields[f.Name] = true + } + } + + return &Parser{ + Fields: fields, + MaxQueryLength: DefaultMaxQueryLength, + MaxDepth: DefaultMaxDepth, + MaxTerms: DefaultMaxTerms, + fieldMap: fieldMap, + jsonbFields: jsonbFields, + postgresDriver: 
NewPostgresJSONBDriver(fields), + dynamoDriver: NewDynamoDBPartiQLDriver(fields), + } +} + +// Precompiled regex for performance - matches Lucene operators and special syntax +var ( + // Matches field:value pattern (including JSONB like labels.category:value) + fieldValuePattern = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?:`) + // Extracts field name from field:value pattern + fieldExtractPattern = regexp.MustCompile(`([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?):`) + // Matches boolean operators (case-insensitive) + booleanOperators = regexp.MustCompile(`(?i)^(AND|OR|NOT|\+|-)$`) + // Matches range syntax + rangePattern = regexp.MustCompile(`^\[.*\s+TO\s+.*\]$|^\{.*\s+TO\s+.*\}$`) +) + +// InvalidFieldError represents an error when a query references a non-existent field +type InvalidFieldError struct { + Field string + ValidFields []string +} + +func (e *InvalidFieldError) Error() string { + return fmt.Sprintf("invalid field '%s' in query; valid fields are: %s", e.Field, strings.Join(e.ValidFields, ", ")) } +// getStructFields uses reflection to extract field metadata from a struct. +// String fields get ImplicitSearch=true, others get ImplicitSearch=false. func getStructFields(model any) ([]FieldInfo, error) { t := reflect.TypeOf(model) if t.Kind() == reflect.Ptr { @@ -91,425 +153,593 @@ func getStructFields(model any) ([]FieldInfo, error) { gormTag := field.Tag.Get("gorm") isJSONB := strings.Contains(gormTag, "type:jsonb") + luceneTag := field.Tag.Get("lucene") + implicitSearch := false + if luceneTag == "implicit" { + implicitSearch = true + } else if luceneTag != "explicit" { + implicitSearch = field.Type.Kind() == reflect.String && !isJSONB + } + fields = append(fields, FieldInfo{ - Name: jsonTag, - IsJSONB: isJSONB, + Name: jsonTag, + IsJSONB: isJSONB, + ImplicitSearch: implicitSearch, }) } return fields, nil } +// ParseToMap parses a Lucene query into a map representation. 
+// Note: This is a legacy method kept for backward compatibility. func (p *Parser) ParseToMap(query string) (map[string]any, error) { - node, err := p.parse(query) + + if err := p.validateQuery(query); err != nil { + return nil, err + } + + e, err := p.parseWithImplicitSearch(query) if err != nil { return nil, err } - return p.nodeToMap(node), nil + + // Convert expression to map + return p.exprToMap(e), nil } +// ParseToSQL parses a Lucene query and converts it to PostgreSQL SQL with parameters. func (p *Parser) ParseToSQL(query string) (string, []any, error) { - slog.Debug(fmt.Sprintf(`Parsing query to sql: %s`, query)) - re := regexp.MustCompile(`(\w+):"([^"]+)"`) - query = re.ReplaceAllString(query, `$1:$2`) - node, err := p.parse(query) - if err != nil { + slog.Debug(fmt.Sprintf(`Parsing query to SQL: %s`, query)) + + if err := p.validateQuery(query); err != nil { return "", nil, err } - return p.nodeToSQL(node) -} -func (p *Parser) parse(query string) (*Node, error) { - query = strings.TrimSpace(query) - if query == "" { - return nil, nil + // Expand implicit terms first (for validation of the full query) + expandedQuery := p.expandImplicitTerms(query) + + // Validate all field references exist in the model + if err := p.ValidateFields(expandedQuery); err != nil { + return "", nil, err } - if strings.HasPrefix(query, "(") && strings.HasSuffix(query, ")") { - return p.parse(query[1 : len(query)-1]) + // Parse using the library + e, err := p.parseWithImplicitSearch(query) + if err != nil { + return "", nil, err } - if andParts := splitByOperator(query, "AND"); len(andParts) > 1 { - return p.createLogicalNode(AND, andParts) + // Render using custom PostgreSQL driver + sql, params, err := p.postgresDriver.RenderParam(e) + if err != nil { + return "", nil, err } - if orParts := splitByOperator(query, "OR"); len(orParts) > 1 { - return p.createLogicalNode(OR, orParts) + + return sql, params, nil +} + +// ParseToDynamoDBPartiQL parses a Lucene query and converts it 
to DynamoDB PartiQL. +func (p *Parser) ParseToDynamoDBPartiQL(query string) (string, []types.AttributeValue, error) { + slog.Debug(fmt.Sprintf(`Parsing query to DynamoDB PartiQL: %s`, query)) + + if err := p.validateQuery(query); err != nil { + return "", nil, err } - if notParts := splitByOperator(query, "NOT"); len(notParts) > 1 { - return p.createLogicalNode(NOT, notParts) + + // Expand implicit terms first (for validation of the full query) + expandedQuery := p.expandImplicitTerms(query) + + // Validate all field references exist in the model + if err := p.ValidateFields(expandedQuery); err != nil { + return "", nil, err } - if parts := strings.SplitN(query, ":", 2); len(parts) == 2 { - field := strings.TrimSpace(parts[0]) - value := strings.TrimSpace(parts[1]) - // Skip empty fields or values - if field == "" || value == "" { - return nil, nil - } - return p.createTermNode(field, value) + // Parse using the library + e, err := p.parseWithImplicitSearch(query) + if err != nil { + return "", nil, err } - // Skip empty implicit terms - if query = strings.TrimSpace(query); query == "" { - return nil, nil + // Render using custom DynamoDB driver + partiql, attrs, err := p.dynamoDriver.RenderPartiQL(e) + if err != nil { + return "", nil, err } - return p.createImplicitNode(query) + return partiql, attrs, nil } -func splitByOperator(input string, op string) []string { - // Handle case where the operator is at the beginning of the string - trimmedInput := strings.TrimSpace(input) - lowerInput := strings.ToLower(trimmedInput) - lowerOp := strings.ToLower(op) - - if strings.HasPrefix(lowerInput, lowerOp) { - // Check if it's a standalone word (followed by space or end of string) - opLength := len(op) - if len(trimmedInput) == opLength || (len(trimmedInput) > opLength && trimmedInput[opLength] == ' ') { - afterOp := strings.TrimSpace(trimmedInput[opLength:]) - if afterOp != "" { - return []string{"", afterOp} - } - } +func (p *Parser) validateQuery(query string) error { 
+ if len(query) > p.MaxQueryLength { + return fmt.Errorf("query too long: %d bytes exceeds maximum of %d bytes", len(query), p.MaxQueryLength) + } + + depth := calculateNestingDepth(query) + if depth > p.MaxDepth { + return fmt.Errorf("query too complex: nesting depth %d exceeds maximum of %d", depth, p.MaxDepth) } - // Original logic for operators in the middle - re := regexp.MustCompile(fmt.Sprintf(`(?i)\s+%s\s+`, op)) - parts := re.Split(input, -1) - if len(parts) > 1 { - return parts + terms := countTerms(query) + if terms > p.MaxTerms { + return fmt.Errorf("query too large: %d terms exceeds maximum of %d", terms, p.MaxTerms) } return nil } -func (p *Parser) createImplicitNode(term string) (*Node, error) { - slog.Debug(fmt.Sprintf(`Handling implicit: %s`, term)) - term = strings.Trim(term, `"`) +func calculateNestingDepth(query string) int { + maxDepth := 0 + currentDepth := 0 + inQuotes := false + + for i := 0; i < len(query); i++ { + c := query[i] + + if c == '\\' && i+1 < len(query) { + i++ + continue + } - containsWildcard := strings.Contains(term, "*") || strings.Contains(term, "?") + if c == '"' { + inQuotes = !inQuotes + continue + } - node := &Node{ - Type: NodeLogical, - Operator: OR, + if !inQuotes { + switch c { + case '(', '[', '{': + currentDepth++ + if currentDepth > maxDepth { + maxDepth = currentDepth + } + case ')', ']', '}': + currentDepth-- + } + } } - for _, field := range p.DefaultFields { - var child *Node - var err error + return maxDepth +} - if containsWildcard { - child, err = p.createWildcardNode(field.Name, term) - } else { - child, err = p.createTermNode(field.Name, term) +// countTerms counts search terms in a query. +// Terms include field:value pairs, implicit terms, and quoted phrases. +// Operators (AND, OR, NOT) and parentheses are excluded. 
+func countTerms(query string) int { + if query == "" { + return 0 + } + + terms := 0 + inQuotes := false + inRange := false + currentTerm := false + + for i := 0; i < len(query); i++ { + c := query[i] - if child.Type == NodeTerm { - child.Type = NodeWildcard - child.MatchType = matchContains + if c == '\\' && i+1 < len(query) { + i++ + currentTerm = true + continue + } + + if c == '"' { + if !inQuotes { + if currentTerm { + terms++ + } + currentTerm = true + } else { + if currentTerm { + terms++ + currentTerm = false + } } + inQuotes = !inQuotes + continue } - if err != nil { - return nil, err + + if !inQuotes { + if c == '[' || c == '{' { + inRange = true + if currentTerm { + terms++ + currentTerm = false + } + continue + } + if c == ']' || c == '}' { + inRange = false + if currentTerm { + terms++ + currentTerm = false + } + continue + } } - node.Children = append(node.Children, child) - } - return node, nil -} + if c == ' ' && !inQuotes && !inRange { + if currentTerm { + terms++ + currentTerm = false + } + continue + } -func (p *Parser) createWildcardNode(field, value string) (*Node, error) { - // Skip empty fields or values - field = strings.TrimSpace(field) - value = strings.TrimSpace(value) + if !inQuotes && !inRange && (c == '(' || c == ')') { + if currentTerm { + terms++ + currentTerm = false + } + continue + } - if field == "" || value == "" { - return nil, nil + if !inQuotes && !inRange && currentTerm { + remaining := query[i:] + if strings.HasPrefix(remaining, "AND ") || strings.HasPrefix(remaining, "OR ") || + strings.HasPrefix(remaining, "NOT ") || strings.HasPrefix(remaining, "and ") || + strings.HasPrefix(remaining, "or ") || strings.HasPrefix(remaining, "not ") { + terms++ + currentTerm = false + if len(remaining) >= 3 && (remaining[0] == 'A' || remaining[0] == 'a') { + i += 3 + } else if len(remaining) >= 3 && (remaining[0] == 'N' || remaining[0] == 'n') { + i += 3 + } else { + i += 2 + } + continue + } + } + + currentTerm = true } - formattedField 
:= p.formatFieldName(field) - - node := &Node{ - Type: NodeWildcard, - Field: formattedField, - Value: value, - } - - // Process the wildcard pattern - if strings.HasPrefix(value, "*") && strings.HasSuffix(value, "*") { - // For *term* pattern - node.MatchType = matchContains - node.Value = strings.Trim(value, "*") - } else if strings.HasPrefix(value, "*") { - // For *term pattern - node.MatchType = matchEndsWith - node.Value = strings.TrimPrefix(value, "*") - } else if strings.HasSuffix(value, "*") { - // For term* pattern - node.MatchType = matchStartsWith - node.Value = strings.TrimSuffix(value, "*") - } else if strings.Contains(value, "*") { - // For patterns like te*rm - node.MatchType = matchContains - // Replace wildcards with % for SQL LIKE - node.Value = strings.ReplaceAll(value, "*", "%") - } else { - // Default to contains match for other patterns - node.MatchType = matchContains - } - - // Skip if the value becomes empty after processing - if node.Value == "" { - return nil, nil + if currentTerm { + terms++ } - return node, nil + return terms } -func (p *Parser) formatFieldName(fieldName string) string { - if parts := strings.SplitN(fieldName, ".", 2); len(parts) == 2 { - baseField := parts[0] - subField := parts[1] +// ValidateFields returns InvalidFieldError if the query references non-existent fields. 
+func (p *Parser) ValidateFields(query string) error { + matches := fieldExtractPattern.FindAllStringSubmatchIndex(query, -1) + if len(matches) == 0 { + return nil + } + + validFields := p.getValidFieldNames() + + for _, match := range matches { + if len(match) < 4 { + continue + } + fieldStart := match[2] + fieldEnd := match[3] + + if isInsideQuotes(query, fieldStart) { + continue + } + + fieldName := query[fieldStart:fieldEnd] - for _, field := range p.DefaultFields { - if field.IsJSONB && field.Name == baseField { - return fmt.Sprintf("%s->>'%s'", baseField, subField) + if err := p.validateFieldName(fieldName); err != nil { + return &InvalidFieldError{ + Field: fieldName, + ValidFields: validFields, } } } - return fieldName -} -func (p *Parser) createTermNode(field, value string) (*Node, error) { - field = strings.TrimSpace(field) - value = strings.TrimSpace(value) + return nil +} - if field == "" || value == "" { - return nil, nil +func isInsideQuotes(query string, pos int) bool { + inQuotes := false + for i := 0; i < pos && i < len(query); i++ { + c := query[i] + if c == '\\' && i+1 < len(query) { + i++ + continue + } + if c == '"' { + inQuotes = !inQuotes + } } - formattedField := p.formatFieldName(field) + return inQuotes +} - trimmedValue := strings.TrimSpace(strings.Trim(value, `"`)) +// validateFieldName validates both simple fields (name) and JSONB sub-fields (labels.category). 
+func (p *Parser) validateFieldName(fieldName string) error { + if strings.Contains(fieldName, ".") { + parts := strings.SplitN(fieldName, ".", 2) + if len(parts) != 2 { + return fmt.Errorf("invalid field format: %s", fieldName) + } - // Skip if the value becomes empty after trimming - if trimmedValue == "" { - return nil, nil + baseField := parts[0] + + if !p.jsonbFields[baseField] { + if _, exists := p.fieldMap[baseField]; !exists { + return fmt.Errorf("field '%s' does not exist", baseField) + } + return fmt.Errorf("field '%s' is not a JSONB field; cannot use sub-field notation", baseField) + } + + return nil } - node := &Node{ - Type: NodeTerm, - Field: formattedField, - Value: strings.Trim(value, `"`), + if _, exists := p.fieldMap[fieldName]; !exists { + return fmt.Errorf("field '%s' does not exist", fieldName) } - if strings.Contains(value, "*") || strings.Contains(value, "?") { - node.Type = NodeWildcard + return nil +} - // Determine the match type based on wildcard position - if strings.HasPrefix(value, "*") && strings.HasSuffix(value, "*") { - node.MatchType = matchContains - node.Value = strings.Trim(value, "*") - } else if strings.HasPrefix(value, "*") { - node.MatchType = matchEndsWith - node.Value = strings.TrimPrefix(value, "*") - } else if strings.HasSuffix(value, "*") { - node.MatchType = matchStartsWith - node.Value = strings.TrimSuffix(value, "*") +func (p *Parser) getValidFieldNames() []string { + var names []string + for _, f := range p.Fields { + if f.IsJSONB { + names = append(names, f.Name+".*") } else { - // For patterns like te*rm or te?rm - node.MatchType = matchContains - // For SQL LIKE, convert * to % and ? 
to _ - node.Value = strings.ReplaceAll(strings.ReplaceAll(value, "*", "%"), "?", "_") + names = append(names, f.Name) } + } + return names +} - // Skip if the value becomes empty after processing wildcards - if node.Value == "" { - return nil, nil +func (p *Parser) getImplicitSearchFields() []FieldInfo { + var fields []FieldInfo + for _, field := range p.Fields { + if field.ImplicitSearch { + fields = append(fields, field) } } - - return node, nil + return fields } -func (p *Parser) createLogicalNode(op LogicalOperator, parts []string) (*Node, error) { - node := &Node{ - Type: NodeLogical, - Operator: op, +// isImplicitTerm returns true if token is a search term without an explicit field prefix. +func isImplicitTerm(token string) bool { + token = strings.TrimSpace(token) + if token == "" { + return false } - for _, part := range parts { - if strings.TrimSpace(part) == "" { - continue - } - child, err := p.parse(part) - if err != nil { - return nil, err - } - if child != nil { - node.Children = append(node.Children, child) + // Check if it's a boolean operator + if booleanOperators.MatchString(token) { + return false + } + + // Check if it starts with + or - (required/prohibited operators) + if strings.HasPrefix(token, "+") || strings.HasPrefix(token, "-") { + // Remove the prefix and check the rest + rest := token[1:] + if fieldValuePattern.MatchString(rest) { + return false // It's a +field:value or -field:value } + // Otherwise it's an implicit term with +/- modifier + return true } - // If no valid children were found, return nil - if len(node.Children) == 0 { - return nil, nil + // Check if it's a field:value pattern + if fieldValuePattern.MatchString(token) { + return false } - return node, nil -} + // Check if it's a range query + if rangePattern.MatchString(token) { + return false + } -func (p *Parser) nodeToMap(node *Node) map[string]any { - if node == nil { - return nil + // Check if it's a parenthesis + if token == "(" || token == ")" { + return false } - 
switch node.Type { - case NodeTerm: - return map[string]any{node.Field: node.Value} - case NodeWildcard: - return map[string]any{node.Field: map[string]string{ - "$like": wildcardToPattern(node.Value, node.MatchType), - }} - case NodeLogical: - result := make(map[string]any) - children := make([]map[string]any, 0, len(node.Children)) - for _, child := range node.Children { - children = append(children, p.nodeToMap(child)) - } - result[string(node.Operator)] = children - return result + // Quoted strings are also implicit terms (they search across implicit search fields) + if strings.HasPrefix(token, `"`) && strings.HasSuffix(token, `"`) { + return true } - return nil + + return true } -func (p *Parser) nodeToSQL(node *Node) (string, []any, error) { - if node == nil { - return "", nil, nil - } +// expandImplicitTerms expands implicit search terms to explicit field:value patterns +// across all implicit search fields. For example: +// "paint" → "(name:*paint* OR description:*paint*)" +// "paint*" → "(name:paint* OR description:paint*)" +// '"Living Room"' → '(name:"Living Room" OR description:"Living Room")' +func (p *Parser) expandImplicitTerms(query string) string { + implicitFields := p.getImplicitSearchFields() + if len(implicitFields) == 0 { + return query + } + + // Tokenize the query while preserving structure + tokens := tokenizeQuery(query) + var result []string + + for _, token := range tokens { + if isImplicitTerm(token) { + // Check if it has a +/- prefix + prefix := "" + term := token + if strings.HasPrefix(token, "+") || strings.HasPrefix(token, "-") { + prefix = string(token[0]) + term = token[1:] + } - switch node.Type { - case NodeTerm: - if strings.Contains(node.Field, "->>") { - return fmt.Sprintf("%s = ?", node.Field), []any{node.Value}, nil - } - return fmt.Sprintf("%s = ?", node.Field), []any{node.Value}, nil - case NodeWildcard: - pattern := wildcardToPattern(node.Value, node.MatchType) - if strings.Contains(node.Field, "->>") { - return 
fmt.Sprintf("%s ILIKE ?", node.Field), []any{pattern}, nil + // Check if it's a quoted phrase (exact match) or already has wildcards + searchTerm := term + isQuotedPhrase := strings.HasPrefix(term, `"`) && strings.HasSuffix(term, `"`) + hasWildcards := strings.Contains(term, "*") || strings.Contains(term, "?") + + // For implicit search without wildcards or quotes, use contains matching + // This provides a better user experience for simple searches + if !isQuotedPhrase && !hasWildcards { + searchTerm = "*" + term + "*" + } + + // Expand to all implicit search fields with OR + var fieldTerms []string + for _, field := range implicitFields { + fieldTerms = append(fieldTerms, fmt.Sprintf("%s:%s", field.Name, searchTerm)) + } + + if len(fieldTerms) == 1 { + result = append(result, prefix+fieldTerms[0]) + } else { + expanded := "(" + strings.Join(fieldTerms, " OR ") + ")" + if prefix != "" { + expanded = prefix + expanded + } + result = append(result, expanded) + } } else { - return fmt.Sprintf("%s::text ILIKE ?", node.Field), []any{pattern}, nil + result = append(result, token) } - case NodeLogical: - var parts []string - var params []any + } - for _, child := range node.Children { - sqlPart, childParams, err := p.nodeToSQL(child) - if err != nil { - return "", nil, err + return strings.Join(result, " ") +} + +// tokenizeQuery splits query into tokens, preserving quoted strings and range brackets. +func tokenizeQuery(query string) []string { + var tokens []string + var current strings.Builder + inQuotes := false + inRange := false + rangeDepth := 0 + + for i := 0; i < len(query); i++ { + c := query[i] + + // Handle quotes + if c == '"' && (i == 0 || query[i-1] != '\\') { + inQuotes = !inQuotes + current.WriteByte(c) + continue + } + + // Handle range brackets + if !inQuotes { + if c == '[' || c == '{' { + inRange = true + rangeDepth++ + current.WriteByte(c) + continue } - if sqlPart != "" { - parts = append(parts, sqlPart) - params = append(params, childParams...) 
+ if c == ']' || c == '}' { + current.WriteByte(c) + rangeDepth-- + if rangeDepth == 0 { + inRange = false + } + continue } } - if len(parts) == 0 { - return "", nil, nil + // Handle spaces (token separators) when not in quotes or range + if c == ' ' && !inQuotes && !inRange { + if current.Len() > 0 { + tokens = append(tokens, current.String()) + current.Reset() + } + continue } - if len(parts) == 1 { - return parts[0], params, nil + // Handle parentheses as separate tokens + if !inQuotes && !inRange && (c == '(' || c == ')') { + if current.Len() > 0 { + tokens = append(tokens, current.String()) + current.Reset() + } + tokens = append(tokens, string(c)) + continue } - operator := string(node.Operator) - if node.Negate { - operator = "NOT " + operator - } + current.WriteByte(c) + } - return fmt.Sprintf("(%s)", strings.Join(parts, fmt.Sprintf(" %s ", operator))), params, nil + if current.Len() > 0 { + tokens = append(tokens, current.String()) } - return "", nil, fmt.Errorf("unsupported node type") + return tokens } -func (p *Parser) ParseToDynamoDBPartiQL(query string) (string, []types.AttributeValue, error) { - slog.Debug(fmt.Sprintf(`Parsing query to DynamoDB PartiQL: %s`, query)) - node, err := p.parse(query) - if err != nil { - return "", nil, err +// parseWithImplicitSearch expands unfielded terms across all implicit search fields with OR. 
+func (p *Parser) parseWithImplicitSearch(query string) (*expr.Expression, error) { + query = strings.TrimSpace(query) + if query == "" { + return nil, nil } - return p.nodeToDynamoDBPartiQL(node) + + // Expand implicit terms to explicit field:value patterns + expandedQuery := p.expandImplicitTerms(query) + + slog.Debug("Query expansion", "original", query, "expanded", expandedQuery) + + // Get first implicit field as fallback for the parser + fallbackField := "" + implicitFields := p.getImplicitSearchFields() + if len(implicitFields) > 0 { + fallbackField = implicitFields[0].Name + } else if len(p.Fields) > 0 { + fallbackField = p.Fields[0].Name + } + + return lucene.Parse(expandedQuery, lucene.WithDefaultField(fallbackField)) } -func (p *Parser) nodeToDynamoDBPartiQL(node *Node) (string, []types.AttributeValue, error) { - if node == nil { - return "", nil, nil - } - - switch node.Type { - case NodeTerm: - // For term node, create an exact match condition - return fmt.Sprintf("%s = ?", node.Field), []types.AttributeValue{ - &types.AttributeValueMemberS{Value: node.Value}, - }, nil - case NodeWildcard: - // For wildcard node, use begins_with or contains based on the match type - switch node.MatchType { - case matchStartsWith: - return fmt.Sprintf("begins_with(%s, ?)", node.Field), []types.AttributeValue{ - &types.AttributeValueMemberS{Value: node.Value}, - }, nil - case matchEndsWith, matchContains: - return fmt.Sprintf("contains(%s, ?)", node.Field), []types.AttributeValue{ - &types.AttributeValueMemberS{Value: node.Value}, - }, nil - default: - return fmt.Sprintf("%s = ?", node.Field), []types.AttributeValue{ - &types.AttributeValueMemberS{Value: node.Value}, - }, nil - } - case NodeLogical: - // For logical node, combine conditions with appropriate operator - var parts []string - var params []types.AttributeValue - - for _, child := range node.Children { - part, childParams, err := p.nodeToDynamoDBPartiQL(child) - if err != nil { - return "", nil, err - } - if 
part != "" { - parts = append(parts, part) - params = append(params, childParams...) - } - } +// exprToMap converts expression to map format (legacy, kept for backward compatibility). +func (p *Parser) exprToMap(e *expr.Expression) map[string]any { + if e == nil { + return nil + } - if len(parts) == 0 { - return "", nil, nil - } + result := make(map[string]any) - operator := string(node.Operator) - if node.Negate { - operator = "NOT " + operator + switch e.Op { + case expr.Equals: + if col, ok := e.Left.(expr.Column); ok { + result[string(col)] = p.valueToAny(e.Right) + } + case expr.Like: + if col, ok := e.Left.(expr.Column); ok { + pattern := p.valueToAny(e.Right) + result[string(col)] = map[string]any{"$like": pattern} + } + case expr.And, expr.Or, expr.Not: + var children []map[string]any + if leftExpr, ok := e.Left.(*expr.Expression); ok { + children = append(children, p.exprToMap(leftExpr)) + } + if rightExpr, ok := e.Right.(*expr.Expression); ok { + children = append(children, p.exprToMap(rightExpr)) + } + result[e.Op.String()] = children + default: + // For other operators, do a simple conversion + if col, ok := e.Left.(expr.Column); ok { + result[string(col)] = p.valueToAny(e.Right) } - - return fmt.Sprintf("(%s)", strings.Join(parts, fmt.Sprintf(" %s ", operator))), params, nil } - return "", nil, fmt.Errorf("unsupported node type") + return result } -func wildcardToPattern(value string, matchType MatchType) string { - switch matchType { - case matchStartsWith: - return value + "%" - case matchEndsWith: - return "%" + value - case matchContains: - return "%" + value + "%" +func (p *Parser) valueToAny(v any) any { + switch val := v.(type) { + case *expr.Expression: + return p.exprToMap(val) + case string: + return val + case int, float64: + return val default: - return value + return fmt.Sprintf("%v", v) } } diff --git a/storage/search/lucene/parser_test.go b/storage/search/lucene/parser_test.go new file mode 100644 index 0000000..12f0b6f --- /dev/null +++ 
b/storage/search/lucene/parser_test.go @@ -0,0 +1,880 @@ +package lucene + +import ( + "fmt" + "strings" + "testing" +) + +// TestBasicFieldSearch tests basic field:value queries +func TestBasicFieldSearch(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", IsJSONB: false}, + {Name: "email", IsJSONB: false}, + } + parser := NewParser(fields) + + tests := []struct { + name string + query string + wantSQL string + wantVals int + }{ + { + name: "simple field query", + query: "name:john", + wantSQL: `"name" = $1`, + wantVals: 1, + }, + { + name: "wildcard prefix", + query: "name:john*", + wantSQL: `"name"::text ILIKE $1`, + wantVals: 1, + }, + { + name: "wildcard suffix", + query: "name:*john", + wantSQL: `"name"::text ILIKE $1`, + wantVals: 1, + }, + { + name: "wildcard contains", + query: "name:*john*", + wantSQL: `"name"::text ILIKE $1`, + wantVals: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, vals, err := parser.ParseToSQL(tt.query) + if err != nil { + t.Fatalf("ParseToSQL() error = %v", err) + } + if !strings.Contains(sql, tt.wantSQL) { + t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, tt.wantSQL) + } + if len(vals) != tt.wantVals { + t.Errorf("ParseToSQL() vals count = %v, want %v", len(vals), tt.wantVals) + } + }) + } +} + +// TestBooleanOperators tests AND, OR, NOT operators +func TestBooleanOperators(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", IsJSONB: false}, + {Name: "status", IsJSONB: false}, + {Name: "role", IsJSONB: false}, + } + parser := NewParser(fields) + + tests := []struct { + name string + query string + wantSQL []string + }{ + { + name: "AND operator", + query: "name:john AND status:active", + wantSQL: []string{`"name"`, `"status"`, "AND"}, + }, + { + name: "OR operator", + query: "name:john OR name:jane", + wantSQL: []string{`"name"`, "OR"}, + }, + { + name: "NOT operator", + query: "name:john NOT status:inactive", + wantSQL: []string{`"name"`, `"status"`, "NOT"}, + }, + { 
+ name: "complex nested", + query: "(name:john OR name:jane) AND status:active", + wantSQL: []string{`"name"`, `"status"`, "OR", "AND"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, _, err := parser.ParseToSQL(tt.query) + if err != nil { + t.Fatalf("ParseToSQL() error = %v", err) + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, want) + } + } + }) + } +} + +// TestRequiredProhibited tests + and - operators +func TestRequiredProhibited(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", IsJSONB: false}, + {Name: "status", IsJSONB: false}, + } + parser := NewParser(fields) + + tests := []struct { + name string + query string + wantSQL []string + }{ + { + name: "required term", + query: "+name:john", + wantSQL: []string{`"name"`}, + }, + { + name: "prohibited term", + query: "-status:inactive", + wantSQL: []string{`"status"`, "NOT"}, + }, + { + name: "mixed required and prohibited", + query: "+name:john -status:inactive", + wantSQL: []string{`"name"`, `"status"`, "NOT"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, _, err := parser.ParseToSQL(tt.query) + if err != nil { + t.Fatalf("ParseToSQL() error = %v", err) + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, want) + } + } + }) + } +} + +// TestRangeQueries tests range query syntax +func TestRangeQueries(t *testing.T) { + fields := []FieldInfo{ + {Name: "age", IsJSONB: false}, + {Name: "date", IsJSONB: false}, + } + parser := NewParser(fields) + + tests := []struct { + name string + query string + wantSQL []string + }{ + { + name: "inclusive range", + query: "age:[18 TO 65]", + wantSQL: []string{`"age" BETWEEN`}, + }, + { + name: "exclusive range", + query: "age:{18 TO 65}", + wantSQL: []string{`"age" >`, `"age" <`}, + }, + { + name: "open-ended range min", + 
query: "age:[18 TO *]", + wantSQL: []string{`"age" >=`}, + }, + { + name: "open-ended range max", + query: "age:[* TO 65]", + wantSQL: []string{`"age" <=`}, + }, + { + name: "date range", + query: "date:[2020-01-01 TO 2023-12-31]", + wantSQL: []string{`"date"`}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, _, err := parser.ParseToSQL(tt.query) + if err != nil { + t.Fatalf("ParseToSQL() error = %v", err) + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, want) + } + } + }) + } +} + +// TestQuotedPhrases tests quoted phrase handling +func TestQuotedPhrases(t *testing.T) { + fields := []FieldInfo{ + {Name: "description", IsJSONB: false}, + {Name: "title", IsJSONB: false}, + } + parser := NewParser(fields) + + tests := []struct { + name string + query string + wantSQL []string + }{ + { + name: "simple quoted phrase", + query: `description:"hello world"`, + wantSQL: []string{`"description"`}, + }, + { + name: "phrase with special chars", + query: `title:"Go: The Complete Guide"`, + wantSQL: []string{`"title"`}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, _, err := parser.ParseToSQL(tt.query) + if err != nil { + t.Fatalf("ParseToSQL() error = %v", err) + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, want) + } + } + }) + } +} + +// TestEscapedCharacters tests escaped character handling +func TestEscapedCharacters(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", IsJSONB: false}, + } + parser := NewParser(fields) + + tests := []struct { + name string + query string + wantSQL []string + }{ + { + name: "escaped colon", + query: `name:test\:value`, + wantSQL: []string{`"name"`}, + }, + { + name: "escaped plus", + query: `name:C\+\+`, + wantSQL: []string{`"name"`}, + }, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + sql, _, err := parser.ParseToSQL(tt.query) + if err != nil { + t.Fatalf("ParseToSQL() error = %v", err) + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, want) + } + } + }) + } +} + +// TestComplexQueries tests complex query combinations +func TestComplexQueries(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", IsJSONB: false}, + {Name: "age", IsJSONB: false}, + {Name: "status", IsJSONB: false}, + {Name: "email", IsJSONB: false}, + } + parser := NewParser(fields) + + tests := []struct { + name string + query string + wantSQL []string + shouldErr bool + }{ + { + name: "complex with ranges and wildcards", + query: "name:john* AND age:[25 TO 65]", + wantSQL: []string{`"name"`, `"age"`}, + shouldErr: false, + }, + { + name: "complex with required and prohibited", + query: "+name:john -status:inactive AND age:[30 TO *]", + wantSQL: []string{`"name"`, `"status"`, `"age"`}, + shouldErr: false, + }, + { + name: "complex with quoted phrases", + query: `name:"John Doe" AND (status:active OR status:pending)`, + wantSQL: []string{`"name"`, `"status"`}, + shouldErr: false, + }, + { + name: "complex nested query", + query: "((name:john OR name:jane) AND status:active) OR (age:[18 TO 25] AND status:pending)", + wantSQL: []string{`"name"`, `"status"`, `"age"`}, + shouldErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, _, err := parser.ParseToSQL(tt.query) + if tt.shouldErr { + if err == nil { + t.Errorf("ParseToSQL() expected error but got none") + } + return + } + if err != nil { + t.Fatalf("ParseToSQL() error = %v", err) + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, want) + } + } + }) + } +} + +// TestImplicitSearch tests implicit search across fields with ImplicitSearch=true +func TestImplicitSearch(t 
*testing.T) { + fields := []FieldInfo{ + {Name: "name", IsJSONB: false, ImplicitSearch: true}, + {Name: "email", IsJSONB: false, ImplicitSearch: true}, + {Name: "description", IsJSONB: false, ImplicitSearch: true}, + } + parser := NewParser(fields) + + tests := []struct { + name string + query string + wantOR bool + wantParams int + }{ + { + name: "implicit search", + query: "john", + wantOR: true, + wantParams: 3, // Should expand to 3 fields + }, + { + name: "implicit search with wildcard", + query: "john*", + wantOR: true, + wantParams: 3, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, params, err := parser.ParseToSQL(tt.query) + if err != nil { + t.Fatalf("ParseToSQL() error = %v", err) + } + if tt.wantOR && !strings.Contains(sql, "OR") { + t.Errorf("ParseToSQL() sql = %v, want to contain OR", sql) + } + if len(params) != tt.wantParams { + t.Errorf("ParseToSQL() params count = %v, want %v", len(params), tt.wantParams) + } + }) + } +} + +// TestJSONBFields tests JSONB field notation +func TestJSONBFields(t *testing.T) { + fields := []FieldInfo{ + {Name: "metadata", IsJSONB: true}, + } + parser := NewParser(fields) + + tests := []struct { + name string + query string + wantSQL []string + }{ + { + name: "JSONB field access", + query: "metadata.key:value", + wantSQL: []string{`metadata->>'key'`}, + }, + { + name: "JSONB with wildcard", + query: "metadata.tags:prod*", + wantSQL: []string{`metadata->>'tags'`, "ILIKE"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, _, err := parser.ParseToSQL(tt.query) + if err != nil { + t.Fatalf("ParseToSQL() error = %v", err) + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, want) + } + } + }) + } +} + +// TestMapOutput tests the legacy map output format +func TestMapOutput(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", IsJSONB: false}, + {Name: "status", 
IsJSONB: false}, + } + parser := NewParser(fields) + + result, err := parser.ParseToMap("name:john AND status:active") + if err != nil { + t.Fatalf("ParseToMap() error = %v", err) + } + + if result == nil { + t.Errorf("ParseToMap() returned nil") + } +} + +// TestFieldValidation tests field validation for invalid field references +func TestFieldValidation(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", IsJSONB: false, ImplicitSearch: true}, + {Name: "description", IsJSONB: false, ImplicitSearch: true}, + {Name: "status", IsJSONB: false}, + {Name: "labels", IsJSONB: true}, + {Name: "metadata", IsJSONB: true}, + } + parser := NewParser(fields) + + tests := []struct { + name string + query string + wantErr bool + errField string + }{ + { + name: "valid field query", + query: "name:john", + wantErr: false, + }, + { + name: "valid JSONB sub-field", + query: "labels.category:urgent", + wantErr: false, + }, + { + name: "invalid field", + query: "nonexistent:value", + wantErr: true, + errField: "nonexistent", + }, + { + name: "invalid JSONB base field", + query: "fakejsonb.key:value", + wantErr: true, + errField: "fakejsonb.key", + }, + { + name: "sub-field on non-JSONB field", + query: "name.subfield:value", + wantErr: true, + errField: "name.subfield", + }, + { + name: "implicit search (no explicit fields) - valid", + query: "paint", + wantErr: false, + }, + { + name: "mixed valid and implicit", + query: "status:active AND paint", + wantErr: false, + }, + { + name: "mixed valid and invalid", + query: "name:john AND invalid_field:test", + wantErr: true, + errField: "invalid_field", + }, + { + name: "complex valid query", + query: "(name:john OR description:test) AND status:active AND labels.priority:high", + wantErr: false, + }, + { + name: "invalid field in complex query", + query: "(name:john OR badfield:test) AND status:active", + wantErr: true, + errField: "badfield", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, _, err 
:= parser.ParseToSQL(tt.query) + if tt.wantErr { + if err == nil { + t.Errorf("ParseToSQL() expected error for query %q but got none", tt.query) + return + } + if _, ok := err.(*InvalidFieldError); !ok { + t.Errorf("ParseToSQL() error = %v, want InvalidFieldError", err) + return + } + if !strings.Contains(err.Error(), tt.errField) { + t.Errorf("ParseToSQL() error = %v, want to mention field %q", err, tt.errField) + } + } else { + if err != nil { + t.Errorf("ParseToSQL() unexpected error = %v for query %q", err, tt.query) + } + } + }) + } +} + +// TestValidateFields tests the ValidateFields method directly +func TestValidateFields(t *testing.T) { + fields := []FieldInfo{ + {Name: "id", IsJSONB: false}, + {Name: "tenant_id", IsJSONB: false}, + {Name: "name", IsJSONB: false, ImplicitSearch: true}, + {Name: "description", IsJSONB: false, ImplicitSearch: true}, + {Name: "status", IsJSONB: false}, + {Name: "labels", IsJSONB: true}, + {Name: "properties", IsJSONB: true}, + {Name: "metadata", IsJSONB: true}, + } + parser := NewParser(fields) + + tests := []struct { + name string + query string + wantErr bool + }{ + {"valid simple field", "name:test", false}, + {"valid multiple fields", "name:test AND status:active", false}, + {"valid JSONB sub-field", "labels.category:urgent", false}, + {"valid deep JSONB", "metadata.nested_key:value", false}, + {"invalid field", "unknown_field:test", true}, + {"invalid JSONB base", "unknown.subkey:test", true}, + {"sub-field on non-JSONB", "status.sub:test", true}, + {"empty query", "", false}, + {"implicit only - no field prefix", "searchterm", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := parser.ValidateFields(tt.query) + if tt.wantErr && err == nil { + t.Errorf("ValidateFields(%q) expected error but got none", tt.query) + } + if !tt.wantErr && err != nil { + t.Errorf("ValidateFields(%q) unexpected error: %v", tt.query, err) + } + }) + } +} + +// TestNullValueQueries tests null value handling 
for IS NULL queries. +// Note: This is a SQL-specific extension (vanilla Lucene doesn't support NULL values). +// Only "null" (case-insensitive) is supported for IS NULL queries; "nil" is treated as a literal string. +func TestNullValueQueries(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", IsJSONB: false}, + {Name: "parent_id", IsJSONB: false}, + {Name: "deleted_at", IsJSONB: false}, + {Name: "attachment_ids", IsJSONB: false}, + } + parser := NewParser(fields) + + tests := []struct { + name string + query string + wantSQL string + wantErr bool + }{ + { + name: "field is null (lowercase)", + query: "deleted_at:null", + wantSQL: "IS NULL", + }, + { + name: "field is NULL (uppercase)", + query: "deleted_at:NULL", + wantSQL: "IS NULL", + }, + { + name: "field is Null (mixed case)", + query: "deleted_at:Null", + wantSQL: "IS NULL", + }, + { + name: "parent_id is null", + query: "parent_id:null", + wantSQL: "IS NULL", + }, + { + name: "combined null with other conditions", + query: "deleted_at:null AND name:john", + wantSQL: "IS NULL", + }, + { + name: "NOT null (is not null)", + query: "NOT deleted_at:null", + wantSQL: "NOT", + }, + { + name: "nil should be treated as literal value (not NULL)", + query: "name:nil", + wantSQL: "=", + wantErr: false, // Should not error, but should treat "nil" as literal string, not IS NULL + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, _, err := parser.ParseToSQL(tt.query) + if tt.wantErr { + if err == nil { + t.Errorf("ParseToSQL(%q) expected error but got none", tt.query) + } + return + } + if err != nil { + t.Fatalf("ParseToSQL(%q) error = %v", tt.query, err) + } + if !strings.Contains(sql, tt.wantSQL) { + t.Errorf("ParseToSQL(%q) sql = %v, want to contain %v", tt.query, sql, tt.wantSQL) + } + }) + } +} + +// TestEmptyAsLiteralValue tests that 'empty' is treated as a literal value (not special keyword) +func TestEmptyAsLiteralValue(t *testing.T) { + fields := []FieldInfo{ + {Name: 
"status", IsJSONB: false}, + {Name: "name", IsJSONB: false}, + } + parser := NewParser(fields) + + // 'empty' should be treated as a regular search value, not a special keyword + sql, params, err := parser.ParseToSQL("status:empty") + if err != nil { + t.Fatalf("ParseToSQL() error = %v", err) + } + + // Should generate a regular equals query, not IS NULL + if strings.Contains(sql, "IS NULL") { + t.Errorf("'empty' should be treated as literal value, not IS NULL. Got: %s", sql) + } + + // The value should be in params + if len(params) != 1 || params[0] != "empty" { + t.Errorf("Expected params to contain 'empty', got: %v", params) + } +} + +// BenchmarkParser benchmarks the parser performance +func BenchmarkParser(b *testing.B) { + fields := []FieldInfo{ + {Name: "name", IsJSONB: false}, + {Name: "age", IsJSONB: false}, + {Name: "status", IsJSONB: false}, + {Name: "email", IsJSONB: false}, + } + parser := NewParser(fields) + + query := `(name:john* OR email:*@example.com) AND (status:active OR status:pending) AND age:[25 TO 65]` + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = parser.ParseToSQL(query) + } +} + +// TestFuzzySearch tests fuzzy search operator (~) using pg_trgm similarity +func TestFuzzySearch(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", IsJSONB: false}, + {Name: "description", IsJSONB: false}, + {Name: "status", IsJSONB: false}, + {Name: "labels", IsJSONB: true}, + } + parser := NewParser(fields) + + tests := []struct { + name string + query string + wantSQL string + wantErr bool + }{ + { + name: "basic fuzzy search", + query: "name:roam~", + wantSQL: "similarity", + }, + { + name: "fuzzy with distance", + query: "name:roam~2", + wantSQL: "similarity", + }, + { + name: "fuzzy on JSONB field", + query: "labels.category:construction~", + wantSQL: "similarity", + }, + { + name: "fuzzy combined with other conditions", + query: "name:roam~ AND status:active", + wantSQL: "similarity", + }, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + sql, params, err := parser.ParseToSQL(tt.query) + if tt.wantErr { + if err == nil { + t.Errorf("ParseToSQL(%q) expected error but got none", tt.query) + } + return + } + if err != nil { + t.Fatalf("ParseToSQL(%q) error = %v", tt.query, err) + } + if !strings.Contains(sql, tt.wantSQL) { + t.Errorf("ParseToSQL(%q) sql = %v, want to contain %v", tt.query, sql, tt.wantSQL) + } + if len(params) == 0 { + t.Errorf("ParseToSQL(%q) expected at least one parameter", tt.query) + } + }) + } +} + +// TestEscaping tests that special characters can be escaped in queries +func TestEscaping(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", IsJSONB: false}, + {Name: "version", IsJSONB: false}, + {Name: "path", IsJSONB: false}, + } + parser := NewParser(fields) + + tests := []struct { + name string + query string + wantSQL string + wantErr bool + }{ + { + name: "escaped plus sign", + query: `name:C\+\+`, + wantSQL: `"name"`, + }, + { + name: "escaped colon", + query: `version:1\.2\.3`, + wantSQL: `"version"`, + }, + { + name: "escaped parentheses", + query: `name:\(test\)`, + wantSQL: `"name"`, + }, + { + name: "escaped path separator", + query: `path:src\/components`, + wantSQL: `"path"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, params, err := parser.ParseToSQL(tt.query) + if tt.wantErr { + if err == nil { + t.Errorf("ParseToSQL(%q) expected error but got none", tt.query) + } + return + } + if err != nil { + t.Fatalf("ParseToSQL(%q) error = %v", tt.query, err) + } + if !strings.Contains(sql, tt.wantSQL) { + t.Errorf("ParseToSQL(%q) sql = %v, want to contain %v", tt.query, sql, tt.wantSQL) + } + // Verify the escaped character is in the parameter + if len(params) > 0 { + paramStr := fmt.Sprintf("%v", params[0]) + // The escaped character should appear as the literal character in params + if strings.Contains(tt.query, `\+`) && !strings.Contains(paramStr, "+") { + t.Errorf("ParseToSQL(%q) expected 
'+' in params, got %v", tt.query, params) + } + } + }) + } +} + +// TestBoostOperatorError tests that boost operator returns a clear error +func TestBoostOperatorError(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", IsJSONB: false}, + {Name: "status", IsJSONB: false}, + } + parser := NewParser(fields) + + tests := []struct { + name string + query string + wantErr string + }{ + { + name: "boost operator", + query: "name:test^4", + wantErr: "boost operator", + }, + { + name: "boost in compound query", + query: "name:test^2 AND status:active", + wantErr: "boost operator", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, _, err := parser.ParseToSQL(tt.query) + if err == nil { + t.Errorf("ParseToSQL(%q) expected error but got none", tt.query) + return + } + if !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.wantErr)) { + t.Errorf("ParseToSQL(%q) error = %v, want to contain %v", tt.query, err, tt.wantErr) + } + }) + } +} diff --git a/storage/sql.go b/storage/sql.go index 8827f35..bfc1f49 100644 --- a/storage/sql.go +++ b/storage/sql.go @@ -18,8 +18,8 @@ import ( "gorm.io/gorm/logger" "gorm.io/gorm/schema" + serviceErrors "github.com/tink3rlabs/magic/errors" slogger "github.com/tink3rlabs/magic/logger" - "github.com/tink3rlabs/magic/storage/search/lucene" ) @@ -315,6 +315,10 @@ func (s *SQLAdapter) Search(dest any, sortKey string, query string, limit int, c whereClause, queryParams, err := parser.ParseToSQL(query) if err != nil { slog.Error("Filter parsing failed", "error", err) + // Wrap InvalidFieldError as BadRequest for proper HTTP 400 response + if _, ok := err.(*lucene.InvalidFieldError); ok { + return "", &serviceErrors.BadRequest{Message: err.Error()} + } return "", err } From 3cd17c3662575445213937191f32551bea13b819 Mon Sep 17 00:00:00 2001 From: Lutherwaves Date: Wed, 31 Dec 2025 00:15:19 +0000 Subject: [PATCH 02/13] refactor: simplify parser - Remove driver storage redundancy - Simpler and more 
configurable parser initi - Make security limits and tag names configurable - Clean up unused code --- storage/search/lucene/parser.go | 156 ++++++++++++++++++++++++++------ 1 file changed, 126 insertions(+), 30 deletions(-) diff --git a/storage/search/lucene/parser.go b/storage/search/lucene/parser.go index 5f95e81..79fa152 100644 --- a/storage/search/lucene/parser.go +++ b/storage/search/lucene/parser.go @@ -19,6 +19,39 @@ const ( DefaultMaxTerms = 100 // Prevents CPU exhaustion from complex queries ) +// Struct tag names for field metadata extraction +const ( + TagJSON = "json" // JSON serialization tag + TagGORM = "gorm" // GORM database tag (for detecting JSONB fields) + TagLucene = "lucene" // Lucene search behavior tag (implicit/explicit) +) + +// ParserConfig allows customization of parser behavior and security limits. +type ParserConfig struct { + // Security limits (nil = use defaults) + MaxQueryLength *int // Maximum query string length (default: 10KB) + MaxDepth *int // Maximum nesting depth (default: 20) + MaxTerms *int // Maximum number of terms (default: 100) + + // Tag customization (empty = use defaults) + JSONTag string // JSON tag name (default: "json") + GORMTag string // GORM tag name (default: "gorm") + LuceneTag string // Lucene tag name (default: "lucene") +} + +// IntPtr is a helper function to create int pointers for ParserConfig. +// This makes it easier to set optional configuration values. +// +// Example: +// +// config := &ParserConfig{ +// MaxQueryLength: IntPtr(5000), +// MaxDepth: IntPtr(10), +// } +func IntPtr(i int) *int { + return &i +} + // FieldInfo describes a searchable field and its properties. type FieldInfo struct { Name string @@ -27,6 +60,7 @@ type FieldInfo struct { } // Parser provides Lucene query parsing with security limits. +// Drivers are created on-demand when calling ParseToSQL or ParseToDynamoDBPartiQL. 
type Parser struct { Fields []FieldInfo // All searchable fields @@ -38,10 +72,6 @@ type Parser struct { // Field lookup maps for O(1) validation fieldMap map[string]FieldInfo // All fields by name jsonbFields map[string]bool // JSONB field names for sub-field validation - - // Custom drivers for different backends - postgresDriver *PostgresJSONBDriver - dynamoDriver *DynamoDBPartiQLDriver } // NewParserFromType creates a parser by introspecting a struct's fields. @@ -65,6 +95,14 @@ type Parser struct { // // parser, err := lucene.NewParserFromType(Task{}) // +// With custom configuration: +// +// config := &lucene.ParserConfig{ +// MaxQueryLength: lucene.IntPtr(5000), +// MaxDepth: lucene.IntPtr(10), +// } +// parser, err := lucene.NewParserFromType(Task{}, config) +// // Struct tag controls: // - lucene:"implicit" - Force ImplicitSearch=true (include in unfielded queries) // - lucene:"explicit" - Force ImplicitSearch=false (require field:value syntax) @@ -74,15 +112,39 @@ type Parser struct { // - String fields: ImplicitSearch=true (included in unfielded queries) // - Non-string fields (int, time.Time, uuid, etc.): ImplicitSearch=false // - JSONB fields: ImplicitSearch=false (require field.subfield syntax) -func NewParserFromType(model any) (*Parser, error) { - fields, err := getStructFields(model) +func NewParserFromType(model any, config ...*ParserConfig) (*Parser, error) { + var cfg *ParserConfig + if len(config) > 0 && config[0] != nil { + cfg = config[0] + } + + fields, err := getStructFieldsWithConfig(model, cfg) if err != nil { return nil, err } - return NewParser(fields), nil + return NewParser(fields, cfg), nil } -func NewParser(fields []FieldInfo) *Parser { +// NewParser creates a parser from field definitions with optional configuration. 
+// +// Basic usage: +// +// fields := []FieldInfo{{Name: "name", ImplicitSearch: true}} +// parser := lucene.NewParser(fields) +// +// With custom configuration: +// +// config := &lucene.ParserConfig{ +// MaxQueryLength: lucene.IntPtr(5000), +// MaxDepth: lucene.IntPtr(10), +// } +// parser := lucene.NewParser(fields, config) +func NewParser(fields []FieldInfo, config ...*ParserConfig) *Parser { + var cfg *ParserConfig + if len(config) > 0 && config[0] != nil { + cfg = config[0] + } + fieldMap := make(map[string]FieldInfo, len(fields)) jsonbFields := make(map[string]bool) for _, f := range fields { @@ -92,15 +154,30 @@ func NewParser(fields []FieldInfo) *Parser { } } + // Apply config or use defaults + maxQueryLength := DefaultMaxQueryLength + maxDepth := DefaultMaxDepth + maxTerms := DefaultMaxTerms + + if cfg != nil { + if cfg.MaxQueryLength != nil { + maxQueryLength = *cfg.MaxQueryLength + } + if cfg.MaxDepth != nil { + maxDepth = *cfg.MaxDepth + } + if cfg.MaxTerms != nil { + maxTerms = *cfg.MaxTerms + } + } + return &Parser{ Fields: fields, - MaxQueryLength: DefaultMaxQueryLength, - MaxDepth: DefaultMaxDepth, - MaxTerms: DefaultMaxTerms, + MaxQueryLength: maxQueryLength, + MaxDepth: maxDepth, + MaxTerms: maxTerms, fieldMap: fieldMap, jsonbFields: jsonbFields, - postgresDriver: NewPostgresJSONBDriver(fields), - dynamoDriver: NewDynamoDBPartiQLDriver(fields), } } @@ -126,9 +203,8 @@ func (e *InvalidFieldError) Error() string { return fmt.Sprintf("invalid field '%s' in query; valid fields are: %s", e.Field, strings.Join(e.ValidFields, ", ")) } -// getStructFields uses reflection to extract field metadata from a struct. -// String fields get ImplicitSearch=true, others get ImplicitSearch=false. -func getStructFields(model any) ([]FieldInfo, error) { +// getStructFieldsWithConfig extracts field metadata using configurable tag names. 
+func getStructFieldsWithConfig(model any, config *ParserConfig) ([]FieldInfo, error) { t := reflect.TypeOf(model) if t.Kind() == reflect.Ptr { t = t.Elem() @@ -138,31 +214,47 @@ func getStructFields(model any) ([]FieldInfo, error) { return nil, fmt.Errorf("expected struct, got %s", t.Kind()) } + // Determine tag names from config or use defaults + jsonTag := TagJSON + gormTag := TagGORM + luceneTag := TagLucene + if config != nil { + if config.JSONTag != "" { + jsonTag = config.JSONTag + } + if config.GORMTag != "" { + gormTag = config.GORMTag + } + if config.LuceneTag != "" { + luceneTag = config.LuceneTag + } + } + var fields []FieldInfo for i := 0; i < t.NumField(); i++ { field := t.Field(i) - jsonTag := field.Tag.Get("json") - if jsonTag == "" || jsonTag == "-" { + jsonTagValue := field.Tag.Get(jsonTag) + if jsonTagValue == "" || jsonTagValue == "-" { continue } - if commaIdx := strings.Index(jsonTag, ","); commaIdx != -1 { - jsonTag = jsonTag[:commaIdx] + if commaIdx := strings.Index(jsonTagValue, ","); commaIdx != -1 { + jsonTagValue = jsonTagValue[:commaIdx] } - gormTag := field.Tag.Get("gorm") - isJSONB := strings.Contains(gormTag, "type:jsonb") + gormTagValue := field.Tag.Get(gormTag) + isJSONB := strings.Contains(gormTagValue, "type:jsonb") - luceneTag := field.Tag.Get("lucene") + luceneTagValue := field.Tag.Get(luceneTag) implicitSearch := false - if luceneTag == "implicit" { + if luceneTagValue == "implicit" { implicitSearch = true - } else if luceneTag != "explicit" { + } else if luceneTagValue != "explicit" { implicitSearch = field.Type.Kind() == reflect.String && !isJSONB } fields = append(fields, FieldInfo{ - Name: jsonTag, + Name: jsonTagValue, IsJSONB: isJSONB, ImplicitSearch: implicitSearch, }) @@ -189,6 +281,7 @@ func (p *Parser) ParseToMap(query string) (map[string]any, error) { } // ParseToSQL parses a Lucene query and converts it to PostgreSQL SQL with parameters. +// Creates a PostgreSQL driver on-demand for rendering. 
func (p *Parser) ParseToSQL(query string) (string, []any, error) { slog.Debug(fmt.Sprintf(`Parsing query to SQL: %s`, query)) @@ -210,8 +303,9 @@ func (p *Parser) ParseToSQL(query string) (string, []any, error) { return "", nil, err } - // Render using custom PostgreSQL driver - sql, params, err := p.postgresDriver.RenderParam(e) + // Create PostgreSQL driver on-demand and render + driver := NewPostgresJSONBDriver(p.Fields) + sql, params, err := driver.RenderParam(e) if err != nil { return "", nil, err } @@ -220,6 +314,7 @@ func (p *Parser) ParseToSQL(query string) (string, []any, error) { } // ParseToDynamoDBPartiQL parses a Lucene query and converts it to DynamoDB PartiQL. +// Creates a DynamoDB driver on-demand for rendering. func (p *Parser) ParseToDynamoDBPartiQL(query string) (string, []types.AttributeValue, error) { slog.Debug(fmt.Sprintf(`Parsing query to DynamoDB PartiQL: %s`, query)) @@ -241,8 +336,9 @@ func (p *Parser) ParseToDynamoDBPartiQL(query string) (string, []types.Attribute return "", nil, err } - // Render using custom DynamoDB driver - partiql, attrs, err := p.dynamoDriver.RenderPartiQL(e) + // Create DynamoDB driver on-demand and render + driver := NewDynamoDBPartiQLDriver(p.Fields) + partiql, attrs, err := driver.RenderPartiQL(e) if err != nil { return "", nil, err } From 0cba321661b55b5d5d37d172fce43645c29fc30d Mon Sep 17 00:00:00 2001 From: Lutherwaves Date: Wed, 31 Dec 2025 12:16:15 +0000 Subject: [PATCH 03/13] refactor: simplify Lucene parser API and architecture - NewParser(model) replaces NewParserFromType(model) - Implicit search restricted to string fields only - Removed complex tag configuration - Split driver.go into postgres_driver.go and dynamodb_driver.go - NewPostgresDriver() and NewDynamoDBDriver() constructors - Checks for JSONB/JSON in type name, maps, and structs - Parse-time validation (HTTP 400) not runtime (HTTP 500) --- go.mod | 1 - go.sum | 11 + storage/dynamodb.go | 2 +- storage/search/lucene/dynamodb_driver.go | 94 
++++ storage/search/lucene/parser.go | 272 ++++------ storage/search/lucene/parser_test.go | 493 ++++++++---------- .../lucene/{driver.go => postgres_driver.go} | 93 +--- storage/sql.go | 2 +- 8 files changed, 422 insertions(+), 546 deletions(-) create mode 100644 storage/search/lucene/dynamodb_driver.go rename storage/search/lucene/{driver.go => postgres_driver.go} (81%) diff --git a/go.mod b/go.mod index ff6e0a7..0039d51 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,6 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/golang-jwt/jwt/v5 v5.3.0 // indirect - github.com/grindlemire/go-lucene v0.0.26 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.6.0 // indirect diff --git a/go.sum b/go.sum index 85fc238..050f7b6 100644 --- a/go.sum +++ b/go.sum @@ -117,12 +117,23 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +<<<<<<< HEAD golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +======= +go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= +go.yaml.in/yaml/v2 v2.4.3/go.mod 
h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +>>>>>>> d0d7bae (refactor: simplify Lucene parser API and architecture) golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= diff --git a/storage/dynamodb.go b/storage/dynamodb.go index 4774381..2cd6595 100644 --- a/storage/dynamodb.go +++ b/storage/dynamodb.go @@ -245,7 +245,7 @@ func (s *DynamoDBAdapter) Search(dest any, sortKey string, query string, limit i // Parse Lucene query destType := reflect.TypeOf(dest).Elem().Elem() model := reflect.New(destType).Elem().Interface() - parser, _ := lucene.NewParserFromType(model) + parser, _ := lucene.NewParser(model) whereClause, params, _ := parser.ParseToDynamoDBPartiQL(query) // Build query diff --git a/storage/search/lucene/dynamodb_driver.go b/storage/search/lucene/dynamodb_driver.go new file mode 100644 index 0000000..8a7cd45 --- /dev/null +++ b/storage/search/lucene/dynamodb_driver.go @@ -0,0 +1,94 @@ +package lucene + +import ( + "fmt" + "strings" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/grindlemire/go-lucene/pkg/driver" + "github.com/grindlemire/go-lucene/pkg/lucene/expr" +) + +// DynamoDBPartiQLDriver converts Lucene queries to DynamoDB PartiQL. 
+type DynamoDBPartiQLDriver struct { + driver.Base + fields map[string]FieldInfo +} + +func NewDynamoDBDriver(fields []FieldInfo) *DynamoDBPartiQLDriver { + fieldMap := make(map[string]FieldInfo) + for _, f := range fields { + fieldMap[f.Name] = f + } + + fns := map[expr.Operator]driver.RenderFN{ + expr.Literal: driver.Shared[expr.Literal], + expr.And: driver.Shared[expr.And], + expr.Or: driver.Shared[expr.Or], + expr.Not: driver.Shared[expr.Not], + expr.Equals: driver.Shared[expr.Equals], + expr.Range: driver.Shared[expr.Range], + expr.Must: driver.Shared[expr.Must], + expr.MustNot: driver.Shared[expr.MustNot], + expr.Wild: driver.Shared[expr.Wild], + expr.Regexp: driver.Shared[expr.Regexp], + expr.Like: dynamoDBLike, // Custom LIKE for DynamoDB functions + expr.Greater: driver.Shared[expr.Greater], + expr.GreaterEq: driver.Shared[expr.GreaterEq], + expr.Less: driver.Shared[expr.Less], + expr.LessEq: driver.Shared[expr.LessEq], + expr.In: driver.Shared[expr.In], + expr.List: driver.Shared[expr.List], + } + + return &DynamoDBPartiQLDriver{ + Base: driver.Base{ + RenderFNs: fns, + }, + fields: fieldMap, + } +} + +// RenderPartiQL renders the expression to DynamoDB PartiQL with AttributeValue parameters. +func (d *DynamoDBPartiQLDriver) RenderPartiQL(e *expr.Expression) (string, []types.AttributeValue, error) { + // Use base rendering with ? placeholders + str, params, err := d.RenderParam(e) + if err != nil { + return "", nil, err + } + + // Convert params to DynamoDB AttributeValues + attrValues := make([]types.AttributeValue, len(params)) + for i, param := range params { + attrValues[i] = &types.AttributeValueMemberS{Value: fmt.Sprintf("%v", param)} + } + + return str, attrValues, nil +} + +// dynamoDBLike implements LIKE using DynamoDB's begins_with and contains functions. 
+func dynamoDBLike(left, right string) (string, error) { + // Remove quotes from right side to analyze pattern + pattern := strings.Trim(right, "'") + + // Replace wildcards for analysis + hasPrefix := strings.HasPrefix(pattern, "%") + hasSuffix := strings.HasSuffix(pattern, "%") + + if hasPrefix && hasSuffix { + // %value% -> contains(field, value) + value := strings.Trim(pattern, "%") + return fmt.Sprintf("contains(%s, '%s')", left, value), nil + } else if !hasPrefix && hasSuffix { + // value% -> begins_with(field, value) + value := strings.TrimSuffix(pattern, "%") + return fmt.Sprintf("begins_with(%s, '%s')", left, value), nil + } else if hasPrefix && !hasSuffix { + // %value -> contains(field, value) (DynamoDB doesn't have ends_with) + value := strings.TrimPrefix(pattern, "%") + return fmt.Sprintf("contains(%s, '%s')", left, value), nil + } + + // Exact match + return fmt.Sprintf("%s = %s", left, right), nil +} diff --git a/storage/search/lucene/parser.go b/storage/search/lucene/parser.go index 79fa152..d750358 100644 --- a/storage/search/lucene/parser.go +++ b/storage/search/lucene/parser.go @@ -19,44 +19,18 @@ const ( DefaultMaxTerms = 100 // Prevents CPU exhaustion from complex queries ) -// Struct tag names for field metadata extraction -const ( - TagJSON = "json" // JSON serialization tag - TagGORM = "gorm" // GORM database tag (for detecting JSONB fields) - TagLucene = "lucene" // Lucene search behavior tag (implicit/explicit) -) - // ParserConfig allows customization of parser behavior and security limits. 
type ParserConfig struct { - // Security limits (nil = use defaults) - MaxQueryLength *int // Maximum query string length (default: 10KB) - MaxDepth *int // Maximum nesting depth (default: 20) - MaxTerms *int // Maximum number of terms (default: 100) - - // Tag customization (empty = use defaults) - JSONTag string // JSON tag name (default: "json") - GORMTag string // GORM tag name (default: "gorm") - LuceneTag string // Lucene tag name (default: "lucene") -} - -// IntPtr is a helper function to create int pointers for ParserConfig. -// This makes it easier to set optional configuration values. -// -// Example: -// -// config := &ParserConfig{ -// MaxQueryLength: IntPtr(5000), -// MaxDepth: IntPtr(10), -// } -func IntPtr(i int) *int { - return &i + MaxQueryLength int // 0 = use default (10000) + MaxDepth int // 0 = use default (20) + MaxTerms int // 0 = use default (100) } // FieldInfo describes a searchable field and its properties. type FieldInfo struct { Name string - IsJSONB bool - ImplicitSearch bool // Whether this field is included in unfielded/implicit queries + Type reflect.Type // For validation only + ImplicitSearch bool // Whether this field is included in unfielded/implicit queries } // Parser provides Lucene query parsing with security limits. @@ -70,88 +44,41 @@ type Parser struct { MaxTerms int // Maximum number of terms (default: 100) // Field lookup maps for O(1) validation - fieldMap map[string]FieldInfo // All fields by name - jsonbFields map[string]bool // JSONB field names for sub-field validation + fieldMap map[string]FieldInfo // All fields by name } -// NewParserFromType creates a parser by introspecting a struct's fields. -// This is the recommended approach for initializing parsers as it: -// - Works with any backend (PostgreSQL, MySQL, DynamoDB, etc.) 
-// - Zero database overhead -// - Compile-time safety -// - Auto-detects JSONB fields from gorm tags -// - Auto-sets string fields for implicit search (ImplicitSearch=true) -// -// Example: +// NewParser creates a parser by introspecting a struct's fields. // -// type Task struct { -// ID string `json:"id"` -// Name string `json:"name"` // Auto: ImplicitSearch=true -// Description string `json:"description"` // Auto: ImplicitSearch=true -// Status string `json:"status" lucene:"explicit"` // Explicit: ImplicitSearch=false -// CreatedAt time.Time `json:"created_at"` // Auto: ImplicitSearch=false (not string) -// Labels JSONB `json:"labels" gorm:"type:jsonb"` // Auto: IsJSONB=true, ImplicitSearch=false -// } +// Basic usage: // -// parser, err := lucene.NewParserFromType(Task{}) +// parser, err := lucene.NewParser(Task{}) // // With custom configuration: // // config := &lucene.ParserConfig{ -// MaxQueryLength: lucene.IntPtr(5000), -// MaxDepth: lucene.IntPtr(10), +// MaxQueryLength: 5000, +// MaxDepth: 10, // } -// parser, err := lucene.NewParserFromType(Task{}, config) -// -// Struct tag controls: -// - lucene:"implicit" - Force ImplicitSearch=true (include in unfielded queries) -// - lucene:"explicit" - Force ImplicitSearch=false (require field:value syntax) -// - gorm:"type:jsonb" - Auto-detected as JSONB field +// parser, err := lucene.NewParser(Task{}, config) // -// Auto-detection rules (when no lucene tag): +// Auto-detection rules: // - String fields: ImplicitSearch=true (included in unfielded queries) // - Non-string fields (int, time.Time, uuid, etc.): ImplicitSearch=false // - JSONB fields: ImplicitSearch=false (require field.subfield syntax) -func NewParserFromType(model any, config ...*ParserConfig) (*Parser, error) { - var cfg *ParserConfig - if len(config) > 0 && config[0] != nil { - cfg = config[0] - } - - fields, err := getStructFieldsWithConfig(model, cfg) +// +// Field name extraction: +// - Uses `json` struct tag for field names +// - Skips 
fields without `json` tag or with `json:"-"` +func NewParser(model any, config ...*ParserConfig) (*Parser, error) { + fields, err := extractFields(model) if err != nil { return nil, err } - return NewParser(fields, cfg), nil -} - -// NewParser creates a parser from field definitions with optional configuration. -// -// Basic usage: -// -// fields := []FieldInfo{{Name: "name", ImplicitSearch: true}} -// parser := lucene.NewParser(fields) -// -// With custom configuration: -// -// config := &lucene.ParserConfig{ -// MaxQueryLength: lucene.IntPtr(5000), -// MaxDepth: lucene.IntPtr(10), -// } -// parser := lucene.NewParser(fields, config) -func NewParser(fields []FieldInfo, config ...*ParserConfig) *Parser { - var cfg *ParserConfig - if len(config) > 0 && config[0] != nil { - cfg = config[0] - } + // Build field map fieldMap := make(map[string]FieldInfo, len(fields)) - jsonbFields := make(map[string]bool) for _, f := range fields { fieldMap[f.Name] = f - if f.IsJSONB { - jsonbFields[f.Name] = true - } } // Apply config or use defaults @@ -159,15 +86,16 @@ func NewParser(fields []FieldInfo, config ...*ParserConfig) *Parser { maxDepth := DefaultMaxDepth maxTerms := DefaultMaxTerms - if cfg != nil { - if cfg.MaxQueryLength != nil { - maxQueryLength = *cfg.MaxQueryLength + if len(config) > 0 && config[0] != nil { + cfg := config[0] + if cfg.MaxQueryLength > 0 { + maxQueryLength = cfg.MaxQueryLength } - if cfg.MaxDepth != nil { - maxDepth = *cfg.MaxDepth + if cfg.MaxDepth > 0 { + maxDepth = cfg.MaxDepth } - if cfg.MaxTerms != nil { - maxTerms = *cfg.MaxTerms + if cfg.MaxTerms > 0 { + maxTerms = cfg.MaxTerms } } @@ -177,34 +105,11 @@ func NewParser(fields []FieldInfo, config ...*ParserConfig) *Parser { MaxDepth: maxDepth, MaxTerms: maxTerms, fieldMap: fieldMap, - jsonbFields: jsonbFields, - } + }, nil } -// Precompiled regex for performance - matches Lucene operators and special syntax -var ( - // Matches field:value pattern (including JSONB like labels.category:value) - 
fieldValuePattern = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?:`) - // Extracts field name from field:value pattern - fieldExtractPattern = regexp.MustCompile(`([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?):`) - // Matches boolean operators (case-insensitive) - booleanOperators = regexp.MustCompile(`(?i)^(AND|OR|NOT|\+|-)$`) - // Matches range syntax - rangePattern = regexp.MustCompile(`^\[.*\s+TO\s+.*\]$|^\{.*\s+TO\s+.*\}$`) -) - -// InvalidFieldError represents an error when a query references a non-existent field -type InvalidFieldError struct { - Field string - ValidFields []string -} - -func (e *InvalidFieldError) Error() string { - return fmt.Sprintf("invalid field '%s' in query; valid fields are: %s", e.Field, strings.Join(e.ValidFields, ", ")) -} - -// getStructFieldsWithConfig extracts field metadata using configurable tag names. -func getStructFieldsWithConfig(model any, config *ParserConfig) ([]FieldInfo, error) { +// extractFields uses reflection to extract field metadata from a struct. 
+func extractFields(model any) ([]FieldInfo, error) { t := reflect.TypeOf(model) if t.Kind() == reflect.Ptr { t = t.Elem() @@ -214,48 +119,27 @@ func getStructFieldsWithConfig(model any, config *ParserConfig) ([]FieldInfo, er return nil, fmt.Errorf("expected struct, got %s", t.Kind()) } - // Determine tag names from config or use defaults - jsonTag := TagJSON - gormTag := TagGORM - luceneTag := TagLucene - if config != nil { - if config.JSONTag != "" { - jsonTag = config.JSONTag - } - if config.GORMTag != "" { - gormTag = config.GORMTag - } - if config.LuceneTag != "" { - luceneTag = config.LuceneTag - } - } - var fields []FieldInfo for i := 0; i < t.NumField(); i++ { field := t.Field(i) - jsonTagValue := field.Tag.Get(jsonTag) - if jsonTagValue == "" || jsonTagValue == "-" { + + // Get field name from json tag + jsonTag := field.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { continue } - if commaIdx := strings.Index(jsonTagValue, ","); commaIdx != -1 { - jsonTagValue = jsonTagValue[:commaIdx] + // Strip options from json tag (e.g., "name,omitempty" -> "name") + if commaIdx := strings.Index(jsonTag, ","); commaIdx != -1 { + jsonTag = jsonTag[:commaIdx] } - gormTagValue := field.Tag.Get(gormTag) - isJSONB := strings.Contains(gormTagValue, "type:jsonb") - - luceneTagValue := field.Tag.Get(luceneTag) - implicitSearch := false - if luceneTagValue == "implicit" { - implicitSearch = true - } else if luceneTagValue != "explicit" { - implicitSearch = field.Type.Kind() == reflect.String && !isJSONB - } + // Implicit search: only string fields + implicitSearch := field.Type.Kind() == reflect.String fields = append(fields, FieldInfo{ - Name: jsonTagValue, - IsJSONB: isJSONB, + Name: jsonTag, + Type: field.Type, ImplicitSearch: implicitSearch, }) } @@ -263,10 +147,57 @@ func getStructFieldsWithConfig(model any, config *ParserConfig) ([]FieldInfo, er return fields, nil } +// canUseNestedAccess checks if a field type supports nested access (field.subfield syntax). 
+func canUseNestedAccess(t reflect.Type) bool { + // Return false for nil types + if t == nil { + return false + } + + // Unwrap pointers + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + // Check type name for JSONB-like types + name := t.Name() + if strings.Contains(name, "JSONB") || strings.Contains(name, "JSON") { + return true + } + + // Maps and structs support nested access + if t.Kind() == reflect.Map || t.Kind() == reflect.Struct { + return true + } + + return false +} + +// Precompiled regex for performance - matches Lucene operators and special syntax +var ( + // Matches field:value pattern (including JSONB like labels.category:value) + fieldValuePattern = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?:`) + // Extracts field name from field:value pattern + fieldExtractPattern = regexp.MustCompile(`([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?):`) + // Matches boolean operators (case-insensitive) + booleanOperators = regexp.MustCompile(`(?i)^(AND|OR|NOT|\+|-)$`) + // Matches range syntax + rangePattern = regexp.MustCompile(`^\[.*\s+TO\s+.*\]$|^\{.*\s+TO\s+.*\}$`) +) + +// InvalidFieldError represents an error when a query references a non-existent field +type InvalidFieldError struct { + Field string + ValidFields []string +} + +func (e *InvalidFieldError) Error() string { + return fmt.Sprintf("invalid field '%s' in query; valid fields are: %s", e.Field, strings.Join(e.ValidFields, ", ")) +} + // ParseToMap parses a Lucene query into a map representation. // Note: This is a legacy method kept for backward compatibility. 
func (p *Parser) ParseToMap(query string) (map[string]any, error) { - if err := p.validateQuery(query); err != nil { return nil, err } @@ -304,7 +235,7 @@ func (p *Parser) ParseToSQL(query string) (string, []any, error) { } // Create PostgreSQL driver on-demand and render - driver := NewPostgresJSONBDriver(p.Fields) + driver := NewPostgresDriver(p.Fields) sql, params, err := driver.RenderParam(e) if err != nil { return "", nil, err @@ -337,7 +268,7 @@ func (p *Parser) ParseToDynamoDBPartiQL(query string) (string, []types.Attribute } // Create DynamoDB driver on-demand and render - driver := NewDynamoDBPartiQLDriver(p.Fields) + driver := NewDynamoDBDriver(p.Fields) partiql, attrs, err := driver.RenderPartiQL(e) if err != nil { return "", nil, err @@ -547,7 +478,7 @@ func isInsideQuotes(query string, pos int) bool { return inQuotes } -// validateFieldName validates both simple fields (name) and JSONB sub-fields (labels.category). +// validateFieldName validates both simple fields (name) and nested fields (labels.category). 
func (p *Parser) validateFieldName(fieldName string) error { if strings.Contains(fieldName, ".") { parts := strings.SplitN(fieldName, ".", 2) @@ -557,11 +488,15 @@ func (p *Parser) validateFieldName(fieldName string) error { baseField := parts[0] - if !p.jsonbFields[baseField] { - if _, exists := p.fieldMap[baseField]; !exists { - return fmt.Errorf("field '%s' does not exist", baseField) - } - return fmt.Errorf("field '%s' is not a JSONB field; cannot use sub-field notation", baseField) + // Check if base field exists + field, exists := p.fieldMap[baseField] + if !exists { + return fmt.Errorf("field '%s' does not exist", baseField) + } + + // Check if base field supports nested access + if !canUseNestedAccess(field.Type) { + return fmt.Errorf("field '%s' does not support nested access (field.subfield syntax); use explicit field names only", baseField) } return nil @@ -577,7 +512,8 @@ func (p *Parser) validateFieldName(fieldName string) error { func (p *Parser) getValidFieldNames() []string { var names []string for _, f := range p.Fields { - if f.IsJSONB { + // Add a hint for fields that support nested access + if canUseNestedAccess(f.Type) { names = append(names, f.Name+".*") } else { names = append(names, f.Name) diff --git a/storage/search/lucene/parser_test.go b/storage/search/lucene/parser_test.go index 12f0b6f..a5a96bc 100644 --- a/storage/search/lucene/parser_test.go +++ b/storage/search/lucene/parser_test.go @@ -1,18 +1,68 @@ package lucene import ( - "fmt" "strings" "testing" ) +// Test model definitions +type BasicModel struct { + Name string `json:"name"` + Email string `json:"email"` +} + +type BooleanModel struct { + Name string `json:"name"` + Status string `json:"status"` + Role string `json:"role"` +} + +type RangeModel struct { + Age int `json:"age"` + Date string `json:"date"` +} + +type TextModel struct { + Description string `json:"description"` + Title string `json:"title"` + Name string `json:"name"` +} + +type ComplexModel struct { + Name 
string `json:"name"` + Age int `json:"age"` + Status string `json:"status"` + Email string `json:"email"` +} + +// JSONB types for testing +type JSONBType map[string]interface{} + +type JSONBModel struct { + Metadata JSONBType `json:"metadata"` +} + +type MixedModel struct { + Name string `json:"name"` + Description string `json:"description"` + Status string `json:"status"` + Labels JSONBType `json:"labels"` + Metadata JSONBType `json:"metadata"` +} + +type NullModel struct { + Name string `json:"name"` + ParentID string `json:"parent_id"` + DeletedAt string `json:"deleted_at"` + AttachmentIDs string `json:"attachment_ids"` +} + // TestBasicFieldSearch tests basic field:value queries func TestBasicFieldSearch(t *testing.T) { - fields := []FieldInfo{ - {Name: "name", IsJSONB: false}, - {Name: "email", IsJSONB: false}, + parser, err := NewParser(BasicModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) tests := []struct { name string @@ -64,12 +114,10 @@ func TestBasicFieldSearch(t *testing.T) { // TestBooleanOperators tests AND, OR, NOT operators func TestBooleanOperators(t *testing.T) { - fields := []FieldInfo{ - {Name: "name", IsJSONB: false}, - {Name: "status", IsJSONB: false}, - {Name: "role", IsJSONB: false}, + parser, err := NewParser(BooleanModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) tests := []struct { name string @@ -115,11 +163,10 @@ func TestBooleanOperators(t *testing.T) { // TestRequiredProhibited tests + and - operators func TestRequiredProhibited(t *testing.T) { - fields := []FieldInfo{ - {Name: "name", IsJSONB: false}, - {Name: "status", IsJSONB: false}, + parser, err := NewParser(BooleanModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) tests := []struct { name string @@ -134,12 +181,12 @@ func TestRequiredProhibited(t *testing.T) { { name: "prohibited term", query: "-status:inactive", - wantSQL: 
[]string{`"status"`, "NOT"}, + wantSQL: []string{"NOT", `"status"`}, }, { name: "mixed required and prohibited", query: "+name:john -status:inactive", - wantSQL: []string{`"name"`, `"status"`, "NOT"}, + wantSQL: []string{`"name"`, "NOT", `"status"`}, }, } @@ -160,11 +207,10 @@ func TestRequiredProhibited(t *testing.T) { // TestRangeQueries tests range query syntax func TestRangeQueries(t *testing.T) { - fields := []FieldInfo{ - {Name: "age", IsJSONB: false}, - {Name: "date", IsJSONB: false}, + parser, err := NewParser(RangeModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) tests := []struct { name string @@ -173,28 +219,28 @@ func TestRangeQueries(t *testing.T) { }{ { name: "inclusive range", - query: "age:[18 TO 65]", - wantSQL: []string{`"age" BETWEEN`}, + query: "age:[25 TO 65]", + wantSQL: []string{`"age"`, "BETWEEN"}, }, { name: "exclusive range", - query: "age:{18 TO 65}", - wantSQL: []string{`"age" >`, `"age" <`}, + query: "age:{25 TO 65}", + wantSQL: []string{`"age"`, ">", "<"}, }, { name: "open-ended range min", - query: "age:[18 TO *]", - wantSQL: []string{`"age" >=`}, + query: "age:[25 TO *]", + wantSQL: []string{`"age"`, ">="}, }, { name: "open-ended range max", query: "age:[* TO 65]", - wantSQL: []string{`"age" <=`}, + wantSQL: []string{`"age"`, "<="}, }, { name: "date range", - query: "date:[2020-01-01 TO 2023-12-31]", - wantSQL: []string{`"date"`}, + query: "date:[2024-01-01 TO 2024-12-31]", + wantSQL: []string{`"date"`, "BETWEEN"}, }, } @@ -215,11 +261,10 @@ func TestRangeQueries(t *testing.T) { // TestQuotedPhrases tests quoted phrase handling func TestQuotedPhrases(t *testing.T) { - fields := []FieldInfo{ - {Name: "description", IsJSONB: false}, - {Name: "title", IsJSONB: false}, + parser, err := NewParser(TextModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) tests := []struct { name string @@ -233,7 +278,7 @@ func TestQuotedPhrases(t *testing.T) { 
}, { name: "phrase with special chars", - query: `title:"Go: The Complete Guide"`, + query: `title:"test-app (v1.0)"`, wantSQL: []string{`"title"`}, }, } @@ -255,10 +300,10 @@ func TestQuotedPhrases(t *testing.T) { // TestEscapedCharacters tests escaped character handling func TestEscapedCharacters(t *testing.T) { - fields := []FieldInfo{ - {Name: "name", IsJSONB: false}, + parser, err := NewParser(TextModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) tests := []struct { name string @@ -294,13 +339,10 @@ func TestEscapedCharacters(t *testing.T) { // TestComplexQueries tests complex query combinations func TestComplexQueries(t *testing.T) { - fields := []FieldInfo{ - {Name: "name", IsJSONB: false}, - {Name: "age", IsJSONB: false}, - {Name: "status", IsJSONB: false}, - {Name: "email", IsJSONB: false}, + parser, err := NewParser(ComplexModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) tests := []struct { name string @@ -309,60 +351,50 @@ func TestComplexQueries(t *testing.T) { shouldErr bool }{ { - name: "complex with ranges and wildcards", - query: "name:john* AND age:[25 TO 65]", - wantSQL: []string{`"name"`, `"age"`}, - shouldErr: false, + name: "complex with ranges and wildcards", + query: "(name:john* OR email:test*) AND age:[25 TO 65]", + wantSQL: []string{`"name"`, `"email"`, `"age"`, "OR", "AND", "BETWEEN"}, }, { - name: "complex with required and prohibited", - query: "+name:john -status:inactive AND age:[30 TO *]", - wantSQL: []string{`"name"`, `"status"`, `"age"`}, - shouldErr: false, + name: "complex with required and prohibited", + query: "+name:john -status:inactive age:[25 TO 65]", + wantSQL: []string{`"name"`, `"status"`, `"age"`, "NOT"}, }, { - name: "complex with quoted phrases", - query: `name:"John Doe" AND (status:active OR status:pending)`, - wantSQL: []string{`"name"`, `"status"`}, - shouldErr: false, + name: "complex with quoted phrases", + 
query: `name:"John Doe" AND status:active`, + wantSQL: []string{`"name"`, `"status"`, "AND"}, }, { - name: "complex nested query", - query: "((name:john OR name:jane) AND status:active) OR (age:[18 TO 25] AND status:pending)", - wantSQL: []string{`"name"`, `"status"`, `"age"`}, - shouldErr: false, + name: "complex nested query", + query: "((name:john OR name:jane) AND status:active) OR (status:pending AND age:[18 TO *])", + wantSQL: []string{`"name"`, `"status"`, `"age"`, "OR", "AND"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { sql, _, err := parser.ParseToSQL(tt.query) - if tt.shouldErr { - if err == nil { - t.Errorf("ParseToSQL() expected error but got none") - } - return - } - if err != nil { - t.Fatalf("ParseToSQL() error = %v", err) - } - for _, want := range tt.wantSQL { - if !strings.Contains(sql, want) { - t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, want) + if (err != nil) != tt.shouldErr { + t.Fatalf("ParseToSQL() error = %v, shouldErr = %v", err, tt.shouldErr) + } + if !tt.shouldErr { + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, want) + } } } }) } } -// TestImplicitSearch tests implicit search across fields with ImplicitSearch=true +// TestImplicitSearch tests implicit search across string fields func TestImplicitSearch(t *testing.T) { - fields := []FieldInfo{ - {Name: "name", IsJSONB: false, ImplicitSearch: true}, - {Name: "email", IsJSONB: false, ImplicitSearch: true}, - {Name: "description", IsJSONB: false, ImplicitSearch: true}, + parser, err := NewParser(TextModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) tests := []struct { name string @@ -374,7 +406,7 @@ func TestImplicitSearch(t *testing.T) { name: "implicit search", query: "john", wantOR: true, - wantParams: 3, // Should expand to 3 fields + wantParams: 3, // name, description, title }, { name: "implicit search with 
wildcard", @@ -391,7 +423,7 @@ func TestImplicitSearch(t *testing.T) { t.Fatalf("ParseToSQL() error = %v", err) } if tt.wantOR && !strings.Contains(sql, "OR") { - t.Errorf("ParseToSQL() sql = %v, want to contain OR", sql) + t.Errorf("ParseToSQL() expected OR in implicit search, got: %v", sql) } if len(params) != tt.wantParams { t.Errorf("ParseToSQL() params count = %v, want %v", len(params), tt.wantParams) @@ -402,10 +434,10 @@ func TestImplicitSearch(t *testing.T) { // TestJSONBFields tests JSONB field notation func TestJSONBFields(t *testing.T) { - fields := []FieldInfo{ - {Name: "metadata", IsJSONB: true}, + parser, err := NewParser(JSONBModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) tests := []struct { name string @@ -441,11 +473,10 @@ func TestJSONBFields(t *testing.T) { // TestMapOutput tests the legacy map output format func TestMapOutput(t *testing.T) { - fields := []FieldInfo{ - {Name: "name", IsJSONB: false}, - {Name: "status", IsJSONB: false}, + parser, err := NewParser(BooleanModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) result, err := parser.ParseToMap("name:john AND status:active") if err != nil { @@ -459,14 +490,10 @@ func TestMapOutput(t *testing.T) { // TestFieldValidation tests field validation for invalid field references func TestFieldValidation(t *testing.T) { - fields := []FieldInfo{ - {Name: "name", IsJSONB: false, ImplicitSearch: true}, - {Name: "description", IsJSONB: false, ImplicitSearch: true}, - {Name: "status", IsJSONB: false}, - {Name: "labels", IsJSONB: true}, - {Name: "metadata", IsJSONB: true}, + parser, err := NewParser(MixedModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) tests := []struct { name string @@ -481,46 +508,46 @@ func TestFieldValidation(t *testing.T) { }, { name: "valid JSONB sub-field", - query: "labels.category:urgent", + query: "labels.category:prod", 
wantErr: false, }, { name: "invalid field", - query: "nonexistent:value", + query: "invalidfield:value", wantErr: true, - errField: "nonexistent", + errField: "invalidfield", }, { - name: "invalid JSONB base field", - query: "fakejsonb.key:value", + name: "invalid JSONB base", + query: "notjsonb.subfield:value", wantErr: true, - errField: "fakejsonb.key", + errField: "notjsonb", }, { name: "sub-field on non-JSONB field", query: "name.subfield:value", wantErr: true, - errField: "name.subfield", + errField: "name", }, { name: "implicit search (no explicit fields) - valid", - query: "paint", + query: "searchterm", wantErr: false, }, { name: "mixed valid and implicit", - query: "status:active AND paint", + query: "name:john OR searchterm", wantErr: false, }, { name: "mixed valid and invalid", - query: "name:john AND invalid_field:test", + query: "name:john OR invalidfield:value", wantErr: true, - errField: "invalid_field", + errField: "invalidfield", }, { name: "complex valid query", - query: "(name:john OR description:test) AND status:active AND labels.priority:high", + query: "(name:john OR description:test) AND labels.env:prod", wantErr: false, }, { @@ -534,81 +561,22 @@ func TestFieldValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, _, err := parser.ParseToSQL(tt.query) - if tt.wantErr { - if err == nil { - t.Errorf("ParseToSQL() expected error for query %q but got none", tt.query) - return - } - if _, ok := err.(*InvalidFieldError); !ok { - t.Errorf("ParseToSQL() error = %v, want InvalidFieldError", err) - return - } - if !strings.Contains(err.Error(), tt.errField) { - t.Errorf("ParseToSQL() error = %v, want to mention field %q", err, tt.errField) - } - } else { - if err != nil { - t.Errorf("ParseToSQL() unexpected error = %v for query %q", err, tt.query) - } - } - }) - } -} - -// TestValidateFields tests the ValidateFields method directly -func TestValidateFields(t *testing.T) { - fields := []FieldInfo{ - {Name: "id", 
IsJSONB: false}, - {Name: "tenant_id", IsJSONB: false}, - {Name: "name", IsJSONB: false, ImplicitSearch: true}, - {Name: "description", IsJSONB: false, ImplicitSearch: true}, - {Name: "status", IsJSONB: false}, - {Name: "labels", IsJSONB: true}, - {Name: "properties", IsJSONB: true}, - {Name: "metadata", IsJSONB: true}, - } - parser := NewParser(fields) - - tests := []struct { - name string - query string - wantErr bool - }{ - {"valid simple field", "name:test", false}, - {"valid multiple fields", "name:test AND status:active", false}, - {"valid JSONB sub-field", "labels.category:urgent", false}, - {"valid deep JSONB", "metadata.nested_key:value", false}, - {"invalid field", "unknown_field:test", true}, - {"invalid JSONB base", "unknown.subkey:test", true}, - {"sub-field on non-JSONB", "status.sub:test", true}, - {"empty query", "", false}, - {"implicit only - no field prefix", "searchterm", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := parser.ValidateFields(tt.query) - if tt.wantErr && err == nil { - t.Errorf("ValidateFields(%q) expected error but got none", tt.query) + if (err != nil) != tt.wantErr { + t.Errorf("ParseToSQL() error = %v, wantErr = %v", err, tt.wantErr) } - if !tt.wantErr && err != nil { - t.Errorf("ValidateFields(%q) unexpected error: %v", tt.query, err) + if tt.wantErr && tt.errField != "" && !strings.Contains(err.Error(), tt.errField) { + t.Errorf("ParseToSQL() error = %v, want to contain field %v", err, tt.errField) } }) } } -// TestNullValueQueries tests null value handling for IS NULL queries. -// Note: This is a SQL-specific extension (vanilla Lucene doesn't support NULL values). -// Only "null" (case-insensitive) is supported for IS NULL queries; "nil" is treated as a literal string. 
+// TestNullValueQueries tests null value handling for IS NULL queries func TestNullValueQueries(t *testing.T) { - fields := []FieldInfo{ - {Name: "name", IsJSONB: false}, - {Name: "parent_id", IsJSONB: false}, - {Name: "deleted_at", IsJSONB: false}, - {Name: "attachment_ids", IsJSONB: false}, + parser, err := NewParser(NullModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) tests := []struct { name string @@ -618,113 +586,80 @@ func TestNullValueQueries(t *testing.T) { }{ { name: "field is null (lowercase)", - query: "deleted_at:null", - wantSQL: "IS NULL", + query: "parent_id:null", + wantSQL: `"parent_id" IS NULL`, }, { name: "field is NULL (uppercase)", - query: "deleted_at:NULL", - wantSQL: "IS NULL", + query: "parent_id:NULL", + wantSQL: `"parent_id" IS NULL`, }, { name: "field is Null (mixed case)", - query: "deleted_at:Null", - wantSQL: "IS NULL", + query: "parent_id:Null", + wantSQL: `"parent_id" IS NULL`, }, { name: "parent_id is null", query: "parent_id:null", - wantSQL: "IS NULL", + wantSQL: `"parent_id" IS NULL`, }, { name: "combined null with other conditions", - query: "deleted_at:null AND name:john", - wantSQL: "IS NULL", + query: "name:john AND deleted_at:null", + wantSQL: `"deleted_at" IS NULL`, }, { name: "NOT null (is not null)", query: "NOT deleted_at:null", - wantSQL: "NOT", + wantSQL: `NOT(`, }, { name: "nil should be treated as literal value (not NULL)", query: "name:nil", - wantSQL: "=", - wantErr: false, // Should not error, but should treat "nil" as literal string, not IS NULL + wantSQL: `"name" =`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { sql, _, err := parser.ParseToSQL(tt.query) - if tt.wantErr { - if err == nil { - t.Errorf("ParseToSQL(%q) expected error but got none", tt.query) - } - return + if (err != nil) != tt.wantErr { + t.Errorf("ParseToSQL() error = %v, wantErr = %v", err, tt.wantErr) } - if err != nil { - t.Fatalf("ParseToSQL(%q) error = %v", 
tt.query, err) - } - if !strings.Contains(sql, tt.wantSQL) { - t.Errorf("ParseToSQL(%q) sql = %v, want to contain %v", tt.query, sql, tt.wantSQL) + if !tt.wantErr && !strings.Contains(sql, tt.wantSQL) { + t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, tt.wantSQL) } }) } } -// TestEmptyAsLiteralValue tests that 'empty' is treated as a literal value (not special keyword) +// TestEmptyAsLiteralValue tests that 'empty' is treated as a literal value func TestEmptyAsLiteralValue(t *testing.T) { - fields := []FieldInfo{ - {Name: "status", IsJSONB: false}, - {Name: "name", IsJSONB: false}, + parser, err := NewParser(BooleanModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) - // 'empty' should be treated as a regular search value, not a special keyword sql, params, err := parser.ParseToSQL("status:empty") if err != nil { t.Fatalf("ParseToSQL() error = %v", err) } - // Should generate a regular equals query, not IS NULL - if strings.Contains(sql, "IS NULL") { - t.Errorf("'empty' should be treated as literal value, not IS NULL. 
Got: %s", sql) + if !strings.Contains(sql, `"status" =`) { + t.Errorf("Expected regular equals query, got: %v", sql) } - - // The value should be in params if len(params) != 1 || params[0] != "empty" { t.Errorf("Expected params to contain 'empty', got: %v", params) } } -// BenchmarkParser benchmarks the parser performance -func BenchmarkParser(b *testing.B) { - fields := []FieldInfo{ - {Name: "name", IsJSONB: false}, - {Name: "age", IsJSONB: false}, - {Name: "status", IsJSONB: false}, - {Name: "email", IsJSONB: false}, - } - parser := NewParser(fields) - - query := `(name:john* OR email:*@example.com) AND (status:active OR status:pending) AND age:[25 TO 65]` - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _, _ = parser.ParseToSQL(query) - } -} - -// TestFuzzySearch tests fuzzy search operator (~) using pg_trgm similarity +// TestFuzzySearch tests fuzzy search operator (~) func TestFuzzySearch(t *testing.T) { - fields := []FieldInfo{ - {Name: "name", IsJSONB: false}, - {Name: "description", IsJSONB: false}, - {Name: "status", IsJSONB: false}, - {Name: "labels", IsJSONB: true}, + parser, err := NewParser(MixedModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) tests := []struct { name string @@ -744,46 +679,41 @@ func TestFuzzySearch(t *testing.T) { }, { name: "fuzzy on JSONB field", - query: "labels.category:construction~", + query: "labels.tag:prod~", wantSQL: "similarity", }, { name: "fuzzy combined with other conditions", - query: "name:roam~ AND status:active", + query: "name:test~ AND status:active", wantSQL: "similarity", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, params, err := parser.ParseToSQL(tt.query) - if tt.wantErr { - if err == nil { - t.Errorf("ParseToSQL(%q) expected error but got none", tt.query) - } - return - } - if err != nil { - t.Fatalf("ParseToSQL(%q) error = %v", tt.query, err) - } - if !strings.Contains(sql, tt.wantSQL) { - t.Errorf("ParseToSQL(%q) sql = 
%v, want to contain %v", tt.query, sql, tt.wantSQL) + sql, _, err := parser.ParseToSQL(tt.query) + if (err != nil) != tt.wantErr { + t.Errorf("ParseToSQL() error = %v, wantErr = %v", err, tt.wantErr) } - if len(params) == 0 { - t.Errorf("ParseToSQL(%q) expected at least one parameter", tt.query) + if !tt.wantErr && !strings.Contains(sql, tt.wantSQL) { + t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, tt.wantSQL) } }) } } -// TestEscaping tests that special characters can be escaped in queries +// TestEscaping tests that special characters can be escaped func TestEscaping(t *testing.T) { - fields := []FieldInfo{ - {Name: "name", IsJSONB: false}, - {Name: "version", IsJSONB: false}, - {Name: "path", IsJSONB: false}, + type EscapeModel struct { + Name string `json:"name"` + Version string `json:"version"` + Path string `json:"path"` + } + + parser, err := NewParser(EscapeModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) tests := []struct { name string @@ -798,43 +728,24 @@ func TestEscaping(t *testing.T) { }, { name: "escaped colon", - query: `version:1\.2\.3`, - wantSQL: `"version"`, - }, - { - name: "escaped parentheses", - query: `name:\(test\)`, + query: `name:test\:value`, wantSQL: `"name"`, }, { name: "escaped path separator", - query: `path:src\/components`, + query: `path:\/usr\/bin`, wantSQL: `"path"`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, params, err := parser.ParseToSQL(tt.query) - if tt.wantErr { - if err == nil { - t.Errorf("ParseToSQL(%q) expected error but got none", tt.query) - } - return - } - if err != nil { - t.Fatalf("ParseToSQL(%q) error = %v", tt.query, err) - } - if !strings.Contains(sql, tt.wantSQL) { - t.Errorf("ParseToSQL(%q) sql = %v, want to contain %v", tt.query, sql, tt.wantSQL) + sql, _, err := parser.ParseToSQL(tt.query) + if (err != nil) != tt.wantErr { + t.Errorf("ParseToSQL() error = %v, wantErr = %v", err, tt.wantErr) } - // Verify the 
escaped character is in the parameter - if len(params) > 0 { - paramStr := fmt.Sprintf("%v", params[0]) - // The escaped character should appear as the literal character in params - if strings.Contains(tt.query, `\+`) && !strings.Contains(paramStr, "+") { - t.Errorf("ParseToSQL(%q) expected '+' in params, got %v", tt.query, params) - } + if !tt.wantErr && !strings.Contains(sql, tt.wantSQL) { + t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, tt.wantSQL) } }) } @@ -842,11 +753,10 @@ func TestEscaping(t *testing.T) { // TestBoostOperatorError tests that boost operator returns a clear error func TestBoostOperatorError(t *testing.T) { - fields := []FieldInfo{ - {Name: "name", IsJSONB: false}, - {Name: "status", IsJSONB: false}, + parser, err := NewParser(BooleanModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) } - parser := NewParser(fields) tests := []struct { name string @@ -855,13 +765,13 @@ func TestBoostOperatorError(t *testing.T) { }{ { name: "boost operator", - query: "name:test^4", - wantErr: "boost operator", + query: "name:john^2", + wantErr: "boost", }, { name: "boost in compound query", - query: "name:test^2 AND status:active", - wantErr: "boost operator", + query: "name:john^2 AND status:active", + wantErr: "boost", }, } @@ -878,3 +788,14 @@ func TestBoostOperatorError(t *testing.T) { }) } } + +// BenchmarkParser benchmarks the parser performance +func BenchmarkParser(b *testing.B) { + parser, _ := NewParser(ComplexModel{}) + query := `(name:john* OR email:*@example.com) AND (status:active OR status:pending) AND age:[25 TO 65]` + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = parser.ParseToSQL(query) + } +} diff --git a/storage/search/lucene/driver.go b/storage/search/lucene/postgres_driver.go similarity index 81% rename from storage/search/lucene/driver.go rename to storage/search/lucene/postgres_driver.go index 8320901..6b0153b 100644 --- a/storage/search/lucene/driver.go +++ 
b/storage/search/lucene/postgres_driver.go @@ -4,7 +4,6 @@ import ( "fmt" "strings" - "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/grindlemire/go-lucene/pkg/driver" "github.com/grindlemire/go-lucene/pkg/lucene/expr" ) @@ -16,7 +15,7 @@ type PostgresJSONBDriver struct { fields map[string]FieldInfo // Map of field names to their metadata } -func NewPostgresJSONBDriver(fields []FieldInfo) *PostgresJSONBDriver { +func NewPostgresDriver(fields []FieldInfo) *PostgresJSONBDriver { fieldMap := make(map[string]FieldInfo) for _, f := range fields { fieldMap[f.Name] = f @@ -338,14 +337,14 @@ func (p *PostgresJSONBDriver) processJSONBFields(e *expr.Expression) { } } -// formatFieldName converts field.subfield to JSONB syntax if the base field is JSONB. +// formatFieldName converts field.subfield to JSONB syntax if the base field supports nested access. func (p *PostgresJSONBDriver) formatFieldName(fieldName string) expr.Column { parts := strings.SplitN(fieldName, ".", 2) if len(parts) == 2 { baseField := parts[0] subField := parts[1] - if field, exists := p.fields[baseField]; exists && field.IsJSONB { + if field, exists := p.fields[baseField]; exists && canUseNestedAccess(field.Type) { // Return as JSONB operator syntax return expr.Column(fmt.Sprintf("%s->>'%s'", baseField, subField)) } @@ -353,7 +352,7 @@ func (p *PostgresJSONBDriver) formatFieldName(fieldName string) expr.Column { return expr.Column(fieldName) } -// Helper functions for DRY and cleaner code +// Helper functions for PostgreSQL driver // convertWildcards converts Lucene wildcards to SQL wildcards. // * (any characters) → % (SQL wildcard) @@ -481,90 +480,6 @@ func (p *PostgresJSONBDriver) renderRange(e *expr.Expression) (string, []any, er return fmt.Sprintf("(%s > ? AND %s < ?)", colStr, colStr), params, nil } -// DynamoDBPartiQLDriver converts Lucene queries to DynamoDB PartiQL. 
-type DynamoDBPartiQLDriver struct { - driver.Base - fields map[string]FieldInfo -} - -func NewDynamoDBPartiQLDriver(fields []FieldInfo) *DynamoDBPartiQLDriver { - fieldMap := make(map[string]FieldInfo) - for _, f := range fields { - fieldMap[f.Name] = f - } - - fns := map[expr.Operator]driver.RenderFN{ - expr.Literal: driver.Shared[expr.Literal], - expr.And: driver.Shared[expr.And], - expr.Or: driver.Shared[expr.Or], - expr.Not: driver.Shared[expr.Not], - expr.Equals: driver.Shared[expr.Equals], - expr.Range: driver.Shared[expr.Range], - expr.Must: driver.Shared[expr.Must], - expr.MustNot: driver.Shared[expr.MustNot], - expr.Wild: driver.Shared[expr.Wild], - expr.Regexp: driver.Shared[expr.Regexp], - expr.Like: dynamoDBLike, // Custom LIKE for DynamoDB functions - expr.Greater: driver.Shared[expr.Greater], - expr.GreaterEq: driver.Shared[expr.GreaterEq], - expr.Less: driver.Shared[expr.Less], - expr.LessEq: driver.Shared[expr.LessEq], - expr.In: driver.Shared[expr.In], - expr.List: driver.Shared[expr.List], - } - - return &DynamoDBPartiQLDriver{ - Base: driver.Base{ - RenderFNs: fns, - }, - fields: fieldMap, - } -} - -// RenderPartiQL renders the expression to DynamoDB PartiQL with AttributeValue parameters. -func (d *DynamoDBPartiQLDriver) RenderPartiQL(e *expr.Expression) (string, []types.AttributeValue, error) { - // Use base rendering with ? placeholders - str, params, err := d.RenderParam(e) - if err != nil { - return "", nil, err - } - - // Convert params to DynamoDB AttributeValues - attrValues := make([]types.AttributeValue, len(params)) - for i, param := range params { - attrValues[i] = &types.AttributeValueMemberS{Value: fmt.Sprintf("%v", param)} - } - - return str, attrValues, nil -} - -// dynamoDBLike implements LIKE using DynamoDB's begins_with and contains functions. 
-func dynamoDBLike(left, right string) (string, error) { - // Remove quotes from right side to analyze pattern - pattern := strings.Trim(right, "'") - - // Replace wildcards for analysis - hasPrefix := strings.HasPrefix(pattern, "%") - hasSuffix := strings.HasSuffix(pattern, "%") - - if hasPrefix && hasSuffix { - // %value% -> contains(field, value) - value := strings.Trim(pattern, "%") - return fmt.Sprintf("contains(%s, '%s')", left, value), nil - } else if !hasPrefix && hasSuffix { - // value% -> begins_with(field, value) - value := strings.TrimSuffix(pattern, "%") - return fmt.Sprintf("begins_with(%s, '%s')", left, value), nil - } else if hasPrefix && !hasSuffix { - // %value -> contains(field, value) (DynamoDB doesn't have ends_with) - value := strings.TrimPrefix(pattern, "%") - return fmt.Sprintf("contains(%s, '%s')", left, value), nil - } - - // Exact match - return fmt.Sprintf("%s = %s", left, right), nil -} - // convertToPostgresPlaceholders converts ? placeholders to PostgreSQL's $N format. 
func convertToPostgresPlaceholders(query string) string { paramIndex := 1 diff --git a/storage/sql.go b/storage/sql.go index bfc1f49..0165bd4 100644 --- a/storage/sql.go +++ b/storage/sql.go @@ -306,7 +306,7 @@ func (s *SQLAdapter) Search(dest any, sortKey string, query string, limit int, c destType := reflect.TypeOf(dest).Elem().Elem() model := reflect.New(destType).Elem().Interface() - parser, err := lucene.NewParserFromType(model) + parser, err := lucene.NewParser(model) if err != nil { slog.Error("Parser creation failed", "error", err) return "", err From f018b5b555ac036367467660b93031bc01333426 Mon Sep 17 00:00:00 2001 From: Lutherwaves Date: Wed, 31 Dec 2025 14:34:39 +0000 Subject: [PATCH 04/13] refactor: make Lucene parser SQL generation generic for PostgreSQL, MySQL, and SQLite - Renamed postgres_driver.go to sql_driver.go for generic SQL support - Refactored PostgresJSONBDriver to SQLDriver with provider field - Added provider-specific switch statements for: * Case-insensitive LIKE (ILIKE vs LOWER() vs LIKE) * Fuzzy search (similarity() vs SOUNDEX() vs error) * JSON field extraction (JSONB ->> vs JSON_EXTRACT) * Parameter placeholders ($N vs ?) - Updated ParseToSQL() to accept provider string parameter - Updated SQLAdapter.Search() to pass provider to parser - Updated all tests to include "postgresql" provider parameter - PostgreSQL: ILIKE, ::text casting, similarity(), JSONB ->> operators, $N placeholders - MySQL: LOWER() + LIKE, JSON_UNQUOTE(JSON_EXTRACT()), SOUNDEX(), ? placeholders - SQLite: LIKE (case-insensitive), JSON_EXTRACT(), no fuzzy search, ? 
placeholders --- storage/search/lucene/parser.go | 11 +- storage/search/lucene/parser_test.go | 32 +-- .../{postgres_driver.go => sql_driver.go} | 207 +++++++++++------- storage/sql.go | 3 +- 4 files changed, 156 insertions(+), 97 deletions(-) rename storage/search/lucene/{postgres_driver.go => sql_driver.go} (62%) diff --git a/storage/search/lucene/parser.go b/storage/search/lucene/parser.go index d750358..b35fe62 100644 --- a/storage/search/lucene/parser.go +++ b/storage/search/lucene/parser.go @@ -211,9 +211,10 @@ func (p *Parser) ParseToMap(query string) (map[string]any, error) { return p.exprToMap(e), nil } -// ParseToSQL parses a Lucene query and converts it to PostgreSQL SQL with parameters. -// Creates a PostgreSQL driver on-demand for rendering. -func (p *Parser) ParseToSQL(query string) (string, []any, error) { +// ParseToSQL parses a Lucene query and converts it to SQL with parameters for the specified provider. +// Creates a SQL driver on-demand for rendering with provider-specific syntax. 
+// Provider should be one of: "postgresql", "mysql", "sqlite" +func (p *Parser) ParseToSQL(query string, provider string) (string, []any, error) { slog.Debug(fmt.Sprintf(`Parsing query to SQL: %s`, query)) if err := p.validateQuery(query); err != nil { @@ -234,8 +235,8 @@ func (p *Parser) ParseToSQL(query string) (string, []any, error) { return "", nil, err } - // Create PostgreSQL driver on-demand and render - driver := NewPostgresDriver(p.Fields) + // Create SQL driver on-demand for the specified provider and render + driver := NewSQLDriver(p.Fields, provider) sql, params, err := driver.RenderParam(e) if err != nil { return "", nil, err diff --git a/storage/search/lucene/parser_test.go b/storage/search/lucene/parser_test.go index a5a96bc..7a3df95 100644 --- a/storage/search/lucene/parser_test.go +++ b/storage/search/lucene/parser_test.go @@ -98,7 +98,7 @@ func TestBasicFieldSearch(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, vals, err := parser.ParseToSQL(tt.query) + sql, vals, err := parser.ParseToSQL(tt.query, "postgresql") if err != nil { t.Fatalf("ParseToSQL() error = %v", err) } @@ -148,7 +148,7 @@ func TestBooleanOperators(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, _, err := parser.ParseToSQL(tt.query) + sql, _, err := parser.ParseToSQL(tt.query, "postgresql") if err != nil { t.Fatalf("ParseToSQL() error = %v", err) } @@ -192,7 +192,7 @@ func TestRequiredProhibited(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, _, err := parser.ParseToSQL(tt.query) + sql, _, err := parser.ParseToSQL(tt.query, "postgresql") if err != nil { t.Fatalf("ParseToSQL() error = %v", err) } @@ -246,7 +246,7 @@ func TestRangeQueries(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, _, err := parser.ParseToSQL(tt.query) + sql, _, err := parser.ParseToSQL(tt.query, "postgresql") if err != nil { t.Fatalf("ParseToSQL() error = %v", 
err) } @@ -285,7 +285,7 @@ func TestQuotedPhrases(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, _, err := parser.ParseToSQL(tt.query) + sql, _, err := parser.ParseToSQL(tt.query, "postgresql") if err != nil { t.Fatalf("ParseToSQL() error = %v", err) } @@ -324,7 +324,7 @@ func TestEscapedCharacters(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, _, err := parser.ParseToSQL(tt.query) + sql, _, err := parser.ParseToSQL(tt.query, "postgresql") if err != nil { t.Fatalf("ParseToSQL() error = %v", err) } @@ -374,7 +374,7 @@ func TestComplexQueries(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, _, err := parser.ParseToSQL(tt.query) + sql, _, err := parser.ParseToSQL(tt.query, "postgresql") if (err != nil) != tt.shouldErr { t.Fatalf("ParseToSQL() error = %v, shouldErr = %v", err, tt.shouldErr) } @@ -418,7 +418,7 @@ func TestImplicitSearch(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, params, err := parser.ParseToSQL(tt.query) + sql, params, err := parser.ParseToSQL(tt.query, "postgresql") if err != nil { t.Fatalf("ParseToSQL() error = %v", err) } @@ -458,7 +458,7 @@ func TestJSONBFields(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, _, err := parser.ParseToSQL(tt.query) + sql, _, err := parser.ParseToSQL(tt.query, "postgresql") if err != nil { t.Fatalf("ParseToSQL() error = %v", err) } @@ -560,7 +560,7 @@ func TestFieldValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, _, err := parser.ParseToSQL(tt.query) + _, _, err := parser.ParseToSQL(tt.query, "postgresql") if (err != nil) != tt.wantErr { t.Errorf("ParseToSQL() error = %v, wantErr = %v", err, tt.wantErr) } @@ -623,7 +623,7 @@ func TestNullValueQueries(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, _, err := parser.ParseToSQL(tt.query) + sql, _, 
err := parser.ParseToSQL(tt.query, "postgresql") if (err != nil) != tt.wantErr { t.Errorf("ParseToSQL() error = %v, wantErr = %v", err, tt.wantErr) } @@ -641,7 +641,7 @@ func TestEmptyAsLiteralValue(t *testing.T) { t.Fatalf("NewParser() error = %v", err) } - sql, params, err := parser.ParseToSQL("status:empty") + sql, params, err := parser.ParseToSQL("status:empty", "postgresql") if err != nil { t.Fatalf("ParseToSQL() error = %v", err) } @@ -691,7 +691,7 @@ func TestFuzzySearch(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, _, err := parser.ParseToSQL(tt.query) + sql, _, err := parser.ParseToSQL(tt.query, "postgresql") if (err != nil) != tt.wantErr { t.Errorf("ParseToSQL() error = %v, wantErr = %v", err, tt.wantErr) } @@ -740,7 +740,7 @@ func TestEscaping(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, _, err := parser.ParseToSQL(tt.query) + sql, _, err := parser.ParseToSQL(tt.query, "postgresql") if (err != nil) != tt.wantErr { t.Errorf("ParseToSQL() error = %v, wantErr = %v", err, tt.wantErr) } @@ -777,7 +777,7 @@ func TestBoostOperatorError(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, _, err := parser.ParseToSQL(tt.query) + _, _, err := parser.ParseToSQL(tt.query, "postgresql") if err == nil { t.Errorf("ParseToSQL(%q) expected error but got none", tt.query) return @@ -796,6 +796,6 @@ func BenchmarkParser(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, _, _ = parser.ParseToSQL(query) + _, _, _ = parser.ParseToSQL(query, "postgresql") } } diff --git a/storage/search/lucene/postgres_driver.go b/storage/search/lucene/sql_driver.go similarity index 62% rename from storage/search/lucene/postgres_driver.go rename to storage/search/lucene/sql_driver.go index 6b0153b..b48eec0 100644 --- a/storage/search/lucene/postgres_driver.go +++ b/storage/search/lucene/sql_driver.go @@ -8,14 +8,17 @@ import ( 
"github.com/grindlemire/go-lucene/pkg/lucene/expr" ) -// PostgresJSONBDriver is a custom PostgreSQL driver that supports JSONB field notation. -// It extends the base PostgreSQL driver to handle field->>'subfield' syntax. -type PostgresJSONBDriver struct { +// SQLDriver is a SQL driver that supports multiple SQL dialects (PostgreSQL, MySQL, SQLite). +// It handles database-specific syntax for LIKE operators, JSON field access, and parameter placeholders. +type SQLDriver struct { driver.Base - fields map[string]FieldInfo // Map of field names to their metadata + fields map[string]FieldInfo // Map of field names to their metadata + provider string // SQL provider: "postgresql", "mysql", or "sqlite" } -func NewPostgresDriver(fields []FieldInfo) *PostgresJSONBDriver { +// NewSQLDriver creates a new SQL driver for the specified provider. +// Provider should be one of: "postgresql", "mysql", "sqlite" +func NewSQLDriver(fields []FieldInfo, provider string) *SQLDriver { fieldMap := make(map[string]FieldInfo) for _, f := range fields { fieldMap[f.Name] = f @@ -43,113 +46,147 @@ func NewPostgresDriver(fields []FieldInfo) *PostgresJSONBDriver { expr.List: driver.Shared[expr.List], } - return &PostgresJSONBDriver{ + return &SQLDriver{ Base: driver.Base{ RenderFNs: fns, }, - fields: fieldMap, + fields: fieldMap, + provider: provider, } } -// RenderParam renders the expression with PostgreSQL-style $N placeholders. -func (p *PostgresJSONBDriver) RenderParam(e *expr.Expression) (string, []any, error) { - // Process JSONB field notation before rendering - p.processJSONBFields(e) +// RenderParam renders the expression with provider-specific parameter placeholders. 
+func (s *SQLDriver) RenderParam(e *expr.Expression) (string, []any, error) { + // Process JSON field notation before rendering + s.processJSONFields(e) // Use our custom rendering logic - str, params, err := p.renderParamInternal(e) + str, params, err := s.renderParamInternal(e) if err != nil { return "", nil, err } - // Convert ? to $N format - str = convertToPostgresPlaceholders(str) + // Convert ? placeholders to provider-specific format + // PostgreSQL uses $1, $2, $3; MySQL and SQLite use ? + switch s.provider { + case "postgresql": + str = convertToPostgresPlaceholders(str) + case "mysql", "sqlite": + // Already uses ? placeholders, no conversion needed + } return str, params, nil } // renderParamInternal dispatches to specialized renderers based on operator type. -func (p *PostgresJSONBDriver) renderParamInternal(e *expr.Expression) (string, []any, error) { +func (s *SQLDriver) renderParamInternal(e *expr.Expression) (string, []any, error) { if e == nil { return "", nil, nil } switch e.Op { case expr.Like, expr.Wild: - return p.renderLikeOrWild(e) + return s.renderLikeOrWild(e) case expr.Fuzzy: - return p.renderFuzzy(e) + return s.renderFuzzy(e) case expr.Boost: return "", nil, fmt.Errorf("boost operator (^) is not supported in SQL filtering; it only affects ranking/scoring") case expr.Range: - return p.renderRange(e) + return s.renderRange(e) case expr.Equals, expr.Greater, expr.Less, expr.GreaterEq, expr.LessEq: - return p.renderComparison(e) + return s.renderComparison(e) case expr.And, expr.Or, expr.Must, expr.MustNot: - return p.renderBinary(e) + return s.renderBinary(e) default: // Use base implementation for all other operators - return p.Base.RenderParam(e) + return s.Base.RenderParam(e) } } -// renderLikeOrWild converts LIKE and Wild operators to PostgreSQL ILIKE for case-insensitive matching. 
-func (p *PostgresJSONBDriver) renderLikeOrWild(e *expr.Expression) (string, []any, error) { - leftStr, leftParams, err := p.serializeColumn(e.Left) +// renderLikeOrWild converts LIKE and Wild operators to provider-specific case-insensitive matching. +func (s *SQLDriver) renderLikeOrWild(e *expr.Expression) (string, []any, error) { + leftStr, leftParams, err := s.serializeColumn(e.Left) if err != nil { return "", nil, err } - rightStr, rightParams, err := p.serializeValue(e.Right) + rightStr, rightParams, err := s.serializeValue(e.Right) if err != nil { return "", nil, err } params := append(leftParams, rightParams...) - if isJSONBSyntax(leftStr) { - return fmt.Sprintf("%s ILIKE %s", leftStr, rightStr), params, nil + switch s.provider { + case "postgresql": + // PostgreSQL: ILIKE for case-insensitive matching + if isJSONSyntax(leftStr) { + return fmt.Sprintf("%s ILIKE %s", leftStr, rightStr), params, nil + } + return fmt.Sprintf("%s::text ILIKE %s", leftStr, rightStr), params, nil + + case "mysql": + // MySQL: Use LOWER() for case-insensitive matching + return fmt.Sprintf("LOWER(%s) LIKE LOWER(%s)", leftStr, rightStr), params, nil + + case "sqlite": + // SQLite: LIKE is already case-insensitive for ASCII by default + return fmt.Sprintf("%s LIKE %s", leftStr, rightStr), params, nil + + default: + return "", nil, fmt.Errorf("unsupported SQL provider: %s", s.provider) } - return fmt.Sprintf("%s::text ILIKE %s", leftStr, rightStr), params, nil } -// renderFuzzy handles fuzzy search using PostgreSQL similarity() function. -// Requires pg_trgm extension. +// renderFuzzy handles fuzzy search with provider-specific implementations. 
// For queries like "name:roam~2", the structure is: // - Op: Fuzzy // - Left: Equals expression (name:roam) with Left=Column("name"), Right=Literal("roam") // - Right: nil (distance stored in unexported fuzzyDistance field) -func (p *PostgresJSONBDriver) renderFuzzy(e *expr.Expression) (string, []any, error) { +func (s *SQLDriver) renderFuzzy(e *expr.Expression) (string, []any, error) { leftExpr, ok := e.Left.(*expr.Expression) if !ok || leftExpr.Op != expr.Equals { return "", nil, fmt.Errorf("fuzzy operator requires field:value syntax (e.g., name:roam~2)") } - colStr, colParams, err := p.serializeColumn(leftExpr.Left) + colStr, colParams, err := s.serializeColumn(leftExpr.Left) if err != nil { return "", nil, err } - termStr, termParams, err := p.serializeValue(leftExpr.Right) + termStr, termParams, err := s.serializeValue(leftExpr.Right) if err != nil { return "", nil, err } params := append(colParams, termParams...) - // Use threshold 0.3 (lower = more matches, higher = stricter). - // The fuzzy distance from go-lucene is unexported, so we use a reasonable default. 
- threshold := 0.3 + switch s.provider { + case "postgresql": + // PostgreSQL: Use similarity() function from pg_trgm extension + // Threshold 0.3 (lower = more matches, higher = stricter) + threshold := 0.3 + if isJSONSyntax(colStr) { + return fmt.Sprintf("similarity(%s, %s) > %f", colStr, termStr, threshold), params, nil + } + return fmt.Sprintf("similarity(%s::text, %s) > %f", colStr, termStr, threshold), params, nil + + case "mysql": + // MySQL: Use SOUNDEX for phonetic matching (limited fuzzy support) + return fmt.Sprintf("SOUNDEX(%s) = SOUNDEX(%s)", colStr, termStr), params, nil - if isJSONBSyntax(colStr) { - return fmt.Sprintf("similarity(%s, %s) > %f", colStr, termStr, threshold), params, nil + case "sqlite": + // SQLite: No built-in fuzzy search support + return "", nil, fmt.Errorf("fuzzy search (field:term~N) is not supported with SQLite; use wildcards instead (e.g., field:term*)") + + default: + return "", nil, fmt.Errorf("unsupported SQL provider: %s", s.provider) } - return fmt.Sprintf("similarity(%s::text, %s) > %f", colStr, termStr, threshold), params, nil } // renderComparison handles comparison operators with IS NULL support for null values. 
-func (p *PostgresJSONBDriver) renderComparison(e *expr.Expression) (string, []any, error) { - leftStr, leftParams, err := p.serializeColumn(e.Left) +func (s *SQLDriver) renderComparison(e *expr.Expression) (string, []any, error) { + leftStr, leftParams, err := s.serializeColumn(e.Left) if err != nil { return "", nil, err } @@ -161,7 +198,7 @@ func (p *PostgresJSONBDriver) renderComparison(e *expr.Expression) (string, []an return "", nil, fmt.Errorf("cannot use comparison operators (>, <, >=, <=) with null value") } - rightStr, rightParams, err := p.serializeValue(e.Right) + rightStr, rightParams, err := s.serializeValue(e.Right) if err != nil { return "", nil, err } @@ -187,7 +224,7 @@ func (p *PostgresJSONBDriver) renderComparison(e *expr.Expression) (string, []an // renderBinary handles binary and unary logical operators recursively. // Note: Must and MustNot are unary (only Left operand), while And and Or are binary. -func (p *PostgresJSONBDriver) renderBinary(e *expr.Expression) (string, []any, error) { +func (s *SQLDriver) renderBinary(e *expr.Expression) (string, []any, error) { switch e.Op { case expr.Must, expr.MustNot: if e.Left == nil { @@ -195,7 +232,7 @@ func (p *PostgresJSONBDriver) renderBinary(e *expr.Expression) (string, []any, e } if leftExpr, ok := e.Left.(*expr.Expression); ok { - leftStr, leftParams, err := p.renderParamInternal(leftExpr) + leftStr, leftParams, err := s.renderParamInternal(leftExpr) if err != nil { return "", nil, err } @@ -206,11 +243,11 @@ func (p *PostgresJSONBDriver) renderBinary(e *expr.Expression) (string, []any, e return fmt.Sprintf("NOT (%s)", leftStr), leftParams, nil } - leftStr, leftParams, err := p.serializeColumn(e.Left) + leftStr, leftParams, err := s.serializeColumn(e.Left) if err != nil { - leftStr, leftParams, err = p.serializeValue(e.Left) + leftStr, leftParams, err = s.serializeValue(e.Left) if err != nil { - return p.Base.RenderParam(e) + return s.Base.RenderParam(e) } } @@ -228,15 +265,15 @@ func (p 
*PostgresJSONBDriver) renderBinary(e *expr.Expression) (string, []any, e rightExpr, rightIsExpr := e.Right.(*expr.Expression) if !leftIsExpr || !rightIsExpr { - return p.Base.RenderParam(e) + return s.Base.RenderParam(e) } - leftStr, leftParams, err := p.renderParamInternal(leftExpr) + leftStr, leftParams, err := s.renderParamInternal(leftExpr) if err != nil { return "", nil, err } - rightStr, rightParams, err := p.renderParamInternal(rightExpr) + rightStr, rightParams, err := s.renderParamInternal(rightExpr) if err != nil { return "", nil, err } @@ -253,16 +290,16 @@ func (p *PostgresJSONBDriver) renderBinary(e *expr.Expression) (string, []any, e } } -func (p *PostgresJSONBDriver) serializeColumn(in any) (string, []any, error) { +func (s *SQLDriver) serializeColumn(in any) (string, []any, error) { switch v := in.(type) { case expr.Column: colStr := string(v) - if isJSONBSyntax(colStr) { + if isJSONSyntax(colStr) { return colStr, nil, nil } return fmt.Sprintf(`"%s"`, colStr), nil, nil case string: - if isJSONBSyntax(v) { + if isJSONSyntax(v) { return v, nil, nil } return fmt.Sprintf(`"%s"`, v), nil, nil @@ -270,20 +307,20 @@ func (p *PostgresJSONBDriver) serializeColumn(in any) (string, []any, error) { if v.Op == expr.Literal && v.Left != nil { if col, ok := v.Left.(expr.Column); ok { colStr := string(col) - if isJSONBSyntax(colStr) { + if isJSONSyntax(colStr) { return colStr, nil, nil } return fmt.Sprintf(`"%s"`, colStr), nil, nil } } - return p.renderParamInternal(v) + return s.renderParamInternal(v) default: return "", nil, fmt.Errorf("unexpected column type: %T", v) } } // serializeValue converts Lucene wildcards (* and ?) to SQL wildcards (% and _). 
-func (p *PostgresJSONBDriver) serializeValue(in any) (string, []any, error) { +func (s *SQLDriver) serializeValue(in any) (string, []any, error) { switch v := in.(type) { case string: return "?", []any{convertWildcards(v)}, nil @@ -296,7 +333,7 @@ func (p *PostgresJSONBDriver) serializeValue(in any) (string, []any, error) { literalVal := fmt.Sprintf("%v", v.Left) return "?", []any{convertWildcards(literalVal)}, nil } - return p.renderParamInternal(v) + return s.renderParamInternal(v) case nil: return "", nil, fmt.Errorf("nil value in expression") default: @@ -304,55 +341,66 @@ func (p *PostgresJSONBDriver) serializeValue(in any) (string, []any, error) { } } -// processJSONBFields recursively processes the expression tree to convert -// field.subfield notation to PostgreSQL JSONB syntax field->>'subfield'. -func (p *PostgresJSONBDriver) processJSONBFields(e *expr.Expression) { +// processJSONFields recursively processes the expression tree to convert +// field.subfield notation to provider-specific JSON syntax. +func (s *SQLDriver) processJSONFields(e *expr.Expression) { if e == nil { return } // Process left side if it's a column if col, ok := e.Left.(expr.Column); ok { - e.Left = p.formatFieldName(string(col)) + e.Left = s.formatFieldName(string(col)) } // Recursively process expressions if leftExpr, ok := e.Left.(*expr.Expression); ok { - p.processJSONBFields(leftExpr) + s.processJSONFields(leftExpr) } if rightExpr, ok := e.Right.(*expr.Expression); ok { - p.processJSONBFields(rightExpr) + s.processJSONFields(rightExpr) } // Process expression slices if exprs, ok := e.Left.([]*expr.Expression); ok { for _, ex := range exprs { - p.processJSONBFields(ex) + s.processJSONFields(ex) } } if exprs, ok := e.Right.([]*expr.Expression); ok { for _, ex := range exprs { - p.processJSONBFields(ex) + s.processJSONFields(ex) } } } -// formatFieldName converts field.subfield to JSONB syntax if the base field supports nested access. 
-func (p *PostgresJSONBDriver) formatFieldName(fieldName string) expr.Column { +// formatFieldName converts field.subfield to provider-specific JSON syntax. +func (s *SQLDriver) formatFieldName(fieldName string) expr.Column { parts := strings.SplitN(fieldName, ".", 2) if len(parts) == 2 { baseField := parts[0] subField := parts[1] - if field, exists := p.fields[baseField]; exists && canUseNestedAccess(field.Type) { - // Return as JSONB operator syntax - return expr.Column(fmt.Sprintf("%s->>'%s'", baseField, subField)) + if field, exists := s.fields[baseField]; exists && canUseNestedAccess(field.Type) { + switch s.provider { + case "postgresql": + // PostgreSQL: JSONB operator ->> + return expr.Column(fmt.Sprintf("%s->>'%s'", baseField, subField)) + + case "mysql": + // MySQL 5.7+: JSON_UNQUOTE(JSON_EXTRACT(column, '$.field')) + return expr.Column(fmt.Sprintf("JSON_UNQUOTE(JSON_EXTRACT(%s, '$.%s'))", baseField, subField)) + + case "sqlite": + // SQLite: JSON_EXTRACT(column, '$.field') + return expr.Column(fmt.Sprintf("JSON_EXTRACT(%s, '$.%s')", baseField, subField)) + } } } return expr.Column(fieldName) } -// Helper functions for PostgreSQL driver +// Helper functions for SQL driver // convertWildcards converts Lucene wildcards to SQL wildcards. // * (any characters) → % (SQL wildcard) @@ -380,8 +428,17 @@ func convertWildcards(s string) string { return result.String() } -func isJSONBSyntax(col string) bool { - return strings.Contains(col, "->>") +// isJSONSyntax checks if a column string contains provider-specific JSON syntax. +func isJSONSyntax(col string) bool { + // Check for PostgreSQL JSONB operator + if strings.Contains(col, "->>") { + return true + } + // Check for MySQL/SQLite JSON_EXTRACT + if strings.Contains(col, "JSON_EXTRACT") || strings.Contains(col, "JSON_UNQUOTE") { + return true + } + return false } // isNullValue checks if a value represents null in Lucene query syntax. 
@@ -431,8 +488,8 @@ func extractLiteralValue(v any) string { } // renderRange handles range queries including open-ended ranges with wildcards (*). -func (p *PostgresJSONBDriver) renderRange(e *expr.Expression) (string, []any, error) { - colStr, _, err := p.serializeColumn(e.Left) +func (s *SQLDriver) renderRange(e *expr.Expression) (string, []any, error) { + colStr, _, err := s.serializeColumn(e.Left) if err != nil { return "", nil, err } diff --git a/storage/sql.go b/storage/sql.go index 0165bd4..39ac883 100644 --- a/storage/sql.go +++ b/storage/sql.go @@ -312,7 +312,8 @@ func (s *SQLAdapter) Search(dest any, sortKey string, query string, limit int, c return "", err } - whereClause, queryParams, err := parser.ParseToSQL(query) + // Pass the SQL provider to generate provider-specific SQL syntax + whereClause, queryParams, err := parser.ParseToSQL(query, string(s.provider)) if err != nil { slog.Error("Filter parsing failed", "error", err) // Wrap InvalidFieldError as BadRequest for proper HTTP 400 response From 0226b0e80b1454beaf94c934818200c487a9646d Mon Sep 17 00:00:00 2001 From: Lutherwaves Date: Wed, 31 Dec 2025 16:22:24 +0000 Subject: [PATCH 05/13] feat: add SearchQuery schema with Lucene search examples to OpenAPI definitions Schema Structure: - Type: string - Description: Full Lucene query syntax reference - Default example: "name:john AND status:active" - 34 example queries covering: * Basic field searches and wildcards * Boolean operators (AND, OR, NOT, +, -) * Range queries (inclusive, exclusive, open-ended, dates) * Quoted phrases and escaped characters * Complex nested queries * Implicit search across string fields * JSONB/nested field access (field.subfield:value) * Null value queries (field:null) * Fuzzy search (term~, term~2) --- types/types.go | 43 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/types/types.go b/types/types.go index c33e2c9..e372a3b 100644 --- a/types/types.go +++ 
b/types/types.go @@ -136,8 +136,49 @@ const definitions = ` "description": "A string containing a JSON Pointer value." } } + }, + "SearchQuery": { + "type": "string", + "description": "Lucene-style search query supporting field searches, wildcards, boolean operators, ranges, and more. Syntax: field:value, wildcards (*,?), operators (AND, OR, NOT, +, -), ranges ([min TO max]), quoted phrases, JSONB access (field.subfield:value), null checks (field:null), and fuzzy search (term~).", + "example": "name:john AND status:active", + "examples": [ + "name:john", + "name:john*", + "email:*@example.com", + "description:*important*", + "name:john* OR email:*@example.com", + "name:john AND status:active", + "status:active OR status:pending", + "name:john NOT status:inactive", + "+name:john +status:active", + "name:john -status:deleted", + "age:[25 TO 65]", + "age:{25 TO 65}", + "age:[25 TO *]", + "age:[* TO 65]", + "created_at:[2024-01-01 TO 2024-12-31]", + "description:\"hello world\"", + "title:\"test-app (v1.0)\"", + "name:C\\+\\+ OR path:\\/usr\\/bin", + "(name:john* OR email:*@example.com) AND status:active AND age:[25 TO 65]", + "((name:john OR name:jane) AND status:active) OR (status:pending AND age:[18 TO *])", + "searchterm", + "john*", + "labels.category:production", + "metadata.tags:prod*", + "name:john AND labels.env:prod AND metadata.team:engineering", + "parent_id:null", + "NOT deleted_at:null", + "name:john AND deleted_at:null", + "name:roam~", + "name:roam~2", + "labels.tag:prod~", + "+name:john* -status:deleted age:[25 TO 65] AND (role:admin OR role:moderator)", + "name:john OR email:john@example.com OR phone:*555*", + "(name:*admin* OR role:administrator) AND status:active AND NOT deleted_at:null AND created_at:[2024-01-01 TO *]" + ] } - } + } } } ` From 7cce62b3a14a1a2f794a538facc9b87cd8d3aaff Mon Sep 17 00:00:00 2001 From: Lutherwaves Date: Thu, 1 Jan 2026 23:23:36 +0200 Subject: [PATCH 06/13] chore: update grindlemire/go-lucene deps --- go.mod | 1 + go.sum 
| 11 ----------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/go.mod b/go.mod index 0039d51..97eceeb 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.19.7 github.com/aws/aws-sdk-go-v2/service/dynamodb v1.55.0 github.com/aws/aws-sdk-go-v2/service/sns v1.39.11 + github.com/grindlemire/go-lucene v0.0.26 gopkg.in/yaml.v3 v3.0.1 gorm.io/gorm v1.31.1 ) diff --git a/go.sum b/go.sum index 050f7b6..85fc238 100644 --- a/go.sum +++ b/go.sum @@ -117,23 +117,12 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= -<<<<<<< HEAD golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -======= -go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= -go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= -golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= -golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= -golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= -golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= -golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= -golang.org/x/sync v0.18.0/go.mod 
h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= ->>>>>>> d0d7bae (refactor: simplify Lucene parser API and architecture) golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= From 6c6afa185bf7387ffe3b47467f02d5b16cdb9ecb Mon Sep 17 00:00:00 2001 From: Lutherwaves Date: Sun, 18 Jan 2026 16:59:56 +0200 Subject: [PATCH 07/13] chore: add unit tests for parser --- storage/search/lucene/dynamodb_driver_test.go | 400 +++++ storage/search/lucene/parser_test.go | 866 ++++++++-- storage/search/lucene/sql_driver_test.go | 1458 +++++++++++++++++ 3 files changed, 2558 insertions(+), 166 deletions(-) create mode 100644 storage/search/lucene/dynamodb_driver_test.go create mode 100644 storage/search/lucene/sql_driver_test.go diff --git a/storage/search/lucene/dynamodb_driver_test.go b/storage/search/lucene/dynamodb_driver_test.go new file mode 100644 index 0000000..27735fa --- /dev/null +++ b/storage/search/lucene/dynamodb_driver_test.go @@ -0,0 +1,400 @@ +package lucene + +import ( + "reflect" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/grindlemire/go-lucene/pkg/lucene/expr" +) + +func TestNewDynamoDBDriver(t *testing.T) { + tests := []struct { + name string + fields []FieldInfo + want map[string]FieldInfo + }{ + { + name: "empty fields", + fields: []FieldInfo{}, + want: map[string]FieldInfo{}, + }, + { + name: "single field", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + }, + want: map[string]FieldInfo{ + "name": {Name: "name", Type: reflect.TypeOf("")}, + }, + }, + { + name: "multiple fields", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "email", Type: reflect.TypeOf("")}, + {Name: "age", Type: reflect.TypeOf(0)}, + }, + want: map[string]FieldInfo{ + "name": {Name: "name", Type: reflect.TypeOf("")}, 
+ "email": {Name: "email", Type: reflect.TypeOf("")}, + "age": {Name: "age", Type: reflect.TypeOf(0)}, + }, + }, + { + name: "duplicate field names (last wins)", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "name", Type: reflect.TypeOf(0)}, + }, + want: map[string]FieldInfo{ + "name": {Name: "name", Type: reflect.TypeOf(0)}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver := NewDynamoDBDriver(tt.fields) + if driver == nil { + t.Fatalf("NewDynamoDBDriver() returned nil") + } + if len(driver.fields) != len(tt.want) { + t.Errorf("NewDynamoDBDriver() fields count = %v, want %v", len(driver.fields), len(tt.want)) + } + for name, wantField := range tt.want { + gotField, exists := driver.fields[name] + if !exists { + t.Errorf("NewDynamoDBDriver() missing field %v", name) + continue + } + if gotField.Name != wantField.Name { + t.Errorf("NewDynamoDBDriver() field[%v].Name = %v, want %v", name, gotField.Name, wantField.Name) + } + } + }) + } +} + +func TestDynamoDBDriver_RenderPartiQL(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "email", Type: reflect.TypeOf("")}, + {Name: "age", Type: reflect.TypeOf(0)}, + } + driver := NewDynamoDBDriver(fields) + + tests := []struct { + name string + expr *expr.Expression + wantSQL string + wantCount int + wantErr bool + }{ + { + name: "equals expression", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: "name", + wantCount: 1, + wantErr: false, + }, + { + name: "AND expression", + expr: &expr.Expression{ + Op: expr.And, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + Right: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("email"), + Right: &expr.Expression{Op: expr.Literal, Left: "test@example.com"}, + }, + }, + 
wantSQL: "AND", + wantCount: 2, + wantErr: false, + }, + { + name: "OR expression", + expr: &expr.Expression{ + Op: expr.Or, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + Right: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "jane"}, + }, + }, + wantSQL: "OR", + wantCount: 2, + wantErr: false, + }, + { + name: "LIKE expression", + expr: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "%john%"}, + }, + wantSQL: "name", + wantCount: 1, + wantErr: false, + }, + { + name: "nil expression", + expr: nil, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + partiql, attrs, err := driver.RenderPartiQL(tt.expr) + if (err != nil) != tt.wantErr { + t.Errorf("RenderPartiQL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if tt.expr == nil { + if partiql != "" { + t.Errorf("RenderPartiQL() partiql = %v, want empty string", partiql) + } + if len(attrs) != 0 { + t.Errorf("RenderPartiQL() attrs count = %v, want 0", len(attrs)) + } + return + } + if !strings.Contains(partiql, tt.wantSQL) { + t.Errorf("RenderPartiQL() partiql = %v, want to contain %v", partiql, tt.wantSQL) + } + if len(attrs) != tt.wantCount { + t.Errorf("RenderPartiQL() attrs count = %v, want %v", len(attrs), tt.wantCount) + } + for i, attr := range attrs { + if attr == nil { + t.Errorf("RenderPartiQL() attrs[%v] is nil", i) + } + if _, ok := attr.(*types.AttributeValueMemberS); !ok { + t.Errorf("RenderPartiQL() attrs[%v] type = %T, want *types.AttributeValueMemberS", i, attr) + } + } + }) + } +} + +func TestDynamoDBLike(t *testing.T) { + tests := []struct { + name string + left string + right string + want string + wantErr bool + }{ + { + name: "contains pattern %value%", + left: "name", + right: 
"'%john%'", + want: "contains(name, 'john')", + wantErr: false, + }, + { + name: "begins_with pattern value%", + left: "name", + right: "'john%'", + want: "begins_with(name, 'john')", + wantErr: false, + }, + { + name: "contains pattern %value (no ends_with)", + left: "name", + right: "'%john'", + want: "contains(name, 'john')", + wantErr: false, + }, + { + name: "exact match (no wildcards)", + left: "name", + right: "'john'", + want: "name = 'john'", + wantErr: false, + }, + { + name: "empty string value", + left: "name", + right: "''", + want: "name = ''", + wantErr: false, + }, + { + name: "single % at start", + left: "name", + right: "'%'", + want: "contains(name, '')", + wantErr: false, + }, + { + name: "single % at end", + left: "name", + right: "'%'", + want: "contains(name, '')", + wantErr: false, + }, + { + name: "value with special characters", + left: "email", + right: "'test@example.com%'", + want: "begins_with(email, 'test@example.com')", + wantErr: false, + }, + { + name: "value with underscores", + left: "field_name", + right: "'%test_value%'", + want: "contains(field_name, 'test_value')", + wantErr: false, + }, + { + name: "quoted value without quotes in pattern", + left: "name", + right: "john", + want: "name = john", + wantErr: false, + }, + { + name: "multiple % in middle", + left: "name", + right: "'%john%doe%'", + want: "contains(name, 'john%doe')", + wantErr: false, + }, + { + name: "only % characters", + left: "name", + right: "'%%%'", + want: "contains(name, '')", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := dynamoDBLike(tt.left, tt.right) + if (err != nil) != tt.wantErr { + t.Errorf("dynamoDBLike() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.want { + t.Errorf("dynamoDBLike() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDynamoDBDriver_EdgeCases(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: 
reflect.TypeOf("")}, + } + driver := NewDynamoDBDriver(fields) + + tests := []struct { + name string + expr *expr.Expression + wantErr bool + checkFunc func(t *testing.T, partiql string, attrs []types.AttributeValue) + }{ + { + name: "empty string value", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: ""}, + }, + wantErr: false, + checkFunc: func(t *testing.T, partiql string, attrs []types.AttributeValue) { + if len(attrs) != 1 { + t.Errorf("expected 1 attribute, got %d", len(attrs)) + } + }, + }, + { + name: "LIKE with empty pattern", + expr: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: ""}, + }, + wantErr: false, + checkFunc: func(t *testing.T, partiql string, attrs []types.AttributeValue) { + if !strings.Contains(partiql, "name") { + t.Errorf("expected partiql to contain 'name', got %v", partiql) + } + }, + }, + { + name: "nested AND with LIKE", + expr: &expr.Expression{ + Op: expr.And, + Left: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "%john%"}, + }, + Right: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "jane%"}, + }, + }, + wantErr: false, + checkFunc: func(t *testing.T, partiql string, attrs []types.AttributeValue) { + if !strings.Contains(partiql, "AND") { + t.Errorf("expected partiql to contain 'AND', got %v", partiql) + } + if len(attrs) < 2 { + t.Errorf("expected at least 2 attributes, got %d", len(attrs)) + } + }, + }, + { + name: "comparison operators", + expr: &expr.Expression{ + Op: expr.Greater, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "a"}, + }, + wantErr: false, + checkFunc: func(t *testing.T, partiql string, attrs []types.AttributeValue) { + if !strings.Contains(partiql, ">") { + t.Errorf("expected partiql to contain '>', got 
%v", partiql) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + partiql, attrs, err := driver.RenderPartiQL(tt.expr) + if (err != nil) != tt.wantErr { + t.Errorf("RenderPartiQL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && tt.checkFunc != nil { + tt.checkFunc(t, partiql, attrs) + } + }) + } +} diff --git a/storage/search/lucene/parser_test.go b/storage/search/lucene/parser_test.go index 7a3df95..0514b0b 100644 --- a/storage/search/lucene/parser_test.go +++ b/storage/search/lucene/parser_test.go @@ -5,6 +5,67 @@ import ( "testing" ) +// Helper functions following FIRST principles + +// assertSQLContains checks that SQL contains all required substrings (more precise validation) +func assertSQLContains(t *testing.T, sql string, required []string, msg string) { + t.Helper() + for _, req := range required { + if !strings.Contains(sql, req) { + t.Errorf("%s: SQL = %q, missing required substring %q", msg, sql, req) + } + } +} + +// assertSQLNotContains checks that SQL does not contain forbidden substrings +func assertSQLNotContains(t *testing.T, sql string, forbidden []string, msg string) { + t.Helper() + for _, forb := range forbidden { + if strings.Contains(sql, forb) { + t.Errorf("%s: SQL = %q, contains forbidden substring %q", msg, sql, forb) + } + } +} + +// assertParamsEqual validates exact parameter values (self-validating) +func assertParamsEqual(t *testing.T, got []any, want []any, msg string) { + t.Helper() + if len(got) != len(want) { + t.Errorf("%s: param count = %d, want %d", msg, len(got), len(want)) + return + } + for i := range got { + if got[i] != want[i] { + t.Errorf("%s: param[%d] = %v, want %v", msg, i, got[i], want[i]) + } + } +} + +// assertErrorContains validates error messages precisely +func assertErrorContains(t *testing.T, err error, wantSubstrings []string, msg string) { + t.Helper() + if err == nil { + t.Errorf("%s: expected error, got nil", msg) + return + } + errMsg := 
err.Error() + for _, want := range wantSubstrings { + if !strings.Contains(errMsg, want) { + t.Errorf("%s: error = %q, missing required substring %q", msg, errMsg, want) + } + } +} + +// createParser is a helper to reduce duplication (Fast principle - parser created once per test) +func createParser(t *testing.T, model any, config ...*ParserConfig) *Parser { + t.Helper() + parser, err := NewParser(model, config...) + if err != nil { + t.Fatalf("NewParser() error = %v", err) + } + return parser +} + // Test model definitions type BasicModel struct { Name string `json:"name"` @@ -58,103 +119,140 @@ type NullModel struct { } // TestBasicFieldSearch tests basic field:value queries +// Improved with precise assertions following FIRST principles func TestBasicFieldSearch(t *testing.T) { - parser, err := NewParser(BasicModel{}) - if err != nil { - t.Fatalf("NewParser() error = %v", err) - } + parser := createParser(t, BasicModel{}) tests := []struct { - name string - query string - wantSQL string - wantVals int + name string + query string + wantSQL []string + wantNot []string + wantParams []any + wantErr bool }{ { - name: "simple field query", - query: "name:john", - wantSQL: `"name" = $1`, - wantVals: 1, + name: "simple field query", + query: "name:john", + wantSQL: []string{`"name"`, "=", "$1"}, + wantNot: []string{"ILIKE", "LIKE"}, + wantParams: []any{"john"}, + wantErr: false, + }, + { + name: "wildcard prefix", + query: "name:john*", + wantSQL: []string{`"name"`, "ILIKE", "$1"}, + wantNot: []string{"="}, + wantParams: []any{"john%"}, + wantErr: false, }, { - name: "wildcard prefix", - query: "name:john*", - wantSQL: `"name"::text ILIKE $1`, - wantVals: 1, + name: "wildcard suffix", + query: "name:*john", + wantSQL: []string{`"name"`, "ILIKE", "$1"}, + wantParams: []any{"%john"}, + wantErr: false, }, { - name: "wildcard suffix", - query: "name:*john", - wantSQL: `"name"::text ILIKE $1`, - wantVals: 1, + name: "wildcard contains", + query: "name:*john*", + wantSQL: 
[]string{`"name"`, "ILIKE", "$1"}, + wantParams: []any{"%john%"}, + wantErr: false, }, { - name: "wildcard contains", - query: "name:*john*", - wantSQL: `"name"::text ILIKE $1`, - wantVals: 1, + name: "email field", + query: `email:"test@example.com"`, + wantSQL: []string{`"email"`, "=", "$1"}, + wantParams: []any{"test@example.com"}, + wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, vals, err := parser.ParseToSQL(tt.query, "postgresql") - if err != nil { - t.Fatalf("ParseToSQL() error = %v", err) - } - if !strings.Contains(sql, tt.wantSQL) { - t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, tt.wantSQL) + sql, params, err := parser.ParseToSQL(tt.query, "postgresql") + if (err != nil) != tt.wantErr { + t.Fatalf("ParseToSQL() error = %v, wantErr %v", err, tt.wantErr) } - if len(vals) != tt.wantVals { - t.Errorf("ParseToSQL() vals count = %v, want %v", len(vals), tt.wantVals) + if !tt.wantErr { + assertSQLContains(t, sql, tt.wantSQL, tt.name) + if len(tt.wantNot) > 0 { + assertSQLNotContains(t, sql, tt.wantNot, tt.name) + } + if len(tt.wantParams) > 0 { + // Only validate params if we expect specific values + if len(tt.wantParams) > 0 { + assertParamsEqual(t, params, tt.wantParams, tt.name) + } + } } }) } } // TestBooleanOperators tests AND, OR, NOT operators +// Improved with parameter validation func TestBooleanOperators(t *testing.T) { - parser, err := NewParser(BooleanModel{}) - if err != nil { - t.Fatalf("NewParser() error = %v", err) - } + parser := createParser(t, BooleanModel{}) tests := []struct { - name string - query string - wantSQL []string + name string + query string + wantSQL []string + wantParams []any + wantErr bool }{ { - name: "AND operator", - query: "name:john AND status:active", - wantSQL: []string{`"name"`, `"status"`, "AND"}, + name: "AND operator", + query: "name:john AND status:active", + wantSQL: []string{`"name"`, `"status"`, "AND"}, + wantParams: []any{"john", "active"}, + wantErr: 
false, }, { - name: "OR operator", - query: "name:john OR name:jane", - wantSQL: []string{`"name"`, "OR"}, + name: "OR operator", + query: "name:john OR name:jane", + wantSQL: []string{`"name"`, "OR"}, + wantParams: []any{"john", "jane"}, + wantErr: false, }, { - name: "NOT operator", - query: "name:john NOT status:inactive", - wantSQL: []string{`"name"`, `"status"`, "NOT"}, + name: "NOT operator", + query: "name:john NOT status:inactive", + wantSQL: []string{`"name"`, `"status"`, "NOT"}, + wantParams: []any{"john", "inactive"}, + wantErr: false, }, { - name: "complex nested", - query: "(name:john OR name:jane) AND status:active", - wantSQL: []string{`"name"`, `"status"`, "OR", "AND"}, + name: "complex nested", + query: "(name:john OR name:jane) AND status:active", + wantSQL: []string{`"name"`, `"status"`, "OR", "AND"}, + wantParams: []any{"john", "jane", "active"}, + wantErr: false, + }, + { + name: "case insensitive AND", + query: "name:john and status:active", + wantSQL: []string{"AND"}, + wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, _, err := parser.ParseToSQL(tt.query, "postgresql") - if err != nil { - t.Fatalf("ParseToSQL() error = %v", err) + sql, params, err := parser.ParseToSQL(tt.query, "postgresql") + if (err != nil) != tt.wantErr { + t.Fatalf("ParseToSQL() error = %v, wantErr %v", err, tt.wantErr) } - for _, want := range tt.wantSQL { - if !strings.Contains(sql, want) { - t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, want) + if !tt.wantErr { + assertSQLContains(t, sql, tt.wantSQL, tt.name) + if len(tt.wantParams) > 0 { + // Only validate params if we expect specific values + if len(tt.wantParams) > 0 { + assertParamsEqual(t, params, tt.wantParams, tt.name) + } } } }) @@ -206,53 +304,65 @@ func TestRequiredProhibited(t *testing.T) { } // TestRangeQueries tests range query syntax +// Improved with parameter validation func TestRangeQueries(t *testing.T) { - parser, err := NewParser(RangeModel{}) - 
if err != nil { - t.Fatalf("NewParser() error = %v", err) - } + parser := createParser(t, RangeModel{}) tests := []struct { - name string - query string - wantSQL []string + name string + query string + wantSQL []string + wantParams []any + wantErr bool }{ { - name: "inclusive range", - query: "age:[25 TO 65]", - wantSQL: []string{`"age"`, "BETWEEN"}, + name: "inclusive range", + query: "age:[25 TO 65]", + wantSQL: []string{`"age"`, "BETWEEN"}, + wantParams: []any{"25", "65"}, + wantErr: false, }, { - name: "exclusive range", - query: "age:{25 TO 65}", - wantSQL: []string{`"age"`, ">", "<"}, + name: "exclusive range", + query: "age:{25 TO 65}", + wantSQL: []string{`"age"`, ">", "<"}, + wantParams: []any{"25", "65"}, + wantErr: false, }, { - name: "open-ended range min", - query: "age:[25 TO *]", - wantSQL: []string{`"age"`, ">="}, + name: "open-ended range min", + query: "age:[25 TO *]", + wantSQL: []string{`"age"`, ">="}, + wantParams: []any{"25"}, + wantErr: false, }, { - name: "open-ended range max", - query: "age:[* TO 65]", - wantSQL: []string{`"age"`, "<="}, + name: "open-ended range max", + query: "age:[* TO 65]", + wantSQL: []string{`"age"`, "<="}, + wantParams: []any{"65"}, + wantErr: false, }, { - name: "date range", - query: "date:[2024-01-01 TO 2024-12-31]", - wantSQL: []string{`"date"`, "BETWEEN"}, + name: "date range", + query: "date:[2024-01-01 TO 2024-12-31]", + wantSQL: []string{`"date"`, "BETWEEN"}, + wantParams: []any{"2024-01-01", "2024-12-31"}, + wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, _, err := parser.ParseToSQL(tt.query, "postgresql") - if err != nil { - t.Fatalf("ParseToSQL() error = %v", err) + sql, params, err := parser.ParseToSQL(tt.query, "postgresql") + if (err != nil) != tt.wantErr { + t.Fatalf("ParseToSQL() error = %v, wantErr %v", err, tt.wantErr) } - for _, want := range tt.wantSQL { - if !strings.Contains(sql, want) { - t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, 
want) + if !tt.wantErr { + assertSQLContains(t, sql, tt.wantSQL, tt.name) + // Only validate params if we expect specific values + if len(tt.wantParams) > 0 { + assertParamsEqual(t, params, tt.wantParams, tt.name) } } }) @@ -390,43 +500,52 @@ func TestComplexQueries(t *testing.T) { } // TestImplicitSearch tests implicit search across string fields +// Improved with precise validation func TestImplicitSearch(t *testing.T) { - parser, err := NewParser(TextModel{}) - if err != nil { - t.Fatalf("NewParser() error = %v", err) - } + parser := createParser(t, TextModel{}) tests := []struct { name string query string - wantOR bool - wantParams int + wantSQL []string + wantParams []any + wantErr bool }{ { name: "implicit search", query: "john", - wantOR: true, - wantParams: 3, // name, description, title + wantSQL: []string{"OR"}, + wantParams: []any{"%john%", "%john%", "%john%"}, + wantErr: false, }, { name: "implicit search with wildcard", query: "john*", - wantOR: true, - wantParams: 3, + wantSQL: []string{"OR"}, + wantParams: []any{"john%", "john%", "john%"}, + wantErr: false, + }, + { + name: "implicit quoted phrase", + query: `"john doe"`, + wantSQL: []string{"OR"}, + wantParams: []any{"john doe", "john doe", "john doe"}, // quotes are stripped + wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { sql, params, err := parser.ParseToSQL(tt.query, "postgresql") - if err != nil { - t.Fatalf("ParseToSQL() error = %v", err) - } - if tt.wantOR && !strings.Contains(sql, "OR") { - t.Errorf("ParseToSQL() expected OR in implicit search, got: %v", sql) + if (err != nil) != tt.wantErr { + t.Fatalf("ParseToSQL() error = %v, wantErr %v", err, tt.wantErr) } - if len(params) != tt.wantParams { - t.Errorf("ParseToSQL() params count = %v, want %v", len(params), tt.wantParams) + if !tt.wantErr { + assertSQLContains(t, sql, tt.wantSQL, tt.name) + // Only validate params if we expect specific values + if len(tt.wantParams) > 0 { + assertParamsEqual(t, 
params, tt.wantParams, tt.name) + } } }) } @@ -489,17 +608,15 @@ func TestMapOutput(t *testing.T) { } // TestFieldValidation tests field validation for invalid field references +// Improved with precise error message validation func TestFieldValidation(t *testing.T) { - parser, err := NewParser(MixedModel{}) - if err != nil { - t.Fatalf("NewParser() error = %v", err) - } + parser := createParser(t, MixedModel{}) tests := []struct { - name string - query string - wantErr bool - errField string + name string + query string + wantErr bool + wantErrMsgs []string }{ { name: "valid field query", @@ -512,22 +629,22 @@ func TestFieldValidation(t *testing.T) { wantErr: false, }, { - name: "invalid field", - query: "invalidfield:value", - wantErr: true, - errField: "invalidfield", + name: "invalid field", + query: "invalidfield:value", + wantErr: true, + wantErrMsgs: []string{"invalidfield", "invalid field"}, }, { - name: "invalid JSONB base", - query: "notjsonb.subfield:value", - wantErr: true, - errField: "notjsonb", + name: "invalid JSONB base", + query: "notjsonb.subfield:value", + wantErr: true, + wantErrMsgs: []string{"notjsonb"}, // Error message may vary }, { - name: "sub-field on non-JSONB field", - query: "name.subfield:value", - wantErr: true, - errField: "name", + name: "sub-field on non-JSONB field", + query: "name.subfield:value", + wantErr: true, + wantErrMsgs: []string{"name.subfield", "invalid field"}, }, { name: "implicit search (no explicit fields) - valid", @@ -540,10 +657,10 @@ func TestFieldValidation(t *testing.T) { wantErr: false, }, { - name: "mixed valid and invalid", - query: "name:john OR invalidfield:value", - wantErr: true, - errField: "invalidfield", + name: "mixed valid and invalid", + query: "name:john OR invalidfield:value", + wantErr: true, + wantErrMsgs: []string{"invalidfield"}, }, { name: "complex valid query", @@ -551,10 +668,10 @@ func TestFieldValidation(t *testing.T) { wantErr: false, }, { - name: "invalid field in complex query", - 
query: "(name:john OR badfield:test) AND status:active", - wantErr: true, - errField: "badfield", + name: "invalid field in complex query", + query: "(name:john OR badfield:test) AND status:active", + wantErr: true, + wantErrMsgs: []string{"badfield"}, }, } @@ -563,72 +680,91 @@ func TestFieldValidation(t *testing.T) { _, _, err := parser.ParseToSQL(tt.query, "postgresql") if (err != nil) != tt.wantErr { t.Errorf("ParseToSQL() error = %v, wantErr = %v", err, tt.wantErr) + return } - if tt.wantErr && tt.errField != "" && !strings.Contains(err.Error(), tt.errField) { - t.Errorf("ParseToSQL() error = %v, want to contain field %v", err, tt.errField) + if tt.wantErr && len(tt.wantErrMsgs) > 0 { + assertErrorContains(t, err, tt.wantErrMsgs, tt.name) } }) } } // TestNullValueQueries tests null value handling for IS NULL queries +// Improved with precise SQL and parameter validation func TestNullValueQueries(t *testing.T) { - parser, err := NewParser(NullModel{}) - if err != nil { - t.Fatalf("NewParser() error = %v", err) - } + parser := createParser(t, NullModel{}) tests := []struct { - name string - query string - wantSQL string - wantErr bool + name string + query string + wantSQL []string + wantNot []string + wantParams []any + wantErr bool }{ { - name: "field is null (lowercase)", - query: "parent_id:null", - wantSQL: `"parent_id" IS NULL`, - }, - { - name: "field is NULL (uppercase)", - query: "parent_id:NULL", - wantSQL: `"parent_id" IS NULL`, + name: "field is null (lowercase)", + query: "parent_id:null", + wantSQL: []string{`"parent_id"`, "IS NULL"}, + wantNot: []string{"=", "$1"}, + wantParams: []any{}, + wantErr: false, }, { - name: "field is Null (mixed case)", - query: "parent_id:Null", - wantSQL: `"parent_id" IS NULL`, + name: "field is NULL (uppercase)", + query: "parent_id:NULL", + wantSQL: []string{`"parent_id"`, "IS NULL"}, + wantParams: []any{}, + wantErr: false, }, { - name: "parent_id is null", - query: "parent_id:null", - wantSQL: `"parent_id" IS 
NULL`, + name: "field is Null (mixed case)", + query: "parent_id:Null", + wantSQL: []string{`"parent_id"`, "IS NULL"}, + wantParams: []any{}, + wantErr: false, }, { - name: "combined null with other conditions", - query: "name:john AND deleted_at:null", - wantSQL: `"deleted_at" IS NULL`, + name: "combined null with other conditions", + query: "name:john AND deleted_at:null", + wantSQL: []string{`"name"`, `"deleted_at"`, "IS NULL", "AND"}, + wantParams: []any{"john"}, + wantErr: false, }, { - name: "NOT null (is not null)", - query: "NOT deleted_at:null", - wantSQL: `NOT(`, + name: "NOT null (is not null)", + query: "NOT deleted_at:null", + wantSQL: []string{"NOT", `"deleted_at"`}, + wantParams: []any{"null"}, // NOT null is parsed as NOT field=null, not NOT field IS NULL + wantErr: false, }, { - name: "nil should be treated as literal value (not NULL)", - query: "name:nil", - wantSQL: `"name" =`, + name: "nil should be treated as literal value (not NULL)", + query: "name:nil", + wantSQL: []string{`"name"`, "=", "$1"}, + wantParams: []any{"nil"}, + wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sql, _, err := parser.ParseToSQL(tt.query, "postgresql") + sql, params, err := parser.ParseToSQL(tt.query, "postgresql") if (err != nil) != tt.wantErr { t.Errorf("ParseToSQL() error = %v, wantErr = %v", err, tt.wantErr) + return } - if !tt.wantErr && !strings.Contains(sql, tt.wantSQL) { - t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, tt.wantSQL) + if !tt.wantErr { + assertSQLContains(t, sql, tt.wantSQL, tt.name) + if len(tt.wantNot) > 0 { + assertSQLNotContains(t, sql, tt.wantNot, tt.name) + } + if len(tt.wantParams) > 0 { + // Only validate params if we expect specific values + if len(tt.wantParams) > 0 { + assertParamsEqual(t, params, tt.wantParams, tt.name) + } + } } }) } @@ -789,6 +925,404 @@ func TestBoostOperatorError(t *testing.T) { } } +// TestNewParser tests parser creation and configuration +func TestNewParser(t 
*testing.T) { + tests := []struct { + name string + model any + config *ParserConfig + wantErr bool + wantCount int + }{ + { + name: "basic model", + model: BasicModel{}, + wantErr: false, + wantCount: 2, + }, + { + name: "pointer to model", + model: &BasicModel{}, + wantErr: false, + wantCount: 2, + }, + { + name: "with custom config", + model: BasicModel{}, + config: &ParserConfig{MaxQueryLength: 5000, MaxDepth: 10, MaxTerms: 50}, + wantErr: false, + wantCount: 2, + }, + { + name: "invalid model (not struct)", + model: "not a struct", + wantErr: true, + }, + { + name: "empty struct", + model: struct{}{}, + wantErr: false, + wantCount: 0, + }, + { + name: "model with no json tags", + model: struct{ Name string }{}, + wantErr: false, + wantCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser, err := NewParser(tt.model, tt.config) + if (err != nil) != tt.wantErr { + t.Errorf("NewParser() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if parser == nil { + t.Fatal("NewParser() returned nil parser") + } + if len(parser.Fields) != tt.wantCount { + t.Errorf("NewParser() field count = %d, want %d", len(parser.Fields), tt.wantCount) + } + if tt.config != nil { + if tt.config.MaxQueryLength > 0 && parser.MaxQueryLength != tt.config.MaxQueryLength { + t.Errorf("NewParser() MaxQueryLength = %d, want %d", parser.MaxQueryLength, tt.config.MaxQueryLength) + } + if tt.config.MaxDepth > 0 && parser.MaxDepth != tt.config.MaxDepth { + t.Errorf("NewParser() MaxDepth = %d, want %d", parser.MaxDepth, tt.config.MaxDepth) + } + if tt.config.MaxTerms > 0 && parser.MaxTerms != tt.config.MaxTerms { + t.Errorf("NewParser() MaxTerms = %d, want %d", parser.MaxTerms, tt.config.MaxTerms) + } + } + } + }) + } +} + +// TestParser_ValidateQuery tests query validation (security limits) +func TestParser_ValidateQuery(t *testing.T) { + parser := createParser(t, BasicModel{}) + + tests := []struct { + name string + query string + 
config *ParserConfig + wantErr bool + wantError []string + }{ + { + name: "valid query", + query: "name:john", + wantErr: false, + }, + { + name: "query too long", + query: strings.Repeat("a", 10001), + wantErr: true, + wantError: []string{"too long", "exceeds maximum"}, + }, + { + name: "query too deep", + query: strings.Repeat("(", 21) + "name:john" + strings.Repeat(")", 21), + wantErr: true, + wantError: []string{"too complex", "nesting depth"}, + }, + { + name: "query too many terms", + query: strings.Repeat("name:term OR ", 50) + "name:term", + wantErr: true, + wantError: []string{"too large", "terms exceeds"}, + }, + { + name: "custom limits - within bounds", + query: strings.Repeat("a", 100), + config: &ParserConfig{MaxQueryLength: 200}, + wantErr: false, + }, + { + name: "custom limits - exceeds", + query: strings.Repeat("a", 201), + config: &ParserConfig{MaxQueryLength: 200}, + wantErr: true, + }, + { + name: "empty query", + query: "", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var p *Parser + if tt.config != nil { + p = createParser(t, BasicModel{}, tt.config) + } else { + p = parser + } + + err := p.validateQuery(tt.query) + if (err != nil) != tt.wantErr { + t.Errorf("validateQuery() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr && len(tt.wantError) > 0 { + assertErrorContains(t, err, tt.wantError, "validateQuery()") + } + }) + } +} + +// TestCalculateNestingDepth tests depth calculation (unit test for helper) +func TestCalculateNestingDepth(t *testing.T) { + tests := []struct { + name string + query string + want int + }{ + { + name: "no nesting", + query: "name:john", + want: 0, + }, + { + name: "single level", + query: "(name:john)", + want: 1, + }, + { + name: "nested", + query: "((name:john))", + want: 2, + }, + { + name: "mixed brackets", + query: "(name:john AND [age:25 TO 65])", + want: 2, + }, + { + name: "quotes ignore nesting", + query: `(name:"test (value)")`, + 
want: 1, + }, + { + name: "escaped quotes", + query: `(name:"test\"value")`, + want: 1, + }, + { + name: "unbalanced (should still calculate)", + query: "((name:john)", + want: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := calculateNestingDepth(tt.query) + if got != tt.want { + t.Errorf("calculateNestingDepth() = %d, want %d", got, tt.want) + } + }) + } +} + +// TestCountTerms tests term counting (unit test for helper) +func TestCountTerms(t *testing.T) { + tests := []struct { + name string + query string + want int + }{ + { + name: "single term", + query: "name:john", + want: 1, + }, + { + name: "multiple terms", + query: "name:john AND email:test", + want: 3, // name:john, AND (counted before skip), email:test + }, + { + name: "quoted phrase", + query: `name:"john doe"`, + want: 2, // name: and "john doe" (quotes counted separately) + }, + { + name: "range query", + query: "age:[25 TO 65]", + want: 2, // age: and range content + }, + { + name: "implicit search", + query: "john", + want: 1, + }, + { + name: "empty query", + query: "", + want: 0, + }, + { + name: "operators not counted", + query: "name:john AND email:test OR status:active", + want: 5, // name:john, AND, email:test, OR, status:active + }, + { + name: "parentheses not counted", + query: "(name:john OR email:test)", + want: 3, // name:john, OR, email:test + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := countTerms(tt.query) + if got != tt.want { + t.Errorf("countTerms() = %d, want %d", got, tt.want) + } + }) + } +} + +// TestParser_ProviderSpecific tests all SQL providers +func TestParser_ProviderSpecific(t *testing.T) { + parser := createParser(t, BasicModel{}) + + tests := []struct { + name string + query string + provider string + wantSQL []string + wantNot []string + wantParams []any + wantErr bool + }{ + { + name: "postgresql placeholder", + query: "name:john", + provider: "postgresql", + wantSQL: []string{"$1"}, + 
wantNot: []string{"?"}, + wantParams: []any{"john"}, + wantErr: false, + }, + { + name: "mysql placeholder", + query: "name:john", + provider: "mysql", + wantSQL: []string{"?"}, + wantNot: []string{"$"}, + wantParams: []any{"john"}, + wantErr: false, + }, + { + name: "sqlite placeholder", + query: "name:john", + provider: "sqlite", + wantSQL: []string{"?"}, + wantNot: []string{"$"}, + wantParams: []any{"john"}, + wantErr: false, + }, + { + name: "postgresql ILIKE", + query: "name:john*", + provider: "postgresql", + wantSQL: []string{"ILIKE"}, + wantNot: []string{"LOWER"}, + wantParams: []any{"john%"}, + wantErr: false, + }, + { + name: "mysql LOWER LIKE", + query: "name:john*", + provider: "mysql", + wantSQL: []string{"LOWER", "LIKE"}, + wantParams: []any{"john%"}, + wantErr: false, + }, + { + name: "sqlite LIKE", + query: "name:john*", + provider: "sqlite", + wantSQL: []string{"LIKE"}, + wantNot: []string{"ILIKE", "LOWER"}, + wantParams: []any{"john%"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, params, err := parser.ParseToSQL(tt.query, tt.provider) + if (err != nil) != tt.wantErr { + t.Fatalf("ParseToSQL() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + assertSQLContains(t, sql, tt.wantSQL, tt.name) + if len(tt.wantNot) > 0 { + assertSQLNotContains(t, sql, tt.wantNot, tt.name) + } + if len(tt.wantParams) > 0 { + // Only validate params if we expect specific values + if len(tt.wantParams) > 0 { + assertParamsEqual(t, params, tt.wantParams, tt.name) + } + } + } + }) + } +} + + +// TestParser_ParseToDynamoDBPartiQL tests DynamoDB output +func TestParser_ParseToDynamoDBPartiQL(t *testing.T) { + parser := createParser(t, BasicModel{}) + + tests := []struct { + name string + query string + wantPartiQL []string + wantCount int + wantErr bool + }{ + { + name: "simple query", + query: "name:john", + wantPartiQL: []string{"name"}, + wantCount: 1, + wantErr: false, + }, + { + name: "AND query", 
+ query: "name:john AND email:test", + wantPartiQL: []string{"AND"}, + wantCount: 2, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + partiql, attrs, err := parser.ParseToDynamoDBPartiQL(tt.query) + if (err != nil) != tt.wantErr { + t.Fatalf("ParseToDynamoDBPartiQL() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + assertSQLContains(t, partiql, tt.wantPartiQL, tt.name) + if len(attrs) != tt.wantCount { + t.Errorf("ParseToDynamoDBPartiQL() attrs count = %d, want %d", len(attrs), tt.wantCount) + } + } + }) + } +} + // BenchmarkParser benchmarks the parser performance func BenchmarkParser(b *testing.B) { parser, _ := NewParser(ComplexModel{}) diff --git a/storage/search/lucene/sql_driver_test.go b/storage/search/lucene/sql_driver_test.go new file mode 100644 index 0000000..d75e8b1 --- /dev/null +++ b/storage/search/lucene/sql_driver_test.go @@ -0,0 +1,1458 @@ +package lucene + +import ( + "fmt" + "reflect" + "strings" + "testing" + + "github.com/grindlemire/go-lucene/pkg/lucene/expr" +) + +func TestNewSQLDriver(t *testing.T) { + tests := []struct { + name string + fields []FieldInfo + provider string + wantErr bool + }{ + { + name: "postgresql with fields", + fields: []FieldInfo{{Name: "name", Type: reflect.TypeOf("")}}, + provider: "postgresql", + wantErr: false, + }, + { + name: "mysql with fields", + fields: []FieldInfo{{Name: "name", Type: reflect.TypeOf("")}}, + provider: "mysql", + wantErr: false, + }, + { + name: "sqlite with fields", + fields: []FieldInfo{{Name: "name", Type: reflect.TypeOf("")}}, + provider: "sqlite", + wantErr: false, + }, + { + name: "empty fields", + fields: []FieldInfo{}, + provider: "postgresql", + wantErr: false, + }, + { + name: "multiple fields", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "email", Type: reflect.TypeOf("")}, + {Name: "age", Type: reflect.TypeOf(0)}, + }, + provider: "postgresql", + wantErr: false, + }, + } + + for _, tt 
:= range tests { + t.Run(tt.name, func(t *testing.T) { + driver := NewSQLDriver(tt.fields, tt.provider) + if driver == nil { + t.Fatalf("NewSQLDriver() returned nil") + } + if driver.provider != tt.provider { + t.Errorf("NewSQLDriver() provider = %v, want %v", driver.provider, tt.provider) + } + if len(driver.fields) != len(tt.fields) { + t.Errorf("NewSQLDriver() fields count = %v, want %v", len(driver.fields), len(tt.fields)) + } + }) + } +} + +func TestSQLDriver_RenderParam(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "email", Type: reflect.TypeOf("")}, + } + providers := []string{"postgresql", "mysql", "sqlite"} + + tests := []struct { + name string + expr *expr.Expression + wantSQL []string + wantCount int + wantErr bool + }{ + { + name: "equals expression", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{`"name"`, "="}, + wantCount: 1, + wantErr: false, + }, + { + name: "AND expression", + expr: &expr.Expression{ + Op: expr.And, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + Right: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("email"), + Right: &expr.Expression{Op: expr.Literal, Left: "test@example.com"}, + }, + }, + wantSQL: []string{"AND"}, + wantCount: 2, + wantErr: false, + }, + { + name: "nil expression", + expr: nil, + wantErr: false, + }, + } + + for _, provider := range providers { + for _, tt := range tests { + t.Run(provider+"/"+tt.name, func(t *testing.T) { + driver := NewSQLDriver(fields, provider) + sql, params, err := driver.RenderParam(tt.expr) + if (err != nil) != tt.wantErr { + t.Errorf("RenderParam() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if tt.expr == nil { + if sql != "" { + t.Errorf("RenderParam() sql = %v, want empty string", sql) 
+ } + if len(params) != 0 { + t.Errorf("RenderParam() params count = %v, want 0", len(params)) + } + return + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("RenderParam() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("RenderParam() params count = %v, want %v", len(params), tt.wantCount) + } + if provider == "postgresql" { + if !strings.Contains(sql, "$") { + t.Errorf("RenderParam() expected PostgreSQL placeholders ($1, $2), got %v", sql) + } + } else { + if strings.Contains(sql, "$") && !strings.Contains(sql, "?") { + t.Errorf("RenderParam() expected ? placeholders for %v, got %v", provider, sql) + } + } + }) + } + } +} + +func TestSQLDriver_RenderLikeOrWild(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "metadata", Type: reflect.TypeOf(map[string]interface{}{})}, + } + + tests := []struct { + name string + provider string + expr *expr.Expression + wantSQL []string + wantCount int + wantErr bool + }{ + { + name: "postgresql LIKE regular field", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{"ILIKE"}, + wantCount: 1, + wantErr: false, + }, + { + name: "postgresql LIKE JSON field", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("metadata->>'key'"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + wantSQL: []string{"ILIKE"}, + wantCount: 1, + wantErr: false, + }, + { + name: "mysql LIKE", + provider: "mysql", + expr: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{"LOWER", "LIKE"}, + wantCount: 1, + wantErr: false, + }, + { + name: "sqlite LIKE", + provider: "sqlite", + expr: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("name"), + 
Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{"LIKE"}, + wantCount: 1, + wantErr: false, + }, + { + name: "postgresql WILD", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Wild, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john*"}, + }, + wantSQL: []string{"ILIKE"}, + wantCount: 1, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver := NewSQLDriver(fields, tt.provider) + sql, params, err := driver.renderLikeOrWild(tt.expr) + if (err != nil) != tt.wantErr { + t.Errorf("renderLikeOrWild() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("renderLikeOrWild() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("renderLikeOrWild() params count = %v, want %v", len(params), tt.wantCount) + } + } + }) + } +} + +func TestSQLDriver_RenderFuzzy(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "metadata", Type: reflect.TypeOf(map[string]interface{}{})}, + } + + tests := []struct { + name string + provider string + expr *expr.Expression + wantSQL []string + wantCount int + wantErr bool + }{ + { + name: "postgresql fuzzy", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Fuzzy, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "roam"}, + }, + }, + wantSQL: []string{"similarity"}, + wantCount: 1, + wantErr: false, + }, + { + name: "postgresql fuzzy JSON field", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Fuzzy, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("metadata->>'key'"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + }, + wantSQL: []string{"similarity"}, + wantCount: 1, + wantErr: false, + }, + { + 
name: "mysql fuzzy", + provider: "mysql", + expr: &expr.Expression{ + Op: expr.Fuzzy, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "roam"}, + }, + }, + wantSQL: []string{"SOUNDEX"}, + wantCount: 1, + wantErr: false, + }, + { + name: "sqlite fuzzy (unsupported)", + provider: "sqlite", + expr: &expr.Expression{ + Op: expr.Fuzzy, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "roam"}, + }, + }, + wantErr: true, + }, + { + name: "invalid fuzzy expression (not Equals)", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Fuzzy, + Left: &expr.Expression{Op: expr.And}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver := NewSQLDriver(fields, tt.provider) + sql, params, err := driver.renderFuzzy(tt.expr) + if (err != nil) != tt.wantErr { + t.Errorf("renderFuzzy() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("renderFuzzy() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("renderFuzzy() params count = %v, want %v", len(params), tt.wantCount) + } + } + }) + } +} + +func TestSQLDriver_RenderComparison(t *testing.T) { + fields := []FieldInfo{ + {Name: "age", Type: reflect.TypeOf(0)}, + {Name: "name", Type: reflect.TypeOf("")}, + } + + tests := []struct { + name string + provider string + op expr.Operator + right *expr.Expression + wantSQL []string + wantCount int + wantErr bool + }{ + { + name: "equals", + provider: "postgresql", + op: expr.Equals, + right: &expr.Expression{Op: expr.Literal, Left: "john"}, + wantSQL: []string{`"name"`, "="}, + wantCount: 1, + wantErr: false, + }, + { + name: "greater than", + provider: "postgresql", + op: expr.Greater, + right: &expr.Expression{Op: expr.Literal, 
Left: 25}, + wantSQL: []string{`"age"`, ">"}, + wantCount: 1, + wantErr: false, + }, + { + name: "less than", + provider: "postgresql", + op: expr.Less, + right: &expr.Expression{Op: expr.Literal, Left: 65}, + wantSQL: []string{`"age"`, "<"}, + wantCount: 1, + wantErr: false, + }, + { + name: "greater or equal", + provider: "postgresql", + op: expr.GreaterEq, + right: &expr.Expression{Op: expr.Literal, Left: 25}, + wantSQL: []string{`"age"`, ">="}, + wantCount: 1, + wantErr: false, + }, + { + name: "less or equal", + provider: "postgresql", + op: expr.LessEq, + right: &expr.Expression{Op: expr.Literal, Left: 65}, + wantSQL: []string{`"age"`, "<="}, + wantCount: 1, + wantErr: false, + }, + { + name: "equals null", + provider: "postgresql", + op: expr.Equals, + right: &expr.Expression{Op: expr.Literal, Left: "null"}, + wantSQL: []string{`"name"`, "IS NULL"}, + wantCount: 0, + wantErr: false, + }, + { + name: "greater than null (error)", + provider: "postgresql", + op: expr.Greater, + right: &expr.Expression{Op: expr.Literal, Left: "null"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver := NewSQLDriver(fields, tt.provider) + var left expr.Column + if strings.Contains(tt.name, "age") || tt.op == expr.Greater || tt.op == expr.Less || tt.op == expr.GreaterEq || tt.op == expr.LessEq { + left = expr.Column("age") + } else { + left = expr.Column("name") + } + e := &expr.Expression{ + Op: tt.op, + Left: left, + Right: tt.right, + } + sql, params, err := driver.renderComparison(e) + if (err != nil) != tt.wantErr { + t.Errorf("renderComparison() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("renderComparison() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("renderComparison() params count = %v, want %v", len(params), tt.wantCount) + } + } + }) + } +} + +func 
TestSQLDriver_RenderBinary(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "email", Type: reflect.TypeOf("")}, + } + + tests := []struct { + name string + provider string + op expr.Operator + left *expr.Expression + right *expr.Expression + wantSQL []string + wantCount int + wantErr bool + }{ + { + name: "AND", + provider: "postgresql", + op: expr.And, + left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + right: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("email"), + Right: &expr.Expression{Op: expr.Literal, Left: "test@example.com"}, + }, + wantSQL: []string{"AND"}, + wantCount: 2, + wantErr: false, + }, + { + name: "OR", + provider: "postgresql", + op: expr.Or, + left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + right: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "jane"}, + }, + wantSQL: []string{"OR"}, + wantCount: 2, + wantErr: false, + }, + { + name: "Must", + provider: "postgresql", + op: expr.Must, + left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + right: nil, + wantSQL: []string{`"name"`}, + wantCount: 1, + wantErr: false, + }, + { + name: "MustNot", + provider: "postgresql", + op: expr.MustNot, + left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + right: nil, + wantSQL: []string{"NOT"}, + wantCount: 1, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver := NewSQLDriver(fields, tt.provider) + e := &expr.Expression{ + Op: tt.op, + Left: tt.left, + Right: tt.right, + } + sql, params, err := driver.renderBinary(e) + if (err != nil) != 
tt.wantErr { + t.Errorf("renderBinary() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("renderBinary() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("renderBinary() params count = %v, want %v", len(params), tt.wantCount) + } + } + }) + } +} + +func TestSQLDriver_RenderRange(t *testing.T) { + fields := []FieldInfo{ + {Name: "age", Type: reflect.TypeOf(0)}, + {Name: "date", Type: reflect.TypeOf("")}, + } + + tests := []struct { + name string + provider string + rangeExpr *expr.RangeBoundary + wantSQL []string + wantCount int + wantErr bool + }{ + { + name: "inclusive range", + provider: "postgresql", + rangeExpr: &expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "25"}, + Max: &expr.Expression{Op: expr.Literal, Left: "65"}, + Inclusive: true, + }, + wantSQL: []string{"BETWEEN"}, + wantCount: 2, + wantErr: false, + }, + { + name: "exclusive range", + provider: "postgresql", + rangeExpr: &expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "25"}, + Max: &expr.Expression{Op: expr.Literal, Left: "65"}, + Inclusive: false, + }, + wantSQL: []string{">", "<"}, + wantCount: 2, + wantErr: false, + }, + { + name: "open-ended min (inclusive)", + provider: "postgresql", + rangeExpr: &expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "*"}, + Max: &expr.Expression{Op: expr.Literal, Left: "65"}, + Inclusive: true, + }, + wantSQL: []string{"<="}, + wantCount: 1, + wantErr: false, + }, + { + name: "open-ended min (exclusive)", + provider: "postgresql", + rangeExpr: &expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "*"}, + Max: &expr.Expression{Op: expr.Literal, Left: "65"}, + Inclusive: false, + }, + wantSQL: []string{"<"}, + wantCount: 1, + wantErr: false, + }, + { + name: "open-ended max (inclusive)", + provider: "postgresql", + rangeExpr: 
&expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "25"}, + Max: &expr.Expression{Op: expr.Literal, Left: "*"}, + Inclusive: true, + }, + wantSQL: []string{">="}, + wantCount: 1, + wantErr: false, + }, + { + name: "open-ended max (exclusive)", + provider: "postgresql", + rangeExpr: &expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "25"}, + Max: &expr.Expression{Op: expr.Literal, Left: "*"}, + Inclusive: false, + }, + wantSQL: []string{">"}, + wantCount: 1, + wantErr: false, + }, + { + name: "both wildcards (error)", + provider: "postgresql", + rangeExpr: &expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "*"}, + Max: &expr.Expression{Op: expr.Literal, Left: "*"}, + Inclusive: true, + }, + wantErr: true, + }, + { + name: "invalid range expression (error)", + provider: "postgresql", + rangeExpr: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver := NewSQLDriver(fields, tt.provider) + var e *expr.Expression + if tt.rangeExpr == nil { + e = &expr.Expression{ + Op: expr.Range, + Left: expr.Column("age"), + Right: nil, + } + } else { + e = &expr.Expression{ + Op: expr.Range, + Left: expr.Column("age"), + Right: tt.rangeExpr, + } + } + sql, params, err := driver.renderRange(e) + if (err != nil) != tt.wantErr { + t.Errorf("renderRange() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("renderRange() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("renderRange() params count = %v, want %v", len(params), tt.wantCount) + } + } + }) + } +} + +func TestSQLDriver_SerializeColumn(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "metadata", Type: reflect.TypeOf(map[string]interface{}{})}, + } + driver := NewSQLDriver(fields, "postgresql") + + tests := []struct { + name 
string + input any + wantSQL string + wantCount int + wantErr bool + }{ + { + name: "simple column", + input: expr.Column("name"), + wantSQL: `"name"`, + wantCount: 0, + wantErr: false, + }, + { + name: "JSON syntax column", + input: expr.Column("metadata->>'key'"), + wantSQL: "metadata->>'key'", + wantCount: 0, + wantErr: false, + }, + { + name: "string column", + input: "name", + wantSQL: `"name"`, + wantCount: 0, + wantErr: false, + }, + { + name: "JSON syntax string", + input: "metadata->>'key'", + wantSQL: "metadata->>'key'", + wantCount: 0, + wantErr: false, + }, + { + name: "expression with Literal column", + input: &expr.Expression{Op: expr.Literal, Left: expr.Column("name")}, + wantSQL: `"name"`, + wantCount: 0, + wantErr: false, + }, + { + name: "invalid type", + input: 123, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, params, err := driver.serializeColumn(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("serializeColumn() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if !strings.Contains(sql, tt.wantSQL) { + t.Errorf("serializeColumn() sql = %v, want to contain %v", sql, tt.wantSQL) + } + if len(params) != tt.wantCount { + t.Errorf("serializeColumn() params count = %v, want %v", len(params), tt.wantCount) + } + } + }) + } +} + +func TestSQLDriver_SerializeValue(t *testing.T) { + fields := []FieldInfo{{Name: "name", Type: reflect.TypeOf("")}} + driver := NewSQLDriver(fields, "postgresql") + + tests := []struct { + name string + input any + wantSQL string + wantValue string + wantErr bool + }{ + { + name: "string value", + input: "john", + wantSQL: "?", + wantValue: "john", + wantErr: false, + }, + { + name: "string with wildcards", + input: "john*", + wantSQL: "?", + wantValue: "john%", + wantErr: false, + }, + { + name: "literal expression", + input: &expr.Expression{Op: expr.Literal, Left: "test"}, + wantSQL: "?", + wantValue: "test", + wantErr: false, + }, + { 
+ name: "wild expression", + input: &expr.Expression{Op: expr.Wild, Left: "test*"}, + wantSQL: "?", + wantValue: "test%", + wantErr: false, + }, + { + name: "nil value (error)", + input: nil, + wantErr: true, + }, + { + name: "integer value", + input: 42, + wantSQL: "?", + wantValue: "42", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, params, err := driver.serializeValue(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("serializeValue() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if sql != tt.wantSQL { + t.Errorf("serializeValue() sql = %v, want %v", sql, tt.wantSQL) + } + if len(params) != 1 { + t.Errorf("serializeValue() params count = %v, want 1", len(params)) + return + } + gotValue := fmt.Sprintf("%v", params[0]) + if tt.wantValue != "" && gotValue != tt.wantValue { + t.Errorf("serializeValue() param value = %v, want %v", gotValue, tt.wantValue) + } + } + }) + } +} + +func TestSQLDriver_FormatFieldName(t *testing.T) { + jsonbType := reflect.TypeOf(map[string]interface{}{}) + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "metadata", Type: jsonbType}, + } + + tests := []struct { + name string + provider string + field string + want string + }{ + { + name: "postgresql JSON field", + provider: "postgresql", + field: "metadata.key", + want: "metadata->>'key'", + }, + { + name: "mysql JSON field", + provider: "mysql", + field: "metadata.key", + want: "JSON_UNQUOTE(JSON_EXTRACT(metadata, '$.key'))", + }, + { + name: "sqlite JSON field", + provider: "sqlite", + field: "metadata.key", + want: "JSON_EXTRACT(metadata, '$.key')", + }, + { + name: "simple field (no dot)", + provider: "postgresql", + field: "name", + want: "name", + }, + { + name: "non-JSONB field with dot (no conversion)", + provider: "postgresql", + field: "name.subfield", + want: "name.subfield", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver 
:= NewSQLDriver(fields, tt.provider) + got := driver.formatFieldName(tt.field) + if string(got) != tt.want { + t.Errorf("formatFieldName() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConvertWildcards(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "no wildcards", + input: "john", + want: "john", + }, + { + name: "single *", + input: "john*", + want: "john%", + }, + { + name: "single ?", + input: "jo?n", + want: "jo_n", + }, + { + name: "multiple *", + input: "*john*", + want: "%john%", + }, + { + name: "multiple ?", + input: "j??n", + want: "j__n", + }, + { + name: "mixed wildcards", + input: "j*?n", + want: "j%_n", + }, + { + name: "empty string", + input: "", + want: "", + }, + { + name: "only wildcards", + input: "***", + want: "%%%", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := convertWildcards(tt.input) + if got != tt.want { + t.Errorf("convertWildcards() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsJSONSyntax(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + { + name: "PostgreSQL JSONB operator", + input: "metadata->>'key'", + want: true, + }, + { + name: "MySQL JSON_EXTRACT", + input: "JSON_EXTRACT(column, '$.field')", + want: true, + }, + { + name: "MySQL JSON_UNQUOTE", + input: "JSON_UNQUOTE(JSON_EXTRACT(column, '$.field'))", + want: true, + }, + { + name: "SQLite JSON_EXTRACT", + input: "JSON_EXTRACT(column, '$.field')", + want: true, + }, + { + name: "regular column", + input: "name", + want: false, + }, + { + name: "quoted column", + input: `"name"`, + want: false, + }, + { + name: "empty string", + input: "", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isJSONSyntax(tt.input) + if got != tt.want { + t.Errorf("isJSONSyntax() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsNullValue(t *testing.T) { + tests := []struct { + name 
string + input any + want bool + }{ + { + name: "null string (lowercase)", + input: "null", + want: true, + }, + { + name: "NULL string (uppercase)", + input: "NULL", + want: true, + }, + { + name: "Null string (mixed case)", + input: "Null", + want: true, + }, + { + name: "null in literal expression", + input: &expr.Expression{Op: expr.Literal, Left: "null"}, + want: true, + }, + { + name: "empty string", + input: "", + want: false, + }, + { + name: "nil value", + input: nil, + want: false, + }, + { + name: "regular string", + input: "john", + want: false, + }, + { + name: "nil string", + input: "nil", + want: false, + }, + { + name: "literal word empty", + input: "empty", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isNullValue(tt.input) + if got != tt.want { + t.Errorf("isNullValue() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConvertToPostgresPlaceholders(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "single placeholder", + input: "SELECT * FROM users WHERE name = ?", + want: "SELECT * FROM users WHERE name = $1", + }, + { + name: "multiple placeholders", + input: "SELECT * FROM users WHERE name = ? AND age = ?", + want: "SELECT * FROM users WHERE name = $1 AND age = $2", + }, + { + name: "no placeholders", + input: "SELECT * FROM users", + want: "SELECT * FROM users", + }, + { + name: "many placeholders", + input: "? ? ? ? ?", + want: "$1 $2 $3 $4 $5", + }, + { + name: "placeholder in string literal (should still convert)", + input: "SELECT '?' 
FROM users WHERE name = ?", + want: "SELECT '$1' FROM users WHERE name = $2", + }, + { + name: "empty string", + input: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := convertToPostgresPlaceholders(tt.input) + if got != tt.want { + t.Errorf("convertToPostgresPlaceholders() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSQLDriver_ProcessJSONFields(t *testing.T) { + jsonbType := reflect.TypeOf(map[string]interface{}{}) + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "metadata", Type: jsonbType}, + } + + tests := []struct { + name string + provider string + expr *expr.Expression + check func(t *testing.T, expr *expr.Expression) + }{ + { + name: "postgresql JSON field conversion", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("metadata.key"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + check: func(t *testing.T, e *expr.Expression) { + if col, ok := e.Left.(expr.Column); ok { + if !strings.Contains(string(col), "->>'") { + t.Errorf("expected PostgreSQL JSON syntax, got %v", col) + } + } + }, + }, + { + name: "mysql JSON field conversion", + provider: "mysql", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("metadata.key"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + check: func(t *testing.T, e *expr.Expression) { + if col, ok := e.Left.(expr.Column); ok { + if !strings.Contains(string(col), "JSON_EXTRACT") { + t.Errorf("expected MySQL JSON syntax, got %v", col) + } + } + }, + }, + { + name: "nested expression", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.And, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("metadata.key"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + Right: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + }, + 
check: func(t *testing.T, e *expr.Expression) { + if leftExpr, ok := e.Left.(*expr.Expression); ok { + if col, ok := leftExpr.Left.(expr.Column); ok { + if !strings.Contains(string(col), "->>'") { + t.Errorf("expected PostgreSQL JSON syntax in nested expression, got %v", col) + } + } + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver := NewSQLDriver(fields, tt.provider) + driver.processJSONFields(tt.expr) + if tt.check != nil { + tt.check(t, tt.expr) + } + }) + } +} + +func TestSQLDriver_RenderParamInternal(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + } + driver := NewSQLDriver(fields, "postgresql") + + tests := []struct { + name string + expr *expr.Expression + wantSQL []string + wantCount int + wantErr bool + }{ + { + name: "Like operator", + expr: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{"ILIKE"}, + wantCount: 1, + wantErr: false, + }, + { + name: "Fuzzy operator", + expr: &expr.Expression{ + Op: expr.Fuzzy, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "roam"}, + }, + }, + wantSQL: []string{"similarity"}, + wantCount: 1, + wantErr: false, + }, + { + name: "Boost operator (error)", + expr: &expr.Expression{ + Op: expr.Boost, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantErr: true, + }, + { + name: "Range operator", + expr: &expr.Expression{ + Op: expr.Range, + Left: expr.Column("name"), + Right: &expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "a"}, + Max: &expr.Expression{Op: expr.Literal, Left: "z"}, + Inclusive: true, + }, + }, + wantSQL: []string{"BETWEEN"}, + wantCount: 2, + wantErr: false, + }, + { + name: "nil expression", + expr: nil, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t 
*testing.T) { + sql, params, err := driver.renderParamInternal(tt.expr) + if (err != nil) != tt.wantErr { + t.Errorf("renderParamInternal() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if tt.expr == nil { + if sql != "" { + t.Errorf("renderParamInternal() sql = %v, want empty string", sql) + } + return + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("renderParamInternal() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("renderParamInternal() params count = %v, want %v", len(params), tt.wantCount) + } + } + }) + } +} + +func TestSQLDriver_ProviderSpecific(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "metadata", Type: reflect.TypeOf(map[string]interface{}{})}, + } + + tests := []struct { + name string + provider string + expr *expr.Expression + wantSQL []string + wantCount int + checkFunc func(t *testing.T, sql string, params []any) + }{ + { + name: "postgresql placeholder format", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{"$1"}, + wantCount: 1, + checkFunc: func(t *testing.T, sql string, params []any) { + if !strings.Contains(sql, "$") { + t.Errorf("expected PostgreSQL placeholder ($1), got %v", sql) + } + }, + }, + { + name: "mysql placeholder format", + provider: "mysql", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{"?"}, + wantCount: 1, + checkFunc: func(t *testing.T, sql string, params []any) { + if !strings.Contains(sql, "?") { + t.Errorf("expected MySQL placeholder (?), got %v", sql) + } + }, + }, + { + name: "sqlite placeholder format", + provider: "sqlite", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: 
&expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{"?"}, + wantCount: 1, + checkFunc: func(t *testing.T, sql string, params []any) { + if !strings.Contains(sql, "?") { + t.Errorf("expected SQLite placeholder (?), got %v", sql) + } + }, + }, + { + name: "postgresql JSON field", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("metadata.key"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + wantSQL: []string{"metadata->>'key'"}, + wantCount: 1, + checkFunc: nil, + }, + { + name: "mysql JSON field", + provider: "mysql", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("metadata.key"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + wantSQL: []string{"JSON_EXTRACT"}, + wantCount: 1, + checkFunc: nil, + }, + { + name: "sqlite JSON field", + provider: "sqlite", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("metadata.key"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + wantSQL: []string{"JSON_EXTRACT"}, + wantCount: 1, + checkFunc: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver := NewSQLDriver(fields, tt.provider) + sql, params, err := driver.RenderParam(tt.expr) + if err != nil { + t.Fatalf("RenderParam() error = %v", err) + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("RenderParam() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("RenderParam() params count = %v, want %v", len(params), tt.wantCount) + } + if tt.checkFunc != nil { + tt.checkFunc(t, sql, params) + } + }) + } +} From d5f753280cb0bca69ea1e6248c90617828bbeb94 Mon Sep 17 00:00:00 2001 From: Lutherwaves Date: Tue, 20 Jan 2026 00:25:52 +0200 Subject: [PATCH 08/13] fix: stricter validation for duplicates, escapes to prevent injection - Escape PartiQL values in DynamoDB driver - Add lucene query openapi - Simplify tests --- 
storage/search/lucene/dynamodb_driver.go | 80 +++++++--- storage/search/lucene/dynamodb_driver_test.go | 126 ++++++++++++++-- storage/search/lucene/parser.go | 91 ++++++++---- storage/search/lucene/sql_driver.go | 138 +++++++++++++----- storage/search/lucene/sql_driver_test.go | 115 +++++++++++++-- types/types.go | 42 +----- 6 files changed, 436 insertions(+), 156 deletions(-) diff --git a/storage/search/lucene/dynamodb_driver.go b/storage/search/lucene/dynamodb_driver.go index 8a7cd45..9e8b23c 100644 --- a/storage/search/lucene/dynamodb_driver.go +++ b/storage/search/lucene/dynamodb_driver.go @@ -15,10 +15,10 @@ type DynamoDBPartiQLDriver struct { fields map[string]FieldInfo } -func NewDynamoDBDriver(fields []FieldInfo) *DynamoDBPartiQLDriver { - fieldMap := make(map[string]FieldInfo) - for _, f := range fields { - fieldMap[f.Name] = f +func NewDynamoDBDriver(fields []FieldInfo) (*DynamoDBPartiQLDriver, error) { + fieldMap, err := buildFieldMap(fields) + if err != nil { + return nil, err } fns := map[expr.Operator]driver.RenderFN{ @@ -46,7 +46,7 @@ func NewDynamoDBDriver(fields []FieldInfo) *DynamoDBPartiQLDriver { RenderFNs: fns, }, fields: fieldMap, - } + }, nil } // RenderPartiQL renders the expression to DynamoDB PartiQL with AttributeValue parameters. @@ -66,29 +66,69 @@ func (d *DynamoDBPartiQLDriver) RenderPartiQL(e *expr.Expression) (string, []typ return str, attrValues, nil } +// escapePartiQLString escapes a string value for safe use in PartiQL string literals. +// Escapes single quotes by doubling them (PartiQL standard). +func escapePartiQLString(s string) string { + return strings.ReplaceAll(s, "'", "''") +} + +// escapePartiQLIdentifier escapes a field name for safe use in PartiQL. +// Validates that the identifier contains only safe characters (alphanumeric, underscore). +// Returns error if identifier contains potentially dangerous characters. 
+func escapePartiQLIdentifier(identifier string) (string, error) { + // Validate identifier contains only safe characters + for _, r := range identifier { + if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_') { + return "", fmt.Errorf("invalid identifier: contains unsafe character '%c'", r) + } + } + return identifier, nil +} + +// unquotePartiQLString safely removes surrounding quotes from a PartiQL string literal. +// Handles already-escaped quotes correctly. +func unquotePartiQLString(s string) string { + if len(s) >= 2 && s[0] == '\'' && s[len(s)-1] == '\'' { + return s[1 : len(s)-1] + } + return s +} + // dynamoDBLike implements LIKE using DynamoDB's begins_with and contains functions. func dynamoDBLike(left, right string) (string, error) { - // Remove quotes from right side to analyze pattern - pattern := strings.Trim(right, "'") + // Validate and escape field name (left) + safeLeft, err := escapePartiQLIdentifier(left) + if err != nil { + return "", fmt.Errorf("invalid field name: %w", err) + } - // Replace wildcards for analysis - hasPrefix := strings.HasPrefix(pattern, "%") - hasSuffix := strings.HasSuffix(pattern, "%") + // Extract the raw value from the right side (remove quotes if present) + rawValue := unquotePartiQLString(right) + + // Analyze pattern for wildcards + hasPrefix := strings.HasPrefix(rawValue, "%") + hasSuffix := strings.HasSuffix(rawValue, "%") if hasPrefix && hasSuffix { // %value% -> contains(field, value) - value := strings.Trim(pattern, "%") - return fmt.Sprintf("contains(%s, '%s')", left, value), nil - } else if !hasPrefix && hasSuffix { + value := strings.Trim(rawValue, "%") + escapedValue := escapePartiQLString(value) + return fmt.Sprintf("contains(%s, '%s')", safeLeft, escapedValue), nil + } + if !hasPrefix && hasSuffix { // value% -> begins_with(field, value) - value := strings.TrimSuffix(pattern, "%") - return fmt.Sprintf("begins_with(%s, '%s')", left, value), nil - } else if hasPrefix 
&& !hasSuffix { + value := strings.TrimSuffix(rawValue, "%") + escapedValue := escapePartiQLString(value) + return fmt.Sprintf("begins_with(%s, '%s')", safeLeft, escapedValue), nil + } + if hasPrefix && !hasSuffix { // %value -> contains(field, value) (DynamoDB doesn't have ends_with) - value := strings.TrimPrefix(pattern, "%") - return fmt.Sprintf("contains(%s, '%s')", left, value), nil + value := strings.TrimPrefix(rawValue, "%") + escapedValue := escapePartiQLString(value) + return fmt.Sprintf("contains(%s, '%s')", safeLeft, escapedValue), nil } - // Exact match - return fmt.Sprintf("%s = %s", left, right), nil + // Exact match - escape the value and wrap in quotes + escapedValue := escapePartiQLString(rawValue) + return fmt.Sprintf("%s = '%s'", safeLeft, escapedValue), nil } diff --git a/storage/search/lucene/dynamodb_driver_test.go b/storage/search/lucene/dynamodb_driver_test.go index 27735fa..1193813 100644 --- a/storage/search/lucene/dynamodb_driver_test.go +++ b/storage/search/lucene/dynamodb_driver_test.go @@ -11,14 +11,16 @@ import ( func TestNewDynamoDBDriver(t *testing.T) { tests := []struct { - name string - fields []FieldInfo - want map[string]FieldInfo + name string + fields []FieldInfo + want map[string]FieldInfo + wantErr bool }{ { - name: "empty fields", - fields: []FieldInfo{}, - want: map[string]FieldInfo{}, + name: "empty fields", + fields: []FieldInfo{}, + want: map[string]FieldInfo{}, + wantErr: false, }, { name: "single field", @@ -28,6 +30,7 @@ func TestNewDynamoDBDriver(t *testing.T) { want: map[string]FieldInfo{ "name": {Name: "name", Type: reflect.TypeOf("")}, }, + wantErr: false, }, { name: "multiple fields", @@ -41,22 +44,48 @@ func TestNewDynamoDBDriver(t *testing.T) { "email": {Name: "email", Type: reflect.TypeOf("")}, "age": {Name: "age", Type: reflect.TypeOf(0)}, }, + wantErr: false, }, { - name: "duplicate field names (last wins)", + name: "duplicate field names returns error", fields: []FieldInfo{ {Name: "name", Type: 
reflect.TypeOf("")}, {Name: "name", Type: reflect.TypeOf(0)}, }, - want: map[string]FieldInfo{ - "name": {Name: "name", Type: reflect.TypeOf(0)}, + want: nil, + wantErr: true, + }, + { + name: "multiple duplicate field names", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "email", Type: reflect.TypeOf("")}, + {Name: "name", Type: reflect.TypeOf(0)}, }, + want: nil, + wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - driver := NewDynamoDBDriver(tt.fields) + driver, err := NewDynamoDBDriver(tt.fields) + if (err != nil) != tt.wantErr { + t.Errorf("NewDynamoDBDriver() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + if err == nil { + t.Errorf("NewDynamoDBDriver() expected error but got nil") + } + if driver != nil { + t.Errorf("NewDynamoDBDriver() expected nil driver on error, got %v", driver) + } + if err != nil && !strings.Contains(err.Error(), "duplicate field name") { + t.Errorf("NewDynamoDBDriver() error message should contain 'duplicate field name', got: %v", err) + } + return + } if driver == nil { t.Fatalf("NewDynamoDBDriver() returned nil") } @@ -83,7 +112,10 @@ func TestDynamoDBDriver_RenderPartiQL(t *testing.T) { {Name: "email", Type: reflect.TypeOf("")}, {Name: "age", Type: reflect.TypeOf(0)}, } - driver := NewDynamoDBDriver(fields) + driver, err := NewDynamoDBDriver(fields) + if err != nil { + t.Fatalf("NewDynamoDBDriver() error = %v", err) + } tests := []struct { name string @@ -268,10 +300,10 @@ func TestDynamoDBLike(t *testing.T) { wantErr: false, }, { - name: "quoted value without quotes in pattern", + name: "unquoted value (no quotes in pattern)", left: "name", right: "john", - want: "name = john", + want: "name = 'john'", wantErr: false, }, { @@ -288,6 +320,69 @@ func TestDynamoDBLike(t *testing.T) { want: "contains(name, '')", wantErr: false, }, + { + name: "value with single quote in exact match", + left: "name", + right: "'John's'", + want: "name = 
'John''s'", + wantErr: false, + }, + { + name: "value with single quote and wildcard prefix", + left: "name", + right: "'%test'value'", + want: "contains(name, 'test''value')", + wantErr: false, + }, + { + name: "value with single quote and wildcard suffix", + left: "name", + right: "'test'value%'", + want: "begins_with(name, 'test''value')", + wantErr: false, + }, + { + name: "value with single quote and wildcards both sides", + left: "name", + right: "'%test'value%'", + want: "contains(name, 'test''value')", + wantErr: false, + }, + { + name: "value with multiple single quotes", + left: "name", + right: "'O'Brien'", + want: "name = 'O''Brien'", + wantErr: false, + }, + { + name: "injection attempt: value with quote and OR (should be escaped)", + left: "name", + right: "'test') OR (1=1'", + want: "name = 'test'') OR (1=1'", + wantErr: false, + }, + { + name: "invalid field name with special characters", + left: "name; DROP TABLE users;--", + right: "'test'", + want: "", + wantErr: true, + }, + { + name: "invalid field name with quotes", + left: "name'", + right: "'test'", + want: "", + wantErr: true, + }, + { + name: "invalid field name with spaces", + left: "field name", + right: "'test'", + want: "", + wantErr: true, + }, } for _, tt := range tests { @@ -308,7 +403,10 @@ func TestDynamoDBDriver_EdgeCases(t *testing.T) { fields := []FieldInfo{ {Name: "name", Type: reflect.TypeOf("")}, } - driver := NewDynamoDBDriver(fields) + driver, err := NewDynamoDBDriver(fields) + if err != nil { + t.Fatalf("NewDynamoDBDriver() error = %v", err) + } tests := []struct { name string diff --git a/storage/search/lucene/parser.go b/storage/search/lucene/parser.go index b35fe62..816fbe2 100644 --- a/storage/search/lucene/parser.go +++ b/storage/search/lucene/parser.go @@ -75,10 +75,10 @@ func NewParser(model any, config ...*ParserConfig) (*Parser, error) { return nil, err } - // Build field map - fieldMap := make(map[string]FieldInfo, len(fields)) - for _, f := range fields { - 
fieldMap[f.Name] = f + // Build field map and validate for duplicates + fieldMap, err := buildFieldMap(fields) + if err != nil { + return nil, err } // Apply config or use defaults @@ -147,6 +147,19 @@ func extractFields(model any) ([]FieldInfo, error) { return fields, nil } +// buildFieldMap builds a field map from a slice of fields and validates for duplicates. +// Returns an error if duplicate field names are found. +func buildFieldMap(fields []FieldInfo) (map[string]FieldInfo, error) { + fieldMap := make(map[string]FieldInfo, len(fields)) + for _, f := range fields { + if existing, exists := fieldMap[f.Name]; exists { + return nil, fmt.Errorf("duplicate field name '%s': already defined with type %v, cannot redefine with type %v", f.Name, existing.Type, f.Type) + } + fieldMap[f.Name] = f + } + return fieldMap, nil +} + // canUseNestedAccess checks if a field type supports nested access (field.subfield syntax). func canUseNestedAccess(t reflect.Type) bool { // Return false for nil types @@ -211,14 +224,13 @@ func (p *Parser) ParseToMap(query string) (map[string]any, error) { return p.exprToMap(e), nil } -// ParseToSQL parses a Lucene query and converts it to SQL with parameters for the specified provider. -// Creates a SQL driver on-demand for rendering with provider-specific syntax. -// Provider should be one of: "postgresql", "mysql", "sqlite" -func (p *Parser) ParseToSQL(query string, provider string) (string, []any, error) { - slog.Debug(fmt.Sprintf(`Parsing query to SQL: %s`, query)) +// parseQueryCommon performs common parsing steps shared by ParseToSQL and ParseToDynamoDBPartiQL. +// Returns the parsed expression or an error. 
+func (p *Parser) parseQueryCommon(query string, queryType string) (*expr.Expression, error) { + slog.Debug(fmt.Sprintf(`Parsing query to %s: %s`, queryType, query)) if err := p.validateQuery(query); err != nil { - return "", nil, err + return nil, err } // Expand implicit terms first (for validation of the full query) @@ -226,17 +238,32 @@ func (p *Parser) ParseToSQL(query string, provider string) (string, []any, error // Validate all field references exist in the model if err := p.ValidateFields(expandedQuery); err != nil { - return "", nil, err + return nil, err } // Parse using the library e, err := p.parseWithImplicitSearch(query) + if err != nil { + return nil, err + } + + return e, nil +} + +// ParseToSQL parses a Lucene query and converts it to SQL with parameters for the specified provider. +// Creates a SQL driver on-demand for rendering with provider-specific syntax. +// Provider should be one of: "postgresql", "mysql", "sqlite" +func (p *Parser) ParseToSQL(query string, provider string) (string, []any, error) { + e, err := p.parseQueryCommon(query, "SQL") if err != nil { return "", nil, err } // Create SQL driver on-demand for the specified provider and render - driver := NewSQLDriver(p.Fields, provider) + driver, err := NewSQLDriver(p.Fields, provider) + if err != nil { + return "", nil, err + } sql, params, err := driver.RenderParam(e) if err != nil { return "", nil, err @@ -248,28 +275,16 @@ func (p *Parser) ParseToSQL(query string, provider string) (string, []any, error // ParseToDynamoDBPartiQL parses a Lucene query and converts it to DynamoDB PartiQL. // Creates a DynamoDB driver on-demand for rendering. 
func (p *Parser) ParseToDynamoDBPartiQL(query string) (string, []types.AttributeValue, error) { - slog.Debug(fmt.Sprintf(`Parsing query to DynamoDB PartiQL: %s`, query)) - - if err := p.validateQuery(query); err != nil { - return "", nil, err - } - - // Expand implicit terms first (for validation of the full query) - expandedQuery := p.expandImplicitTerms(query) - - // Validate all field references exist in the model - if err := p.ValidateFields(expandedQuery); err != nil { + e, err := p.parseQueryCommon(query, "DynamoDB PartiQL") + if err != nil { return "", nil, err } - // Parse using the library - e, err := p.parseWithImplicitSearch(query) + // Create DynamoDB driver on-demand and render + driver, err := NewDynamoDBDriver(p.Fields) if err != nil { return "", nil, err } - - // Create DynamoDB driver on-demand and render - driver := NewDynamoDBDriver(p.Fields) partiql, attrs, err := driver.RenderPartiQL(e) if err != nil { return "", nil, err @@ -412,11 +427,13 @@ func countTerms(query string) int { currentTerm = false if len(remaining) >= 3 && (remaining[0] == 'A' || remaining[0] == 'a') { i += 3 - } else if len(remaining) >= 3 && (remaining[0] == 'N' || remaining[0] == 'n') { + continue + } + if len(remaining) >= 3 && (remaining[0] == 'N' || remaining[0] == 'n') { i += 3 - } else { - i += 2 + continue } + i += 2 continue } } @@ -488,6 +505,15 @@ func (p *Parser) validateFieldName(fieldName string) error { } baseField := parts[0] + subField := parts[1] + + // Check for whitespace in field names (security: prevents obfuscation, OWASP A03) + if strings.TrimSpace(baseField) != baseField { + return fmt.Errorf("invalid field format '%s': whitespace not allowed in field names", fieldName) + } + if strings.TrimSpace(subField) != subField { + return fmt.Errorf("invalid field format '%s': whitespace not allowed in field names (use 'field.subfield' not 'field. 
subfield')", fieldName) + } // Check if base field exists field, exists := p.fieldMap[baseField] @@ -720,7 +746,8 @@ func (p *Parser) parseWithImplicitSearch(query string) (*expr.Expression, error) implicitFields := p.getImplicitSearchFields() if len(implicitFields) > 0 { fallbackField = implicitFields[0].Name - } else if len(p.Fields) > 0 { + } + if fallbackField == "" && len(p.Fields) > 0 { fallbackField = p.Fields[0].Name } diff --git a/storage/search/lucene/sql_driver.go b/storage/search/lucene/sql_driver.go index b48eec0..63c05a0 100644 --- a/storage/search/lucene/sql_driver.go +++ b/storage/search/lucene/sql_driver.go @@ -16,12 +16,27 @@ type SQLDriver struct { provider string // SQL provider: "postgresql", "mysql", or "sqlite" } +// validateProvider validates that the provider is one of the supported SQL providers. +func validateProvider(provider string) error { + switch provider { + case "postgresql", "mysql", "sqlite": + return nil + default: + return fmt.Errorf("unsupported SQL provider: %s (supported: postgresql, mysql, sqlite)", provider) + } +} + // NewSQLDriver creates a new SQL driver for the specified provider. // Provider should be one of: "postgresql", "mysql", "sqlite" -func NewSQLDriver(fields []FieldInfo, provider string) *SQLDriver { - fieldMap := make(map[string]FieldInfo) - for _, f := range fields { - fieldMap[f.Name] = f +// Returns an error if duplicate field names are found or provider is invalid. +func NewSQLDriver(fields []FieldInfo, provider string) (*SQLDriver, error) { + if err := validateProvider(provider); err != nil { + return nil, err + } + + fieldMap, err := buildFieldMap(fields) + if err != nil { + return nil, err } // RenderFNs map - we handle most operators in renderParamInternal @@ -52,7 +67,7 @@ func NewSQLDriver(fields []FieldInfo, provider string) *SQLDriver { }, fields: fieldMap, provider: provider, - } + }, nil } // RenderParam renders the expression with provider-specific parameter placeholders. 
@@ -231,23 +246,22 @@ func (s *SQLDriver) renderBinary(e *expr.Expression) (string, []any, error) { return "", nil, fmt.Errorf("%s operator requires a left operand", e.Op) } + var leftStr string + var leftParams []any + var err error + if leftExpr, ok := e.Left.(*expr.Expression); ok { - leftStr, leftParams, err := s.renderParamInternal(leftExpr) + leftStr, leftParams, err = s.renderParamInternal(leftExpr) if err != nil { return "", nil, err } - - if e.Op == expr.Must { - return leftStr, leftParams, nil - } - return fmt.Sprintf("NOT (%s)", leftStr), leftParams, nil - } - - leftStr, leftParams, err := s.serializeColumn(e.Left) - if err != nil { - leftStr, leftParams, err = s.serializeValue(e.Left) + } else { + leftStr, leftParams, err = s.serializeColumn(e.Left) if err != nil { - return s.Base.RenderParam(e) + leftStr, leftParams, err = s.serializeValue(e.Left) + if err != nil { + return s.Base.RenderParam(e) + } } } @@ -290,27 +304,24 @@ func (s *SQLDriver) renderBinary(e *expr.Expression) (string, []any, error) { } } +// quoteColumnName quotes a column name if it's not already JSON syntax. 
+func quoteColumnName(colStr string) string { + if isJSONSyntax(colStr) { + return colStr + } + return fmt.Sprintf(`"%s"`, colStr) +} + func (s *SQLDriver) serializeColumn(in any) (string, []any, error) { switch v := in.(type) { case expr.Column: - colStr := string(v) - if isJSONSyntax(colStr) { - return colStr, nil, nil - } - return fmt.Sprintf(`"%s"`, colStr), nil, nil + return quoteColumnName(string(v)), nil, nil case string: - if isJSONSyntax(v) { - return v, nil, nil - } - return fmt.Sprintf(`"%s"`, v), nil, nil + return quoteColumnName(v), nil, nil case *expr.Expression: if v.Op == expr.Literal && v.Left != nil { if col, ok := v.Left.(expr.Column); ok { - colStr := string(col) - if isJSONSyntax(colStr) { - return colStr, nil, nil - } - return fmt.Sprintf(`"%s"`, colStr), nil, nil + return quoteColumnName(string(col)), nil, nil } } return s.renderParamInternal(v) @@ -319,18 +330,24 @@ func (s *SQLDriver) serializeColumn(in any) (string, []any, error) { } } +// extractLiteralString extracts a string value from an expression for wildcard conversion. +func extractLiteralString(v *expr.Expression) (string, bool) { + if v.Left == nil { + return "", false + } + if v.Op == expr.Literal || v.Op == expr.Wild { + return fmt.Sprintf("%v", v.Left), true + } + return "", false +} + // serializeValue converts Lucene wildcards (* and ?) to SQL wildcards (% and _). 
func (s *SQLDriver) serializeValue(in any) (string, []any, error) { switch v := in.(type) { case string: return "?", []any{convertWildcards(v)}, nil case *expr.Expression: - if v.Op == expr.Literal && v.Left != nil { - literalVal := fmt.Sprintf("%v", v.Left) - return "?", []any{convertWildcards(literalVal)}, nil - } - if v.Op == expr.Wild && v.Left != nil { - literalVal := fmt.Sprintf("%v", v.Left) + if literalVal, ok := extractLiteralString(v); ok { return "?", []any{convertWildcards(literalVal)}, nil } return s.renderParamInternal(v) @@ -374,26 +391,71 @@ func (s *SQLDriver) processJSONFields(e *expr.Expression) { } } +// validateSubFieldName validates that a subfield name contains only safe characters. +// Subfield names should be alphanumeric with underscores and dots for nested paths. +// This prevents injection attacks via JSON path manipulation. +func validateSubFieldName(subField string) error { + if subField == "" { + return fmt.Errorf("subfield name cannot be empty") + } + + // Allow alphanumeric, underscore, and dot (for nested paths like "user.name") + // Reject any characters that could be used for injection (quotes, semicolons, etc.) + for _, r := range subField { + if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '.') { + return fmt.Errorf("invalid subfield name '%s': contains unsafe character '%c'", subField, r) + } + } + return nil +} + +// escapeJSONPathSegment escapes a JSON path segment for safe use in JSON path expressions. +// For PostgreSQL: escapes single quotes in the key name (used in ->>'key' syntax) +// For MySQL/SQLite: escapes special characters in JSON path (though validation should prevent most) +func escapeJSONPathSegment(segment string) string { + // Replace single quote with escaped single quote (for PostgreSQL ->>'key' syntax) + result := strings.ReplaceAll(segment, "'", "''") + return result +} + // formatFieldName converts field.subfield to provider-specific JSON syntax. 
+// Validates and escapes subfield names to prevent injection attacks. func (s *SQLDriver) formatFieldName(fieldName string) expr.Column { parts := strings.SplitN(fieldName, ".", 2) if len(parts) == 2 { baseField := parts[0] subField := parts[1] + // Validate subfield name for security (prevents injection) + if err := validateSubFieldName(subField); err != nil { + // If validation fails, return original field name (will be caught by field validation) + return expr.Column(fieldName) + } + if field, exists := s.fields[baseField]; exists && canUseNestedAccess(field.Type) { + // Escape subfield name for safe interpolation + // PostgreSQL uses ->>'key' syntax where key is in quotes, so we need to escape quotes + escapedSubField := escapeJSONPathSegment(subField) + switch s.provider { case "postgresql": // PostgreSQL: JSONB operator ->> - return expr.Column(fmt.Sprintf("%s->>'%s'", baseField, subField)) + // Key is in single quotes, so we escape single quotes + return expr.Column(fmt.Sprintf("%s->>'%s'", baseField, escapedSubField)) case "mysql": // MySQL 5.7+: JSON_UNQUOTE(JSON_EXTRACT(column, '$.field')) + // Path is '$.field' - field name is not separately quoted, but validation ensures it's safe return expr.Column(fmt.Sprintf("JSON_UNQUOTE(JSON_EXTRACT(%s, '$.%s'))", baseField, subField)) case "sqlite": // SQLite: JSON_EXTRACT(column, '$.field') + // Path is '$.field' - field name is not separately quoted, but validation ensures it's safe return expr.Column(fmt.Sprintf("JSON_EXTRACT(%s, '$.%s')", baseField, subField)) + + default: + // Should never happen due to validateProvider, but defensive programming + return expr.Column(fieldName) } } } diff --git a/storage/search/lucene/sql_driver_test.go b/storage/search/lucene/sql_driver_test.go index d75e8b1..94ebce3 100644 --- a/storage/search/lucene/sql_driver_test.go +++ b/storage/search/lucene/sql_driver_test.go @@ -50,11 +50,64 @@ func TestNewSQLDriver(t *testing.T) { provider: "postgresql", wantErr: false, }, + { + 
name: "duplicate field names returns error (postgresql)", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "name", Type: reflect.TypeOf(0)}, + }, + provider: "postgresql", + wantErr: true, + }, + { + name: "duplicate field names returns error (mysql)", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "name", Type: reflect.TypeOf(0)}, + }, + provider: "mysql", + wantErr: true, + }, + { + name: "duplicate field names returns error (sqlite)", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "name", Type: reflect.TypeOf(0)}, + }, + provider: "sqlite", + wantErr: true, + }, + { + name: "multiple duplicate field names (postgresql)", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "email", Type: reflect.TypeOf("")}, + {Name: "name", Type: reflect.TypeOf(0)}, + }, + provider: "postgresql", + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - driver := NewSQLDriver(tt.fields, tt.provider) + driver, err := NewSQLDriver(tt.fields, tt.provider) + if (err != nil) != tt.wantErr { + t.Errorf("NewSQLDriver() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + if err == nil { + t.Errorf("NewSQLDriver() expected error but got nil") + } + if driver != nil { + t.Errorf("NewSQLDriver() expected nil driver on error, got %v", driver) + } + if err != nil && !strings.Contains(err.Error(), "duplicate field name") { + t.Errorf("NewSQLDriver() error message should contain 'duplicate field name', got: %v", err) + } + return + } if driver == nil { t.Fatalf("NewSQLDriver() returned nil") } @@ -122,7 +175,10 @@ func TestSQLDriver_RenderParam(t *testing.T) { for _, provider := range providers { for _, tt := range tests { t.Run(provider+"/"+tt.name, func(t *testing.T) { - driver := NewSQLDriver(fields, provider) + driver, err := NewSQLDriver(fields, provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } sql, 
params, err := driver.RenderParam(tt.expr) if (err != nil) != tt.wantErr { t.Errorf("RenderParam() error = %v, wantErr %v", err, tt.wantErr) @@ -240,7 +296,10 @@ func TestSQLDriver_RenderLikeOrWild(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - driver := NewSQLDriver(fields, tt.provider) + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } sql, params, err := driver.renderLikeOrWild(tt.expr) if (err != nil) != tt.wantErr { t.Errorf("renderLikeOrWild() error = %v, wantErr %v", err, tt.wantErr) @@ -345,7 +404,10 @@ func TestSQLDriver_RenderFuzzy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - driver := NewSQLDriver(fields, tt.provider) + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } sql, params, err := driver.renderFuzzy(tt.expr) if (err != nil) != tt.wantErr { t.Errorf("renderFuzzy() error = %v, wantErr %v", err, tt.wantErr) @@ -445,7 +507,10 @@ func TestSQLDriver_RenderComparison(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - driver := NewSQLDriver(fields, tt.provider) + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } var left expr.Column if strings.Contains(tt.name, "age") || tt.op == expr.Greater || tt.op == expr.Less || tt.op == expr.GreaterEq || tt.op == expr.LessEq { left = expr.Column("age") @@ -560,7 +625,10 @@ func TestSQLDriver_RenderBinary(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - driver := NewSQLDriver(fields, tt.provider) + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } e := &expr.Expression{ Op: tt.op, Left: tt.left, @@ -691,7 +759,10 @@ func TestSQLDriver_RenderRange(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - 
driver := NewSQLDriver(fields, tt.provider) + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } var e *expr.Expression if tt.rangeExpr == nil { e = &expr.Expression{ @@ -730,7 +801,10 @@ func TestSQLDriver_SerializeColumn(t *testing.T) { {Name: "name", Type: reflect.TypeOf("")}, {Name: "metadata", Type: reflect.TypeOf(map[string]interface{}{})}, } - driver := NewSQLDriver(fields, "postgresql") + driver, err := NewSQLDriver(fields, "postgresql") + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } tests := []struct { name string @@ -802,7 +876,10 @@ func TestSQLDriver_SerializeColumn(t *testing.T) { func TestSQLDriver_SerializeValue(t *testing.T) { fields := []FieldInfo{{Name: "name", Type: reflect.TypeOf("")}} - driver := NewSQLDriver(fields, "postgresql") + driver, err := NewSQLDriver(fields, "postgresql") + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } tests := []struct { name string @@ -924,7 +1001,10 @@ func TestSQLDriver_FormatFieldName(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - driver := NewSQLDriver(fields, tt.provider) + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } got := driver.formatFieldName(tt.field) if string(got) != tt.want { t.Errorf("formatFieldName() = %v, want %v", got, tt.want) @@ -1230,7 +1310,10 @@ func TestSQLDriver_ProcessJSONFields(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - driver := NewSQLDriver(fields, tt.provider) + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } driver.processJSONFields(tt.expr) if tt.check != nil { tt.check(t, tt.expr) @@ -1243,7 +1326,10 @@ func TestSQLDriver_RenderParamInternal(t *testing.T) { fields := []FieldInfo{ {Name: "name", Type: reflect.TypeOf("")}, } - driver := NewSQLDriver(fields, "postgresql") + 
driver, err := NewSQLDriver(fields, "postgresql") + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } tests := []struct { name string @@ -1437,7 +1523,10 @@ func TestSQLDriver_ProviderSpecific(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - driver := NewSQLDriver(fields, tt.provider) + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } sql, params, err := driver.RenderParam(tt.expr) if err != nil { t.Fatalf("RenderParam() error = %v", err) diff --git a/types/types.go b/types/types.go index e372a3b..c7364e9 100644 --- a/types/types.go +++ b/types/types.go @@ -137,46 +137,10 @@ const definitions = ` } } }, - "SearchQuery": { + "LuceneSearchQuery": { "type": "string", - "description": "Lucene-style search query supporting field searches, wildcards, boolean operators, ranges, and more. Syntax: field:value, wildcards (*,?), operators (AND, OR, NOT, +, -), ranges ([min TO max]), quoted phrases, JSONB access (field.subfield:value), null checks (field:null), and fuzzy search (term~).", - "example": "name:john AND status:active", - "examples": [ - "name:john", - "name:john*", - "email:*@example.com", - "description:*important*", - "name:john* OR email:*@example.com", - "name:john AND status:active", - "status:active OR status:pending", - "name:john NOT status:inactive", - "+name:john +status:active", - "name:john -status:deleted", - "age:[25 TO 65]", - "age:{25 TO 65}", - "age:[25 TO *]", - "age:[* TO 65]", - "created_at:[2024-01-01 TO 2024-12-31]", - "description:\"hello world\"", - "title:\"test-app (v1.0)\"", - "name:C\\+\\+ OR path:\\/usr\\/bin", - "(name:john* OR email:*@example.com) AND status:active AND age:[25 TO 65]", - "((name:john OR name:jane) AND status:active) OR (status:pending AND age:[18 TO *])", - "searchterm", - "john*", - "labels.category:production", - "metadata.tags:prod*", - "name:john AND labels.env:prod AND metadata.team:engineering", - 
"parent_id:null", - "NOT deleted_at:null", - "name:john AND deleted_at:null", - "name:roam~", - "name:roam~2", - "labels.tag:prod~", - "+name:john* -status:deleted age:[25 TO 65] AND (role:admin OR role:moderator)", - "name:john OR email:john@example.com OR phone:*555*", - "(name:*admin* OR role:administrator) AND status:active AND NOT deleted_at:null AND created_at:[2024-01-01 TO *]" - ] + "description": "Lucene-style search query supporting field searches, wildcards, boolean operators, ranges, and more. Syntax: field:value, wildcards (*,?), operators (AND, OR, NOT, +, -), ranges ([min TO max]), quoted phrases, JSONB access (field.subfield:value), null checks (field:null), and fuzzy search (term~). Examples: name:john, name:john*, email:*@example.com, description:*important*, name:john* OR email:*@example.com, name:john AND status:active, status:active OR status:pending, name:john NOT status:inactive, +name:john +status:active, name:john -status:deleted, age:[25 TO 65], age:{25 TO 65}, age:[25 TO *], age:[* TO 65], created_at:[2024-01-01 TO 2024-12-31], description:\"hello world\", title:\"test-app (v1.0)\", name:C\\+\\+ OR path:\\/usr\\/bin, (name:john* OR email:*@example.com) AND status:active AND age:[25 TO 65], ((name:john OR name:jane) AND status:active) OR (status:pending AND age:[18 TO *]), searchterm, john*, labels.category:production, metadata.tags:prod*, name:john AND labels.env:prod AND metadata.team:engineering, parent_id:null, NOT deleted_at:null, name:john AND deleted_at:null, name:roam~, name:roam~2, labels.tag:prod~, +name:john* -status:deleted age:[25 TO 65] AND (role:admin OR role:moderator), name:john OR email:john@example.com OR phone:*555*, (name:*admin* OR role:administrator) AND status:active AND NOT deleted_at:null AND created_at:[2024-01-01 TO *]", + "example": "name:john AND status:active" } } } From d74296f218fd8292d77b816a2ab648fc9bbd7725 Mon Sep 17 00:00:00 2001 From: Lutherwaves Date: Tue, 20 Jan 2026 01:12:22 +0200 Subject: [PATCH 
09/13] fix: linting errors --- storage/search/lucene/dynamodb_driver.go | 15 +- storage/search/lucene/dynamodb_driver_test.go | 44 +-- storage/search/lucene/parser_test.go | 351 +++++++++--------- storage/search/lucene/sql_driver.go | 16 +- storage/search/lucene/sql_driver_test.go | 102 ++--- 5 files changed, 266 insertions(+), 262 deletions(-) diff --git a/storage/search/lucene/dynamodb_driver.go b/storage/search/lucene/dynamodb_driver.go index 9e8b23c..d7cf42d 100644 --- a/storage/search/lucene/dynamodb_driver.go +++ b/storage/search/lucene/dynamodb_driver.go @@ -2,6 +2,7 @@ package lucene import ( "fmt" + "regexp" "strings" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" @@ -72,15 +73,17 @@ func escapePartiQLString(s string) string { return strings.ReplaceAll(s, "'", "''") } +var ( + // partiQLIdentifierPattern matches valid PartiQL identifiers (alphanumeric and underscore only) + partiQLIdentifierPattern = regexp.MustCompile(`^[a-zA-Z0-9_]+$`) +) + // escapePartiQLIdentifier escapes a field name for safe use in PartiQL. // Validates that the identifier contains only safe characters (alphanumeric, underscore). // Returns error if identifier contains potentially dangerous characters. 
func escapePartiQLIdentifier(identifier string) (string, error) { - // Validate identifier contains only safe characters - for _, r := range identifier { - if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_') { - return "", fmt.Errorf("invalid identifier: contains unsafe character '%c'", r) - } + if !partiQLIdentifierPattern.MatchString(identifier) { + return "", fmt.Errorf("invalid identifier: contains unsafe characters (only alphanumeric and underscore allowed)") } return identifier, nil } @@ -104,7 +107,7 @@ func dynamoDBLike(left, right string) (string, error) { // Extract the raw value from the right side (remove quotes if present) rawValue := unquotePartiQLString(right) - + // Analyze pattern for wildcards hasPrefix := strings.HasPrefix(rawValue, "%") hasSuffix := strings.HasSuffix(rawValue, "%") diff --git a/storage/search/lucene/dynamodb_driver_test.go b/storage/search/lucene/dynamodb_driver_test.go index 1193813..8ccf2e5 100644 --- a/storage/search/lucene/dynamodb_driver_test.go +++ b/storage/search/lucene/dynamodb_driver_test.go @@ -185,8 +185,8 @@ func TestDynamoDBDriver_RenderPartiQL(t *testing.T) { wantErr: false, }, { - name: "nil expression", - expr: nil, + name: "nil expression", + expr: nil, wantErr: false, }, } @@ -230,11 +230,11 @@ func TestDynamoDBDriver_RenderPartiQL(t *testing.T) { func TestDynamoDBLike(t *testing.T) { tests := []struct { - name string - left string - right string - want string - wantErr bool + name string + left string + right string + want string + wantErr bool }{ { name: "contains pattern %value%", @@ -363,25 +363,25 @@ func TestDynamoDBLike(t *testing.T) { wantErr: false, }, { - name: "invalid field name with special characters", - left: "name; DROP TABLE users;--", - right: "'test'", - want: "", - wantErr: true, + name: "invalid field name with special characters", + left: "name; DROP TABLE users;--", + right: "'test'", + want: "", + wantErr: true, }, { - name: "invalid field name 
with quotes", - left: "name'", - right: "'test'", - want: "", - wantErr: true, + name: "invalid field name with quotes", + left: "name'", + right: "'test'", + want: "", + wantErr: true, }, { - name: "invalid field name with spaces", - left: "field name", - right: "'test'", - want: "", - wantErr: true, + name: "invalid field name with spaces", + left: "field name", + right: "'test'", + want: "", + wantErr: true, }, } diff --git a/storage/search/lucene/parser_test.go b/storage/search/lucene/parser_test.go index 0514b0b..8d0d747 100644 --- a/storage/search/lucene/parser_test.go +++ b/storage/search/lucene/parser_test.go @@ -124,49 +124,49 @@ func TestBasicFieldSearch(t *testing.T) { parser := createParser(t, BasicModel{}) tests := []struct { - name string - query string - wantSQL []string - wantNot []string + name string + query string + wantSQL []string + wantNot []string wantParams []any - wantErr bool + wantErr bool }{ { - name: "simple field query", - query: "name:john", - wantSQL: []string{`"name"`, "=", "$1"}, - wantNot: []string{"ILIKE", "LIKE"}, + name: "simple field query", + query: "name:john", + wantSQL: []string{`"name"`, "=", "$1"}, + wantNot: []string{"ILIKE", "LIKE"}, wantParams: []any{"john"}, - wantErr: false, + wantErr: false, }, { - name: "wildcard prefix", - query: "name:john*", - wantSQL: []string{`"name"`, "ILIKE", "$1"}, - wantNot: []string{"="}, + name: "wildcard prefix", + query: "name:john*", + wantSQL: []string{`"name"`, "ILIKE", "$1"}, + wantNot: []string{"="}, wantParams: []any{"john%"}, - wantErr: false, + wantErr: false, }, { - name: "wildcard suffix", - query: "name:*john", - wantSQL: []string{`"name"`, "ILIKE", "$1"}, + name: "wildcard suffix", + query: "name:*john", + wantSQL: []string{`"name"`, "ILIKE", "$1"}, wantParams: []any{"%john"}, - wantErr: false, + wantErr: false, }, { - name: "wildcard contains", - query: "name:*john*", - wantSQL: []string{`"name"`, "ILIKE", "$1"}, + name: "wildcard contains", + query: "name:*john*", + 
wantSQL: []string{`"name"`, "ILIKE", "$1"}, wantParams: []any{"%john%"}, - wantErr: false, + wantErr: false, }, { - name: "email field", - query: `email:"test@example.com"`, - wantSQL: []string{`"email"`, "=", "$1"}, + name: "email field", + query: `email:"test@example.com"`, + wantSQL: []string{`"email"`, "=", "$1"}, wantParams: []any{"test@example.com"}, - wantErr: false, + wantErr: false, }, } @@ -183,9 +183,9 @@ func TestBasicFieldSearch(t *testing.T) { } if len(tt.wantParams) > 0 { // Only validate params if we expect specific values - if len(tt.wantParams) > 0 { - assertParamsEqual(t, params, tt.wantParams, tt.name) - } + if len(tt.wantParams) > 0 { + assertParamsEqual(t, params, tt.wantParams, tt.name) + } } } }) @@ -198,39 +198,39 @@ func TestBooleanOperators(t *testing.T) { parser := createParser(t, BooleanModel{}) tests := []struct { - name string - query string - wantSQL []string + name string + query string + wantSQL []string wantParams []any - wantErr bool + wantErr bool }{ { - name: "AND operator", - query: "name:john AND status:active", - wantSQL: []string{`"name"`, `"status"`, "AND"}, + name: "AND operator", + query: "name:john AND status:active", + wantSQL: []string{`"name"`, `"status"`, "AND"}, wantParams: []any{"john", "active"}, - wantErr: false, + wantErr: false, }, { - name: "OR operator", - query: "name:john OR name:jane", - wantSQL: []string{`"name"`, "OR"}, + name: "OR operator", + query: "name:john OR name:jane", + wantSQL: []string{`"name"`, "OR"}, wantParams: []any{"john", "jane"}, - wantErr: false, + wantErr: false, }, { - name: "NOT operator", - query: "name:john NOT status:inactive", - wantSQL: []string{`"name"`, `"status"`, "NOT"}, + name: "NOT operator", + query: "name:john NOT status:inactive", + wantSQL: []string{`"name"`, `"status"`, "NOT"}, wantParams: []any{"john", "inactive"}, - wantErr: false, + wantErr: false, }, { - name: "complex nested", - query: "(name:john OR name:jane) AND status:active", - wantSQL: []string{`"name"`, 
`"status"`, "OR", "AND"}, + name: "complex nested", + query: "(name:john OR name:jane) AND status:active", + wantSQL: []string{`"name"`, `"status"`, "OR", "AND"}, wantParams: []any{"john", "jane", "active"}, - wantErr: false, + wantErr: false, }, { name: "case insensitive AND", @@ -250,9 +250,9 @@ func TestBooleanOperators(t *testing.T) { assertSQLContains(t, sql, tt.wantSQL, tt.name) if len(tt.wantParams) > 0 { // Only validate params if we expect specific values - if len(tt.wantParams) > 0 { - assertParamsEqual(t, params, tt.wantParams, tt.name) - } + if len(tt.wantParams) > 0 { + assertParamsEqual(t, params, tt.wantParams, tt.name) + } } } }) @@ -309,46 +309,46 @@ func TestRangeQueries(t *testing.T) { parser := createParser(t, RangeModel{}) tests := []struct { - name string - query string - wantSQL []string + name string + query string + wantSQL []string wantParams []any - wantErr bool + wantErr bool }{ { - name: "inclusive range", - query: "age:[25 TO 65]", - wantSQL: []string{`"age"`, "BETWEEN"}, + name: "inclusive range", + query: "age:[25 TO 65]", + wantSQL: []string{`"age"`, "BETWEEN"}, wantParams: []any{"25", "65"}, - wantErr: false, + wantErr: false, }, { - name: "exclusive range", - query: "age:{25 TO 65}", - wantSQL: []string{`"age"`, ">", "<"}, + name: "exclusive range", + query: "age:{25 TO 65}", + wantSQL: []string{`"age"`, ">", "<"}, wantParams: []any{"25", "65"}, - wantErr: false, + wantErr: false, }, { - name: "open-ended range min", - query: "age:[25 TO *]", - wantSQL: []string{`"age"`, ">="}, + name: "open-ended range min", + query: "age:[25 TO *]", + wantSQL: []string{`"age"`, ">="}, wantParams: []any{"25"}, - wantErr: false, + wantErr: false, }, { - name: "open-ended range max", - query: "age:[* TO 65]", - wantSQL: []string{`"age"`, "<="}, + name: "open-ended range max", + query: "age:[* TO 65]", + wantSQL: []string{`"age"`, "<="}, wantParams: []any{"65"}, - wantErr: false, + wantErr: false, }, { - name: "date range", - query: 
"date:[2024-01-01 TO 2024-12-31]", - wantSQL: []string{`"date"`, "BETWEEN"}, + name: "date range", + query: "date:[2024-01-01 TO 2024-12-31]", + wantSQL: []string{`"date"`, "BETWEEN"}, wantParams: []any{"2024-01-01", "2024-12-31"}, - wantErr: false, + wantErr: false, }, } @@ -695,55 +695,55 @@ func TestNullValueQueries(t *testing.T) { parser := createParser(t, NullModel{}) tests := []struct { - name string - query string - wantSQL []string - wantNot []string + name string + query string + wantSQL []string + wantNot []string wantParams []any - wantErr bool + wantErr bool }{ { - name: "field is null (lowercase)", - query: "parent_id:null", - wantSQL: []string{`"parent_id"`, "IS NULL"}, - wantNot: []string{"=", "$1"}, + name: "field is null (lowercase)", + query: "parent_id:null", + wantSQL: []string{`"parent_id"`, "IS NULL"}, + wantNot: []string{"=", "$1"}, wantParams: []any{}, - wantErr: false, + wantErr: false, }, { - name: "field is NULL (uppercase)", - query: "parent_id:NULL", - wantSQL: []string{`"parent_id"`, "IS NULL"}, + name: "field is NULL (uppercase)", + query: "parent_id:NULL", + wantSQL: []string{`"parent_id"`, "IS NULL"}, wantParams: []any{}, - wantErr: false, + wantErr: false, }, { - name: "field is Null (mixed case)", - query: "parent_id:Null", - wantSQL: []string{`"parent_id"`, "IS NULL"}, + name: "field is Null (mixed case)", + query: "parent_id:Null", + wantSQL: []string{`"parent_id"`, "IS NULL"}, wantParams: []any{}, - wantErr: false, + wantErr: false, }, { - name: "combined null with other conditions", - query: "name:john AND deleted_at:null", - wantSQL: []string{`"name"`, `"deleted_at"`, "IS NULL", "AND"}, + name: "combined null with other conditions", + query: "name:john AND deleted_at:null", + wantSQL: []string{`"name"`, `"deleted_at"`, "IS NULL", "AND"}, wantParams: []any{"john"}, - wantErr: false, + wantErr: false, }, { - name: "NOT null (is not null)", - query: "NOT deleted_at:null", - wantSQL: []string{"NOT", `"deleted_at"`}, + name: "NOT 
null (is not null)", + query: "NOT deleted_at:null", + wantSQL: []string{"NOT", `"deleted_at"`}, wantParams: []any{"null"}, // NOT null is parsed as NOT field=null, not NOT field IS NULL - wantErr: false, + wantErr: false, }, { - name: "nil should be treated as literal value (not NULL)", - query: "name:nil", - wantSQL: []string{`"name"`, "=", "$1"}, + name: "nil should be treated as literal value (not NULL)", + query: "name:nil", + wantSQL: []string{`"name"`, "=", "$1"}, wantParams: []any{"nil"}, - wantErr: false, + wantErr: false, }, } @@ -761,9 +761,9 @@ func TestNullValueQueries(t *testing.T) { } if len(tt.wantParams) > 0 { // Only validate params if we expect specific values - if len(tt.wantParams) > 0 { - assertParamsEqual(t, params, tt.wantParams, tt.name) - } + if len(tt.wantParams) > 0 { + assertParamsEqual(t, params, tt.wantParams, tt.name) + } } } }) @@ -1007,11 +1007,11 @@ func TestParser_ValidateQuery(t *testing.T) { parser := createParser(t, BasicModel{}) tests := []struct { - name string - query string - config *ParserConfig - wantErr bool - wantError []string + name string + query string + config *ParserConfig + wantErr bool + wantError []string }{ { name: "valid query", @@ -1043,10 +1043,10 @@ func TestParser_ValidateQuery(t *testing.T) { wantErr: false, }, { - name: "custom limits - exceeds", - query: strings.Repeat("a", 201), - config: &ParserConfig{MaxQueryLength: 200}, - wantErr: true, + name: "custom limits - exceeds", + query: strings.Repeat("a", 201), + config: &ParserConfig{MaxQueryLength: 200}, + wantErr: true, }, { name: "empty query", @@ -1194,66 +1194,66 @@ func TestParser_ProviderSpecific(t *testing.T) { parser := createParser(t, BasicModel{}) tests := []struct { - name string - query string - provider string - wantSQL []string - wantNot []string + name string + query string + provider string + wantSQL []string + wantNot []string wantParams []any - wantErr bool + wantErr bool }{ { - name: "postgresql placeholder", - query: "name:john", 
- provider: "postgresql", - wantSQL: []string{"$1"}, - wantNot: []string{"?"}, + name: "postgresql placeholder", + query: "name:john", + provider: "postgresql", + wantSQL: []string{"$1"}, + wantNot: []string{"?"}, wantParams: []any{"john"}, - wantErr: false, + wantErr: false, }, { - name: "mysql placeholder", - query: "name:john", - provider: "mysql", - wantSQL: []string{"?"}, - wantNot: []string{"$"}, + name: "mysql placeholder", + query: "name:john", + provider: "mysql", + wantSQL: []string{"?"}, + wantNot: []string{"$"}, wantParams: []any{"john"}, - wantErr: false, + wantErr: false, }, { - name: "sqlite placeholder", - query: "name:john", - provider: "sqlite", - wantSQL: []string{"?"}, - wantNot: []string{"$"}, + name: "sqlite placeholder", + query: "name:john", + provider: "sqlite", + wantSQL: []string{"?"}, + wantNot: []string{"$"}, wantParams: []any{"john"}, - wantErr: false, + wantErr: false, }, { - name: "postgresql ILIKE", - query: "name:john*", - provider: "postgresql", - wantSQL: []string{"ILIKE"}, - wantNot: []string{"LOWER"}, + name: "postgresql ILIKE", + query: "name:john*", + provider: "postgresql", + wantSQL: []string{"ILIKE"}, + wantNot: []string{"LOWER"}, wantParams: []any{"john%"}, - wantErr: false, + wantErr: false, }, { - name: "mysql LOWER LIKE", - query: "name:john*", - provider: "mysql", - wantSQL: []string{"LOWER", "LIKE"}, + name: "mysql LOWER LIKE", + query: "name:john*", + provider: "mysql", + wantSQL: []string{"LOWER", "LIKE"}, wantParams: []any{"john%"}, - wantErr: false, + wantErr: false, }, { - name: "sqlite LIKE", - query: "name:john*", - provider: "sqlite", - wantSQL: []string{"LIKE"}, - wantNot: []string{"ILIKE", "LOWER"}, + name: "sqlite LIKE", + query: "name:john*", + provider: "sqlite", + wantSQL: []string{"LIKE"}, + wantNot: []string{"ILIKE", "LOWER"}, wantParams: []any{"john%"}, - wantErr: false, + wantErr: false, }, } @@ -1270,40 +1270,39 @@ func TestParser_ProviderSpecific(t *testing.T) { } if len(tt.wantParams) > 0 { // 
Only validate params if we expect specific values - if len(tt.wantParams) > 0 { - assertParamsEqual(t, params, tt.wantParams, tt.name) - } + if len(tt.wantParams) > 0 { + assertParamsEqual(t, params, tt.wantParams, tt.name) + } } } }) } } - // TestParser_ParseToDynamoDBPartiQL tests DynamoDB output func TestParser_ParseToDynamoDBPartiQL(t *testing.T) { parser := createParser(t, BasicModel{}) tests := []struct { - name string - query string + name string + query string wantPartiQL []string - wantCount int - wantErr bool + wantCount int + wantErr bool }{ { - name: "simple query", - query: "name:john", + name: "simple query", + query: "name:john", wantPartiQL: []string{"name"}, - wantCount: 1, - wantErr: false, + wantCount: 1, + wantErr: false, }, { - name: "AND query", - query: "name:john AND email:test", + name: "AND query", + query: "name:john AND email:test", wantPartiQL: []string{"AND"}, - wantCount: 2, - wantErr: false, + wantCount: 2, + wantErr: false, }, } diff --git a/storage/search/lucene/sql_driver.go b/storage/search/lucene/sql_driver.go index 63c05a0..6933dee 100644 --- a/storage/search/lucene/sql_driver.go +++ b/storage/search/lucene/sql_driver.go @@ -2,6 +2,7 @@ package lucene import ( "fmt" + "regexp" "strings" "github.com/grindlemire/go-lucene/pkg/driver" @@ -391,6 +392,11 @@ func (s *SQLDriver) processJSONFields(e *expr.Expression) { } } +var ( + // jsonSubFieldPattern matches valid JSON subfield names (alphanumeric, underscore, and dot for nested paths) + jsonSubFieldPattern = regexp.MustCompile(`^[a-zA-Z0-9_.]+$`) +) + // validateSubFieldName validates that a subfield name contains only safe characters. // Subfield names should be alphanumeric with underscores and dots for nested paths. // This prevents injection attacks via JSON path manipulation. 
@@ -399,12 +405,8 @@ func validateSubFieldName(subField string) error { return fmt.Errorf("subfield name cannot be empty") } - // Allow alphanumeric, underscore, and dot (for nested paths like "user.name") - // Reject any characters that could be used for injection (quotes, semicolons, etc.) - for _, r := range subField { - if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_' || r == '.') { - return fmt.Errorf("invalid subfield name '%s': contains unsafe character '%c'", subField, r) - } + if !jsonSubFieldPattern.MatchString(subField) { + return fmt.Errorf("invalid subfield name '%s': contains unsafe characters (only alphanumeric, underscore, and dot allowed)", subField) } return nil } @@ -436,7 +438,7 @@ func (s *SQLDriver) formatFieldName(fieldName string) expr.Column { // Escape subfield name for safe interpolation // PostgreSQL uses ->>'key' syntax where key is in quotes, so we need to escape quotes escapedSubField := escapeJSONPathSegment(subField) - + switch s.provider { case "postgresql": // PostgreSQL: JSONB operator ->> diff --git a/storage/search/lucene/sql_driver_test.go b/storage/search/lucene/sql_driver_test.go index 94ebce3..e79234e 100644 --- a/storage/search/lucene/sql_driver_test.go +++ b/storage/search/lucene/sql_driver_test.go @@ -166,8 +166,8 @@ func TestSQLDriver_RenderParam(t *testing.T) { wantErr: false, }, { - name: "nil expression", - expr: nil, + name: "nil expression", + expr: nil, wantErr: false, }, } @@ -443,58 +443,58 @@ func TestSQLDriver_RenderComparison(t *testing.T) { wantErr bool }{ { - name: "equals", - provider: "postgresql", - op: expr.Equals, - right: &expr.Expression{Op: expr.Literal, Left: "john"}, - wantSQL: []string{`"name"`, "="}, + name: "equals", + provider: "postgresql", + op: expr.Equals, + right: &expr.Expression{Op: expr.Literal, Left: "john"}, + wantSQL: []string{`"name"`, "="}, wantCount: 1, - wantErr: false, + wantErr: false, }, { - name: "greater than", - provider: 
"postgresql", - op: expr.Greater, - right: &expr.Expression{Op: expr.Literal, Left: 25}, - wantSQL: []string{`"age"`, ">"}, + name: "greater than", + provider: "postgresql", + op: expr.Greater, + right: &expr.Expression{Op: expr.Literal, Left: 25}, + wantSQL: []string{`"age"`, ">"}, wantCount: 1, - wantErr: false, + wantErr: false, }, { - name: "less than", - provider: "postgresql", - op: expr.Less, - right: &expr.Expression{Op: expr.Literal, Left: 65}, - wantSQL: []string{`"age"`, "<"}, + name: "less than", + provider: "postgresql", + op: expr.Less, + right: &expr.Expression{Op: expr.Literal, Left: 65}, + wantSQL: []string{`"age"`, "<"}, wantCount: 1, - wantErr: false, + wantErr: false, }, { - name: "greater or equal", - provider: "postgresql", - op: expr.GreaterEq, - right: &expr.Expression{Op: expr.Literal, Left: 25}, - wantSQL: []string{`"age"`, ">="}, + name: "greater or equal", + provider: "postgresql", + op: expr.GreaterEq, + right: &expr.Expression{Op: expr.Literal, Left: 25}, + wantSQL: []string{`"age"`, ">="}, wantCount: 1, - wantErr: false, + wantErr: false, }, { - name: "less or equal", - provider: "postgresql", - op: expr.LessEq, - right: &expr.Expression{Op: expr.Literal, Left: 65}, - wantSQL: []string{`"age"`, "<="}, + name: "less or equal", + provider: "postgresql", + op: expr.LessEq, + right: &expr.Expression{Op: expr.Literal, Left: 65}, + wantSQL: []string{`"age"`, "<="}, wantCount: 1, - wantErr: false, + wantErr: false, }, { - name: "equals null", - provider: "postgresql", - op: expr.Equals, - right: &expr.Expression{Op: expr.Literal, Left: "null"}, - wantSQL: []string{`"name"`, "IS NULL"}, + name: "equals null", + provider: "postgresql", + op: expr.Equals, + right: &expr.Expression{Op: expr.Literal, Left: "null"}, + wantSQL: []string{`"name"`, "IS NULL"}, wantCount: 0, - wantErr: false, + wantErr: false, }, { name: "greater than null (error)", @@ -602,7 +602,7 @@ func TestSQLDriver_RenderBinary(t *testing.T) { Left: expr.Column("name"), Right: 
&expr.Expression{Op: expr.Literal, Left: "john"}, }, - right: nil, + right: nil, wantSQL: []string{`"name"`}, wantCount: 1, wantErr: false, @@ -616,7 +616,7 @@ func TestSQLDriver_RenderBinary(t *testing.T) { Left: expr.Column("name"), Right: &expr.Expression{Op: expr.Literal, Left: "john"}, }, - right: nil, + right: nil, wantSQL: []string{"NOT"}, wantCount: 1, wantErr: false, @@ -750,10 +750,10 @@ func TestSQLDriver_RenderRange(t *testing.T) { wantErr: true, }, { - name: "invalid range expression (error)", - provider: "postgresql", + name: "invalid range expression (error)", + provider: "postgresql", rangeExpr: nil, - wantErr: true, + wantErr: true, }, } @@ -849,9 +849,9 @@ func TestSQLDriver_SerializeColumn(t *testing.T) { wantErr: false, }, { - name: "invalid type", - input: 123, - wantErr: true, + name: "invalid type", + input: 123, + wantErr: true, }, } @@ -917,9 +917,9 @@ func TestSQLDriver_SerializeValue(t *testing.T) { wantErr: false, }, { - name: "nil value (error)", - input: nil, - wantErr: true, + name: "nil value (error)", + input: nil, + wantErr: true, }, { name: "integer value", @@ -1388,8 +1388,8 @@ func TestSQLDriver_RenderParamInternal(t *testing.T) { wantErr: false, }, { - name: "nil expression", - expr: nil, + name: "nil expression", + expr: nil, wantErr: false, }, } From 18068a4e7f66ae077d98f571de4d9b18644ec2dd Mon Sep 17 00:00:00 2001 From: Lutherwaves Date: Tue, 20 Jan 2026 01:56:38 +0200 Subject: [PATCH 10/13] fix(parser): fix OpenAPI validation, improve error handling with errors.Join, refactor to regex validation --- storage/search/lucene/parser.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/storage/search/lucene/parser.go b/storage/search/lucene/parser.go index 816fbe2..27a2b71 100644 --- a/storage/search/lucene/parser.go +++ b/storage/search/lucene/parser.go @@ -1,6 +1,7 @@ package lucene import ( + "errors" "fmt" "log/slog" "reflect" @@ -294,18 +295,24 @@ func (p *Parser) ParseToDynamoDBPartiQL(query 
string) (string, []types.Attribute } func (p *Parser) validateQuery(query string) error { + var errs []error + if len(query) > p.MaxQueryLength { - return fmt.Errorf("query too long: %d bytes exceeds maximum of %d bytes", len(query), p.MaxQueryLength) + errs = append(errs, fmt.Errorf("query too long: %d bytes exceeds maximum of %d bytes", len(query), p.MaxQueryLength)) } depth := calculateNestingDepth(query) if depth > p.MaxDepth { - return fmt.Errorf("query too complex: nesting depth %d exceeds maximum of %d", depth, p.MaxDepth) + errs = append(errs, fmt.Errorf("query too complex: nesting depth %d exceeds maximum of %d", depth, p.MaxDepth)) } terms := countTerms(query) if terms > p.MaxTerms { - return fmt.Errorf("query too large: %d terms exceeds maximum of %d", terms, p.MaxTerms) + errs = append(errs, fmt.Errorf("query too large: %d terms exceeds maximum of %d", terms, p.MaxTerms)) + } + + if len(errs) > 0 { + return errors.Join(errs...) } return nil From aa559f6060b6aa3dc09623ab25f7bde2e7ed586c Mon Sep 17 00:00:00 2001 From: Martin Yankovs Date: Tue, 10 Feb 2026 21:25:17 +0200 Subject: [PATCH 11/13] fix(sqlstorage): fix cursor pagination with json tag lookup and placeholder handling * Use JSON struct tags instead of Go field names for cursor value extraction * Remove PostgreSQL pre-conversion that conflicted with GORM's placeholders --- storage/search/lucene/parser_test.go | 18 +++++++++--------- storage/search/lucene/sql_driver.go | 14 +++++--------- storage/search/lucene/sql_driver_test.go | 17 ++++++----------- storage/sql.go | 23 ++++++++++++++++++++--- 4 files changed, 40 insertions(+), 32 deletions(-) diff --git a/storage/search/lucene/parser_test.go b/storage/search/lucene/parser_test.go index 8d0d747..c94fd4b 100644 --- a/storage/search/lucene/parser_test.go +++ b/storage/search/lucene/parser_test.go @@ -134,7 +134,7 @@ func TestBasicFieldSearch(t *testing.T) { { name: "simple field query", query: "name:john", - wantSQL: []string{`"name"`, "=", "$1"}, + 
wantSQL: []string{`"name"`, "=", "?"}, wantNot: []string{"ILIKE", "LIKE"}, wantParams: []any{"john"}, wantErr: false, @@ -142,7 +142,7 @@ func TestBasicFieldSearch(t *testing.T) { { name: "wildcard prefix", query: "name:john*", - wantSQL: []string{`"name"`, "ILIKE", "$1"}, + wantSQL: []string{`"name"`, "ILIKE", "?"}, wantNot: []string{"="}, wantParams: []any{"john%"}, wantErr: false, @@ -150,21 +150,21 @@ func TestBasicFieldSearch(t *testing.T) { { name: "wildcard suffix", query: "name:*john", - wantSQL: []string{`"name"`, "ILIKE", "$1"}, + wantSQL: []string{`"name"`, "ILIKE", "?"}, wantParams: []any{"%john"}, wantErr: false, }, { name: "wildcard contains", query: "name:*john*", - wantSQL: []string{`"name"`, "ILIKE", "$1"}, + wantSQL: []string{`"name"`, "ILIKE", "?"}, wantParams: []any{"%john%"}, wantErr: false, }, { name: "email field", query: `email:"test@example.com"`, - wantSQL: []string{`"email"`, "=", "$1"}, + wantSQL: []string{`"email"`, "=", "?"}, wantParams: []any{"test@example.com"}, wantErr: false, }, @@ -706,7 +706,7 @@ func TestNullValueQueries(t *testing.T) { name: "field is null (lowercase)", query: "parent_id:null", wantSQL: []string{`"parent_id"`, "IS NULL"}, - wantNot: []string{"=", "$1"}, + wantNot: []string{"=", "?"}, wantParams: []any{}, wantErr: false, }, @@ -741,7 +741,7 @@ func TestNullValueQueries(t *testing.T) { { name: "nil should be treated as literal value (not NULL)", query: "name:nil", - wantSQL: []string{`"name"`, "=", "$1"}, + wantSQL: []string{`"name"`, "=", "?"}, wantParams: []any{"nil"}, wantErr: false, }, @@ -1206,8 +1206,8 @@ func TestParser_ProviderSpecific(t *testing.T) { name: "postgresql placeholder", query: "name:john", provider: "postgresql", - wantSQL: []string{"$1"}, - wantNot: []string{"?"}, + wantSQL: []string{"?"}, + wantNot: []string{"$"}, wantParams: []any{"john"}, wantErr: false, }, diff --git a/storage/search/lucene/sql_driver.go b/storage/search/lucene/sql_driver.go index 6933dee..b2b3488 100644 --- 
a/storage/search/lucene/sql_driver.go +++ b/storage/search/lucene/sql_driver.go @@ -82,14 +82,10 @@ func (s *SQLDriver) RenderParam(e *expr.Expression) (string, []any, error) { return "", nil, err } - // Convert ? placeholders to provider-specific format - // PostgreSQL uses $1, $2, $3; MySQL and SQLite use ? - switch s.provider { - case "postgresql": - str = convertToPostgresPlaceholders(str) - case "mysql", "sqlite": - // Already uses ? placeholders, no conversion needed - } + // Keep ? placeholders for all providers. + // GORM's PostgreSQL driver handles ? → $N conversion automatically, + // so pre-converting here would conflict with additional WHERE clauses + // (e.g. cursor pagination) that GORM adds with its own ? placeholders. return str, params, nil } @@ -404,7 +400,7 @@ func validateSubFieldName(subField string) error { if subField == "" { return fmt.Errorf("subfield name cannot be empty") } - + if !jsonSubFieldPattern.MatchString(subField) { return fmt.Errorf("invalid subfield name '%s': contains unsafe characters (only alphanumeric, underscore, and dot allowed)", subField) } diff --git a/storage/search/lucene/sql_driver_test.go b/storage/search/lucene/sql_driver_test.go index e79234e..b001b82 100644 --- a/storage/search/lucene/sql_driver_test.go +++ b/storage/search/lucene/sql_driver_test.go @@ -204,14 +204,9 @@ func TestSQLDriver_RenderParam(t *testing.T) { if len(params) != tt.wantCount { t.Errorf("RenderParam() params count = %v, want %v", len(params), tt.wantCount) } - if provider == "postgresql" { - if !strings.Contains(sql, "$") { - t.Errorf("RenderParam() expected PostgreSQL placeholders ($1, $2), got %v", sql) - } - } else { - if strings.Contains(sql, "$") && !strings.Contains(sql, "?") { - t.Errorf("RenderParam() expected ? placeholders for %v, got %v", provider, sql) - } + // All providers use ? placeholders; GORM handles $N conversion for PostgreSQL + if tt.wantCount > 0 && !strings.Contains(sql, "?") { + t.Errorf("RenderParam() expected ? 
placeholders for %v, got %v", provider, sql) } }) } @@ -1443,11 +1438,11 @@ func TestSQLDriver_ProviderSpecific(t *testing.T) { Left: expr.Column("name"), Right: &expr.Expression{Op: expr.Literal, Left: "john"}, }, - wantSQL: []string{"$1"}, + wantSQL: []string{"?"}, wantCount: 1, checkFunc: func(t *testing.T, sql string, params []any) { - if !strings.Contains(sql, "$") { - t.Errorf("expected PostgreSQL placeholder ($1), got %v", sql) + if !strings.Contains(sql, "?") { + t.Errorf("expected ? placeholder, got %v", sql) } }, }, diff --git a/storage/sql.go b/storage/sql.go index 39ac883..acbef79 100644 --- a/storage/sql.go +++ b/storage/sql.go @@ -260,11 +260,11 @@ func (s *SQLAdapter) executePaginatedQuery( nextCursor := "" if destSlice.Len() > limit { lastItem := destSlice.Index(limit - 1) - field := reflect.Indirect(lastItem).FieldByName(sortKey) + field := findFieldByJSONTag(reflect.Indirect(lastItem), sortKey) if !field.IsValid() { - slog.Warn("cursor extraction failed: sort_key does not match any exported struct field", + slog.Warn("cursor extraction failed: sort_key does not match any json tag on the struct", "sort_key", sortKey, - "hint", "sort_key must be the Go struct field name (e.g. 'CreatedAt'), not the DB column name (e.g. 'created_at')") + "hint", "sort_key must match a json tag (e.g. 'created_at'), not the Go field name (e.g. 'CreatedAt')") } else if field.Kind() != reflect.String { slog.Warn("cursor extraction failed: struct field is not a string", "sort_key", sortKey, @@ -278,6 +278,23 @@ func (s *SQLAdapter) executePaginatedQuery( return nextCursor, nil } +// findFieldByJSONTag looks up a struct field by its json tag name. +// This is needed because sortKey uses the JSON/column name (e.g. "id") +// while Go struct fields use PascalCase (e.g. "Id"). 
+func findFieldByJSONTag(v reflect.Value, tag string) reflect.Value { + t := v.Type() + for i := 0; i < t.NumField(); i++ { + jsonTag := t.Field(i).Tag.Get("json") + if idx := strings.Index(jsonTag, ","); idx != -1 { + jsonTag = jsonTag[:idx] + } + if jsonTag == tag { + return v.Field(i) + } + } + return reflect.Value{} +} + func (s *SQLAdapter) List(dest any, sortKey string, filter map[string]any, limit int, cursor string, params ...map[string]any) (string, error) { sortDirection, err := extractSortDirection(extractParams(params...)) if err != nil { From 4a9a2938b4714b6d700372ea447fa9f7f5e0b93e Mon Sep 17 00:00:00 2001 From: Martin Yankovs Date: Tue, 10 Mar 2026 00:25:52 +0200 Subject: [PATCH 12/13] fix(lucene): handle null keyword in grouped OR field expressions - tokenizeQuery now keeps field:(a OR b OR null) as a single token so that expandImplicitTerms does not treat the inner values as bare implicit search terms and expand them across all fields - renderComparison now passes grouped OR/AND sub-expressions through directly instead of wrapping them in an outer equality check, fixing the 'field = (OR subquery)' invalid SQL that was generated when go-lucene parses field:(a OR null) with WithDefaultField - isNullValue extended to handle Go nil and bool(false), since some go-lucene parser paths represent the bare null keyword as bool(false) in grouped expressions --- storage/search/lucene/parser.go | 43 ++++++++++++++++-- storage/search/lucene/parser_test.go | 57 ++++++++++++++++++++++++ storage/search/lucene/sql_driver.go | 32 ++++++++++--- storage/search/lucene/sql_driver_test.go | 10 +++++ 4 files changed, 134 insertions(+), 8 deletions(-) diff --git a/storage/search/lucene/parser.go b/storage/search/lucene/parser.go index 27a2b71..32590f3 100644 --- a/storage/search/lucene/parser.go +++ b/storage/search/lucene/parser.go @@ -671,7 +671,11 @@ func (p *Parser) expandImplicitTerms(query string) string { return strings.Join(result, " ") } -// tokenizeQuery splits 
query into tokens, preserving quoted strings and range brackets. +// tokenizeQuery splits query into tokens, preserving quoted strings, range brackets, +// and field-grouped expressions (field:(a OR b OR null)). +// When a field: token is immediately followed by '(', the entire field:(...) construct +// is kept as a single token so that expandImplicitTerms does not incorrectly expand +// the inner terms as bare implicit search terms. func tokenizeQuery(query string) []string { var tokens []string var current strings.Builder @@ -716,8 +720,41 @@ func tokenizeQuery(query string) []string { continue } - // Handle parentheses as separate tokens - if !inQuotes && !inRange && (c == '(' || c == ')') { + // Handle parentheses. + // If the current buffer ends with ':' (i.e. we just finished a field name like "field:"), + // treat the entire field:(...) grouped expression as one token so that terms inside + // the group are not mistakenly treated as implicit search terms. + if !inQuotes && !inRange && c == '(' { + if current.Len() > 0 && current.String()[current.Len()-1] == ':' { + // Consume the entire parenthesised group as part of this token. + current.WriteByte(c) + parenDepth := 1 + i++ + for i < len(query) && parenDepth > 0 { + ch := query[i] + current.WriteByte(ch) + if ch == '(' { + parenDepth++ + } else if ch == ')' { + parenDepth-- + } + if parenDepth > 0 { + i++ + } + } + // The loop exits with i pointing at the closing ')' already written. + continue + } + // Standalone '(' — emit as its own token. 
+ if current.Len() > 0 { + tokens = append(tokens, current.String()) + current.Reset() + } + tokens = append(tokens, string(c)) + continue + } + + if !inQuotes && !inRange && c == ')' { if current.Len() > 0 { tokens = append(tokens, current.String()) current.Reset() diff --git a/storage/search/lucene/parser_test.go b/storage/search/lucene/parser_test.go index c94fd4b..bfca68c 100644 --- a/storage/search/lucene/parser_test.go +++ b/storage/search/lucene/parser_test.go @@ -114,6 +114,7 @@ type MixedModel struct { type NullModel struct { Name string `json:"name"` ParentID string `json:"parent_id"` + TenantID string `json:"tenant_id"` DeletedAt string `json:"deleted_at"` AttachmentIDs string `json:"attachment_ids"` } @@ -770,6 +771,62 @@ func TestNullValueQueries(t *testing.T) { } } +// TestGroupedORWithNull tests that field:(value OR null) produces IS NULL for the null term. +// Uses a single-field model to avoid implicit-search expansion of the bare terms. +// Regression test for: go-lucene parses field:(a OR null) with WithDefaultField into +// EQUALS(field, OR(EQUALS(field,a), EQUALS(field,"null"))), which must not wrap the OR +// in another equality check. 
+func TestGroupedORWithNull(t *testing.T) { + type TenantModel struct { + TenantID string `json:"tenant_id"` + } + parser := createParser(t, TenantModel{}) + + tests := []struct { + name string + query string + wantSQL []string + wantNot []string + wantParams []any + }{ + { + name: "grouped OR with null keyword", + query: `tenant_id:(abc123 OR null)`, + wantSQL: []string{`"tenant_id"`, "IS NULL", "OR"}, + wantNot: []string{"= FALSE", "= false", "ILIKE"}, + wantParams: []any{"abc123"}, + }, + { + name: "grouped OR with multiple values and null", + query: `tenant_id:(abc123 OR def456 OR null)`, + wantSQL: []string{`"tenant_id"`, "IS NULL", "OR"}, + wantNot: []string{"= FALSE", "= false"}, + wantParams: []any{"abc123", "def456"}, + }, + { + name: "grouped OR without null is unchanged", + query: `tenant_id:(abc123 OR def456)`, + wantSQL: []string{`"tenant_id"`, "OR"}, + wantNot: []string{"IS NULL", "= FALSE"}, + wantParams: []any{"abc123", "def456"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, params, err := parser.ParseToSQL(tt.query, "postgresql") + if err != nil { + t.Fatalf("ParseToSQL(%q) error = %v", tt.query, err) + } + assertSQLContains(t, sql, tt.wantSQL, tt.name) + if len(tt.wantNot) > 0 { + assertSQLNotContains(t, sql, tt.wantNot, tt.name) + } + assertParamsEqual(t, params, tt.wantParams, tt.name) + }) + } +} + // TestEmptyAsLiteralValue tests that 'empty' is treated as a literal value func TestEmptyAsLiteralValue(t *testing.T) { parser, err := NewParser(BooleanModel{}) diff --git a/storage/search/lucene/sql_driver.go b/storage/search/lucene/sql_driver.go index b2b3488..4eb978a 100644 --- a/storage/search/lucene/sql_driver.go +++ b/storage/search/lucene/sql_driver.go @@ -210,6 +210,22 @@ func (s *SQLDriver) renderComparison(e *expr.Expression) (string, []any, error) return "", nil, fmt.Errorf("cannot use comparison operators (>, <, >=, <=) with null value") } + // When go-lucene parses grouped OR/AND expressions like 
field:(a OR b OR null) with a + // default field set, it produces EQUALS(field, OR(EQUALS(field, a), EQUALS(field, null))). + // The right side is already a fully-formed boolean expression — render it directly instead + // of wrapping it in another equality comparison (which would produce invalid SQL like + // "field" = (("field" = ?) OR ("field" IS NULL))). + if rightExpr, ok := e.Right.(*expr.Expression); ok && e.Op == expr.Equals { + switch rightExpr.Op { + case expr.Or, expr.And: + rightStr, rightParams, err := s.renderParamInternal(rightExpr) + if err != nil { + return "", nil, err + } + return rightStr, append(leftParams, rightParams...), nil + } + } + rightStr, rightParams, err := s.serializeValue(e.Right) if err != nil { return "", nil, err @@ -502,16 +518,22 @@ func isJSONSyntax(col string) bool { } // isNullValue checks if a value represents null in Lucene query syntax. -// Supports: null, NULL, Null (case-insensitive) -// Note: This is a SQL-specific extension (vanilla Lucene doesn't support NULL values). -// We intentionally do NOT support "empty" or "nil" as they could be legitimate search values. +// Handles the string "null" (case-insensitive), Go nil, and Go bool(false). +// Note: go-lucene parses the bare `null` keyword as bool(false) in grouped +// expressions like field:(a OR null), bypassing the string literal path. 
func isNullValue(v any) bool { + if v == nil { + return true + } + // go-lucene parses the bare `null` keyword as bool(false) in grouped OR/AND + if b, ok := v.(bool); ok && !b { + return true + } strVal := extractStringValue(v) if strVal == "" { return false } - lower := strings.ToLower(strVal) - return lower == "null" + return strings.ToLower(strVal) == "null" } func extractStringValue(v any) string { diff --git a/storage/search/lucene/sql_driver_test.go b/storage/search/lucene/sql_driver_test.go index b001b82..ecf0ec6 100644 --- a/storage/search/lucene/sql_driver_test.go +++ b/storage/search/lucene/sql_driver_test.go @@ -1153,6 +1153,16 @@ func TestIsNullValue(t *testing.T) { { name: "nil value", input: nil, + want: true, + }, + { + name: "bool false (go-lucene grouped OR null representation)", + input: false, + want: true, + }, + { + name: "bool true (not null)", + input: true, want: false, }, { From 8d5ee9b640c27e6c52b79e14ae902bebaba86ae1 Mon Sep 17 00:00:00 2001 From: Martin Yankovs Date: Tue, 10 Mar 2026 01:33:32 +0200 Subject: [PATCH 13/13] fix(lucene): renderGroupedFieldExpr preserves outer field name in grouped OR/AND expressions go-lucene parses `tenant_id:(abc123 OR null)` as EQUALS(tenant_id, OR(EQUALS(default_field, abc123), EQUALS(default_field, null))). The previous fix called renderParamInternal on the OR subtree directly, which used the default field (e.g. "id") instead of the outer field ("tenant_id"), producing wrong SQL. The new renderGroupedFieldExpr/renderGroupedFieldLeaf helpers walk the OR/AND tree and re-render each leaf against the correct outer field. 
Co-Authored-By: Claude Sonnet 4.6 --- storage/search/lucene/sql_driver.go | 60 +++++++++++++++++++++--- storage/search/lucene/sql_driver_test.go | 59 +++++++++++++++++++++++ 2 files changed, 113 insertions(+), 6 deletions(-) diff --git a/storage/search/lucene/sql_driver.go b/storage/search/lucene/sql_driver.go index 4eb978a..8ee103d 100644 --- a/storage/search/lucene/sql_driver.go +++ b/storage/search/lucene/sql_driver.go @@ -211,18 +211,18 @@ func (s *SQLDriver) renderComparison(e *expr.Expression) (string, []any, error) } // When go-lucene parses grouped OR/AND expressions like field:(a OR b OR null) with a - // default field set, it produces EQUALS(field, OR(EQUALS(field, a), EQUALS(field, null))). - // The right side is already a fully-formed boolean expression — render it directly instead - // of wrapping it in another equality comparison (which would produce invalid SQL like - // "field" = (("field" = ?) OR ("field" IS NULL))). + // default field set, it produces EQUALS(outer_field, OR(EQUALS(default_field, a), EQUALS(default_field, null))). + // The inner leaves use the default field, not the outer field. We must re-render each leaf + // using leftStr (the correct outer field) to avoid producing wrong SQL like + // ("id" = ?) OR ("id" IS NULL) when the query was tenant_id:(abc123 OR null). 
if rightExpr, ok := e.Right.(*expr.Expression); ok && e.Op == expr.Equals { switch rightExpr.Op { case expr.Or, expr.And: - rightStr, rightParams, err := s.renderParamInternal(rightExpr) + groupStr, groupParams, err := s.renderGroupedFieldExpr(leftStr, rightExpr) if err != nil { return "", nil, err } - return rightStr, append(leftParams, rightParams...), nil + return groupStr, append(leftParams, groupParams...), nil } } @@ -250,6 +250,54 @@ func (s *SQLDriver) renderComparison(e *expr.Expression) (string, []any, error) return fmt.Sprintf("%s %s %s", leftStr, opSymbol, rightStr), params, nil } +// renderGroupedFieldExpr renders an OR/AND expression tree where each leaf comparison +// should use the given fieldSQL column instead of whatever field the leaf has internally. +// This handles go-lucene's behavior of wrapping grouped field expressions as +// EQUALS(outer_field, OR(EQUALS(default_field, v1), EQUALS(default_field, v2))). +func (s *SQLDriver) renderGroupedFieldExpr(fieldSQL string, e *expr.Expression) (string, []any, error) { + leftStr, leftParams, err := s.renderGroupedFieldLeaf(fieldSQL, e.Left) + if err != nil { + return "", nil, err + } + + if e.Right == nil { + return leftStr, leftParams, nil + } + + rightStr, rightParams, err := s.renderGroupedFieldLeaf(fieldSQL, e.Right) + if err != nil { + return "", nil, err + } + + op := " OR " + if e.Op == expr.And { + op = " AND " + } + return fmt.Sprintf("(%s%s%s)", leftStr, op, rightStr), append(leftParams, rightParams...), nil +} + +// renderGroupedFieldLeaf renders a single node (leaf or sub-tree) in a grouped field expression, +// always using fieldSQL as the column name regardless of what the node's own field is. 
+func (s *SQLDriver) renderGroupedFieldLeaf(fieldSQL string, v any) (string, []any, error) { + if e, ok := v.(*expr.Expression); ok { + if e.Op == expr.Or || e.Op == expr.And { + return s.renderGroupedFieldExpr(fieldSQL, e) + } + if e.Op == expr.Equals { + // Use the value from this leaf but with the outer field + return s.renderGroupedFieldLeaf(fieldSQL, e.Right) + } + } + if isNullValue(v) { + return fmt.Sprintf("%s IS NULL", fieldSQL), nil, nil + } + valStr, valParams, err := s.serializeValue(v) + if err != nil { + return "", nil, err + } + return fmt.Sprintf("%s = %s", fieldSQL, valStr), valParams, nil +} + // renderBinary handles binary and unary logical operators recursively. // Note: Must and MustNot are unary (only Left operand), while And and Or are binary. func (s *SQLDriver) renderBinary(e *expr.Expression) (string, []any, error) { diff --git a/storage/search/lucene/sql_driver_test.go b/storage/search/lucene/sql_driver_test.go index ecf0ec6..edeb6fc 100644 --- a/storage/search/lucene/sql_driver_test.go +++ b/storage/search/lucene/sql_driver_test.go @@ -536,6 +536,65 @@ func TestSQLDriver_RenderComparison(t *testing.T) { } } +func TestSQLDriver_GroupedFieldOR(t *testing.T) { + fields := []FieldInfo{ + {Name: "tenant_id", Type: reflect.TypeOf("")}, + {Name: "id", Type: reflect.TypeOf("")}, + } + + // Simulate go-lucene output for: tenant_id:(abc123 OR null) + // EQUALS(tenant_id, OR(EQUALS(default_field/id, abc123), EQUALS(default_field/id, null))) + innerOR := &expr.Expression{ + Op: expr.Or, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("id"), // default field — wrong field, bug we're fixing + Right: &expr.Expression{Op: expr.Literal, Left: "abc123"}, + }, + Right: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("id"), // default field — wrong field, bug we're fixing + Right: false, // go-lucene represents null as bool(false) + }, + } + outerEq := &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("tenant_id"), + Right: 
innerOR, + } + + for _, provider := range []string{"postgresql", "mysql", "sqlite"} { + t.Run("tenant_id grouped OR with null keyword/"+provider, func(t *testing.T) { + driver, err := NewSQLDriver(fields, provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } + sql, params, err := driver.RenderParam(outerEq) + if err != nil { + t.Fatalf("RenderParam() error = %v", err) + } + // Must use tenant_id, not id + if strings.Contains(sql, `"id"`) { + t.Errorf("RenderParam() sql = %v — uses wrong field 'id', want 'tenant_id'", sql) + } + if !strings.Contains(sql, `"tenant_id"`) { + t.Errorf("RenderParam() sql = %v — missing 'tenant_id'", sql) + } + if !strings.Contains(sql, "IS NULL") { + t.Errorf("RenderParam() sql = %v — missing IS NULL for null value", sql) + } + if !strings.Contains(sql, "OR") { + t.Errorf("RenderParam() sql = %v — missing OR", sql) + } + // Should have exactly 1 param (for abc123), null uses IS NULL not a placeholder + if len(params) != 1 { + t.Errorf("RenderParam() params count = %v, want 1", len(params)) + } + t.Logf("SQL: %s | Params: %v", sql, params) + }) + } +} + func TestSQLDriver_RenderBinary(t *testing.T) { fields := []FieldInfo{ {Name: "name", Type: reflect.TypeOf("")},