Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
498 changes: 498 additions & 0 deletions datafusion/expr-common/src/signature.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ pub use udaf::{
udaf_default_window_function_schema_name, AggregateUDF, AggregateUDFImpl,
ReversedUDAF, SetMonotonicity, StatisticsArgs,
};
pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
pub use udf::{arguments, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
pub use udwf::{ReversedUDWF, WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};

Expand Down
274 changes: 274 additions & 0 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -949,3 +949,277 @@ mod tests {
hasher.finish()
}
}

/// Argument resolution logic for named function parameters
pub mod arguments {
use datafusion_common::{plan_err, Result};
use crate::Expr;

/// Resolves function arguments, handling named and positional notation.
///
/// This function validates and reorders arguments to match the function's parameter names
/// when named arguments are used.
///
/// # Rules
/// - All positional arguments must come before named arguments
/// - Named arguments can be in any order after positional arguments
/// - All parameter names must match the provided parameter_names
/// - No duplicate parameter names allowed
///
/// # Arguments
/// * `param_names` - The function's parameter names in order
/// * `args` - The argument expressions
/// * `arg_names` - Optional parameter name for each argument
///
/// # Returns
/// A vector of expressions in the correct order matching the parameter names
///
/// # Examples
/// ```rust,ignore
/// // Given parameters ["a", "b", "c"]
/// // And call: func(10, c => 30, b => 20)
/// // Returns: [Expr(10), Expr(20), Expr(30)]
/// ```
pub fn resolve_function_arguments(
param_names: &[String],
args: Vec<Expr>,
arg_names: Vec<Option<String>>,
) -> Result<Vec<Expr>> {
// Validate that arg_names length matches args length
if args.len() != arg_names.len() {
return plan_err!(
"Internal error: args length ({}) != arg_names length ({})",
args.len(),
arg_names.len()
);
}

// Check if all arguments are positional (fast path)
if arg_names.iter().all(|name| name.is_none()) {
return Ok(args);
}

// Validate mixed positional and named arguments
validate_argument_order(&arg_names)?;

// Validate and reorder named arguments
reorder_named_arguments(param_names, args, arg_names)
}

/// Validates that positional arguments come before named arguments
fn validate_argument_order(arg_names: &[Option<String>]) -> Result<()> {
let mut seen_named = false;
for (i, arg_name) in arg_names.iter().enumerate() {
match arg_name {
Some(_) => seen_named = true,
None if seen_named => {
return plan_err!(
"Positional argument at position {} follows named argument. \
All positional arguments must come before named arguments.",
i
);
}
None => {}
}
}
Ok(())
}

/// Reorders arguments based on named parameters to match signature order
fn reorder_named_arguments(
param_names: &[String],
args: Vec<Expr>,
arg_names: Vec<Option<String>>,
) -> Result<Vec<Expr>> {
// Count positional vs named arguments
let positional_count = arg_names.iter().filter(|n| n.is_none()).count();

// Capture args length before consuming the vector
let args_len = args.len();

// Create a result vector with the expected size
let expected_arg_count = param_names.len();
let mut result: Vec<Option<Expr>> = vec![None; expected_arg_count];

// Track which parameters have been assigned
let mut assigned = vec![false; expected_arg_count];

// Process all arguments (both positional and named)
for (i, (arg, arg_name)) in args.into_iter().zip(arg_names).enumerate() {
if let Some(name) = arg_name {
// Named argument - find its position in param_names
let param_index = param_names
.iter()
.position(|p| p == &name)
.ok_or_else(|| {
datafusion_common::plan_datafusion_err!(
"Unknown parameter name '{}'. Valid parameters are: [{}]",
name,
param_names.join(", ")
)
})?;

// Check if this parameter was already assigned
if assigned[param_index] {
return plan_err!(
"Parameter '{}' specified multiple times",
name
);
}

result[param_index] = Some(arg);
assigned[param_index] = true;
} else {
// Positional argument - place at current position
if i >= expected_arg_count {
return plan_err!(
"Too many positional arguments: expected at most {}, got {}",
expected_arg_count,
positional_count
);
}
result[i] = Some(arg);
assigned[i] = true;
}
}

// Check if all required parameters were provided
// Only require parameters up to the number of arguments provided (supports optional parameters)
let required_count = args_len;
for i in 0..required_count {
if !assigned[i] {
return plan_err!(
"Missing required parameter '{}'",
param_names[i]
);
}
}

// Return only the assigned parameters (handles optional trailing parameters)
Ok(result.into_iter().take(required_count).map(|e| e.unwrap()).collect())
}

#[cfg(test)]
mod tests {
use super::*;
use crate::lit;

#[test]
fn test_all_positional() {
let param_names = vec!["a".to_string(), "b".to_string()];

let args = vec![lit(1), lit("hello")];
let arg_names = vec![None, None];

let result = resolve_function_arguments(&param_names, args.clone(), arg_names).unwrap();
assert_eq!(result.len(), 2);
}

#[test]
fn test_all_named() {
let param_names = vec!["a".to_string(), "b".to_string()];

let args = vec![lit(1), lit("hello")];
let arg_names = vec![Some("a".to_string()), Some("b".to_string())];

let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();
assert_eq!(result.len(), 2);
}

#[test]
fn test_named_reordering() {
let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()];

// Call with: func(c => 3.0, a => 1, b => "hello")
let args = vec![lit(3.0), lit(1), lit("hello")];
let arg_names = vec![
Some("c".to_string()),
Some("a".to_string()),
Some("b".to_string()),
];

let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();

// Should be reordered to [a, b, c] = [1, "hello", 3.0]
assert_eq!(result.len(), 3);
assert_eq!(result[0], lit(1));
assert_eq!(result[1], lit("hello"));
assert_eq!(result[2], lit(3.0));
}

#[test]
fn test_mixed_positional_and_named() {
let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()];

// Call with: func(1, c => 3.0, b => "hello")
let args = vec![lit(1), lit(3.0), lit("hello")];
let arg_names = vec![None, Some("c".to_string()), Some("b".to_string())];

let result = resolve_function_arguments(&param_names, args, arg_names).unwrap();

// Should be reordered to [a, b, c] = [1, "hello", 3.0]
assert_eq!(result.len(), 3);
assert_eq!(result[0], lit(1));
assert_eq!(result[1], lit("hello"));
assert_eq!(result[2], lit(3.0));
}

