Skip to content

Commit

Permalink
migrate builtin function lookup to unicase
Browse files Browse the repository at this point in the history
  • Loading branch information
notauserx committed Nov 27, 2024
1 parent 7d2fc12 commit c08754d
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 68 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/query/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ siphasher = { workspace = true }
strength_reduce = { workspace = true }
stringslice = { workspace = true }
twox-hash = { workspace = true }
unicase = { workspace = true }

[dev-dependencies]
comfy-table = { workspace = true }
Expand Down
81 changes: 44 additions & 37 deletions src/query/functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@
use aggregates::AggregateFunctionFactory;
use ctor::ctor;
use databend_common_expression::FunctionRegistry;
use unicase::Ascii;

pub mod aggregates;
mod cast_rules;
pub mod scalars;
pub mod srfs;

pub fn is_builtin_function(name: &str) -> bool {
BUILTIN_FUNCTIONS.contains(name)
|| AggregateFunctionFactory::instance().contains(name)
let name = Ascii::new(name);
BUILTIN_FUNCTIONS.contains(name.into_inner())
|| AggregateFunctionFactory::instance().contains(name.into_inner())
|| GENERAL_WINDOW_FUNCTIONS.contains(&name)
|| GENERAL_LAMBDA_FUNCTIONS.contains(&name)
|| GENERAL_SEARCH_FUNCTIONS.contains(&name)
Expand All @@ -44,53 +46,58 @@ pub fn is_builtin_function(name: &str) -> bool {
// The plans of search functions, async functions and UDFs contain some arguments defined in meta,
// which may be modified by the user at any time. Those functions are not suitable for caching.
pub fn is_cacheable_function(name: &str) -> bool {
BUILTIN_FUNCTIONS.contains(name)
|| AggregateFunctionFactory::instance().contains(name)
let name = Ascii::new(name);
BUILTIN_FUNCTIONS.contains(name.into_inner())
|| AggregateFunctionFactory::instance().contains(name.into_inner())
|| GENERAL_WINDOW_FUNCTIONS.contains(&name)
|| GENERAL_LAMBDA_FUNCTIONS.contains(&name)
}

#[ctor]
pub static BUILTIN_FUNCTIONS: FunctionRegistry = builtin_functions();

pub const ASYNC_FUNCTIONS: [&str; 2] = ["nextval", "dict_get"];
pub const ASYNC_FUNCTIONS: [Ascii<&str>; 2] = [Ascii::new("nextval"), Ascii::new("dict_get")];

pub const GENERAL_WINDOW_FUNCTIONS: [&str; 13] = [
"row_number",
"rank",
"dense_rank",
"percent_rank",
"lag",
"lead",
"first_value",
"first",
"last_value",
"last",
"nth_value",
"ntile",
"cume_dist",
pub const GENERAL_WINDOW_FUNCTIONS: [Ascii<&str>; 13] = [
Ascii::new("row_number"),
Ascii::new("rank"),
Ascii::new("dense_rank"),
Ascii::new("percent_rank"),
Ascii::new("lag"),
Ascii::new("lead"),
Ascii::new("first_value"),
Ascii::new("first"),
Ascii::new("last_value"),
Ascii::new("last"),
Ascii::new("nth_value"),
Ascii::new("ntile"),
Ascii::new("cume_dist"),
];

pub const GENERAL_LAMBDA_FUNCTIONS: [&str; 16] = [
"array_transform",
"array_apply",
"array_map",
"array_filter",
"array_reduce",
"json_array_transform",
"json_array_apply",
"json_array_map",
"json_array_filter",
"json_array_reduce",
"map_filter",
"map_transform_keys",
"map_transform_values",
"json_map_filter",
"json_map_transform_keys",
"json_map_transform_values",
pub const GENERAL_LAMBDA_FUNCTIONS: [Ascii<&str>; 16] = [
Ascii::new("array_transform"),
Ascii::new("array_apply"),
Ascii::new("array_map"),
Ascii::new("array_filter"),
Ascii::new("array_reduce"),
Ascii::new("json_array_transform"),
Ascii::new("json_array_apply"),
Ascii::new("json_array_map"),
Ascii::new("json_array_filter"),
Ascii::new("json_array_reduce"),
Ascii::new("map_filter"),
Ascii::new("map_transform_keys"),
Ascii::new("map_transform_values"),
Ascii::new("json_map_filter"),
Ascii::new("json_map_transform_keys"),
Ascii::new("json_map_transform_values"),
];

pub const GENERAL_SEARCH_FUNCTIONS: [&str; 3] = ["match", "query", "score"];
pub const GENERAL_SEARCH_FUNCTIONS: [Ascii<&str>; 3] = [
Ascii::new("match"),
Ascii::new("query"),
Ascii::new("score"),
];

fn builtin_functions() -> FunctionRegistry {
let mut registry = FunctionRegistry::empty();
Expand Down
40 changes: 19 additions & 21 deletions src/query/sql/src/planner/semantic/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ use itertools::Itertools;
use jsonb::keypath::KeyPath;
use jsonb::keypath::KeyPaths;
use simsearch::SimSearch;
use unicase::UniCase;
use unicase::Ascii;

use super::name_resolution::NameResolutionContext;
use super::normalize_identifier;
Expand Down Expand Up @@ -185,7 +185,7 @@ pub struct TypeChecker<'a> {
// This is used to check if there is nested aggregate function.
in_aggregate_function: bool,

// true if current expr is inside an window function.
// true if current expr is inside a window function.
// This is used to allow aggregation function in window's aggregate function.
in_window_function: bool,
forbid_udf: bool,
Expand Down Expand Up @@ -722,7 +722,10 @@ impl<'a> TypeChecker<'a> {
} => {
let func_name = normalize_identifier(name, self.name_resolution_ctx).to_string();
let func_name = func_name.as_str();
if !is_builtin_function(func_name) && !Self::is_sugar_function(func_name) {
let uni_case_func_name = Ascii::new(func_name);
if !is_builtin_function(func_name)
&& !Self::all_sugar_functions().contains(&uni_case_func_name)
{
if let Some(udf) = self.resolve_udf(*span, func_name, args)? {
return Ok(udf);
} else {
Expand All @@ -731,10 +734,10 @@ impl<'a> TypeChecker<'a> {
.all_function_names()
.into_iter()
.chain(AggregateFunctionFactory::instance().registered_names())
.chain(GENERAL_WINDOW_FUNCTIONS.iter().cloned().map(str::to_string))
.chain(GENERAL_LAMBDA_FUNCTIONS.iter().cloned().map(str::to_string))
.chain(GENERAL_SEARCH_FUNCTIONS.iter().cloned().map(str::to_string))
.chain(ASYNC_FUNCTIONS.iter().cloned().map(str::to_string))
.chain(GENERAL_WINDOW_FUNCTIONS.iter().cloned().map(|ascii| ascii.into_inner()))
.chain(GENERAL_LAMBDA_FUNCTIONS.iter().cloned().map(|ascii| ascii.into_inner()))
.chain(GENERAL_SEARCH_FUNCTIONS.iter().cloned().map(|ascii| ascii.into_inner()))
.chain(ASYNC_FUNCTIONS.iter().cloned().map(|ascii| ascii.into_inner()))
.chain(
Self::all_sugar_functions()
.iter()
Expand Down Expand Up @@ -768,15 +771,15 @@ impl<'a> TypeChecker<'a> {
// check window function legal
if window.is_some()
&& !AggregateFunctionFactory::instance().contains(func_name)
&& !GENERAL_WINDOW_FUNCTIONS.contains(&func_name)
&& !GENERAL_WINDOW_FUNCTIONS.contains(&uni_case_func_name)
{
return Err(ErrorCode::SemanticError(
"only window and aggregate functions allowed in window syntax",
)
.set_span(*span));
}
// check lambda function legal
if lambda.is_some() && !GENERAL_LAMBDA_FUNCTIONS.contains(&func_name) {
if lambda.is_some() && !GENERAL_LAMBDA_FUNCTIONS.contains(&uni_case_func_name) {
return Err(ErrorCode::SemanticError(
"only lambda functions allowed in lambda syntax",
)
Expand All @@ -785,7 +788,7 @@ impl<'a> TypeChecker<'a> {

let args: Vec<&Expr> = args.iter().collect();

if GENERAL_WINDOW_FUNCTIONS.contains(&func_name) {
if GENERAL_WINDOW_FUNCTIONS.contains(&uni_case_func_name) {
// general window function
if window.is_none() {
return Err(ErrorCode::SemanticError(format!(
Expand Down Expand Up @@ -851,7 +854,7 @@ impl<'a> TypeChecker<'a> {
// aggregate function
Box::new((new_agg_func.into(), data_type))
}
} else if GENERAL_LAMBDA_FUNCTIONS.contains(&func_name) {
} else if GENERAL_LAMBDA_FUNCTIONS.contains(&uni_case_func_name) {
if lambda.is_none() {
return Err(ErrorCode::SemanticError(format!(
"function {func_name} must have a lambda expression",
Expand All @@ -860,8 +863,8 @@ impl<'a> TypeChecker<'a> {
}
let lambda = lambda.as_ref().unwrap();
self.resolve_lambda_function(*span, func_name, &args, lambda)?
} else if GENERAL_SEARCH_FUNCTIONS.contains(&func_name) {
match func_name {
} else if GENERAL_SEARCH_FUNCTIONS.contains(&uni_case_func_name) {
match func_name.to_lowercase().as_str() {
"score" => self.resolve_score_search_function(*span, func_name, &args)?,
"match" => self.resolve_match_search_function(*span, func_name, &args)?,
"query" => self.resolve_query_search_function(*span, func_name, &args)?,
Expand All @@ -873,7 +876,7 @@ impl<'a> TypeChecker<'a> {
.set_span(*span));
}
}
} else if ASYNC_FUNCTIONS.contains(&func_name) {
} else if ASYNC_FUNCTIONS.contains(&uni_case_func_name) {
self.resolve_async_function(*span, func_name, &args)?
} else if BUILTIN_FUNCTIONS
.get_property(func_name)
Expand Down Expand Up @@ -1435,7 +1438,7 @@ impl<'a> TypeChecker<'a> {
self.in_window_function = false;

// If { IGNORE | RESPECT } NULLS is not specified, the default is RESPECT NULLS
// (i.e. a NULL value will be returned if the expression contains a NULL value and it is the first value in the expression).
// (i.e. a NULL value will be returned if the expression contains a NULL value, and it is the first value in the expression).
let ignore_null = if let Some(ignore_null) = window_ignore_null {
*ignore_null
} else {
Expand Down Expand Up @@ -2080,7 +2083,7 @@ impl<'a> TypeChecker<'a> {
param_count: usize,
span: Span,
) -> Result<()> {
// json lambda functions are casted to array or map, ignored here.
// json lambda functions are cast to array or map, ignored here.
let expected_count = if func_name == "array_reduce" {
2
} else if func_name.starts_with("array") {
Expand Down Expand Up @@ -3153,11 +3156,6 @@ impl<'a> TypeChecker<'a> {
FUNCTIONS
}

pub fn is_sugar_function(name: &str) -> bool {
let name = Ascii::new(name);
all_sugar_functions().iter().any(|func| func.eq(&name))
}

fn try_rewrite_sugar_function(
&mut self,
span: Span,
Expand Down
10 changes: 0 additions & 10 deletions src/query/sql/tests/type_check.rs

This file was deleted.

0 comments on commit c08754d

Please sign in to comment.