@@ -110,6 +110,7 @@ use itertools::Itertools;
110
110
use jsonb:: keypath:: KeyPath ;
111
111
use jsonb:: keypath:: KeyPaths ;
112
112
use simsearch:: SimSearch ;
113
+ use unicase:: Ascii ;
113
114
114
115
use super :: name_resolution:: NameResolutionContext ;
115
116
use super :: normalize_identifier;
@@ -194,7 +195,7 @@ pub struct TypeChecker<'a> {
194
195
// This is used to check if there is nested aggregate function.
195
196
in_aggregate_function : bool ,
196
197
197
- // true if current expr is inside an window function.
198
+ // true if current expr is inside a window function.
198
199
// This is used to allow aggregation function in window's aggregate function.
199
200
in_window_function : bool ,
200
201
forbid_udf : bool ,
@@ -731,8 +732,9 @@ impl<'a> TypeChecker<'a> {
731
732
} => {
732
733
let func_name = normalize_identifier ( name, self . name_resolution_ctx ) . to_string ( ) ;
733
734
let func_name = func_name. as_str ( ) ;
735
+ let uni_case_func_name = Ascii :: new ( func_name) ;
734
736
if !is_builtin_function ( func_name)
735
- && !Self :: all_sugar_functions ( ) . contains ( & func_name )
737
+ && !Self :: all_sugar_functions ( ) . contains ( & uni_case_func_name )
736
738
{
737
739
if let Some ( udf) = self . resolve_udf ( * span, func_name, args) ? {
738
740
return Ok ( udf) ;
@@ -743,15 +745,35 @@ impl<'a> TypeChecker<'a> {
743
745
. all_function_names ( )
744
746
. into_iter ( )
745
747
. chain ( AggregateFunctionFactory :: instance ( ) . registered_names ( ) )
746
- . chain ( GENERAL_WINDOW_FUNCTIONS . iter ( ) . cloned ( ) . map ( str:: to_string) )
747
- . chain ( GENERAL_LAMBDA_FUNCTIONS . iter ( ) . cloned ( ) . map ( str:: to_string) )
748
- . chain ( GENERAL_SEARCH_FUNCTIONS . iter ( ) . cloned ( ) . map ( str:: to_string) )
749
- . chain ( ASYNC_FUNCTIONS . iter ( ) . cloned ( ) . map ( str:: to_string) )
748
+ . chain (
749
+ GENERAL_WINDOW_FUNCTIONS
750
+ . iter ( )
751
+ . cloned ( )
752
+ . map ( |ascii| ascii. into_inner ( ) . to_string ( ) ) ,
753
+ )
754
+ . chain (
755
+ GENERAL_LAMBDA_FUNCTIONS
756
+ . iter ( )
757
+ . cloned ( )
758
+ . map ( |ascii| ascii. into_inner ( ) . to_string ( ) ) ,
759
+ )
760
+ . chain (
761
+ GENERAL_SEARCH_FUNCTIONS
762
+ . iter ( )
763
+ . cloned ( )
764
+ . map ( |ascii| ascii. into_inner ( ) . to_string ( ) ) ,
765
+ )
766
+ . chain (
767
+ ASYNC_FUNCTIONS
768
+ . iter ( )
769
+ . cloned ( )
770
+ . map ( |ascii| ascii. into_inner ( ) . to_string ( ) ) ,
771
+ )
750
772
. chain (
751
773
Self :: all_sugar_functions ( )
752
774
. iter ( )
753
775
. cloned ( )
754
- . map ( str :: to_string) ,
776
+ . map ( |ascii| ascii . into_inner ( ) . to_string ( ) ) ,
755
777
) ;
756
778
let mut engine: SimSearch < String > = SimSearch :: new ( ) ;
757
779
for func_name in all_funcs {
@@ -779,15 +801,15 @@ impl<'a> TypeChecker<'a> {
779
801
// check window function legal
780
802
if window. is_some ( )
781
803
&& !AggregateFunctionFactory :: instance ( ) . contains ( func_name)
782
- && !GENERAL_WINDOW_FUNCTIONS . contains ( & func_name )
804
+ && !GENERAL_WINDOW_FUNCTIONS . contains ( & uni_case_func_name )
783
805
{
784
806
return Err ( ErrorCode :: SemanticError (
785
807
"only window and aggregate functions allowed in window syntax" ,
786
808
)
787
809
. set_span ( * span) ) ;
788
810
}
789
811
// check lambda function legal
790
- if lambda. is_some ( ) && !GENERAL_LAMBDA_FUNCTIONS . contains ( & func_name ) {
812
+ if lambda. is_some ( ) && !GENERAL_LAMBDA_FUNCTIONS . contains ( & uni_case_func_name ) {
791
813
return Err ( ErrorCode :: SemanticError (
792
814
"only lambda functions allowed in lambda syntax" ,
793
815
)
@@ -796,7 +818,7 @@ impl<'a> TypeChecker<'a> {
796
818
797
819
let args: Vec < & Expr > = args. iter ( ) . collect ( ) ;
798
820
799
- if GENERAL_WINDOW_FUNCTIONS . contains ( & func_name ) {
821
+ if GENERAL_WINDOW_FUNCTIONS . contains ( & uni_case_func_name ) {
800
822
// general window function
801
823
if window. is_none ( ) {
802
824
return Err ( ErrorCode :: SemanticError ( format ! (
@@ -862,7 +884,7 @@ impl<'a> TypeChecker<'a> {
862
884
// aggregate function
863
885
Box :: new ( ( new_agg_func. into ( ) , data_type) )
864
886
}
865
- } else if GENERAL_LAMBDA_FUNCTIONS . contains ( & func_name ) {
887
+ } else if GENERAL_LAMBDA_FUNCTIONS . contains ( & uni_case_func_name ) {
866
888
if lambda. is_none ( ) {
867
889
return Err ( ErrorCode :: SemanticError ( format ! (
868
890
"function {func_name} must have a lambda expression" ,
@@ -871,8 +893,8 @@ impl<'a> TypeChecker<'a> {
871
893
}
872
894
let lambda = lambda. as_ref ( ) . unwrap ( ) ;
873
895
self . resolve_lambda_function ( * span, func_name, & args, lambda) ?
874
- } else if GENERAL_SEARCH_FUNCTIONS . contains ( & func_name ) {
875
- match func_name {
896
+ } else if GENERAL_SEARCH_FUNCTIONS . contains ( & uni_case_func_name ) {
897
+ match func_name. to_lowercase ( ) . as_str ( ) {
876
898
"score" => self . resolve_score_search_function ( * span, func_name, & args) ?,
877
899
"match" => self . resolve_match_search_function ( * span, func_name, & args) ?,
878
900
"query" => self . resolve_query_search_function ( * span, func_name, & args) ?,
@@ -884,7 +906,7 @@ impl<'a> TypeChecker<'a> {
884
906
. set_span ( * span) ) ;
885
907
}
886
908
}
887
- } else if ASYNC_FUNCTIONS . contains ( & func_name ) {
909
+ } else if ASYNC_FUNCTIONS . contains ( & uni_case_func_name ) {
888
910
self . resolve_async_function ( * span, func_name, & args) ?
889
911
} else if BUILTIN_FUNCTIONS
890
912
. get_property ( func_name)
@@ -1445,7 +1467,7 @@ impl<'a> TypeChecker<'a> {
1445
1467
self . in_window_function = false ;
1446
1468
1447
1469
// If { IGNORE | RESPECT } NULLS is not specified, the default is RESPECT NULLS
1448
- // (i.e. a NULL value will be returned if the expression contains a NULL value and it is the first value in the expression).
1470
+ // (i.e. a NULL value will be returned if the expression contains a NULL value, and it is the first value in the expression).
1449
1471
let ignore_null = if let Some ( ignore_null) = window_ignore_null {
1450
1472
* ignore_null
1451
1473
} else {
@@ -2090,7 +2112,7 @@ impl<'a> TypeChecker<'a> {
2090
2112
param_count : usize ,
2091
2113
span : Span ,
2092
2114
) -> Result < ( ) > {
2093
- // json lambda functions are casted to array or map, ignored here.
2115
+ // json lambda functions are cast to array or map, ignored here.
2094
2116
let expected_count = if func_name == "array_reduce" {
2095
2117
2
2096
2118
} else if func_name. starts_with ( "array" ) {
@@ -3124,37 +3146,38 @@ impl<'a> TypeChecker<'a> {
3124
3146
Ok ( Box :: new ( ( subquery_expr. into ( ) , data_type) ) )
3125
3147
}
3126
3148
3127
- pub fn all_sugar_functions ( ) -> & ' static [ & ' static str ] {
3128
- & [
3129
- "current_catalog" ,
3130
- "database" ,
3131
- "currentdatabase" ,
3132
- "current_database" ,
3133
- "version" ,
3134
- "user" ,
3135
- "currentuser" ,
3136
- "current_user" ,
3137
- "current_role" ,
3138
- "connection_id" ,
3139
- "timezone" ,
3140
- "nullif" ,
3141
- "ifnull" ,
3142
- "nvl" ,
3143
- "nvl2" ,
3144
- "is_null" ,
3145
- "is_error" ,
3146
- "error_or" ,
3147
- "coalesce" ,
3148
- "last_query_id" ,
3149
- "array_sort" ,
3150
- "array_aggregate" ,
3151
- "to_variant" ,
3152
- "try_to_variant" ,
3153
- "greatest" ,
3154
- "least" ,
3155
- "stream_has_data" ,
3156
- "getvariable" ,
3157
- ]
3149
+ pub fn all_sugar_functions ( ) -> & ' static [ Ascii < & ' static str > ] {
3150
+ static FUNCTIONS : & [ Ascii < & ' static str > ] = & [
3151
+ Ascii :: new ( "current_catalog" ) ,
3152
+ Ascii :: new ( "database" ) ,
3153
+ Ascii :: new ( "currentdatabase" ) ,
3154
+ Ascii :: new ( "current_database" ) ,
3155
+ Ascii :: new ( "version" ) ,
3156
+ Ascii :: new ( "user" ) ,
3157
+ Ascii :: new ( "currentuser" ) ,
3158
+ Ascii :: new ( "current_user" ) ,
3159
+ Ascii :: new ( "current_role" ) ,
3160
+ Ascii :: new ( "connection_id" ) ,
3161
+ Ascii :: new ( "timezone" ) ,
3162
+ Ascii :: new ( "nullif" ) ,
3163
+ Ascii :: new ( "ifnull" ) ,
3164
+ Ascii :: new ( "nvl" ) ,
3165
+ Ascii :: new ( "nvl2" ) ,
3166
+ Ascii :: new ( "is_null" ) ,
3167
+ Ascii :: new ( "is_error" ) ,
3168
+ Ascii :: new ( "error_or" ) ,
3169
+ Ascii :: new ( "coalesce" ) ,
3170
+ Ascii :: new ( "last_query_id" ) ,
3171
+ Ascii :: new ( "array_sort" ) ,
3172
+ Ascii :: new ( "array_aggregate" ) ,
3173
+ Ascii :: new ( "to_variant" ) ,
3174
+ Ascii :: new ( "try_to_variant" ) ,
3175
+ Ascii :: new ( "greatest" ) ,
3176
+ Ascii :: new ( "least" ) ,
3177
+ Ascii :: new ( "stream_has_data" ) ,
3178
+ Ascii :: new ( "getvariable" ) ,
3179
+ ] ;
3180
+ FUNCTIONS
3158
3181
}
3159
3182
3160
3183
fn try_rewrite_sugar_function (
0 commit comments