#[test]
fn test_positional_after_named_error() {
let param_names = vec!["a".to_string(), "b".to_string()];

// Call with: func(a => 1, "hello") - ERROR
let args = vec![lit(1), lit("hello")];
let arg_names = vec![Some("a".to_string()), None];

let result = resolve_function_arguments(&param_names, args, arg_names);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Positional argument"));
}

#[test]
fn test_unknown_parameter_name() {
let param_names = vec!["a".to_string(), "b".to_string()];

// Call with: func(x => 1, b => "hello") - ERROR
let args = vec![lit(1), lit("hello")];
let arg_names = vec![Some("x".to_string()), Some("b".to_string())];

let result = resolve_function_arguments(&param_names, args, arg_names);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Unknown parameter"));
}

#[test]
fn test_duplicate_parameter_name() {
let param_names = vec!["a".to_string(), "b".to_string()];

// Call with: func(a => 1, a => 2) - ERROR
let args = vec![lit(1), lit(2)];
let arg_names = vec![Some("a".to_string()), Some("a".to_string())];

let result = resolve_function_arguments(&param_names, args, arg_names);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("specified multiple times"));
}

#[test]
fn test_missing_required_parameter() {
let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()];

// Call with: func(a => 1, c => 3.0) - missing 'b'
let args = vec![lit(1), lit(3.0)];
let arg_names = vec![Some("a".to_string()), Some("c".to_string())];

let result = resolve_function_arguments(&param_names, args, arg_names);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Missing required parameter"));
}
}
}
62 changes: 61 additions & 1 deletion datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ pub fn generate_signature_error_msg(
) -> String {
let candidate_signatures = func_signature
.type_signature
.to_string_repr()
.to_string_repr_with_names(func_signature.parameter_names.as_deref())
.iter()
.map(|args_str| format!("\t{func_name}({args_str})"))
.collect::<Vec<String>>()
Expand Down Expand Up @@ -1714,4 +1714,64 @@ mod tests {
DataType::List(Arc::new(Field::new("my_union", union_type, true)));
assert!(!can_hash(&list_union_type));
}

#[test]
fn test_generate_signature_error_msg_with_parameter_names() {
use datafusion_expr_common::signature::{TypeSignature, Volatility};

// Create a signature like substr with parameter names
let sig = Signature::one_of(
vec![TypeSignature::Any(2), TypeSignature::Any(3)],
Volatility::Immutable,
)
.with_parameter_names(vec![
"str".to_string(),
"start_pos".to_string(),
"length".to_string(),
])
.expect("valid parameter names");

// Generate error message with only 1 argument provided
let error_msg = generate_signature_error_msg("substr", sig, &[DataType::Utf8]);

// Error message should contain parameter names
assert!(
error_msg.contains("str, start_pos"),
"Expected 'str, start_pos' in error message, got: {}",
error_msg
);
assert!(
error_msg.contains("str, start_pos, length"),
"Expected 'str, start_pos, length' in error message, got: {}",
error_msg
);

// Should NOT contain generic "Any" types
assert!(
!error_msg.contains("Any, Any"),
"Should not contain 'Any, Any', got: {}",
error_msg
);
}

#[test]
fn test_generate_signature_error_msg_without_parameter_names() {
use datafusion_expr_common::signature::{TypeSignature, Volatility};

// Create a signature without parameter names
let sig = Signature::one_of(
vec![TypeSignature::Any(2), TypeSignature::Any(3)],
Volatility::Immutable,
);

// Generate error message
let error_msg = generate_signature_error_msg("my_func", sig, &[DataType::Int32]);

// Should contain generic "Any" types when no parameter names
assert!(
error_msg.contains("Any, Any"),
"Expected 'Any, Any' without parameter names, got: {}",
error_msg
);
}
}
3 changes: 3 additions & 0 deletions datafusion/functions-nested/src/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ impl ArrayReplace {
},
),
volatility: Volatility::Immutable,
parameter_names: None,
},
aliases: vec![String::from("list_replace")],
}
Expand Down Expand Up @@ -186,6 +187,7 @@ impl ArrayReplaceN {
},
),
volatility: Volatility::Immutable,
parameter_names: None,
},
aliases: vec![String::from("list_replace_n")],
}
Expand Down Expand Up @@ -265,6 +267,7 @@ impl ArrayReplaceAll {
},
),
volatility: Volatility::Immutable,
parameter_names: None,
},
aliases: vec![String::from("list_replace_all")],
}
Expand Down
Loading
Loading