@@ -26,7 +26,7 @@ use crate::error::DataFusionError;
26
26
use crate :: logical_plan:: dfschema:: DFSchemaRef ;
27
27
use crate :: sql:: parser:: FileType ;
28
28
use arrow:: datatypes:: { DataType , Field , Schema , SchemaRef } ;
29
- use datafusion_common:: DFSchema ;
29
+ use datafusion_common:: { DFField , DFSchema } ;
30
30
use std:: fmt:: Formatter ;
31
31
use std:: {
32
32
collections:: HashSet ,
@@ -268,21 +268,79 @@ pub struct Limit {
268
268
#[ derive( Clone ) ]
269
269
pub struct Subquery {
270
270
/// The list of sub queries
271
- pub subqueries : Vec < LogicalPlan > ,
271
+ pub subqueries : Vec < ( LogicalPlan , SubqueryType ) > ,
272
272
/// The incoming logical plan
273
273
pub input : Arc < LogicalPlan > ,
274
274
/// The schema description of the output
275
275
pub schema : DFSchemaRef ,
276
276
}
277
277
278
/// Kind of a subquery, i.e. the way its result is consumed by the outer query.
///
/// The type is a plain fieldless enum, so it is cheap to copy and can be
/// compared, hashed and used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SubqueryType {
    /// Scalar (SELECT, WHERE) evaluating to one value
    Scalar,
    /// EXISTS(...) evaluating to true if at least one row was produced
    Exists,
    /// ANY(...)/ALL(...)
    AnyAll,
}
288
+
278
289
impl Subquery {
279
290
/// Merge schema of main input and correlated subquery columns
280
- pub fn merged_schema ( input : & LogicalPlan , subqueries : & [ LogicalPlan ] ) -> DFSchema {
281
- subqueries. iter ( ) . fold ( ( * * input. schema ( ) ) . clone ( ) , |a, b| {
282
- let mut res = a;
283
- res. merge ( b. schema ( ) ) ;
284
- res
285
- } )
291
+ pub fn merged_schema (
292
+ input : & LogicalPlan ,
293
+ subqueries : & [ ( LogicalPlan , SubqueryType ) ] ,
294
+ ) -> DFSchema {
295
+ subqueries
296
+ . iter ( )
297
+ . fold ( ( * * input. schema ( ) ) . clone ( ) , |input_schema, ( plan, typ) | {
298
+ let mut res = input_schema;
299
+ let subquery_schema = Self :: transform_dfschema ( plan. schema ( ) , * typ) ;
300
+ res. merge ( & subquery_schema) ;
301
+ res
302
+ } )
303
+ }
304
+
305
+ /// Transform DataFusion schema according to subquery type
306
+ pub fn transform_dfschema ( schema : & DFSchema , typ : SubqueryType ) -> DFSchema {
307
+ match typ {
308
+ SubqueryType :: Scalar => schema. clone ( ) ,
309
+ SubqueryType :: Exists | SubqueryType :: AnyAll => {
310
+ let new_fields = schema
311
+ . fields ( )
312
+ . iter ( )
313
+ . map ( |field| {
314
+ let new_field = Subquery :: transform_field ( field. field ( ) , typ) ;
315
+ if let Some ( qualifier) = field. qualifier ( ) {
316
+ DFField :: from_qualified ( qualifier, new_field)
317
+ } else {
318
+ DFField :: from ( new_field)
319
+ }
320
+ } )
321
+ . collect ( ) ;
322
+ DFSchema :: new_with_metadata ( new_fields, schema. metadata ( ) . clone ( ) )
323
+ . unwrap ( )
324
+ }
325
+ }
326
+ }
327
+
328
+ /// Transform Arrow field according to subquery type
329
+ pub fn transform_field ( field : & Field , typ : SubqueryType ) -> Field {
330
+ match typ {
331
+ SubqueryType :: Scalar => field. clone ( ) ,
332
+ SubqueryType :: Exists => Field :: new ( field. name ( ) , DataType :: Boolean , false ) ,
333
+ SubqueryType :: AnyAll => {
334
+ let item = Field :: new_dict (
335
+ "item" ,
336
+ field. data_type ( ) . clone ( ) ,
337
+ true ,
338
+ field. dict_id ( ) . unwrap_or ( 0 ) ,
339
+ field. dict_is_ordered ( ) . unwrap_or ( false ) ,
340
+ ) ;
341
+ Field :: new ( field. name ( ) , DataType :: List ( Box :: new ( item) ) , false )
342
+ }
343
+ }
286
344
}
287
345
}
288
346
@@ -585,7 +643,7 @@ impl LogicalPlan {
585
643
input, subqueries, ..
586
644
} ) => vec ! [ input. as_ref( ) ]
587
645
. into_iter ( )
588
- . chain ( subqueries. iter ( ) )
646
+ . chain ( subqueries. iter ( ) . map ( | ( q , _ ) | q ) )
589
647
. collect ( ) ,
590
648
LogicalPlan :: Filter ( Filter { input, .. } ) => vec ! [ input] ,
591
649
LogicalPlan :: Repartition ( Repartition { input, .. } ) => vec ! [ input] ,
@@ -728,7 +786,7 @@ impl LogicalPlan {
728
786
input, subqueries, ..
729
787
} ) => {
730
788
input. accept ( visitor) ?;
731
- for input in subqueries {
789
+ for ( input, _ ) in subqueries {
732
790
if !input. accept ( visitor) ? {
733
791
return Ok ( false ) ;
734
792
}
0 commit comments