diff --git a/arrow/src/util/data_gen.rs b/arrow/src/util/data_gen.rs index 42a0798f5540..7ea05811d55b 100644 --- a/arrow/src/util/data_gen.rs +++ b/arrow/src/util/data_gen.rs @@ -55,6 +55,14 @@ pub fn create_random_batch( /// Create a random [ArrayRef] from a [DataType] with a length, /// null density and true density (for [BooleanArray]). +/// +/// # Arguments +/// +/// * `field` - The field containing the data type for which to create a random array +/// * `size` - The number of elements in the generated array +/// * `null_density` - The approximate fraction of null values in the resulting array (0.0 to 1.0) +/// * `true_density` - The approximate fraction of true values in boolean arrays (0.0 to 1.0) +/// pub fn create_random_array( field: &Field, size: usize, @@ -215,6 +223,8 @@ pub fn create_random_array( crate::compute::cast(&v, d)? } Map(_, _) => create_random_map_array(field, size, null_density, true_density)?, + Decimal128(_, _) => create_random_decimal_array(field, size, null_density)?, + Decimal256(_, _) => create_random_decimal_array(field, size, null_density)?, other => { return Err(ArrowError::NotYetImplemented(format!( "Generating random arrays not yet implemented for {other:?}" @@ -223,6 +233,45 @@ pub fn create_random_array( }) } +#[inline] +fn create_random_decimal_array(field: &Field, size: usize, null_density: f32) -> Result<ArrayRef> { + let mut rng = seedable_rng(); + + match field.data_type() { + DataType::Decimal128(precision, scale) => { + let values = (0..size) + .map(|_| { + if rng.random::<f32>() < null_density { + None + } else { + Some(rng.random::<i128>()) + } + }) + .collect::<Vec<_>>(); + Ok(Arc::new( + Decimal128Array::from(values).with_precision_and_scale(*precision, *scale)?, + )) + } + DataType::Decimal256(precision, scale) => { + let values = (0..size) + .map(|_| { + if rng.random::<f32>() < null_density { + None + } else { + Some(i256::from_parts(rng.random::<u128>(), rng.random::<i128>())) + } + }) + .collect::<Vec<_>>(); + Ok(Arc::new( 
Decimal256Array::from(values).with_precision_and_scale(*precision, *scale)?, + )) + } + _ => Err(ArrowError::InvalidArgumentError(format!( + "Cannot create decimal array for field {field:?}" + ))), + } +} + #[inline] fn create_random_list_array( field: &Field, @@ -745,4 +794,22 @@ mod tests { assert_eq!(array.as_map().keys().data_type(), &DataType::Utf8); assert_eq!(array.as_map().values().data_type(), &DataType::Utf8); } + + #[test] + fn test_create_decimal_array() { + let size = 10; + let fields = vec![ + Field::new("a", DataType::Decimal128(10, -2), true), + Field::new("b", DataType::Decimal256(10, -2), true), + ]; + let schema = Schema::new(fields); + let schema_ref = Arc::new(schema); + let batch = create_random_batch(schema_ref.clone(), size, 0.35, 0.7).unwrap(); + + assert_eq!(batch.schema(), schema_ref); + assert_eq!(batch.num_columns(), schema_ref.fields().len()); + for array in batch.columns() { + assert_eq!(array.len(), size); + } + } }