Skip to content

Commit 672b023

Browse files
committed
add test case for proto udf decode fallback
1 parent 185df60 commit 672b023

File tree

1 file changed

+72
-12
lines changed

1 file changed

+72
-12
lines changed

datafusion/proto/tests/cases/roundtrip_logical_plan.rs

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,21 @@ fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) {
113113
#[cfg(not(feature = "json"))]
114114
fn roundtrip_json_test(_proto: &protobuf::LogicalExprNode) {}
115115

116-
// Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test
117-
// equality.
118116
fn roundtrip_expr_test(initial_struct: Expr, ctx: SessionContext) {
119117
let extension_codec = DefaultLogicalExtensionCodec {};
120-
let proto: protobuf::LogicalExprNode =
121-
serialize_expr(&initial_struct, &extension_codec)
122-
.unwrap_or_else(|e| panic!("Error serializing expression: {:?}", e));
123-
let round_trip: Expr =
124-
from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap();
118+
roundtrip_expr_test_with_codec(initial_struct, ctx, &extension_codec);
119+
}
120+
121+
// Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test
122+
// equality.
123+
fn roundtrip_expr_test_with_codec(
124+
initial_struct: Expr,
125+
ctx: SessionContext,
126+
codec: &dyn LogicalExtensionCodec,
127+
) {
128+
let proto: protobuf::LogicalExprNode = serialize_expr(&initial_struct, codec)
129+
.unwrap_or_else(|e| panic!("Error serializing expression: {:?}", e));
130+
let round_trip: Expr = from_proto::parse_expr(&proto, &ctx, codec).unwrap();
125131

126132
assert_eq!(format!("{:?}", &initial_struct), format!("{round_trip:?}"));
127133

@@ -2185,22 +2191,26 @@ fn roundtrip_aggregate_udf() {
21852191
roundtrip_expr_test(test_expr, ctx);
21862192
}
21872193

2188-
#[test]
2189-
fn roundtrip_scalar_udf() {
2194+
fn dummy_udf() -> ScalarUDF {
21902195
let scalar_fn = Arc::new(|args: &[ColumnarValue]| {
21912196
let ColumnarValue::Array(array) = &args[0] else {
21922197
panic!("should be array")
21932198
};
21942199
Ok(ColumnarValue::from(Arc::new(array.clone()) as ArrayRef))
21952200
});
21962201

2197-
let udf = create_udf(
2202+
create_udf(
21982203
"dummy",
21992204
vec![DataType::Utf8],
22002205
DataType::Utf8,
22012206
Volatility::Immutable,
22022207
scalar_fn,
2203-
);
2208+
)
2209+
}
2210+
2211+
#[test]
2212+
fn roundtrip_scalar_udf() {
2213+
let udf = dummy_udf();
22042214

22052215
let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf(
22062216
Arc::new(udf.clone()),
@@ -2210,7 +2220,57 @@ fn roundtrip_scalar_udf() {
22102220
let ctx = SessionContext::new();
22112221
ctx.register_udf(udf);
22122222

2213-
roundtrip_expr_test(test_expr, ctx);
2223+
roundtrip_expr_test(test_expr.clone(), ctx);
2224+
2225+
// Now test loading the UDF without registering it in the context, but rather creating it in the
2226+
// extension codec.
2227+
#[derive(Debug)]
2228+
struct DummyUDFExtensionCodec;
2229+
2230+
impl LogicalExtensionCodec for DummyUDFExtensionCodec {
2231+
fn try_decode(
2232+
&self,
2233+
_buf: &[u8],
2234+
_inputs: &[LogicalPlan],
2235+
_ctx: &SessionContext,
2236+
) -> Result<Extension> {
2237+
not_impl_err!("LogicalExtensionCodec is not provided")
2238+
}
2239+
2240+
fn try_encode(&self, _node: &Extension, _buf: &mut Vec<u8>) -> Result<()> {
2241+
not_impl_err!("LogicalExtensionCodec is not provided")
2242+
}
2243+
2244+
fn try_decode_table_provider(
2245+
&self,
2246+
_buf: &[u8],
2247+
_table_ref: &TableReference,
2248+
_schema: SchemaRef,
2249+
_ctx: &SessionContext,
2250+
) -> Result<Arc<dyn TableProvider>> {
2251+
not_impl_err!("LogicalExtensionCodec is not provided")
2252+
}
2253+
2254+
fn try_encode_table_provider(
2255+
&self,
2256+
_table_ref: &TableReference,
2257+
_node: Arc<dyn TableProvider>,
2258+
_buf: &mut Vec<u8>,
2259+
) -> Result<()> {
2260+
not_impl_err!("LogicalExtensionCodec is not provided")
2261+
}
2262+
2263+
fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result<Arc<ScalarUDF>> {
2264+
if name == "dummy" {
2265+
Ok(Arc::new(dummy_udf()))
2266+
} else {
2267+
return Err(DataFusionError::Internal(format!("UDF {name} not found")));
2268+
}
2269+
}
2270+
}
2271+
2272+
let ctx = SessionContext::new();
2273+
roundtrip_expr_test_with_codec(test_expr, ctx, &DummyUDFExtensionCodec)
22142274
}
22152275

22162276
#[test]

0 commit comments

Comments
 (0)