diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 3d6b832aa6b27..d4e0fd953ccba 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -58,13 +58,11 @@ use datafusion_common::{ }; use datafusion_expr::select_expr::SelectExpr; use datafusion_expr::{ - ExplainOption, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, case, - dml::InsertOp, - expr::{Alias, ScalarFunction}, - is_null, lit, - utils::COUNT_STAR_EXPANSION, + ExplainOption, ScalarUDF, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, case, + dml::InsertOp, is_null, lit, utils::COUNT_STAR_EXPANSION, }; use datafusion_functions::core::coalesce; +use datafusion_functions::math::nanvl; use datafusion_functions_aggregate::expr_fn::{ avg, count, max, median, min, stddev, sum, }; @@ -2471,6 +2469,65 @@ impl DataFrame { &self, value: ScalarValue, columns: Vec, + ) -> Result { + self.fill_columns(value, &columns, coalesce(), |_| true) + } + + // Helper to find columns from names + fn find_columns(&self, names: &[impl AsRef]) -> Result> { + let schema = self.logical_plan().schema(); + names + .iter() + .map(|name| { + let name = name.as_ref(); + schema + .field_with_name(None, name) + .cloned() + .map_err(|_| plan_datafusion_err!("Column '{}' not found", name)) + }) + .collect() + } + + /// Fill NaN values in specified floating-point columns with a given value + /// If no columns are specified (empty slice), applies to all columns + /// Only floating-point columns are affected; other columns are left unchanged + /// Only fills if the value can be cast to the column's type + /// + /// # Arguments + /// * `value` - Value to fill NaNs with + /// * `columns` - List of column names to fill. If empty, fills all columns. + /// + /// # Example + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # use datafusion_common::ScalarValue; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let df = ctx + /// .read_csv("tests/data/example.csv", CsvReadOptions::new()) + /// .await?; + /// // Fill NaN in only columns "a" and "c": + /// let df = df.fill_nan(ScalarValue::from(0.0), &["a", "c"])?; + /// // Fill NaN across all columns: + /// let df = df.fill_nan(ScalarValue::from(0.0), &[])?; + /// # Ok(()) + /// # } + /// ``` + pub fn fill_nan(&self, value: ScalarValue, columns: &[&str]) -> Result { + self.fill_columns(value, columns, nanvl(), |field| { + field.data_type().is_floating() + }) + } + + #[expect(clippy::needless_pass_by_value)] + fn fill_columns( + &self, + value: ScalarValue, + columns: &[impl AsRef], + func: Arc, + applies: impl Fn(&FieldRef) -> bool, ) -> Result { let cols = if columns.is_empty() { self.logical_plan() @@ -2480,28 +2537,21 @@ impl DataFrame { .map(Arc::clone) .collect() } else { - self.find_columns(&columns)? + self.find_columns(columns)? }; - // Create projections for each column let projections = self .logical_plan() .schema() .fields() .iter() .map(|field| { - if cols.contains(field) { + if cols.contains(field) && applies(field) { // Try to cast fill value to column type. If the cast fails, fallback to the original column. match value.clone().cast_to(field.data_type()) { - Ok(fill_value) => Expr::Alias(Alias { - expr: Box::new(Expr::ScalarFunction(ScalarFunction { - func: coalesce(), - args: vec![col(field.name()), lit(fill_value)], - })), - relation: None, - name: field.name().to_string(), - metadata: None, - }), + Ok(fill_value) => func + .call(vec![col(field.name()), lit(fill_value)]) + .alias(field.name()), Err(_) => col(field.name()), } } else { @@ -2513,20 +2563,6 @@ impl DataFrame { self.clone().select(projections) } - // Helper to find columns from names - fn find_columns(&self, names: &[String]) -> Result> { - let schema = self.logical_plan().schema(); - names - .iter() - .map(|name| { - schema - .field_with_name(None, name) - .cloned() - .map_err(|_| plan_datafusion_err!("Column '{}' not found", name)) - }) - .collect() - } - /// Find qualified columns for this dataframe from names /// /// # Arguments diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index bc1ad4c4c6bb1..0155a13607418 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -6539,6 +6539,173 @@ async fn test_fill_null_all_columns() -> Result<()> { Ok(()) } +async fn create_nan_table() -> Result { + // create a DataFrame with a NaN value in a float column "a" and a + // non-float column "b" that must stay untouched by fill_nan. + // "+-----+---+", + // "| a | b |", + // "+-----+---+", + // "| 1.0 | 1 |", + // "| NaN | 2 |", + // "| 3.0 | 3 |", + // "+-----+---+", + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Int32, true), + ])); + let a_values = Float64Array::from(vec![Some(1.0), Some(f64::NAN), Some(3.0)]); + let b_values = Int32Array::from(vec![Some(1), Some(2), Some(3)]); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(a_values), Arc::new(b_values)], + )?; + + let ctx = SessionContext::new(); + let table = MemTable::try_new(schema.clone(), vec![vec![batch]])?; + ctx.register_table("t_nan", Arc::new(table))?; + let df = ctx.table("t_nan").await?; + Ok(df) +} + +#[tokio::test] +async fn test_fill_nan() -> Result<()> { + let df = create_nan_table().await?; + + // Fill NaNs in the float column "a" with 0.0. + let df_filled = df.fill_nan(ScalarValue::Float64(Some(0.0)), &["a"])?; + + let results = df_filled.collect().await?; + assert_snapshot!( + batches_to_sort_string(&results), + @r" + +-----+---+ + | a | b | + +-----+---+ + | 0.0 | 2 | + | 1.0 | 1 | + | 3.0 | 3 | + +-----+---+ + " + ); + + Ok(()) +} + +#[tokio::test] +async fn test_fill_nan_all_columns() -> Result<()> { + let df = create_nan_table().await?; + + // Fill NaNs across all columns. Only the float column "a" is affected; + // the non-float column "b" is left unchanged since NaN only exists for + // floating-point types. + let df_filled = df.fill_nan(ScalarValue::Float64(Some(0.0)), &[])?; + + let results = df_filled.collect().await?; + assert_snapshot!( + batches_to_sort_string(&results), + @r" + +-----+---+ + | a | b | + +-----+---+ + | 0.0 | 2 | + | 1.0 | 1 | + | 3.0 | 3 | + +-----+---+ + " + ); + Ok(()) +} + +#[tokio::test] +async fn test_fill_nan_non_float_column() -> Result<()> { + let df = create_nan_table().await?; + + // Explicitly naming a non-float column is a no-op, not an error: NaN does + // not exist for Int32, so column "b" (and the un-targeted "a") are unchanged. + let df_filled = df.fill_nan(ScalarValue::Float64(Some(0.0)), &["b"])?; + + let results = df_filled.collect().await?; + assert_snapshot!( + batches_to_sort_string(&results), + @r" + +-----+---+ + | a | b | + +-----+---+ + | 1.0 | 1 | + | 3.0 | 3 | + | NaN | 2 | + +-----+---+ + " + ); + + Ok(()) +} + +#[tokio::test] +async fn test_fill_nan_unknown_column() -> Result<()> { + let df = create_nan_table().await?; + + // A column name that is not in the schema is propagated as an error. + let err = df + .fill_nan(ScalarValue::Float64(Some(0.0)), &["does_not_exist"]) + .unwrap_err(); + + assert_snapshot!(err.strip_backtrace(), @"Error during planning: Column 'does_not_exist' not found"); + + Ok(()) +} + +#[tokio::test] +async fn test_fill_nan_casts_fill_value() -> Result<()> { + let df = create_nan_table().await?; + + // Int32(0) is not the column's type (Float64) but can be cast to it, so the + // NaN is replaced with 0.0. Exercises the cross-type cast path — the other + // positive tests pass a Float64 value, which skips the actual cast. + let df_filled = df.fill_nan(ScalarValue::Int32(Some(0)), &["a"])?; + + let results = df_filled.collect().await?; + assert_snapshot!( + batches_to_sort_string(&results), + @r" + +-----+---+ + | a | b | + +-----+---+ + | 0.0 | 2 | + | 1.0 | 1 | + | 3.0 | 3 | + +-----+---+ + " + ); + + Ok(()) +} + +#[tokio::test] +async fn test_fill_nan_uncastable_value() -> Result<()> { + let df = create_nan_table().await?; + + // The float column "a" is targeted, but "abc" cannot be cast to Float64, so + // the fill is skipped and column "a" keeps its original NaN value. + let df_filled = df.fill_nan(ScalarValue::Utf8(Some("abc".to_string())), &["a"])?; + + let results = df_filled.collect().await?; + assert_snapshot!( + batches_to_sort_string(&results), + @r" + +-----+---+ + | a | b | + +-----+---+ + | 1.0 | 1 | + | 3.0 | 3 | + | NaN | 2 | + +-----+---+ + " + ); + + Ok(()) +} + #[tokio::test] async fn test_insert_into_casting_support() -> Result<()> { // Testing case1: diff --git a/parquet-testing b/parquet-testing index 107b36603e051..ffdcbb5e22828 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit 107b36603e051aee26bd93e04b871034f6c756c0 +Subproject commit ffdcbb5e22828186c7461e56dbd26a0fe3caee56 diff --git a/testing b/testing index 7df2b70baf4f0..9cfebfef8982f 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit 7df2b70baf4f081ebf8e0c6bd22745cf3cbfd824 +Subproject commit 9cfebfef8982fb8612e0a2c59059752bd32321a3