diff --git a/README.md b/README.md index d2951ce..d3baaa5 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ Features: - Fast, low-allocation parser and runtime - Many simple expressions are zero-allocation - Type checking during parsing +- Type conversion for `func()` - Simple - Easy to learn - Easy to read @@ -137,6 +138,8 @@ Math operations between constants are precomputed when possible, so it is effici - `and` - `or` +Both `and` and `or` are short-circuited. + ```py 1 < 2 and 3 < 4 ``` @@ -148,6 +151,20 @@ Non-boolean values are converted to booleans. The following result in `true`: - array with at least one item - map with at least one key/value pair +### Functions + +- `identifier(...)` + +Functions can be called by providing them in the variables map. + +```go +result, err := mexpr.Eval("myFunc(a, b)", map[string]interface{}{ + "myFunc": func(a, b int) int { return a + b }, + "a": 1, + "b": 2, +}) +``` + ### String operators - Indexing, e.g. `foo[0]` @@ -221,6 +238,21 @@ not (items where id > 3) - `in` (has key), e.g. `"key" in foo` - `contains` e.g. `foo contains "key"` +### Conversions + +Any value concatenated with a string will result in a string. For example `"id" + 1` will result in `"id1"`. + +The value of a variable can be mapped to a function. This allows the implementor to use functions to retrieve actual values of variables rather than pre-computing values: + +```go +result, _ := mexpr.Eval(`id + 1`, map[string]interface{}{ + "id": func() int { return 123 }, +}) +// result is 124 +``` + +In combination with short-circuiting with and/or it allows lazy evaluation. + #### Map wildcard filtering A `where` clause can be used as a wildcard key to filter values for all keys in a map. The left side of the clause is the map to be filtered, while the right side is an expression to run on each value of the map. If the right side expression evaluates to true then the value is added to the result slice. For example, given: diff --git a/conversions.go b/conversions.go index 1195302..16aeb1c 100644 --- a/conversions.go +++ b/conversions.go @@ -12,6 +12,10 @@ func isNumber(v interface{}) bool { return true case float32, float64: return true + case func() int: + return true + case func() float64: + return true } return false } @@ -42,6 +46,10 @@ func toNumber(ast *Node, v interface{}) (float64, Error) { return float64(n), nil case float32: return float64(n), nil + case func() int: + return float64(n()), nil + case func() float64: + return n(), nil } return 0, NewError(ast.Offset, ast.Length, "unable to convert to number: %v", v) } @@ -64,6 +72,8 @@ func toString(v interface{}) string { return string(s) case []byte: return string(s) + case func() string: + return s() } return fmt.Sprintf("%v", v) } @@ -162,6 +172,10 @@ func normalize(v interface{}) interface{} { return float64(n) case []byte: return string(n) + case func() int: + return float64(n()) + case func() float64: + return n() } return v diff --git a/interpreter.go b/interpreter.go index a177dcb..f04efac 100644 --- a/interpreter.go +++ b/interpreter.go @@ -96,6 +96,9 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { return value, nil case "length": // Special pseudo-property to get the value's length. + if s, ok := value.(func() string); ok { + return len(s()), nil + } if s, ok := value.(string); ok { return len(s), nil } @@ -103,17 +106,34 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { return len(a), nil } case "lower": + if s, ok := value.(func() string); ok { + return strings.ToLower(s()), nil + } if s, ok := value.(string); ok { return strings.ToLower(s), nil } case "upper": + if s, ok := value.(func() string); ok { + return strings.ToUpper(s()), nil + } if s, ok := value.(string); ok { return strings.ToUpper(s), nil } } if m, ok := value.(map[string]any); ok { if v, ok := m[ast.Value.(string)]; ok { - return v, nil + switch n := v.(type) { + case func() int: + return n(), nil + case func() float64: + return n(), nil + case func() bool: + return n(), nil + case func() string: + return n(), nil + default: + return v, nil + } } } if m, ok := value.(map[any]any); ok { @@ -335,11 +355,21 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { if err != nil { return nil, err } + left := toBool(resultLeft) + switch ast.Type { + case NodeAnd: + if !left { + return left, nil + } + case NodeOr: + if left { + return left, nil + } + } resultRight, err := i.run(ast.Right, value) if err != nil { return nil, err } - left := toBool(resultLeft) right := toBool(resultRight) switch ast.Type { case NodeAnd: @@ -470,6 +500,75 @@ func (i *interpreter) run(ast *Node, value any) (any, Error) { } } return results, nil + case NodeFunctionCall: + funcName := ast.Left.Value.(string) + if m, ok := value.(map[string]any); ok { + if fn, ok := m[funcName]; ok { + // Get function parameters + params := []any{} + for _, param := range ast.Value.([]Node) { + paramValue, err := i.run(¶m, value) + if err != nil { + return nil, err + } + params = append(params, paramValue) + } + + // Execute function based on parameter count + switch f := fn.(type) { + case func() any: + if len(params) != 0 { + return nil, NewError(ast.Offset, ast.Length, "function %s expects 0 parameter, got %d", funcName, len(params)) + } + result := f() + switch result.(type) { + case error: + return nil, NewError(ast.Offset, ast.Length, "Runtime error: %v", result.(error)) + default: + return result, nil + } + case func(any) any: + if len(params) != 1 { + return nil, NewError(ast.Offset, ast.Length, "function %s expects 1 parameter, got %d", funcName, len(params)) + } + result := f(params[0]) + switch result.(type) { + case error: + return nil, NewError(ast.Offset, ast.Length, "Runtime error: %v", result.(error)) + default: + return result, nil + } + case func(any, any) any: + if len(params) != 2 { + return nil, NewError(ast.Offset, ast.Length, "function %s expects 2 parameters, got %d", funcName, len(params)) + } + result := f(params[0], params[1]) + switch result.(type) { + case error: + return nil, NewError(ast.Offset, ast.Length, "Runtime error: %v", result.(error)) + default: + return result, nil + } + case func(any, any, any) any: + if len(params) != 3 { + return nil, NewError(ast.Offset, ast.Length, "function %s expects 3 parameters, got %d", funcName, len(params)) + } + result := f(params[0], params[1], params[2]) + switch result.(type) { + case error: + return nil, NewError(ast.Offset, ast.Length, "Runtime error: %v", result.(error)) + default: + return result, nil + } + + } + return nil, NewError(ast.Offset, ast.Length, "unsupported function type for %s", funcName) + } + } + if i.strict { + return nil, NewError(ast.Offset, ast.Length, "function %s not found", funcName) + } + return nil, nil } return nil, nil } diff --git a/interpreter_test.go b/interpreter_test.go index 9ff9a0a..0ab51f2 100644 --- a/interpreter_test.go +++ b/interpreter_test.go @@ -2,6 +2,7 @@ package mexpr import ( "encoding/json" + "fmt" "reflect" "strings" "testing" @@ -232,6 +233,124 @@ func TestInterpreter(t *testing.T) { } } +func TestFunctions(t *testing.T) { + + varMap := make(map[string]interface{}) + + varMap["func0"] = func() any { + return 43.0 + } + + varMap["func1"] = func(param1 any) any { + switch param1.(type) { + case float64: + return param1.(float64) * 2 + default: + return fmt.Errorf("Invalid type for param1") + } + } + + varMap["func2"] = func(param1 any, param2 any) any { + switch param1.(type) { + case float64: + switch param2.(type) { + case float64: + return param1.(float64) * param2.(float64) + default: + return fmt.Errorf("Invalid type for param2") + } + default: + return fmt.Errorf("Invalid type for param1") + } + } + + varMap["func3"] = func(param1 any, param2 any, param3 any) any { + switch param1.(type) { + case float64: + switch param2.(type) { + case float64: + switch param3.(type) { + case float64: + return param1.(float64) * param2.(float64) * param3.(float64) + default: + return fmt.Errorf("Invalid type for param3") + } + default: + return fmt.Errorf("Invalid type for param2") + } + default: + return fmt.Errorf("Invalid type for param1") + } + } + + type test struct { + expr string + output interface{} + err string + } + cases := []test{ + {expr: "func0()", output: 43.0}, + {expr: "func1(42)", output: 84.0}, + {expr: "func2(3,4)", output: 12.0}, + {expr: "func3(2,3,4)", output: 24.0}, + {expr: "func0(42)", err: "expects 0 parameter"}, + {expr: "func1()", err: "expects 1 parameter"}, + {expr: "func1(1,2)", err: "expects 1 parameter"}, + {expr: "func2()", err: "expects 2 parameters"}, + {expr: "func2(1)", err: "expects 2 parameters"}, + {expr: "func2(1,2,3)", err: "expects 2 parameters"}, + {expr: "func3()", err: "expects 3 parameters"}, + {expr: "func3(1)", err: "expects 3 parameters"}, + {expr: "func3(1,2)", err: "expects 3 parameters"}, + {expr: "func3(1,2,3,4)", err: "expects 3 parameters"}, + {expr: "func1(\"foo\")", err: "Invalid type for"}, + {expr: "func2(\"foo\",\"bar\")", err: "Invalid type for"}, + {expr: "func3(\"foo\",\"qux\",\"quz\")", err: "Invalid type for"}, + } + + for _, tc := range cases { + t.Run(tc.expr, func(t *testing.T) { + + ast, err := Parse(tc.expr, nil) + + if ast != nil { + t.Log("graph G {\n" + ast.Dot("") + "\n}") + } + + if tc.err != "" { + if err != nil { + if strings.Contains(err.Error(), tc.err) { + return + } + t.Fatal(err.Pretty(tc.expr)) + } + } else { + if err != nil { + t.Fatal(err.Pretty(tc.expr)) + } + } + + result, err := Run(ast, varMap, StrictMode) + if tc.err != "" { + if err == nil { + t.Fatal("expected error but found none") + } + if strings.Contains(err.Error(), tc.err) { + return + } + t.Fatal(err.Pretty(tc.expr)) + } else { + if err != nil { + t.Fatal(err.Pretty(tc.expr)) + } + if !reflect.DeepEqual(tc.output, result) { + t.Fatalf("expected %v but found %v", tc.output, result) + } + } + }) + } +} + func FuzzMexpr(f *testing.F) { f.Fuzz(func(t *testing.T, s string) { Eval(s, nil) diff --git a/lexer.go b/lexer.go index aacb7fd..446a4b2 100644 --- a/lexer.go +++ b/lexer.go @@ -31,6 +31,7 @@ const ( TokenStringCompare TokenWhere TokenEOF + TokenComma // New token type for separating function parameters ) func (t TokenType) String() string { @@ -73,6 +74,8 @@ func (t TokenType) String() string { return "where" case TokenEOF: return "eof" + case TokenComma: + return "comma" } return "unknown" } @@ -97,6 +100,8 @@ func basic(input rune) TokenType { return TokenMulDiv case '^': return TokenPower + case ',': + return TokenComma } return TokenUnknown diff --git a/parser.go b/parser.go index 72f8fb7..2e1b161 100644 --- a/parser.go +++ b/parser.go @@ -39,6 +39,7 @@ const ( NodeBefore NodeAfter NodeWhere + NodeFunctionCall ) // Node is a unit of the binary tree that makes up the abstract syntax tree. @@ -144,6 +145,7 @@ var bindingPowers = map[TokenType]int{ TokenPower: 50, TokenLeftBracket: 60, TokenLeftParen: 70, + TokenComma: 1, // Low binding power for parameter lists } // precomputeLiterals takes two `NodeLiteral` nodes and a math operation and @@ -430,6 +432,63 @@ func (p *parser) led(t *Token, n *Node) (*Node, Error) { } nn.Value = []interface{}{0.0, 0.0} return nn, nil + case TokenLeftParen: + // Only treat as function call if left node is identifier + if n.Type != NodeIdentifier { + return nil, NewError(t.Offset, t.Length, "unexpected left parenthesis") + } + + // Parse function parameters + params := []Node{} + offset := t.Offset + + // Handle empty parameter list + if p.token.Type == TokenRightParen { + if err := p.advance(); err != nil { + return nil, err + } + return &Node{ + Type: NodeFunctionCall, + Left: n, + Value: params, + Offset: offset, + Length: uint8(p.token.Offset + uint16(p.token.Length) - offset), + }, nil + } + + // Parse parameters + for { + param, err := p.parse(bindingPowers[TokenComma]) + if err != nil { + return nil, err + } + if param == nil { + return nil, NewError(p.token.Offset, p.token.Length, "expected parameter") + } + params = append(params, *param) + + if p.token.Type == TokenRightParen { + if err := p.advance(); err != nil { + return nil, err + } + break + } + + if p.token.Type != TokenComma { + return nil, NewError(p.token.Offset, p.token.Length, "expected comma or right parenthesis") + } + if err := p.advance(); err != nil { + return nil, err + } + } + + return &Node{ + Type: NodeFunctionCall, + Left: n, + Value: params, + Offset: offset, + Length: uint8(p.token.Offset + uint16(p.token.Length) - offset), + }, nil } return nil, NewError(t.Offset, t.Length, "unexpected token %s", t.Type) }