Skip to content

Commit 0c94ab5

Browse files
authored
fix: ignore case when matching function name (#16912)
* add method to check sugar function * compare sugar functions using unicase ascii * add sqllogictest * migrate builtin function lookup to unicase * fix issues * fix build issue
1 parent 1f9a4eb commit 0c94ab5

File tree

7 files changed

+128
-84
lines changed

7 files changed

+128
-84
lines changed

Cargo.lock

Lines changed: 8 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ tower = { version = "0.5.1", features = ["util"] }
485485
tower-service = "0.3.3"
486486
twox-hash = "1.6.3"
487487
typetag = "0.2.3"
488+
unicase = "2.8.0"
488489
unicode-segmentation = "1.10.1"
489490
unindent = "0.2"
490491
url = "2.3.1"

src/query/functions/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ siphasher = { workspace = true }
6262
strength_reduce = { workspace = true }
6363
stringslice = { workspace = true }
6464
twox-hash = { workspace = true }
65+
unicase = { workspace = true }
6566

6667
[dev-dependencies]
6768
comfy-table = { workspace = true }

src/query/functions/src/lib.rs

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,17 @@
2727
use aggregates::AggregateFunctionFactory;
2828
use ctor::ctor;
2929
use databend_common_expression::FunctionRegistry;
30+
use unicase::Ascii;
3031

3132
pub mod aggregates;
3233
mod cast_rules;
3334
pub mod scalars;
3435
pub mod srfs;
3536

3637
pub fn is_builtin_function(name: &str) -> bool {
37-
BUILTIN_FUNCTIONS.contains(name)
38-
|| AggregateFunctionFactory::instance().contains(name)
38+
let name = Ascii::new(name);
39+
BUILTIN_FUNCTIONS.contains(name.into_inner())
40+
|| AggregateFunctionFactory::instance().contains(name.into_inner())
3941
|| GENERAL_WINDOW_FUNCTIONS.contains(&name)
4042
|| GENERAL_LAMBDA_FUNCTIONS.contains(&name)
4143
|| GENERAL_SEARCH_FUNCTIONS.contains(&name)
@@ -45,56 +47,61 @@ pub fn is_builtin_function(name: &str) -> bool {
4547
// The plan of search function, async function and udf contains some arguments defined in meta,
4648
// which may be modified by user at any time. Those functions are not not suitable for caching.
4749
pub fn is_cacheable_function(name: &str) -> bool {
48-
BUILTIN_FUNCTIONS.contains(name)
49-
|| AggregateFunctionFactory::instance().contains(name)
50+
let name = Ascii::new(name);
51+
BUILTIN_FUNCTIONS.contains(name.into_inner())
52+
|| AggregateFunctionFactory::instance().contains(name.into_inner())
5053
|| GENERAL_WINDOW_FUNCTIONS.contains(&name)
5154
|| GENERAL_LAMBDA_FUNCTIONS.contains(&name)
5255
}
5356

5457
#[ctor]
5558
pub static BUILTIN_FUNCTIONS: FunctionRegistry = builtin_functions();
5659

57-
pub const ASYNC_FUNCTIONS: [&str; 2] = ["nextval", "dict_get"];
60+
pub const ASYNC_FUNCTIONS: [Ascii<&str>; 2] = [Ascii::new("nextval"), Ascii::new("dict_get")];
5861

59-
pub const GENERAL_WINDOW_FUNCTIONS: [&str; 13] = [
60-
"row_number",
61-
"rank",
62-
"dense_rank",
63-
"percent_rank",
64-
"lag",
65-
"lead",
66-
"first_value",
67-
"first",
68-
"last_value",
69-
"last",
70-
"nth_value",
71-
"ntile",
72-
"cume_dist",
62+
pub const GENERAL_WINDOW_FUNCTIONS: [Ascii<&str>; 13] = [
63+
Ascii::new("row_number"),
64+
Ascii::new("rank"),
65+
Ascii::new("dense_rank"),
66+
Ascii::new("percent_rank"),
67+
Ascii::new("lag"),
68+
Ascii::new("lead"),
69+
Ascii::new("first_value"),
70+
Ascii::new("first"),
71+
Ascii::new("last_value"),
72+
Ascii::new("last"),
73+
Ascii::new("nth_value"),
74+
Ascii::new("ntile"),
75+
Ascii::new("cume_dist"),
7376
];
7477

7578
pub const RANK_WINDOW_FUNCTIONS: [&str; 5] =
7679
["first_value", "first", "last_value", "last", "nth_value"];
7780

78-
pub const GENERAL_LAMBDA_FUNCTIONS: [&str; 16] = [
79-
"array_transform",
80-
"array_apply",
81-
"array_map",
82-
"array_filter",
83-
"array_reduce",
84-
"json_array_transform",
85-
"json_array_apply",
86-
"json_array_map",
87-
"json_array_filter",
88-
"json_array_reduce",
89-
"map_filter",
90-
"map_transform_keys",
91-
"map_transform_values",
92-
"json_map_filter",
93-
"json_map_transform_keys",
94-
"json_map_transform_values",
81+
pub const GENERAL_LAMBDA_FUNCTIONS: [Ascii<&str>; 16] = [
82+
Ascii::new("array_transform"),
83+
Ascii::new("array_apply"),
84+
Ascii::new("array_map"),
85+
Ascii::new("array_filter"),
86+
Ascii::new("array_reduce"),
87+
Ascii::new("json_array_transform"),
88+
Ascii::new("json_array_apply"),
89+
Ascii::new("json_array_map"),
90+
Ascii::new("json_array_filter"),
91+
Ascii::new("json_array_reduce"),
92+
Ascii::new("map_filter"),
93+
Ascii::new("map_transform_keys"),
94+
Ascii::new("map_transform_values"),
95+
Ascii::new("json_map_filter"),
96+
Ascii::new("json_map_transform_keys"),
97+
Ascii::new("json_map_transform_values"),
9598
];
9699

97-
pub const GENERAL_SEARCH_FUNCTIONS: [&str; 3] = ["match", "query", "score"];
100+
pub const GENERAL_SEARCH_FUNCTIONS: [Ascii<&str>; 3] = [
101+
Ascii::new("match"),
102+
Ascii::new("query"),
103+
Ascii::new("score"),
104+
];
98105

99106
fn builtin_functions() -> FunctionRegistry {
100107
let mut registry = FunctionRegistry::empty();

src/query/sql/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ serde = { workspace = true }
7373
sha2 = { workspace = true }
7474
simsearch = { workspace = true }
7575
tokio = { workspace = true }
76+
unicase = { workspace = true }
7677
url = { workspace = true }
7778

7879
[lints]

src/query/sql/src/planner/semantic/type_check.rs

Lines changed: 70 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ use itertools::Itertools;
110110
use jsonb::keypath::KeyPath;
111111
use jsonb::keypath::KeyPaths;
112112
use simsearch::SimSearch;
113+
use unicase::Ascii;
113114

114115
use super::name_resolution::NameResolutionContext;
115116
use super::normalize_identifier;
@@ -194,7 +195,7 @@ pub struct TypeChecker<'a> {
194195
// This is used to check if there is nested aggregate function.
195196
in_aggregate_function: bool,
196197

197-
// true if current expr is inside an window function.
198+
// true if current expr is inside a window function.
198199
// This is used to allow aggregation function in window's aggregate function.
199200
in_window_function: bool,
200201
forbid_udf: bool,
@@ -731,8 +732,9 @@ impl<'a> TypeChecker<'a> {
731732
} => {
732733
let func_name = normalize_identifier(name, self.name_resolution_ctx).to_string();
733734
let func_name = func_name.as_str();
735+
let uni_case_func_name = Ascii::new(func_name);
734736
if !is_builtin_function(func_name)
735-
&& !Self::all_sugar_functions().contains(&func_name)
737+
&& !Self::all_sugar_functions().contains(&uni_case_func_name)
736738
{
737739
if let Some(udf) = self.resolve_udf(*span, func_name, args)? {
738740
return Ok(udf);
@@ -743,15 +745,35 @@ impl<'a> TypeChecker<'a> {
743745
.all_function_names()
744746
.into_iter()
745747
.chain(AggregateFunctionFactory::instance().registered_names())
746-
.chain(GENERAL_WINDOW_FUNCTIONS.iter().cloned().map(str::to_string))
747-
.chain(GENERAL_LAMBDA_FUNCTIONS.iter().cloned().map(str::to_string))
748-
.chain(GENERAL_SEARCH_FUNCTIONS.iter().cloned().map(str::to_string))
749-
.chain(ASYNC_FUNCTIONS.iter().cloned().map(str::to_string))
748+
.chain(
749+
GENERAL_WINDOW_FUNCTIONS
750+
.iter()
751+
.cloned()
752+
.map(|ascii| ascii.into_inner().to_string()),
753+
)
754+
.chain(
755+
GENERAL_LAMBDA_FUNCTIONS
756+
.iter()
757+
.cloned()
758+
.map(|ascii| ascii.into_inner().to_string()),
759+
)
760+
.chain(
761+
GENERAL_SEARCH_FUNCTIONS
762+
.iter()
763+
.cloned()
764+
.map(|ascii| ascii.into_inner().to_string()),
765+
)
766+
.chain(
767+
ASYNC_FUNCTIONS
768+
.iter()
769+
.cloned()
770+
.map(|ascii| ascii.into_inner().to_string()),
771+
)
750772
.chain(
751773
Self::all_sugar_functions()
752774
.iter()
753775
.cloned()
754-
.map(str::to_string),
776+
.map(|ascii| ascii.into_inner().to_string()),
755777
);
756778
let mut engine: SimSearch<String> = SimSearch::new();
757779
for func_name in all_funcs {
@@ -779,15 +801,15 @@ impl<'a> TypeChecker<'a> {
779801
// check window function legal
780802
if window.is_some()
781803
&& !AggregateFunctionFactory::instance().contains(func_name)
782-
&& !GENERAL_WINDOW_FUNCTIONS.contains(&func_name)
804+
&& !GENERAL_WINDOW_FUNCTIONS.contains(&uni_case_func_name)
783805
{
784806
return Err(ErrorCode::SemanticError(
785807
"only window and aggregate functions allowed in window syntax",
786808
)
787809
.set_span(*span));
788810
}
789811
// check lambda function legal
790-
if lambda.is_some() && !GENERAL_LAMBDA_FUNCTIONS.contains(&func_name) {
812+
if lambda.is_some() && !GENERAL_LAMBDA_FUNCTIONS.contains(&uni_case_func_name) {
791813
return Err(ErrorCode::SemanticError(
792814
"only lambda functions allowed in lambda syntax",
793815
)
@@ -796,7 +818,7 @@ impl<'a> TypeChecker<'a> {
796818

797819
let args: Vec<&Expr> = args.iter().collect();
798820

799-
if GENERAL_WINDOW_FUNCTIONS.contains(&func_name) {
821+
if GENERAL_WINDOW_FUNCTIONS.contains(&uni_case_func_name) {
800822
// general window function
801823
if window.is_none() {
802824
return Err(ErrorCode::SemanticError(format!(
@@ -862,7 +884,7 @@ impl<'a> TypeChecker<'a> {
862884
// aggregate function
863885
Box::new((new_agg_func.into(), data_type))
864886
}
865-
} else if GENERAL_LAMBDA_FUNCTIONS.contains(&func_name) {
887+
} else if GENERAL_LAMBDA_FUNCTIONS.contains(&uni_case_func_name) {
866888
if lambda.is_none() {
867889
return Err(ErrorCode::SemanticError(format!(
868890
"function {func_name} must have a lambda expression",
@@ -871,8 +893,8 @@ impl<'a> TypeChecker<'a> {
871893
}
872894
let lambda = lambda.as_ref().unwrap();
873895
self.resolve_lambda_function(*span, func_name, &args, lambda)?
874-
} else if GENERAL_SEARCH_FUNCTIONS.contains(&func_name) {
875-
match func_name {
896+
} else if GENERAL_SEARCH_FUNCTIONS.contains(&uni_case_func_name) {
897+
match func_name.to_lowercase().as_str() {
876898
"score" => self.resolve_score_search_function(*span, func_name, &args)?,
877899
"match" => self.resolve_match_search_function(*span, func_name, &args)?,
878900
"query" => self.resolve_query_search_function(*span, func_name, &args)?,
@@ -884,7 +906,7 @@ impl<'a> TypeChecker<'a> {
884906
.set_span(*span));
885907
}
886908
}
887-
} else if ASYNC_FUNCTIONS.contains(&func_name) {
909+
} else if ASYNC_FUNCTIONS.contains(&uni_case_func_name) {
888910
self.resolve_async_function(*span, func_name, &args)?
889911
} else if BUILTIN_FUNCTIONS
890912
.get_property(func_name)
@@ -1445,7 +1467,7 @@ impl<'a> TypeChecker<'a> {
14451467
self.in_window_function = false;
14461468

14471469
// If { IGNORE | RESPECT } NULLS is not specified, the default is RESPECT NULLS
1448-
// (i.e. a NULL value will be returned if the expression contains a NULL value and it is the first value in the expression).
1470+
// (i.e. a NULL value will be returned if the expression contains a NULL value, and it is the first value in the expression).
14491471
let ignore_null = if let Some(ignore_null) = window_ignore_null {
14501472
*ignore_null
14511473
} else {
@@ -2090,7 +2112,7 @@ impl<'a> TypeChecker<'a> {
20902112
param_count: usize,
20912113
span: Span,
20922114
) -> Result<()> {
2093-
// json lambda functions are casted to array or map, ignored here.
2115+
// json lambda functions are cast to array or map, ignored here.
20942116
let expected_count = if func_name == "array_reduce" {
20952117
2
20962118
} else if func_name.starts_with("array") {
@@ -3124,37 +3146,38 @@ impl<'a> TypeChecker<'a> {
31243146
Ok(Box::new((subquery_expr.into(), data_type)))
31253147
}
31263148

3127-
pub fn all_sugar_functions() -> &'static [&'static str] {
3128-
&[
3129-
"current_catalog",
3130-
"database",
3131-
"currentdatabase",
3132-
"current_database",
3133-
"version",
3134-
"user",
3135-
"currentuser",
3136-
"current_user",
3137-
"current_role",
3138-
"connection_id",
3139-
"timezone",
3140-
"nullif",
3141-
"ifnull",
3142-
"nvl",
3143-
"nvl2",
3144-
"is_null",
3145-
"is_error",
3146-
"error_or",
3147-
"coalesce",
3148-
"last_query_id",
3149-
"array_sort",
3150-
"array_aggregate",
3151-
"to_variant",
3152-
"try_to_variant",
3153-
"greatest",
3154-
"least",
3155-
"stream_has_data",
3156-
"getvariable",
3157-
]
3149+
pub fn all_sugar_functions() -> &'static [Ascii<&'static str>] {
3150+
static FUNCTIONS: &[Ascii<&'static str>] = &[
3151+
Ascii::new("current_catalog"),
3152+
Ascii::new("database"),
3153+
Ascii::new("currentdatabase"),
3154+
Ascii::new("current_database"),
3155+
Ascii::new("version"),
3156+
Ascii::new("user"),
3157+
Ascii::new("currentuser"),
3158+
Ascii::new("current_user"),
3159+
Ascii::new("current_role"),
3160+
Ascii::new("connection_id"),
3161+
Ascii::new("timezone"),
3162+
Ascii::new("nullif"),
3163+
Ascii::new("ifnull"),
3164+
Ascii::new("nvl"),
3165+
Ascii::new("nvl2"),
3166+
Ascii::new("is_null"),
3167+
Ascii::new("is_error"),
3168+
Ascii::new("error_or"),
3169+
Ascii::new("coalesce"),
3170+
Ascii::new("last_query_id"),
3171+
Ascii::new("array_sort"),
3172+
Ascii::new("array_aggregate"),
3173+
Ascii::new("to_variant"),
3174+
Ascii::new("try_to_variant"),
3175+
Ascii::new("greatest"),
3176+
Ascii::new("least"),
3177+
Ascii::new("stream_has_data"),
3178+
Ascii::new("getvariable"),
3179+
];
3180+
FUNCTIONS
31583181
}
31593182

31603183
fn try_rewrite_sugar_function(

tests/sqllogictests/suites/query/case_sensitivity/name_hit.test

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ select * from Student
2828
statement ok
2929
set unquoted_ident_case_sensitive = 1
3030

31+
statement ok
32+
SELECT VERSION()
33+
3134
statement error (?s)1025,.*Unknown table `default`\.`default`\.student \.
3235
INSERT INTO student VALUES(1)
3336

0 commit comments

Comments
 (0)