Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 121 additions & 5 deletions crates/iceberg/src/arrow/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ use std::str::FromStr;
use std::sync::Arc;

use arrow_arith::boolean::{and, and_kleene, is_not_null, is_null, not, or, or_kleene};
use arrow_array::cast::AsArray;
use arrow_array::types::{Float32Type, Float64Type};
use arrow_array::{Array, ArrayRef, BooleanArray, Datum as ArrowDatum, RecordBatch, Scalar};
use arrow_buffer::BooleanBuffer;
use arrow_cast::cast::cast;
use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
use arrow_schema::{
Expand Down Expand Up @@ -1509,6 +1512,35 @@ fn project_column(
}
}

fn compute_is_nan(array: &ArrayRef) -> std::result::Result<BooleanArray, ArrowError> {
// Compute NaN over the contiguous values slice, then fold the null bitmap
// in with a single bitwise AND so that null slots become false.
let (is_nan, nulls) = match array.data_type() {
DataType::Float32 => {
let arr = array.as_primitive::<Float32Type>();
(
BooleanBuffer::from_iter(arr.values().iter().map(|v| v.is_nan())),
arr.nulls(),
)
}
DataType::Float64 => {
let arr = array.as_primitive::<Float64Type>();
(
BooleanBuffer::from_iter(arr.values().iter().map(|v| v.is_nan())),
arr.nulls(),
)
}
_ => unreachable!("is_nan is only valid for float types"),
};

let values = match nulls {
Some(nulls) => &is_nan & nulls.inner(),
None => is_nan,
};

Ok(BooleanArray::new(values, None))
}

/// Boxed row-filter closure produced by `PredicateConverter`: consumes a
/// `RecordBatch` and yields a per-row boolean selection mask (`true` = row
/// may satisfy the predicate).
type PredicateResult =
    dyn FnMut(RecordBatch) -> std::result::Result<BooleanArray, ArrowError> + Send + 'static;

Expand Down Expand Up @@ -1591,8 +1623,11 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
reference: &BoundReference,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if self.bound_reference(reference)?.is_some() {
self.build_always_true()
if let Some(idx) = self.bound_reference(reference)? {
Ok(Box::new(move |batch| {
let column = project_column(&batch, idx)?;
compute_is_nan(&column)
}))
} else {
// A missing column, treating it as null.
self.build_always_false()
Expand All @@ -1604,8 +1639,12 @@ impl BoundPredicateVisitor for PredicateConverter<'_> {
reference: &BoundReference,
_predicate: &BoundPredicate,
) -> Result<Box<PredicateResult>> {
if self.bound_reference(reference)?.is_some() {
self.build_always_false()
if let Some(idx) = self.bound_reference(reference)? {
Ok(Box::new(move |batch| {
let column = project_column(&batch, idx)?;
let is_nan = compute_is_nan(&column)?;
not(&is_nan)
}))
} else {
// A missing column, treating it as null.
self.build_always_true()
Expand Down Expand Up @@ -2002,7 +2041,7 @@ mod tests {
use std::sync::Arc;

use arrow_array::cast::AsArray;
use arrow_array::{ArrayRef, LargeStringArray, RecordBatch, StringArray};
use arrow_array::{Array, ArrayRef, BooleanArray, LargeStringArray, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit};
use futures::TryStreamExt;
use parquet::arrow::arrow_reader::{RowSelection, RowSelector};
Expand Down Expand Up @@ -5464,4 +5503,81 @@ message schema {
ts_array.value(0)
);
}

/// Test helper: binds `predicate` against `schema`, lowers it with a
/// `PredicateConverter` configured for a single optional float column
/// (field id 4 -> projected column 0), and evaluates the resulting
/// filter closure on `batch`, returning the per-row boolean mask.
fn apply_predicate_to_batch(
    predicate: Predicate,
    schema: SchemaRef,
    batch: RecordBatch,
) -> BooleanArray {
    use super::PredicateConverter;

    let bound_predicate = predicate.bind(schema, true).unwrap();

    // Minimal Parquet schema: one optional float column at field id 4.
    let message_type = "
        message schema {
            optional float qux = 4;
        }
    ";
    let parquet_schema = SchemaDescriptor::new(Arc::new(
        parse_message_type(message_type).expect("parse schema"),
    ));

    // Field id 4 maps to the only projected column, index 0.
    let field_id_to_column = HashMap::from([(4i32, 0usize)]);
    let projected_columns = vec![0usize];

    let mut converter = PredicateConverter {
        parquet_schema: &parquet_schema,
        column_map: &field_id_to_column,
        column_indices: &projected_columns,
    };

    let mut filter = visit(&mut converter, &bound_predicate).unwrap();
    filter(batch).unwrap()
}

#[test]
fn test_predicate_converter_nan() {
    use arrow_array::Float32Array;

    let schema = table_schema_simple();
    let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new(
        "qux",
        DataType::Float32,
        true,
    )]));
    // One ordinary value, one NaN, one NULL, one zero.
    let values = vec![Some(1.0f32), Some(f32::NAN), None, Some(0.0f32)];
    let make_batch = |s: Arc<ArrowSchema>, v: Vec<Option<f32>>| {
        RecordBatch::try_new(s, vec![Arc::new(Float32Array::from(v)) as ArrayRef]).unwrap()
    };

    // is_nan is non-null-propagating (as in the Java implementation): the
    // NULL slot evaluates to false, not null.
    let result = apply_predicate_to_batch(
        Reference::new("qux").is_nan(),
        schema.clone(),
        make_batch(arrow_schema.clone(), values.clone()),
    );
    let flags: Vec<bool> = (0..4).map(|i| result.value(i)).collect();
    assert_eq!(flags, [false, true, false, false]);
    assert!(!result.is_null(2));

    // is_not_nan is likewise non-null-propagating: the NULL slot evaluates
    // to true, not null.
    let result = apply_predicate_to_batch(
        Reference::new("qux").is_not_nan(),
        schema,
        make_batch(arrow_schema, values),
    );
    let flags: Vec<bool> = (0..4).map(|i| result.value(i)).collect();
    assert_eq!(flags, [true, false, true, true]);
    assert!(!result.is_null(2));
}
}
Loading