diff --git a/src/ast/query.rs b/src/ast/query.rs index 7ffb64d9b..541047ca8 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -1336,7 +1336,7 @@ pub enum TableFactor { Pivot { table: Box, aggregate_functions: Vec, // Function expression - value_column: Vec, + value_column: Vec, // Expr is a identifier or a compound identifier value_source: PivotValueSource, default_on_null: Option, alias: Option, @@ -2010,10 +2010,15 @@ impl fmt::Display for TableFactor { } => { write!( f, - "{table} PIVOT({} FOR {} IN ({value_source})", + "{table} PIVOT({} FOR ", display_comma_separated(aggregate_functions), - Expr::CompoundIdentifier(value_column.to_vec()), )?; + if value_column.len() == 1 { + write!(f, "{}", value_column[0])?; + } else { + write!(f, "({})", display_comma_separated(value_column))?; + } + write!(f, " IN ({value_source})")?; if let Some(expr) = default_on_null { write!(f, " DEFAULT ON NULL ({expr})")?; } diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 3e82905e1..f38d9a64f 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -1971,7 +1971,7 @@ impl Spanned for TableFactor { } => union_spans( core::iter::once(table.span()) .chain(aggregate_functions.iter().map(|i| i.span())) - .chain(value_column.iter().map(|i| i.span)) + .chain(value_column.iter().map(|i| i.span())) .chain(core::iter::once(value_source.span())) .chain(default_on_null.as_ref().map(|i| i.span())) .chain(alias.as_ref().map(|i| i.span())), diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs index 8e0a3139a..7840f0e14 100644 --- a/src/ast/visitor.rs +++ b/src/ast/visitor.rs @@ -884,6 +884,8 @@ mod tests { "PRE: EXPR: a.amount", "POST: EXPR: a.amount", "POST: EXPR: SUM(a.amount)", + "PRE: EXPR: a.MONTH", + "POST: EXPR: a.MONTH", "PRE: EXPR: 'JAN'", "POST: EXPR: 'JAN'", "PRE: EXPR: 'FEB'", diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 8d5a55da0..99f5fe497 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -10820,6 +10820,18 @@ impl<'a> Parser<'a> { self.parse_parenthesized_column_list_inner(optional, allow_empty, |p| p.parse_identifier()) } + pub fn parse_parenthesized_compound_identifier_list( + &mut self, + optional: IsOptional, + allow_empty: bool, + ) -> Result, ParserError> { + self.parse_parenthesized_column_list_inner(optional, allow_empty, |p| { + Ok(Expr::CompoundIdentifier( + p.parse_period_separated(|p| p.parse_identifier())?, + )) + }) + } + /// Parses a parenthesized comma-separated list of index columns, which can be arbitrary /// expressions with ordering information (and an opclass in some dialects). fn parse_parenthesized_index_column_list(&mut self) -> Result, ParserError> { @@ -13828,7 +13840,13 @@ impl<'a> Parser<'a> { self.expect_token(&Token::LParen)?; let aggregate_functions = self.parse_comma_separated(Self::parse_aliased_function_call)?; self.expect_keyword_is(Keyword::FOR)?; - let value_column = self.parse_period_separated(|p| p.parse_identifier())?; + let value_column = if self.peek_token_ref().token == Token::LParen { + self.parse_parenthesized_compound_identifier_list(Mandatory, false)? + } else { + vec![Expr::CompoundIdentifier( + self.parse_period_separated(|p| p.parse_identifier())?, + )] + }; self.expect_keyword_is(Keyword::IN)?; self.expect_token(&Token::LParen)?; diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 5d8284a46..19ce2493a 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -10875,7 +10875,10 @@ fn parse_pivot_table() { expected_function("b", Some("t")), expected_function("c", Some("u")), ], - value_column: vec![Ident::new("a"), Ident::new("MONTH")], + value_column: vec![Expr::CompoundIdentifier(vec![ + Ident::new("a"), + Ident::new("MONTH") + ])], value_source: PivotValueSource::List(vec![ ExprWithAlias { expr: Expr::value(number("1")), @@ -10922,6 +10925,15 @@ fn parse_pivot_table() { verified_stmt(sql_without_table_alias).to_string(), sql_without_table_alias ); + + let sql_with_multiple_value_column = concat!( + "SELECT * FROM person ", + "PIVOT(SUM(age) AS a, AVG(class) AS c FOR (name, age) IN (('John', 30) AS c1, ('Mike', 40) AS c2))" + ); + assert_eq!( + verified_stmt(sql_with_multiple_value_column).to_string(), + sql_with_multiple_value_column + ); } #[test] @@ -11143,7 +11155,7 @@ fn parse_pivot_unpivot_table() { expr: call("sum", [Expr::Identifier(Ident::new("population"))]), alias: None }], - value_column: vec![Ident::new("year")], + value_column: vec![Expr::CompoundIdentifier(vec![Ident::new("year")])], value_source: PivotValueSource::List(vec![ ExprWithAlias { expr: Expr::Value(