diff --git a/crates/fluss/src/metadata/datatype.rs b/crates/fluss/src/metadata/datatype.rs index 4deed2bc..8ad4f7e5 100644 --- a/crates/fluss/src/metadata/datatype.rs +++ b/crates/fluss/src/metadata/datatype.rs @@ -96,25 +96,25 @@ impl DataType { impl Display for DataType { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - DataType::Boolean(v) => write!(f, "{}", v), - DataType::TinyInt(v) => write!(f, "{}", v), - DataType::SmallInt(v) => write!(f, "{}", v), - DataType::Int(v) => write!(f, "{}", v), - DataType::BigInt(v) => write!(f, "{}", v), - DataType::Float(v) => write!(f, "{}", v), - DataType::Double(v) => write!(f, "{}", v), - DataType::Char(v) => write!(f, "{}", v), - DataType::String(v) => write!(f, "{}", v), - DataType::Decimal(v) => write!(f, "{}", v), - DataType::Date(v) => write!(f, "{}", v), - DataType::Time(v) => write!(f, "{}", v), - DataType::Timestamp(v) => write!(f, "{}", v), - DataType::TimestampLTz(v) => write!(f, "{}", v), - DataType::Bytes(v) => write!(f, "{}", v), - DataType::Binary(v) => write!(f, "{}", v), - DataType::Array(v) => write!(f, "{}", v), - DataType::Map(v) => write!(f, "{}", v), - DataType::Row(v) => write!(f, "{}", v), + DataType::Boolean(v) => write!(f, "{v}"), + DataType::TinyInt(v) => write!(f, "{v}"), + DataType::SmallInt(v) => write!(f, "{v}"), + DataType::Int(v) => write!(f, "{v}"), + DataType::BigInt(v) => write!(f, "{v}"), + DataType::Float(v) => write!(f, "{v}"), + DataType::Double(v) => write!(f, "{v}"), + DataType::Char(v) => write!(f, "{v}"), + DataType::String(v) => write!(f, "{v}"), + DataType::Decimal(v) => write!(f, "{v}"), + DataType::Date(v) => write!(f, "{v}"), + DataType::Time(v) => write!(f, "{v}"), + DataType::Timestamp(v) => write!(f, "{v}"), + DataType::TimestampLTz(v) => write!(f, "{v}"), + DataType::Bytes(v) => write!(f, "{v}"), + DataType::Binary(v) => write!(f, "{v}"), + DataType::Array(v) => write!(f, "{v}"), + DataType::Map(v) => write!(f, "{v}"), + DataType::Row(v) => write!(f, "{v}"), } } } @@ -861,7 +861,7 @@ impl Display for RowType { if i > 0 { write!(f, ", ")?; } - write!(f, "{}", field)?; + write!(f, "{field}")?; } write!(f, ">")?; if !self.nullable { diff --git a/crates/fluss/src/record/arrow.rs b/crates/fluss/src/record/arrow.rs index 29bfe41a..e46093dd 100644 --- a/crates/fluss/src/record/arrow.rs +++ b/crates/fluss/src/record/arrow.rs @@ -589,16 +589,84 @@ pub fn to_arrow_type(fluss_type: &DataType) -> ArrowDataType { DataType::Double(_) => ArrowDataType::Float64, DataType::Char(_) => ArrowDataType::Utf8, DataType::String(_) => ArrowDataType::Utf8, - DataType::Decimal(_) => todo!(), + DataType::Decimal(decimal_type) => ArrowDataType::Decimal128( + decimal_type + .precision() + .try_into() + .expect("precision exceeds u8::MAX"), + decimal_type + .scale() + .try_into() + .expect("scale exceeds i8::MAX"), + ), DataType::Date(_) => ArrowDataType::Date32, - DataType::Time(_) => todo!(), - DataType::Timestamp(_) => todo!(), - DataType::TimestampLTz(_) => todo!(), - DataType::Bytes(_) => todo!(), - DataType::Binary(_) => todo!(), - DataType::Array(_data_type) => todo!(), - DataType::Map(_data_type) => todo!(), - DataType::Row(_data_fields) => todo!(), + DataType::Time(time_type) => match time_type.precision() { + 0 => ArrowDataType::Time32(arrow_schema::TimeUnit::Second), + 1..=3 => ArrowDataType::Time32(arrow_schema::TimeUnit::Millisecond), + 4..=6 => ArrowDataType::Time64(arrow_schema::TimeUnit::Microsecond), + 7..=9 => ArrowDataType::Time64(arrow_schema::TimeUnit::Nanosecond), + // This arm should never be reached due to validation in TimeType. + invalid => panic!("Invalid precision value for TimeType: {invalid}"), + }, + DataType::Timestamp(timestamp_type) => match timestamp_type.precision() { + 0 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Second, None), + 1..=3 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None), + 4..=6 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None), + 7..=9 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None), + // This arm should never be reached due to validation in Timestamp. + invalid => panic!("Invalid precision value for TimestampType: {invalid}"), + }, + DataType::TimestampLTz(timestamp_ltz_type) => match timestamp_ltz_type.precision() { + 0 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Second, None), + 1..=3 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None), + 4..=6 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None), + 7..=9 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None), + // This arm should never be reached due to validation in TimestampLTz. + invalid => panic!("Invalid precision value for TimestampLTzType: {invalid}"), + }, + DataType::Bytes(_) => ArrowDataType::Binary, + DataType::Binary(binary_type) => ArrowDataType::FixedSizeBinary( + binary_type + .length() + .try_into() + .expect("length exceeds i32::MAX"), + ), + DataType::Array(array_type) => ArrowDataType::List( + Field::new_list_field( + to_arrow_type(array_type.get_element_type()), + fluss_type.is_nullable(), + ) + .into(), + ), + DataType::Map(map_type) => { + let key_type = to_arrow_type(map_type.key_type()); + let value_type = to_arrow_type(map_type.value_type()); + let entry_fields = vec![ + Field::new("key", key_type, map_type.key_type().is_nullable()), + Field::new("value", value_type, map_type.value_type().is_nullable()), + ]; + ArrowDataType::Map( + Arc::new(Field::new( + "entries", + ArrowDataType::Struct(arrow_schema::Fields::from(entry_fields)), + fluss_type.is_nullable(), + )), + false, + ) + } + DataType::Row(row_type) => ArrowDataType::Struct(arrow_schema::Fields::from( + row_type + .fields() + .iter() + .map(|f| { + Field::new( + f.name(), + to_arrow_type(f.data_type()), + f.data_type().is_nullable(), + ) + }) + .collect::>(), + )), } } @@ -820,3 +888,129 @@ impl ArrowReader { } } pub struct MyVec(pub StreamReader); + +#[cfg(test)] +mod tests { + use super::*; + use crate::metadata::DataTypes; + + #[test] + fn test_to_array_type() { + assert_eq!(to_arrow_type(&DataTypes::boolean()), ArrowDataType::Boolean); + assert_eq!(to_arrow_type(&DataTypes::tinyint()), ArrowDataType::Int8); + assert_eq!(to_arrow_type(&DataTypes::smallint()), ArrowDataType::Int16); + assert_eq!(to_arrow_type(&DataTypes::bigint()), ArrowDataType::Int64); + assert_eq!(to_arrow_type(&DataTypes::int()), ArrowDataType::Int32); + assert_eq!(to_arrow_type(&DataTypes::float()), ArrowDataType::Float32); + assert_eq!(to_arrow_type(&DataTypes::double()), ArrowDataType::Float64); + assert_eq!(to_arrow_type(&DataTypes::char(16)), ArrowDataType::Utf8); + assert_eq!(to_arrow_type(&DataTypes::string()), ArrowDataType::Utf8); + assert_eq!( + to_arrow_type(&DataTypes::decimal(10, 2)), + ArrowDataType::Decimal128(10, 2) + ); + assert_eq!(to_arrow_type(&DataTypes::date()), ArrowDataType::Date32); + assert_eq!( + to_arrow_type(&DataTypes::time()), + ArrowDataType::Time32(arrow_schema::TimeUnit::Second) + ); + assert_eq!( + to_arrow_type(&DataTypes::time_with_precision(3)), + ArrowDataType::Time32(arrow_schema::TimeUnit::Millisecond) + ); + assert_eq!( + to_arrow_type(&DataTypes::time_with_precision(6)), + ArrowDataType::Time64(arrow_schema::TimeUnit::Microsecond) + ); + assert_eq!( + to_arrow_type(&DataTypes::time_with_precision(9)), + ArrowDataType::Time64(arrow_schema::TimeUnit::Nanosecond) + ); + assert_eq!( + to_arrow_type(&DataTypes::timestamp_with_precision(0)), + ArrowDataType::Timestamp(arrow_schema::TimeUnit::Second, None) + ); + assert_eq!( + to_arrow_type(&DataTypes::timestamp_with_precision(3)), + ArrowDataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None) + ); + assert_eq!( + to_arrow_type(&DataTypes::timestamp_with_precision(6)), + ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None) + ); + assert_eq!( + to_arrow_type(&DataTypes::timestamp_with_precision(9)), + ArrowDataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None) + ); + assert_eq!( + to_arrow_type(&DataTypes::timestamp_ltz_with_precision(0)), + ArrowDataType::Timestamp(arrow_schema::TimeUnit::Second, None) + ); + assert_eq!( + to_arrow_type(&DataTypes::timestamp_ltz_with_precision(3)), + ArrowDataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None) + ); + assert_eq!( + to_arrow_type(&DataTypes::timestamp_ltz_with_precision(6)), + ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None) + ); + assert_eq!( + to_arrow_type(&DataTypes::timestamp_ltz_with_precision(9)), + ArrowDataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None) + ); + assert_eq!(to_arrow_type(&DataTypes::bytes()), ArrowDataType::Binary); + assert_eq!( + to_arrow_type(&DataTypes::binary(16)), + ArrowDataType::FixedSizeBinary(16) + ); + + assert_eq!( + to_arrow_type(&DataTypes::array(DataTypes::int())), + ArrowDataType::List(Field::new_list_field(ArrowDataType::Int32, true).into()) + ); + + assert_eq!( + to_arrow_type(&DataTypes::map(DataTypes::string(), DataTypes::int())), + ArrowDataType::Map( + Arc::new(Field::new( + "entries", + ArrowDataType::Struct(arrow_schema::Fields::from(vec![ + Field::new("key", ArrowDataType::Utf8, true), + Field::new("value", ArrowDataType::Int32, true), + ])), + true, + )), + false, + ) + ); + + assert_eq!( + to_arrow_type(&DataTypes::row(vec![ + DataTypes::field("f1".to_string(), DataTypes::int()), + DataTypes::field("f2".to_string(), DataTypes::string()), + ])), + ArrowDataType::Struct(arrow_schema::Fields::from(vec![ + Field::new("f1", ArrowDataType::Int32, true), + Field::new("f2", ArrowDataType::Utf8, true), + ])) + ); + } + + #[test] + #[should_panic(expected = "Invalid precision value for TimeType: 10")] + fn test_time_invalid_precision() { + to_arrow_type(&DataTypes::time_with_precision(10)); + } + + #[test] + #[should_panic(expected = "Invalid precision value for TimestampType: 10")] + fn test_timestamp_invalid_precision() { + to_arrow_type(&DataTypes::timestamp_with_precision(10)); + } + + #[test] + #[should_panic(expected = "Invalid precision value for TimestampLTzType: 10")] + fn test_timestamp_ltz_invalid_precision() { + to_arrow_type(&DataTypes::timestamp_ltz_with_precision(10)); + } +}