Skip to content

Commit c08754d

Browse files
committed
migrate builtin function lookup to unicase
1 parent 7d2fc12 commit c08754d

File tree

5 files changed

+65
-68
lines changed

5 files changed

+65
-68
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/query/functions/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ siphasher = { workspace = true }
6161
strength_reduce = { workspace = true }
6262
stringslice = { workspace = true }
6363
twox-hash = { workspace = true }
64+
unicase = { workspace = true }
6465

6566
[dev-dependencies]
6667
comfy-table = { workspace = true }

src/query/functions/src/lib.rs

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,17 @@
2626
use aggregates::AggregateFunctionFactory;
2727
use ctor::ctor;
2828
use databend_common_expression::FunctionRegistry;
29+
use unicase::Ascii;
2930

3031
pub mod aggregates;
3132
mod cast_rules;
3233
pub mod scalars;
3334
pub mod srfs;
3435

3536
pub fn is_builtin_function(name: &str) -> bool {
36-
BUILTIN_FUNCTIONS.contains(name)
37-
|| AggregateFunctionFactory::instance().contains(name)
37+
let name = Ascii::new(name);
38+
BUILTIN_FUNCTIONS.contains(name.into_inner())
39+
|| AggregateFunctionFactory::instance().contains(name.into_inner())
3840
|| GENERAL_WINDOW_FUNCTIONS.contains(&name)
3941
|| GENERAL_LAMBDA_FUNCTIONS.contains(&name)
4042
|| GENERAL_SEARCH_FUNCTIONS.contains(&name)
@@ -44,53 +46,58 @@ pub fn is_builtin_function(name: &str) -> bool {
4446
// The plan of search function, async function and udf contains some arguments defined in meta,
4547
// which may be modified by user at any time. Those functions are not suitable for caching.
4648
pub fn is_cacheable_function(name: &str) -> bool {
47-
BUILTIN_FUNCTIONS.contains(name)
48-
|| AggregateFunctionFactory::instance().contains(name)
49+
let name = Ascii::new(name);
50+
BUILTIN_FUNCTIONS.contains(name.into_inner())
51+
|| AggregateFunctionFactory::instance().contains(name.into_inner())
4952
|| GENERAL_WINDOW_FUNCTIONS.contains(&name)
5053
|| GENERAL_LAMBDA_FUNCTIONS.contains(&name)
5154
}
5255

5356
#[ctor]
5457
pub static BUILTIN_FUNCTIONS: FunctionRegistry = builtin_functions();
5558

56-
pub const ASYNC_FUNCTIONS: [&str; 2] = ["nextval", "dict_get"];
59+
pub const ASYNC_FUNCTIONS: [Ascii<&str>; 2] = [Ascii::new("nextval"), Ascii::new("dict_get")];
5760

58-
pub const GENERAL_WINDOW_FUNCTIONS: [&str; 13] = [
59-
"row_number",
60-
"rank",
61-
"dense_rank",
62-
"percent_rank",
63-
"lag",
64-
"lead",
65-
"first_value",
66-
"first",
67-
"last_value",
68-
"last",
69-
"nth_value",
70-
"ntile",
71-
"cume_dist",
61+
pub const GENERAL_WINDOW_FUNCTIONS: [Ascii<&str>; 13] = [
62+
Ascii::new("row_number"),
63+
Ascii::new("rank"),
64+
Ascii::new("dense_rank"),
65+
Ascii::new("percent_rank"),
66+
Ascii::new("lag"),
67+
Ascii::new("lead"),
68+
Ascii::new("first_value"),
69+
Ascii::new("first"),
70+
Ascii::new("last_value"),
71+
Ascii::new("last"),
72+
Ascii::new("nth_value"),
73+
Ascii::new("ntile"),
74+
Ascii::new("cume_dist"),
7275
];
7376

74-
pub const GENERAL_LAMBDA_FUNCTIONS: [&str; 16] = [
75-
"array_transform",
76-
"array_apply",
77-
"array_map",
78-
"array_filter",
79-
"array_reduce",
80-
"json_array_transform",
81-
"json_array_apply",
82-
"json_array_map",
83-
"json_array_filter",
84-
"json_array_reduce",
85-
"map_filter",
86-
"map_transform_keys",
87-
"map_transform_values",
88-
"json_map_filter",
89-
"json_map_transform_keys",
90-
"json_map_transform_values",
77+
pub const GENERAL_LAMBDA_FUNCTIONS: [Ascii<&str>; 16] = [
78+
Ascii::new("array_transform"),
79+
Ascii::new("array_apply"),
80+
Ascii::new("array_map"),
81+
Ascii::new("array_filter"),
82+
Ascii::new("array_reduce"),
83+
Ascii::new("json_array_transform"),
84+
Ascii::new("json_array_apply"),
85+
Ascii::new("json_array_map"),
86+
Ascii::new("json_array_filter"),
87+
Ascii::new("json_array_reduce"),
88+
Ascii::new("map_filter"),
89+
Ascii::new("map_transform_keys"),
90+
Ascii::new("map_transform_values"),
91+
Ascii::new("json_map_filter"),
92+
Ascii::new("json_map_transform_keys"),
93+
Ascii::new("json_map_transform_values"),
9194
];
9295

93-
pub const GENERAL_SEARCH_FUNCTIONS: [&str; 3] = ["match", "query", "score"];
96+
pub const GENERAL_SEARCH_FUNCTIONS: [Ascii<&str>; 3] = [
97+
Ascii::new("match"),
98+
Ascii::new("query"),
99+
Ascii::new("score"),
100+
];
94101

95102
fn builtin_functions() -> FunctionRegistry {
96103
let mut registry = FunctionRegistry::empty();

src/query/sql/src/planner/semantic/type_check.rs

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ use itertools::Itertools;
106106
use jsonb::keypath::KeyPath;
107107
use jsonb::keypath::KeyPaths;
108108
use simsearch::SimSearch;
109-
use unicase::UniCase;
109+
use unicase::Ascii;
110110

111111
use super::name_resolution::NameResolutionContext;
112112
use super::normalize_identifier;
@@ -185,7 +185,7 @@ pub struct TypeChecker<'a> {
185185
// This is used to check if there is nested aggregate function.
186186
in_aggregate_function: bool,
187187

188-
// true if current expr is inside an window function.
188+
// true if current expr is inside a window function.
189189
// This is used to allow aggregate functions inside a window's aggregate function.
190190
in_window_function: bool,
191191
forbid_udf: bool,
@@ -722,7 +722,10 @@ impl<'a> TypeChecker<'a> {
722722
} => {
723723
let func_name = normalize_identifier(name, self.name_resolution_ctx).to_string();
724724
let func_name = func_name.as_str();
725-
if !is_builtin_function(func_name) && !Self::is_sugar_function(func_name) {
725+
let uni_case_func_name = Ascii::new(func_name);
726+
if !is_builtin_function(func_name)
727+
&& !Self::all_sugar_functions().contains(&uni_case_func_name)
728+
{
726729
if let Some(udf) = self.resolve_udf(*span, func_name, args)? {
727730
return Ok(udf);
728731
} else {
@@ -731,10 +734,10 @@ impl<'a> TypeChecker<'a> {
731734
.all_function_names()
732735
.into_iter()
733736
.chain(AggregateFunctionFactory::instance().registered_names())
734-
.chain(GENERAL_WINDOW_FUNCTIONS.iter().cloned().map(str::to_string))
735-
.chain(GENERAL_LAMBDA_FUNCTIONS.iter().cloned().map(str::to_string))
736-
.chain(GENERAL_SEARCH_FUNCTIONS.iter().cloned().map(str::to_string))
737-
.chain(ASYNC_FUNCTIONS.iter().cloned().map(str::to_string))
737+
.chain(GENERAL_WINDOW_FUNCTIONS.iter().cloned().map(|ascii| ascii.into_inner()))
738+
.chain(GENERAL_LAMBDA_FUNCTIONS.iter().cloned().map(|ascii| ascii.into_inner()))
739+
.chain(GENERAL_SEARCH_FUNCTIONS.iter().cloned().map(|ascii| ascii.into_inner()))
740+
.chain(ASYNC_FUNCTIONS.iter().cloned().map(|ascii| ascii.into_inner()))
738741
.chain(
739742
Self::all_sugar_functions()
740743
.iter()
@@ -768,15 +771,15 @@ impl<'a> TypeChecker<'a> {
768771
// check window function legal
769772
if window.is_some()
770773
&& !AggregateFunctionFactory::instance().contains(func_name)
771-
&& !GENERAL_WINDOW_FUNCTIONS.contains(&func_name)
774+
&& !GENERAL_WINDOW_FUNCTIONS.contains(&uni_case_func_name)
772775
{
773776
return Err(ErrorCode::SemanticError(
774777
"only window and aggregate functions allowed in window syntax",
775778
)
776779
.set_span(*span));
777780
}
778781
// check lambda function legal
779-
if lambda.is_some() && !GENERAL_LAMBDA_FUNCTIONS.contains(&func_name) {
782+
if lambda.is_some() && !GENERAL_LAMBDA_FUNCTIONS.contains(&uni_case_func_name) {
780783
return Err(ErrorCode::SemanticError(
781784
"only lambda functions allowed in lambda syntax",
782785
)
@@ -785,7 +788,7 @@ impl<'a> TypeChecker<'a> {
785788

786789
let args: Vec<&Expr> = args.iter().collect();
787790

788-
if GENERAL_WINDOW_FUNCTIONS.contains(&func_name) {
791+
if GENERAL_WINDOW_FUNCTIONS.contains(&uni_case_func_name) {
789792
// general window function
790793
if window.is_none() {
791794
return Err(ErrorCode::SemanticError(format!(
@@ -851,7 +854,7 @@ impl<'a> TypeChecker<'a> {
851854
// aggregate function
852855
Box::new((new_agg_func.into(), data_type))
853856
}
854-
} else if GENERAL_LAMBDA_FUNCTIONS.contains(&func_name) {
857+
} else if GENERAL_LAMBDA_FUNCTIONS.contains(&uni_case_func_name) {
855858
if lambda.is_none() {
856859
return Err(ErrorCode::SemanticError(format!(
857860
"function {func_name} must have a lambda expression",
@@ -860,8 +863,8 @@ impl<'a> TypeChecker<'a> {
860863
}
861864
let lambda = lambda.as_ref().unwrap();
862865
self.resolve_lambda_function(*span, func_name, &args, lambda)?
863-
} else if GENERAL_SEARCH_FUNCTIONS.contains(&func_name) {
864-
match func_name {
866+
} else if GENERAL_SEARCH_FUNCTIONS.contains(&uni_case_func_name) {
867+
match func_name.to_lowercase().as_str() {
865868
"score" => self.resolve_score_search_function(*span, func_name, &args)?,
866869
"match" => self.resolve_match_search_function(*span, func_name, &args)?,
867870
"query" => self.resolve_query_search_function(*span, func_name, &args)?,
@@ -873,7 +876,7 @@ impl<'a> TypeChecker<'a> {
873876
.set_span(*span));
874877
}
875878
}
876-
} else if ASYNC_FUNCTIONS.contains(&func_name) {
879+
} else if ASYNC_FUNCTIONS.contains(&uni_case_func_name) {
877880
self.resolve_async_function(*span, func_name, &args)?
878881
} else if BUILTIN_FUNCTIONS
879882
.get_property(func_name)
@@ -1435,7 +1438,7 @@ impl<'a> TypeChecker<'a> {
14351438
self.in_window_function = false;
14361439

14371440
// If { IGNORE | RESPECT } NULLS is not specified, the default is RESPECT NULLS
1438-
// (i.e. a NULL value will be returned if the expression contains a NULL value and it is the first value in the expression).
1441+
// (i.e. a NULL value will be returned if the expression contains a NULL value, and it is the first value in the expression).
14391442
let ignore_null = if let Some(ignore_null) = window_ignore_null {
14401443
*ignore_null
14411444
} else {
@@ -2080,7 +2083,7 @@ impl<'a> TypeChecker<'a> {
20802083
param_count: usize,
20812084
span: Span,
20822085
) -> Result<()> {
2083-
// json lambda functions are casted to array or map, ignored here.
2086+
// json lambda functions are cast to array or map, ignored here.
20842087
let expected_count = if func_name == "array_reduce" {
20852088
2
20862089
} else if func_name.starts_with("array") {
@@ -3153,11 +3156,6 @@ impl<'a> TypeChecker<'a> {
31533156
FUNCTIONS
31543157
}
31553158

3156-
pub fn is_sugar_function(name: &str) -> bool {
3157-
let name = Ascii::new(name);
3158-
all_sugar_functions().iter().any(|func| func.eq(&name))
3159-
}
3160-
31613159
fn try_rewrite_sugar_function(
31623160
&mut self,
31633161
span: Span,

src/query/sql/tests/type_check.rs

Lines changed: 0 additions & 10 deletions
This file was deleted.

0 commit comments

Comments
 (0)