diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index f998143c8..773b119fb 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -1458,12 +1458,19 @@ pub struct ProcedureParam { pub name: Ident, pub data_type: DataType, pub mode: Option, + pub default: Option, } impl fmt::Display for ProcedureParam { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if let Some(mode) = &self.mode { - write!(f, "{mode} {} {}", self.name, self.data_type) + if let Some(default) = &self.default { + write!(f, "{mode} {} {} = {}", self.name, self.data_type, default) + } else { + write!(f, "{mode} {} {}", self.name, self.data_type) + } + } else if let Some(default) = &self.default { + write!(f, "{} {} = {}", self.name, self.data_type, default) } else { write!(f, "{} {}", self.name, self.data_type) } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index d661efd4d..9682804c2 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -7897,10 +7897,17 @@ impl<'a> Parser<'a> { }; let name = self.parse_identifier()?; let data_type = self.parse_data_type()?; + let default = if self.consume_token(&Token::Eq) { + Some(self.parse_expr()?) + } else { + None + }; + Ok(ProcedureParam { name, data_type, mode, + default, }) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 720c1e492..40a2664b0 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -16497,7 +16497,8 @@ fn parse_create_procedure_with_parameter_modes() { span: fake_span, }, data_type: DataType::Integer(None), - mode: Some(ArgMode::In) + mode: Some(ArgMode::In), + default: None, }, ProcedureParam { name: Ident { @@ -16506,7 +16507,8 @@ fn parse_create_procedure_with_parameter_modes() { span: fake_span, }, data_type: DataType::Text, - mode: Some(ArgMode::Out) + mode: Some(ArgMode::Out), + default: None, }, ProcedureParam { name: Ident { @@ -16515,7 +16517,8 @@ fn parse_create_procedure_with_parameter_modes() { span: fake_span, }, data_type: DataType::Timestamp(None, TimezoneInfo::None), - mode: Some(ArgMode::InOut) + mode: Some(ArgMode::InOut), + default: None, }, ProcedureParam { name: Ident { @@ -16524,7 +16527,8 @@ fn parse_create_procedure_with_parameter_modes() { span: fake_span, }, data_type: DataType::Bool, - mode: None + mode: None, + default: None, }, ]) ); @@ -16533,6 +16537,40 @@ fn parse_create_procedure_with_parameter_modes() { } } +#[test] +fn create_procedure_with_parameter_default_value() { + let sql = r#"CREATE PROCEDURE test_proc (a INT = 42) AS BEGIN SELECT 1; END"#; + match verified_stmt(sql) { + Statement::CreateProcedure { + or_alter, + name, + params, + .. + } => { + assert_eq!(or_alter, false); + assert_eq!(name.to_string(), "test_proc"); + let fake_span = Span { + start: Location { line: 0, column: 0 }, + end: Location { line: 0, column: 0 }, + }; + assert_eq!( + params, + Some(vec![ProcedureParam { + name: Ident { + value: "a".into(), + quote_style: None, + span: fake_span, + }, + data_type: DataType::Int(None), + mode: None, + default: Some(Expr::Value((number("42")).with_empty_span())), + },]) + ); + } + _ => unreachable!(), + } +} + #[test] fn parse_not_null() { let _ = all_dialects().expr_parses_to("x NOT NULL", "x IS NOT NULL"); diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index b1ad422ec..38181dc13 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -156,6 +156,7 @@ fn parse_create_procedure() { }, data_type: DataType::Int(None), mode: None, + default: None, }, ProcedureParam { name: Ident { @@ -168,6 +169,7 @@ fn parse_create_procedure() { unit: None })), mode: None, + default: None, } ]), name: ObjectName::from(vec![Ident {