@@ -113,15 +113,21 @@ fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) {
113
113
#[ cfg( not( feature = "json" ) ) ]
114
114
fn roundtrip_json_test ( _proto : & protobuf:: LogicalExprNode ) { }
115
115
116
- // Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test
117
- // equality.
118
116
fn roundtrip_expr_test ( initial_struct : Expr , ctx : SessionContext ) {
119
117
let extension_codec = DefaultLogicalExtensionCodec { } ;
120
- let proto: protobuf:: LogicalExprNode =
121
- serialize_expr ( & initial_struct, & extension_codec)
122
- . unwrap_or_else ( |e| panic ! ( "Error serializing expression: {:?}" , e) ) ;
123
- let round_trip: Expr =
124
- from_proto:: parse_expr ( & proto, & ctx, & extension_codec) . unwrap ( ) ;
118
+ roundtrip_expr_test_with_codec ( initial_struct, ctx, & extension_codec) ;
119
+ }
120
+
121
+ // Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test
122
+ // equality.
123
+ fn roundtrip_expr_test_with_codec (
124
+ initial_struct : Expr ,
125
+ ctx : SessionContext ,
126
+ codec : & dyn LogicalExtensionCodec ,
127
+ ) {
128
+ let proto: protobuf:: LogicalExprNode = serialize_expr ( & initial_struct, codec)
129
+ . unwrap_or_else ( |e| panic ! ( "Error serializing expression: {:?}" , e) ) ;
130
+ let round_trip: Expr = from_proto:: parse_expr ( & proto, & ctx, codec) . unwrap ( ) ;
125
131
126
132
assert_eq ! ( format!( "{:?}" , & initial_struct) , format!( "{round_trip:?}" ) ) ;
127
133
@@ -2185,22 +2191,26 @@ fn roundtrip_aggregate_udf() {
2185
2191
roundtrip_expr_test ( test_expr, ctx) ;
2186
2192
}
2187
2193
2188
- #[ test]
2189
- fn roundtrip_scalar_udf ( ) {
2194
+ fn dummy_udf ( ) -> ScalarUDF {
2190
2195
let scalar_fn = Arc :: new ( |args : & [ ColumnarValue ] | {
2191
2196
let ColumnarValue :: Array ( array) = & args[ 0 ] else {
2192
2197
panic ! ( "should be array" )
2193
2198
} ;
2194
2199
Ok ( ColumnarValue :: from ( Arc :: new ( array. clone ( ) ) as ArrayRef ) )
2195
2200
} ) ;
2196
2201
2197
- let udf = create_udf (
2202
+ create_udf (
2198
2203
"dummy" ,
2199
2204
vec ! [ DataType :: Utf8 ] ,
2200
2205
DataType :: Utf8 ,
2201
2206
Volatility :: Immutable ,
2202
2207
scalar_fn,
2203
- ) ;
2208
+ )
2209
+ }
2210
+
2211
+ #[ test]
2212
+ fn roundtrip_scalar_udf ( ) {
2213
+ let udf = dummy_udf ( ) ;
2204
2214
2205
2215
let test_expr = Expr :: ScalarFunction ( ScalarFunction :: new_udf (
2206
2216
Arc :: new ( udf. clone ( ) ) ,
@@ -2210,7 +2220,57 @@ fn roundtrip_scalar_udf() {
2210
2220
let ctx = SessionContext :: new ( ) ;
2211
2221
ctx. register_udf ( udf) ;
2212
2222
2213
- roundtrip_expr_test ( test_expr, ctx) ;
2223
+ roundtrip_expr_test ( test_expr. clone ( ) , ctx) ;
2224
+
2225
+ // Now test loading the UDF without registering it in the context, but rather creating it in the
2226
+ // extension codec.
2227
+ #[ derive( Debug ) ]
2228
+ struct DummyUDFExtensionCodec ;
2229
+
2230
+ impl LogicalExtensionCodec for DummyUDFExtensionCodec {
2231
+ fn try_decode (
2232
+ & self ,
2233
+ _buf : & [ u8 ] ,
2234
+ _inputs : & [ LogicalPlan ] ,
2235
+ _ctx : & SessionContext ,
2236
+ ) -> Result < Extension > {
2237
+ not_impl_err ! ( "LogicalExtensionCodec is not provided" )
2238
+ }
2239
+
2240
+ fn try_encode ( & self , _node : & Extension , _buf : & mut Vec < u8 > ) -> Result < ( ) > {
2241
+ not_impl_err ! ( "LogicalExtensionCodec is not provided" )
2242
+ }
2243
+
2244
+ fn try_decode_table_provider (
2245
+ & self ,
2246
+ _buf : & [ u8 ] ,
2247
+ _table_ref : & TableReference ,
2248
+ _schema : SchemaRef ,
2249
+ _ctx : & SessionContext ,
2250
+ ) -> Result < Arc < dyn TableProvider > > {
2251
+ not_impl_err ! ( "LogicalExtensionCodec is not provided" )
2252
+ }
2253
+
2254
+ fn try_encode_table_provider (
2255
+ & self ,
2256
+ _table_ref : & TableReference ,
2257
+ _node : Arc < dyn TableProvider > ,
2258
+ _buf : & mut Vec < u8 > ,
2259
+ ) -> Result < ( ) > {
2260
+ not_impl_err ! ( "LogicalExtensionCodec is not provided" )
2261
+ }
2262
+
2263
+ fn try_decode_udf ( & self , name : & str , _buf : & [ u8 ] ) -> Result < Arc < ScalarUDF > > {
2264
+ if name == "dummy" {
2265
+ Ok ( Arc :: new ( dummy_udf ( ) ) )
2266
+ } else {
2267
+ return Err ( DataFusionError :: Internal ( format ! ( "UDF {name} not found" ) ) ) ;
2268
+ }
2269
+ }
2270
+ }
2271
+
2272
+ let ctx = SessionContext :: new ( ) ;
2273
+ roundtrip_expr_test_with_codec ( test_expr, ctx, & DummyUDFExtensionCodec )
2214
2274
}
2215
2275
2216
2276
#[ test]
0 commit comments