diff --git a/crates/iceberg/src/arrow/reader.rs b/crates/iceberg/src/arrow/reader.rs index 488f41cf20..5cd97410ea 100644 --- a/crates/iceberg/src/arrow/reader.rs +++ b/crates/iceberg/src/arrow/reader.rs @@ -23,7 +23,10 @@ use std::str::FromStr; use std::sync::Arc; use arrow_arith::boolean::{and, and_kleene, is_not_null, is_null, not, or, or_kleene}; +use arrow_array::cast::AsArray; +use arrow_array::types::{Float32Type, Float64Type}; use arrow_array::{Array, ArrayRef, BooleanArray, Datum as ArrowDatum, RecordBatch, Scalar}; +use arrow_buffer::BooleanBuffer; use arrow_cast::cast::cast; use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow_schema::{ @@ -1509,6 +1512,35 @@ fn project_column( } } +fn compute_is_nan(array: &ArrayRef) -> std::result::Result<BooleanArray, ArrowError> { + // Compute NaN over the contiguous values slice, then fold the null bitmap + // in with a single bitwise AND so that null slots become false. + let (is_nan, nulls) = match array.data_type() { + DataType::Float32 => { + let arr = array.as_primitive::<Float32Type>(); + ( + BooleanBuffer::from_iter(arr.values().iter().map(|v| v.is_nan())), + arr.nulls(), + ) + } + DataType::Float64 => { + let arr = array.as_primitive::<Float64Type>(); + ( + BooleanBuffer::from_iter(arr.values().iter().map(|v| v.is_nan())), + arr.nulls(), + ) + } + _ => unreachable!("is_nan is only valid for float types"), + }; + + let values = match nulls { + Some(nulls) => &is_nan & nulls.inner(), + None => is_nan, + }; + + Ok(BooleanArray::new(values, None)) +} + type PredicateResult = dyn FnMut(RecordBatch) -> std::result::Result<BooleanArray, ArrowError> + Send + 'static; @@ -1591,8 +1623,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> { reference: &BoundReference, _predicate: &BoundPredicate, ) -> Result<Box<PredicateResult>> { - if self.bound_reference(reference)?.is_some() { - self.build_always_true() + if let Some(idx) = self.bound_reference(reference)? 
{ + Ok(Box::new(move |batch| { + let column = project_column(&batch, idx)?; + compute_is_nan(&column) + })) } else { // A missing column, treating it as null. self.build_always_false() @@ -1604,8 +1639,12 @@ impl BoundPredicateVisitor for PredicateConverter<'_> { reference: &BoundReference, _predicate: &BoundPredicate, ) -> Result<Box<PredicateResult>> { - if self.bound_reference(reference)?.is_some() { - self.build_always_false() + if let Some(idx) = self.bound_reference(reference)? { + Ok(Box::new(move |batch| { + let column = project_column(&batch, idx)?; + let is_nan = compute_is_nan(&column)?; + not(&is_nan) + })) } else { // A missing column, treating it as null. self.build_always_true() @@ -2002,7 +2041,7 @@ mod tests { use std::sync::Arc; use arrow_array::cast::AsArray; - use arrow_array::{ArrayRef, LargeStringArray, RecordBatch, StringArray}; + use arrow_array::{Array, ArrayRef, BooleanArray, LargeStringArray, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit}; use futures::TryStreamExt; use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; @@ -5464,4 +5503,81 @@ message schema { ts_array.value(0) ); } + + fn apply_predicate_to_batch( + predicate: Predicate, + schema: SchemaRef, + batch: RecordBatch, + ) -> BooleanArray { + use super::PredicateConverter; + + let bound = predicate.bind(schema, true).unwrap(); + + // Build a trivial Parquet schema with one float column at field id 4 + let message_type = " + message schema { + optional float qux = 4; + } + "; + let parquet_type = parse_message_type(message_type).expect("parse schema"); + let parquet_schema = SchemaDescriptor::new(Arc::new(parquet_type)); + + let column_map = HashMap::from([(4i32, 0usize)]); + let column_indices = vec![0usize]; + + let mut converter = PredicateConverter { + parquet_schema: &parquet_schema, + column_map: &column_map, + column_indices: &column_indices, + }; + + let mut predicate_fn = visit(&mut converter, &bound).unwrap(); 
predicate_fn(batch).unwrap() + } + + #[test] + fn test_predicate_converter_nan() { + use arrow_array::Float32Array; + + let schema = table_schema_simple(); + let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new( + "qux", + DataType::Float32, + true, + )])); + let values = vec![Some(1.0f32), Some(f32::NAN), None, Some(0.0f32)]; + + // is_nan: non-null-propagating per Java's implementation - NULL → false + let batch = RecordBatch::try_new(arrow_schema.clone(), vec![Arc::new(Float32Array::from( + values.clone(), + ))]) + .unwrap(); + let result = + apply_predicate_to_batch(Reference::new("qux").is_nan(), schema.clone(), batch); + assert_eq!( + [ + result.value(0), + result.value(1), + result.value(2), + result.value(3) + ], + [false, true, false, false] + ); + assert!(!result.is_null(2)); + + // not_nan: non-null-propagating per Java's implementation - NULL → true + let batch = + RecordBatch::try_new(arrow_schema, vec![Arc::new(Float32Array::from(values))]).unwrap(); + let result = apply_predicate_to_batch(Reference::new("qux").is_not_nan(), schema, batch); + assert_eq!( + [ + result.value(0), + result.value(1), + result.value(2), + result.value(3) + ], + [true, false, true, true] + ); + assert!(!result.is_null(2)); + } }