Skip to content

Commit f4bbea5

Browse files
committed
Add support deref in all builtins
1 parent 7d564c0 commit f4bbea5

File tree

3 files changed

+52
-13
lines changed

3 files changed

+52
-13
lines changed

builtin/builtin_test.go

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ func Test_int_unwraps_underlying_value(t *testing.T) {
641641

642642
func TestBuiltin_with_deref(t *testing.T) {
643643
x := 42
644-
arr := []any{1, 2, 3}
644+
arr := []int{1, 2, 3}
645645
arrStr := []string{"1", "2", "3"}
646646
m := map[string]any{"a": 1, "b": 2}
647647
jsonString := `["1"]`
@@ -659,13 +659,31 @@ func TestBuiltin_with_deref(t *testing.T) {
659659
input string
660660
want any
661661
}{
662+
{`all(arr, # > 0)`, true},
663+
{`none(arr, # < 0)`, true},
664+
{`any(arr, # > 0)`, true},
665+
{`one(arr, # > 2)`, true},
666+
{`filter(arr, # > 0)`, []any{1, 2, 3}},
667+
{`map(arr, # * #)`, []any{1, 4, 9}},
668+
{`count(arr, # > 0)`, 3},
669+
{`sum(arr)`, 6},
670+
{`find(arr, # > 0)`, 1},
671+
{`findIndex(arr, # > 1)`, 1},
672+
{`findLast(arr, # > 0)`, 3},
673+
{`findLastIndex(arr, # > 0)`, 2},
674+
{`groupBy(arr, # % 2 == 0)`, map[any][]any{false: {1, 3}, true: {2}}},
675+
{`sortBy(arr, -#)`, []any{3, 2, 1}},
676+
{`reduce(arr, # + #acc, x)`, 6 + 42},
677+
{`ceil(x)`, 42.0},
678+
{`floor(x)`, 42.0},
679+
{`round(x)`, 42.0},
662680
{`int(x)`, 42},
663681
{`float(x)`, 42.0},
664682
{`abs(x)`, 42},
665683
{`first(arr)`, 1},
666684
{`last(arr)`, 3},
667-
{`take(arr, 1)`, []any{1}},
668-
{`take(arr, x)`, []any{1, 2, 3}},
685+
{`take(arr, 1)`, []int{1}},
686+
{`take(arr, x)`, []int{1, 2, 3}},
669687
{`'a' in keys(m)`, true},
670688
{`1 in values(m)`, true},
671689
{`len(arr)`, 3},
@@ -685,10 +703,15 @@ func TestBuiltin_with_deref(t *testing.T) {
685703
t.Run(test.input, func(t *testing.T) {
686704
program, err := expr.Compile(test.input, expr.Env(env))
687705
require.NoError(t, err)
706+
println(program.Disassemble())
688707

689708
out, err := expr.Run(program, env)
690709
require.NoError(t, err)
691710
assert.Equal(t, test.want, out)
711+
712+
out, err = expr.Eval(test.input, env)
713+
require.NoError(t, err)
714+
assert.Equal(t, test.want, out)
692715
})
693716
}
694717
}

checker/checker.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ func (v *checker) functionReturnType(node *ast.CallNode) Nature {
660660
func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature {
661661
switch node.Name {
662662
case "all", "none", "any", "one":
663-
collection := v.visit(node.Arguments[0])
663+
collection := v.visit(node.Arguments[0]).Deref()
664664
if !isArray(collection) && !isUnknown(collection) {
665665
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
666666
}
@@ -681,7 +681,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature {
681681
return v.error(node.Arguments[1], "predicate should has one input and one output param")
682682

683683
case "filter":
684-
collection := v.visit(node.Arguments[0])
684+
collection := v.visit(node.Arguments[0]).Deref()
685685
if !isArray(collection) && !isUnknown(collection) {
686686
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
687687
}
@@ -705,7 +705,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature {
705705
return v.error(node.Arguments[1], "predicate should has one input and one output param")
706706

707707
case "map":
708-
collection := v.visit(node.Arguments[0])
708+
collection := v.visit(node.Arguments[0]).Deref()
709709
if !isArray(collection) && !isUnknown(collection) {
710710
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
711711
}
@@ -723,7 +723,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature {
723723
return v.error(node.Arguments[1], "predicate should has one input and one output param")
724724

725725
case "count":
726-
collection := v.visit(node.Arguments[0])
726+
collection := v.visit(node.Arguments[0]).Deref()
727727
if !isArray(collection) && !isUnknown(collection) {
728728
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
729729
}
@@ -748,7 +748,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature {
748748
return v.error(node.Arguments[1], "predicate should has one input and one output param")
749749

750750
case "sum":
751-
collection := v.visit(node.Arguments[0])
751+
collection := v.visit(node.Arguments[0]).Deref()
752752
if !isArray(collection) && !isUnknown(collection) {
753753
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
754754
}
@@ -771,7 +771,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature {
771771
}
772772

773773
case "find", "findLast":
774-
collection := v.visit(node.Arguments[0])
774+
collection := v.visit(node.Arguments[0]).Deref()
775775
if !isArray(collection) && !isUnknown(collection) {
776776
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
777777
}
@@ -795,7 +795,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature {
795795
return v.error(node.Arguments[1], "predicate should has one input and one output param")
796796

797797
case "findIndex", "findLastIndex":
798-
collection := v.visit(node.Arguments[0])
798+
collection := v.visit(node.Arguments[0]).Deref()
799799
if !isArray(collection) && !isUnknown(collection) {
800800
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
801801
}
@@ -816,7 +816,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature {
816816
return v.error(node.Arguments[1], "predicate should has one input and one output param")
817817

818818
case "groupBy":
819-
collection := v.visit(node.Arguments[0])
819+
collection := v.visit(node.Arguments[0]).Deref()
820820
if !isArray(collection) && !isUnknown(collection) {
821821
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
822822
}
@@ -835,7 +835,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature {
835835
return v.error(node.Arguments[1], "predicate should has one input and one output param")
836836

837837
case "sortBy":
838-
collection := v.visit(node.Arguments[0])
838+
collection := v.visit(node.Arguments[0]).Deref()
839839
if !isArray(collection) && !isUnknown(collection) {
840840
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
841841
}
@@ -857,7 +857,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) Nature {
857857
return v.error(node.Arguments[1], "predicate should has one input and one output param")
858858

859859
case "reduce":
860-
collection := v.visit(node.Arguments[0])
860+
collection := v.visit(node.Arguments[0]).Deref()
861861
if !isArray(collection) && !isUnknown(collection) {
862862
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
863863
}

compiler/compiler.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
791791
switch node.Name {
792792
case "all":
793793
c.compile(node.Arguments[0])
794+
c.derefInNeeded(node.Arguments[0])
794795
c.emit(OpBegin)
795796
var loopBreak int
796797
c.emitLoop(func() {
@@ -805,6 +806,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
805806

806807
case "none":
807808
c.compile(node.Arguments[0])
809+
c.derefInNeeded(node.Arguments[0])
808810
c.emit(OpBegin)
809811
var loopBreak int
810812
c.emitLoop(func() {
@@ -820,6 +822,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
820822

821823
case "any":
822824
c.compile(node.Arguments[0])
825+
c.derefInNeeded(node.Arguments[0])
823826
c.emit(OpBegin)
824827
var loopBreak int
825828
c.emitLoop(func() {
@@ -834,6 +837,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
834837

835838
case "one":
836839
c.compile(node.Arguments[0])
840+
c.derefInNeeded(node.Arguments[0])
837841
c.emit(OpBegin)
838842
c.emitLoop(func() {
839843
c.compile(node.Arguments[1])
@@ -849,6 +853,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
849853

850854
case "filter":
851855
c.compile(node.Arguments[0])
856+
c.derefInNeeded(node.Arguments[0])
852857
c.emit(OpBegin)
853858
c.emitLoop(func() {
854859
c.compile(node.Arguments[1])
@@ -868,6 +873,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
868873

869874
case "map":
870875
c.compile(node.Arguments[0])
876+
c.derefInNeeded(node.Arguments[0])
871877
c.emit(OpBegin)
872878
c.emitLoop(func() {
873879
c.compile(node.Arguments[1])
@@ -879,6 +885,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
879885

880886
case "count":
881887
c.compile(node.Arguments[0])
888+
c.derefInNeeded(node.Arguments[0])
882889
c.emit(OpBegin)
883890
c.emitLoop(func() {
884891
if len(node.Arguments) == 2 {
@@ -896,6 +903,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
896903

897904
case "sum":
898905
c.compile(node.Arguments[0])
906+
c.derefInNeeded(node.Arguments[0])
899907
c.emit(OpBegin)
900908
c.emit(OpInt, 0)
901909
c.emit(OpSetAcc)
@@ -915,6 +923,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
915923

916924
case "find":
917925
c.compile(node.Arguments[0])
926+
c.derefInNeeded(node.Arguments[0])
918927
c.emit(OpBegin)
919928
var loopBreak int
920929
c.emitLoop(func() {
@@ -942,6 +951,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
942951

943952
case "findIndex":
944953
c.compile(node.Arguments[0])
954+
c.derefInNeeded(node.Arguments[0])
945955
c.emit(OpBegin)
946956
var loopBreak int
947957
c.emitLoop(func() {
@@ -960,6 +970,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
960970

961971
case "findLast":
962972
c.compile(node.Arguments[0])
973+
c.derefInNeeded(node.Arguments[0])
963974
c.emit(OpBegin)
964975
var loopBreak int
965976
c.emitLoopBackwards(func() {
@@ -987,6 +998,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
987998

988999
case "findLastIndex":
9891000
c.compile(node.Arguments[0])
1001+
c.derefInNeeded(node.Arguments[0])
9901002
c.emit(OpBegin)
9911003
var loopBreak int
9921004
c.emitLoopBackwards(func() {
@@ -1005,6 +1017,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
10051017

10061018
case "groupBy":
10071019
c.compile(node.Arguments[0])
1020+
c.derefInNeeded(node.Arguments[0])
10081021
c.emit(OpBegin)
10091022
c.emit(OpCreate, 1)
10101023
c.emit(OpSetAcc)
@@ -1018,6 +1031,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
10181031

10191032
case "sortBy":
10201033
c.compile(node.Arguments[0])
1034+
c.derefInNeeded(node.Arguments[0])
10211035
c.emit(OpBegin)
10221036
if len(node.Arguments) == 3 {
10231037
c.compile(node.Arguments[2])
@@ -1036,9 +1050,11 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
10361050

10371051
case "reduce":
10381052
c.compile(node.Arguments[0])
1053+
c.derefInNeeded(node.Arguments[0])
10391054
c.emit(OpBegin)
10401055
if len(node.Arguments) == 3 {
10411056
c.compile(node.Arguments[2])
1057+
c.derefInNeeded(node.Arguments[2])
10421058
c.emit(OpSetAcc)
10431059
} else {
10441060
c.emit(OpPointer)

0 commit comments

Comments
 (0)