Skip to content

Commit 7976b94

Browse files
authored
fix: implement try eval mode arithmetic (#2073)
1 parent 3da912c commit 7976b94

File tree

6 files changed

+285
-40
lines changed

6 files changed

+285
-40
lines changed

native/core/src/execution/planner.rs

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -238,62 +238,61 @@ impl PhysicalPlanner {
238238
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
239239
match spark_expr.expr_struct.as_ref().unwrap() {
240240
ExprStruct::Add(expr) => {
241-
// TODO respect eval mode
242-
// https://github.com/apache/datafusion-comet/issues/2021
241+
// TODO respect ANSI eval mode
243242
// https://github.com/apache/datafusion-comet/issues/536
244-
let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
243+
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
245244
self.create_binary_expr(
246245
expr.left.as_ref().unwrap(),
247246
expr.right.as_ref().unwrap(),
248247
expr.return_type.as_ref(),
249248
DataFusionOperator::Plus,
250249
input_schema,
250+
eval_mode,
251251
)
252252
}
253253
ExprStruct::Subtract(expr) => {
254-
// TODO respect eval mode
255-
// https://github.com/apache/datafusion-comet/issues/2021
254+
// TODO respect ANSI eval mode
256255
// https://github.com/apache/datafusion-comet/issues/535
257-
let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
256+
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
258257
self.create_binary_expr(
259258
expr.left.as_ref().unwrap(),
260259
expr.right.as_ref().unwrap(),
261260
expr.return_type.as_ref(),
262261
DataFusionOperator::Minus,
263262
input_schema,
263+
eval_mode,
264264
)
265265
}
266266
ExprStruct::Multiply(expr) => {
267-
// TODO respect eval mode
268-
// https://github.com/apache/datafusion-comet/issues/2021
267+
// TODO respect ANSI eval mode
269268
// https://github.com/apache/datafusion-comet/issues/534
270-
let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
269+
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
271270
self.create_binary_expr(
272271
expr.left.as_ref().unwrap(),
273272
expr.right.as_ref().unwrap(),
274273
expr.return_type.as_ref(),
275274
DataFusionOperator::Multiply,
276275
input_schema,
276+
eval_mode,
277277
)
278278
}
279279
ExprStruct::Divide(expr) => {
280-
// TODO respect eval mode
281-
// https://github.com/apache/datafusion-comet/issues/2021
280+
// TODO respect ANSI eval mode
282281
// https://github.com/apache/datafusion-comet/issues/533
283-
let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
282+
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
284283
self.create_binary_expr(
285284
expr.left.as_ref().unwrap(),
286285
expr.right.as_ref().unwrap(),
287286
expr.return_type.as_ref(),
288287
DataFusionOperator::Divide,
289288
input_schema,
289+
eval_mode,
290290
)
291291
}
292292
ExprStruct::IntegralDivide(expr) => {
293293
// TODO respect eval mode
294-
// https://github.com/apache/datafusion-comet/issues/2021
295294
// https://github.com/apache/datafusion-comet/issues/533
296-
let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
295+
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
297296
self.create_binary_expr_with_options(
298297
expr.left.as_ref().unwrap(),
299298
expr.right.as_ref().unwrap(),
@@ -303,6 +302,7 @@ impl PhysicalPlanner {
303302
BinaryExprOptions {
304303
is_integral_div: true,
305304
},
305+
eval_mode,
306306
)
307307
}
308308
ExprStruct::Remainder(expr) => {
@@ -1004,6 +1004,7 @@ impl PhysicalPlanner {
10041004
return_type: Option<&spark_expression::DataType>,
10051005
op: DataFusionOperator,
10061006
input_schema: SchemaRef,
1007+
eval_mode: EvalMode,
10071008
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
10081009
self.create_binary_expr_with_options(
10091010
left,
@@ -1012,9 +1013,11 @@ impl PhysicalPlanner {
10121013
op,
10131014
input_schema,
10141015
BinaryExprOptions::default(),
1016+
eval_mode,
10151017
)
10161018
}
10171019

1020+
#[allow(clippy::too_many_arguments)]
10181021
fn create_binary_expr_with_options(
10191022
&self,
10201023
left: &Expr,
@@ -1023,6 +1026,7 @@ impl PhysicalPlanner {
10231026
op: DataFusionOperator,
10241027
input_schema: SchemaRef,
10251028
options: BinaryExprOptions,
1029+
eval_mode: EvalMode,
10261030
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
10271031
let left = self.create_expr(left, Arc::clone(&input_schema))?;
10281032
let right = self.create_expr(right, Arc::clone(&input_schema))?;
@@ -1087,7 +1091,34 @@ impl PhysicalPlanner {
10871091
Arc::new(Field::new(func_name, data_type, true)),
10881092
)))
10891093
}
1090-
_ => Ok(Arc::new(BinaryExpr::new(left, op, right))),
1094+
_ => {
1095+
let data_type = return_type.map(to_arrow_datatype).unwrap();
1096+
if eval_mode == EvalMode::Try && data_type.is_integer() {
1097+
let op_str = match op {
1098+
DataFusionOperator::Plus => "checked_add",
1099+
DataFusionOperator::Minus => "checked_sub",
1100+
DataFusionOperator::Multiply => "checked_mul",
1101+
DataFusionOperator::Divide => "checked_div",
1102+
_ => {
1103+
todo!("Operator yet to be implemented!");
1104+
}
1105+
};
1106+
let fun_expr = create_comet_physical_fun(
1107+
op_str,
1108+
data_type.clone(),
1109+
&self.session_ctx.state(),
1110+
None,
1111+
)?;
1112+
Ok(Arc::new(ScalarFunctionExpr::new(
1113+
op_str,
1114+
fun_expr,
1115+
vec![left, right],
1116+
Arc::new(Field::new(op_str, data_type, true)),
1117+
)))
1118+
} else {
1119+
Ok(Arc::new(BinaryExpr::new(left, op, right)))
1120+
}
1121+
}
10911122
}
10921123
}
10931124

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
use crate::hash_funcs::*;
19+
use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
1920
use crate::math_funcs::modulo_expr::spark_modulo;
2021
use crate::{
2122
spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
@@ -115,6 +116,18 @@ pub fn create_comet_physical_fun(
115116
data_type
116117
)
117118
}
119+
"checked_add" => {
120+
make_comet_scalar_udf!("checked_add", checked_add, data_type)
121+
}
122+
"checked_sub" => {
123+
make_comet_scalar_udf!("checked_sub", checked_sub, data_type)
124+
}
125+
"checked_mul" => {
126+
make_comet_scalar_udf!("checked_mul", checked_mul, data_type)
127+
}
128+
"checked_div" => {
129+
make_comet_scalar_udf!("checked_div", checked_div, data_type)
130+
}
118131
"murmur3_hash" => {
119132
let func = Arc::new(spark_murmur3_hash);
120133
make_comet_scalar_udf!("murmur3_hash", func, without data_type)
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{Array, ArrowNativeTypeOp, PrimitiveArray, PrimitiveBuilder};
19+
use arrow::array::{ArrayRef, AsArray};
20+
21+
use arrow::datatypes::{ArrowPrimitiveType, DataType, Int32Type, Int64Type};
22+
use datafusion::common::DataFusionError;
23+
use datafusion::physical_plan::ColumnarValue;
24+
use std::sync::Arc;
25+
26+
pub fn try_arithmetic_kernel<T>(
27+
left: &PrimitiveArray<T>,
28+
right: &PrimitiveArray<T>,
29+
op: &str,
30+
) -> Result<ArrayRef, DataFusionError>
31+
where
32+
T: ArrowPrimitiveType,
33+
{
34+
let len = left.len();
35+
let mut builder = PrimitiveBuilder::<T>::with_capacity(len);
36+
match op {
37+
"checked_add" => {
38+
for i in 0..len {
39+
if left.is_null(i) || right.is_null(i) {
40+
builder.append_null();
41+
} else {
42+
builder.append_option(left.value(i).add_checked(right.value(i)).ok());
43+
}
44+
}
45+
}
46+
"checked_sub" => {
47+
for i in 0..len {
48+
if left.is_null(i) || right.is_null(i) {
49+
builder.append_null();
50+
} else {
51+
builder.append_option(left.value(i).sub_checked(right.value(i)).ok());
52+
}
53+
}
54+
}
55+
"checked_mul" => {
56+
for i in 0..len {
57+
if left.is_null(i) || right.is_null(i) {
58+
builder.append_null();
59+
} else {
60+
builder.append_option(left.value(i).mul_checked(right.value(i)).ok());
61+
}
62+
}
63+
}
64+
"checked_div" => {
65+
for i in 0..len {
66+
if left.is_null(i) || right.is_null(i) {
67+
builder.append_null();
68+
} else {
69+
builder.append_option(left.value(i).div_checked(right.value(i)).ok());
70+
}
71+
}
72+
}
73+
_ => {
74+
return Err(DataFusionError::Internal(format!(
75+
"Unsupported operation: {:?}",
76+
op
77+
)))
78+
}
79+
}
80+
81+
Ok(Arc::new(builder.finish()) as ArrayRef)
82+
}
83+
84+
pub fn checked_add(
85+
args: &[ColumnarValue],
86+
data_type: &DataType,
87+
) -> Result<ColumnarValue, DataFusionError> {
88+
checked_arithmetic_internal(args, data_type, "checked_add")
89+
}
90+
91+
pub fn checked_sub(
92+
args: &[ColumnarValue],
93+
data_type: &DataType,
94+
) -> Result<ColumnarValue, DataFusionError> {
95+
checked_arithmetic_internal(args, data_type, "checked_sub")
96+
}
97+
98+
pub fn checked_mul(
99+
args: &[ColumnarValue],
100+
data_type: &DataType,
101+
) -> Result<ColumnarValue, DataFusionError> {
102+
checked_arithmetic_internal(args, data_type, "checked_mul")
103+
}
104+
105+
pub fn checked_div(
106+
args: &[ColumnarValue],
107+
data_type: &DataType,
108+
) -> Result<ColumnarValue, DataFusionError> {
109+
checked_arithmetic_internal(args, data_type, "checked_div")
110+
}
111+
112+
fn checked_arithmetic_internal(
113+
args: &[ColumnarValue],
114+
data_type: &DataType,
115+
op: &str,
116+
) -> Result<ColumnarValue, DataFusionError> {
117+
let left = &args[0];
118+
let right = &args[1];
119+
120+
let (left_arr, right_arr): (ArrayRef, ArrayRef) = match (left, right) {
121+
(ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)),
122+
(ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => {
123+
(l.to_array_of_size(r.len())?, Arc::clone(r))
124+
}
125+
(ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => {
126+
(Arc::clone(l), r.to_array_of_size(l.len())?)
127+
}
128+
(ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?),
129+
};
130+
131+
// Rust only supports checked_arithmetic on Int32 and Int64
132+
let result_array = match data_type {
133+
DataType::Int32 => try_arithmetic_kernel::<Int32Type>(
134+
left_arr.as_primitive::<Int32Type>(),
135+
right_arr.as_primitive::<Int32Type>(),
136+
op,
137+
),
138+
DataType::Int64 => try_arithmetic_kernel::<Int64Type>(
139+
left_arr.as_primitive::<Int64Type>(),
140+
right_arr.as_primitive::<Int64Type>(),
141+
op,
142+
),
143+
_ => Err(DataFusionError::Internal(format!(
144+
"Unsupported data type: {:?}",
145+
data_type
146+
))),
147+
};
148+
149+
Ok(ColumnarValue::Array(result_array?))
150+
}

native/spark-expr/src/math_funcs/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
mod ceil;
19+
pub(crate) mod checked_arithmetic;
1920
mod div;
2021
mod floor;
2122
pub(crate) mod hex;

spark/src/main/scala/org/apache/comet/serde/arithmetic.scala

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,6 @@ object CometAdd extends CometExpressionSerde[Add] with MathBase {
9494
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
9595
return None
9696
}
97-
if (expr.evalMode == EvalMode.TRY) {
98-
withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
99-
return None
100-
}
10197
createMathExpression(
10298
expr,
10399
expr.left,
@@ -119,10 +115,6 @@ object CometSubtract extends CometExpressionSerde[Subtract] with MathBase {
119115
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
120116
return None
121117
}
122-
if (expr.evalMode == EvalMode.TRY) {
123-
withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
124-
return None
125-
}
126118
createMathExpression(
127119
expr,
128120
expr.left,
@@ -144,10 +136,6 @@ object CometMultiply extends CometExpressionSerde[Multiply] with MathBase {
144136
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
145137
return None
146138
}
147-
if (expr.evalMode == EvalMode.TRY) {
148-
withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
149-
return None
150-
}
151139
createMathExpression(
152140
expr,
153141
expr.left,
@@ -169,15 +157,10 @@ object CometDivide extends CometExpressionSerde[Divide] with MathBase {
169157
// See https://github.com/apache/arrow-datafusion/pull/6792
170158
// For now, use NullIf to swap zeros with nulls.
171159
val rightExpr = nullIfWhenPrimitive(expr.right)
172-
173160
if (!supportedDataType(expr.left.dataType)) {
174161
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
175162
return None
176163
}
177-
if (expr.evalMode == EvalMode.TRY) {
178-
withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
179-
return None
180-
}
181164
createMathExpression(
182165
expr,
183166
expr.left,
@@ -199,10 +182,6 @@ object CometIntegralDivide extends CometExpressionSerde[IntegralDivide] with Mat
199182
withInfo(expr, s"Unsupported datatype ${expr.left.dataType}")
200183
return None
201184
}
202-
if (expr.evalMode == EvalMode.TRY) {
203-
withInfo(expr, s"Eval mode ${expr.evalMode} is not supported")
204-
return None
205-
}
206185

207186
// Precision is set to 19 (max precision for a numerical data type except DecimalType)
208187

0 commit comments

Comments
 (0)