Skip to content

Commit 7ced0ac

Browse files
discord9alamb
andauthored
fix: block timestamp precision narrowing unwrap (#22837)
## Which issue does this PR close? Part of GreptimeTeam/greptimedb#8214. Follow-up / smaller alternative to #21908. So this quick fix is for fix a very common path, the rest(and the allow list version still lays in #21908 and need more discussion) - closes #22142 ## Rationale for this change `unwrap_cast` can currently rewrite predicates like: ```sql CAST(ts_ns AS timestamp(3)) = timestamp(3) '2024-01-01 00:00:00.001' ``` into an equality against the original nanosecond column. That is not equivalent: the original predicate matches every nanosecond timestamp within the same millisecond, while the rewritten predicate only matches the exact millisecond boundary. ## What changes are included in this PR? This is intentionally a small blocklist-only fix: - add a shared `is_timestamp_precision_narrowing_cast` helper - block comparison cast unwrap when a timestamp cast narrows precision - apply the same guard to logical and physical unwrap-cast simplifiers - keep timestamp precision widening unwraps enabled ## Are these changes tested? Added targeted logical, physical, and helper tests. Ran: ```text cargo test -p datafusion-optimizer unwrap_cast cargo test -p datafusion-physical-expr unwrap_cast cargo test -p datafusion-expr-common test_timestamp_precision_narrowing_cast ``` ## Are there any user-facing changes? Plans will keep timestamp precision-narrowing casts in comparison predicates instead of unwrapping them incorrectly. --------- Signed-off-by: discord9 <discord9@163.com> Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent b846fb2 commit 7ced0ac

4 files changed

Lines changed: 164 additions & 10 deletions

File tree

datafusion/expr-common/src/casts.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,35 @@ fn is_lossy_temporal_cast(from_type: &DataType, to_type: &DataType) -> bool {
103103
|| (is_date_type(to_type) && from_type.is_temporal())
104104
}
105105

106+
/// Returns true when casting a timestamp from `from_type` to `to_type` loses
107+
/// timestamp precision.
108+
///
109+
/// This is used by comparison cast unwrapping to avoid rewrites such as
110+
/// `CAST(ts_ns AS timestamp(ms)) = lit_ms` -> `ts_ns = lit_ns`. The original
111+
/// predicate can match any nanosecond value in the same millisecond, while the
112+
/// rewritten predicate only matches the exact millisecond boundary.
113+
pub fn is_timestamp_precision_narrowing_cast(
114+
from_type: &DataType,
115+
to_type: &DataType,
116+
) -> bool {
117+
let (DataType::Timestamp(from_unit, _), DataType::Timestamp(to_unit, _)) =
118+
(from_type, to_type)
119+
else {
120+
return false;
121+
};
122+
123+
timestamp_unit_scale(from_unit) > timestamp_unit_scale(to_unit)
124+
}
125+
126+
fn timestamp_unit_scale(unit: &TimeUnit) -> i128 {
127+
match unit {
128+
TimeUnit::Second => 1,
129+
TimeUnit::Millisecond => MILLISECONDS as i128,
130+
TimeUnit::Microsecond => MICROSECONDS as i128,
131+
TimeUnit::Nanosecond => NANOSECONDS as i128,
132+
}
133+
}
134+
106135
/// Returns true if unwrap_cast_in_comparison supports this numeric type
107136
fn is_supported_numeric_type(data_type: &DataType) -> bool {
108137
matches!(
@@ -784,6 +813,23 @@ mod tests {
784813
);
785814
}
786815

816+
#[test]
817+
fn test_timestamp_precision_narrowing_cast() {
818+
let ts_ns = DataType::Timestamp(TimeUnit::Nanosecond, None);
819+
let ts_us = DataType::Timestamp(TimeUnit::Microsecond, None);
820+
let ts_ms = DataType::Timestamp(TimeUnit::Millisecond, None);
821+
let ts_s = DataType::Timestamp(TimeUnit::Second, None);
822+
823+
assert!(is_timestamp_precision_narrowing_cast(&ts_ns, &ts_ms));
824+
assert!(is_timestamp_precision_narrowing_cast(&ts_us, &ts_s));
825+
assert!(!is_timestamp_precision_narrowing_cast(&ts_ms, &ts_ns));
826+
assert!(!is_timestamp_precision_narrowing_cast(&ts_ms, &ts_ms));
827+
assert!(!is_timestamp_precision_narrowing_cast(
828+
&DataType::Int64,
829+
&ts_ms
830+
));
831+
}
832+
787833
#[test]
788834
fn test_try_cast_to_type_unsupported() {
789835
// int64 to list

datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ use datafusion_common::{Result, ScalarValue};
5959
use datafusion_common::{internal_err, tree_node::Transformed};
6060
use datafusion_expr::{BinaryExpr, lit};
6161
use datafusion_expr::{Cast, Expr, Operator, TryCast, simplify::SimplifyContext};
62-
use datafusion_expr_common::casts::{is_supported_type, try_cast_literal_to_type};
62+
use datafusion_expr_common::casts::{
63+
is_supported_type, is_timestamp_precision_narrowing_cast, try_cast_literal_to_type,
64+
};
6365

6466
pub(super) fn unwrap_cast_in_comparison_for_binary(
6567
info: &SimplifyContext,
@@ -113,10 +115,14 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
113115
match (expr, literal) {
114116
(
115117
Expr::TryCast(TryCast {
116-
expr: left_expr, ..
118+
expr: left_expr,
119+
field,
120+
..
117121
})
118122
| Expr::Cast(Cast {
119-
expr: left_expr, ..
123+
expr: left_expr,
124+
field,
125+
..
120126
}),
121127
Expr::Literal(lit_val, _),
122128
) => {
@@ -128,6 +134,10 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
128134
return false;
129135
};
130136

137+
if is_timestamp_precision_narrowing_cast(&expr_type, field.data_type()) {
138+
return false;
139+
}
140+
131141
if cast_literal_to_type_with_op(lit_val, &expr_type, op).is_some() {
132142
return true;
133143
}
@@ -146,10 +156,14 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist(
146156
list: &[Expr],
147157
) -> bool {
148158
let (Expr::TryCast(TryCast {
149-
expr: left_expr, ..
159+
expr: left_expr,
160+
field,
161+
..
150162
})
151163
| Expr::Cast(Cast {
152-
expr: left_expr, ..
164+
expr: left_expr,
165+
field,
166+
..
153167
})) = expr
154168
else {
155169
return false;
@@ -163,6 +177,10 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist(
163177
return false;
164178
}
165179

180+
if is_timestamp_precision_narrowing_cast(&expr_type, field.data_type()) {
181+
return false;
182+
}
183+
166184
for right in list {
167185
let Ok(right_type) = info.get_data_type(right) else {
168186
return false;
@@ -586,6 +604,25 @@ mod tests {
586604
assert_eq!(optimize_test(expr_lt, &schema), expected);
587605
}
588606

607+
#[test]
608+
fn test_not_unwrap_cast_timestamp_precision_narrowing() {
609+
let schema = expr_test_schema();
610+
let expr_input = cast(col("ts_nano_none"), timestamp_millis_none_type())
611+
.eq(lit_timestamp_millis_none(1));
612+
613+
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
614+
}
615+
616+
#[test]
617+
fn test_unwrap_cast_timestamp_precision_widening() {
618+
let schema = expr_test_schema();
619+
let expr_input = cast(col("ts_millis_none"), timestamp_nano_none_type())
620+
.eq(lit_timestamp_nano_none(1_000_000));
621+
let expected = col("ts_millis_none").eq(lit_timestamp_millis_none(1));
622+
623+
assert_eq!(optimize_test(expr_input, &schema), expected);
624+
}
625+
589626
fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
590627
let simplifier = ExprSimplifier::new(
591628
SimplifyContext::builder()
@@ -607,6 +644,7 @@ mod tests {
607644
Field::new("c5", DataType::Float32, false),
608645
Field::new("c6", DataType::UInt32, false),
609646
Field::new("ts_nano_none", timestamp_nano_none_type(), false),
647+
Field::new("ts_millis_none", timestamp_millis_none_type(), false),
610648
Field::new("ts_nano_utf", timestamp_nano_utc_type(), false),
611649
Field::new("str1", DataType::Utf8, false),
612650
Field::new("largestr", DataType::LargeUtf8, false),
@@ -643,6 +681,10 @@ mod tests {
643681
lit(ScalarValue::TimestampNanosecond(Some(ts), None))
644682
}
645683

684+
fn lit_timestamp_millis_none(ts: i64) -> Expr {
685+
lit(ScalarValue::TimestampMillisecond(Some(ts), None))
686+
}
687+
646688
fn lit_timestamp_nano_utc(ts: i64) -> Expr {
647689
let utc = Some("+0:00".into());
648690
lit(ScalarValue::TimestampNanosecond(Some(ts), utc))
@@ -652,6 +694,10 @@ mod tests {
652694
DataType::Timestamp(TimeUnit::Nanosecond, None)
653695
}
654696

697+
fn timestamp_millis_none_type() -> DataType {
698+
DataType::Timestamp(TimeUnit::Millisecond, None)
699+
}
700+
655701
// this is the type that now() returns
656702
fn timestamp_nano_utc_type() -> DataType {
657703
let utc = Some("+0:00".into());

datafusion/optimizer/tests/optimizer_integration.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ fn extension_node_does_not_block_projection_pruning() -> Result<()> {
795795
Projection: t.a, CAST(t.ts AS Timestamp(ms, "UTC")) AS ts
796796
Filter: __common_expr_3 > TimestampMillisecond(1000, Some("UTC")) AND __common_expr_3 < TimestampMillisecond(2000, Some("UTC"))
797797
Projection: CAST(t.ts AS Timestamp(ms, "UTC")) AS __common_expr_3, t.a, t.ts
798-
TableScan: t projection=[a, ts], partial_filters=[t.ts > TimestampNanosecond(1000000000, None), t.ts < TimestampNanosecond(2000000000, None), CAST(t.ts AS Timestamp(ms, "UTC")) > TimestampMillisecond(1000, Some("UTC")), CAST(t.ts AS Timestamp(ms, "UTC")) < TimestampMillisecond(2000, Some("UTC"))]
798+
TableScan: t projection=[a, ts], partial_filters=[CAST(t.ts AS Timestamp(ms, "UTC")) > TimestampMillisecond(1000, Some("UTC")), CAST(t.ts AS Timestamp(ms, "UTC")) < TimestampMillisecond(2000, Some("UTC"))]
799799
"#,
800800
);
801801

datafusion/physical-expr/src/simplifier/unwrap_cast.rs

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ use std::sync::Arc;
3636
use arrow::datatypes::{DataType, Schema};
3737
use datafusion_common::{Result, ScalarValue, tree_node::Transformed};
3838
use datafusion_expr::Operator;
39-
use datafusion_expr_common::casts::try_cast_literal_to_type;
39+
use datafusion_expr_common::casts::{
40+
is_timestamp_precision_narrowing_cast, try_cast_literal_to_type,
41+
};
4042

4143
use crate::PhysicalExpr;
4244
use crate::expressions::{BinaryExpr, CastExpr, Literal, TryCastExpr, lit};
@@ -60,13 +62,14 @@ fn try_unwrap_cast_binary(
6062
schema: &Schema,
6163
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
6264
// Case 1: cast(left_expr) op literal
63-
if let (Some((inner_expr, _cast_type)), Some(literal)) = (
65+
if let (Some((inner_expr, cast_type)), Some(literal)) = (
6466
extract_cast_info(binary.left()),
6567
binary.right().downcast_ref::<Literal>(),
6668
) && binary.op().supports_propagation()
6769
&& let Some(unwrapped) = try_unwrap_cast_comparison(
6870
Arc::clone(inner_expr),
6971
literal.value(),
72+
cast_type,
7073
*binary.op(),
7174
schema,
7275
)?
@@ -75,7 +78,7 @@ fn try_unwrap_cast_binary(
7578
}
7679

7780
// Case 2: literal op cast(right_expr)
78-
if let (Some(literal), Some((inner_expr, _cast_type))) = (
81+
if let (Some(literal), Some((inner_expr, cast_type))) = (
7982
binary.left().downcast_ref::<Literal>(),
8083
extract_cast_info(binary.right()),
8184
) {
@@ -85,6 +88,7 @@ fn try_unwrap_cast_binary(
8588
&& let Some(unwrapped) = try_unwrap_cast_comparison(
8689
Arc::clone(inner_expr),
8790
literal.value(),
91+
cast_type,
8892
swapped_op,
8993
schema,
9094
)?
@@ -118,12 +122,17 @@ fn extract_cast_info(
118122
fn try_unwrap_cast_comparison(
119123
inner_expr: Arc<dyn PhysicalExpr>,
120124
literal_value: &ScalarValue,
125+
cast_type: &DataType,
121126
op: Operator,
122127
schema: &Schema,
123128
) -> Result<Option<Arc<dyn PhysicalExpr>>> {
124129
// Get the data type of the inner expression
125130
let inner_type = inner_expr.data_type(schema)?;
126131

132+
if is_timestamp_precision_narrowing_cast(&inner_type, cast_type) {
133+
return Ok(None);
134+
}
135+
127136
// Try to cast the literal to the inner expression's type
128137
if let Some(casted_literal) = try_cast_literal_to_type(literal_value, &inner_type) {
129138
let literal_expr = lit(casted_literal);
@@ -138,7 +147,7 @@ fn try_unwrap_cast_comparison(
138147
mod tests {
139148
use super::*;
140149
use crate::expressions::col;
141-
use arrow::datatypes::Field;
150+
use arrow::datatypes::{Field, TimeUnit};
142151
use datafusion_common::tree_node::TreeNode;
143152

144153
/// Check if an expression is a cast expression
@@ -548,6 +557,59 @@ mod tests {
548557
assert!(!result.transformed);
549558
}
550559

560+
#[test]
561+
fn test_not_unwrap_timestamp_precision_narrowing() {
562+
let schema = Schema::new(vec![Field::new(
563+
"ts",
564+
DataType::Timestamp(TimeUnit::Nanosecond, None),
565+
false,
566+
)]);
567+
568+
let column_expr = col("ts", &schema).unwrap();
569+
let cast_expr = Arc::new(CastExpr::new(
570+
column_expr,
571+
DataType::Timestamp(TimeUnit::Millisecond, None),
572+
None,
573+
));
574+
let literal_expr = lit(ScalarValue::TimestampMillisecond(Some(1), None));
575+
let binary_expr =
576+
Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, literal_expr));
577+
578+
let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
579+
580+
assert!(!result.transformed);
581+
}
582+
583+
#[test]
584+
fn test_unwrap_timestamp_precision_widening() {
585+
let schema = Schema::new(vec![Field::new(
586+
"ts",
587+
DataType::Timestamp(TimeUnit::Millisecond, None),
588+
false,
589+
)]);
590+
591+
let column_expr = col("ts", &schema).unwrap();
592+
let cast_expr = Arc::new(CastExpr::new(
593+
column_expr,
594+
DataType::Timestamp(TimeUnit::Nanosecond, None),
595+
None,
596+
));
597+
let literal_expr = lit(ScalarValue::TimestampNanosecond(Some(1_000_000), None));
598+
let binary_expr =
599+
Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, literal_expr));
600+
601+
let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap();
602+
603+
assert!(result.transformed);
604+
let optimized_binary = result.data.downcast_ref::<BinaryExpr>().unwrap();
605+
assert!(!is_cast_expr(optimized_binary.left()));
606+
let right_literal = optimized_binary.right().downcast_ref::<Literal>().unwrap();
607+
assert_eq!(
608+
right_literal.value(),
609+
&ScalarValue::TimestampMillisecond(Some(1), None)
610+
);
611+
}
612+
551613
#[test]
552614
fn test_complex_nested_expression() {
553615
let schema = test_schema();

0 commit comments

Comments
 (0)