diff --git a/Cargo.toml b/Cargo.toml index 3bec286..87e08fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ name = "cocoindex_engine" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.24.1", features = ["chrono"] } +pyo3 = { version = "0.24.1", features = ["chrono", "auto-initialize"] } pythonize = "0.24.0" pyo3-async-runtimes = { version = "0.24.0", features = ["tokio-runtime"] } diff --git a/src/base/value.rs b/src/base/value.rs index c3266fb..a1c2e21 100644 --- a/src/base/value.rs +++ b/src/base/value.rs @@ -71,7 +71,7 @@ impl<'de> Deserialize<'de> for RangeValue { } /// Value of key. -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize)] pub enum KeyValue { Bytes(Bytes), Str(Arc), @@ -340,7 +340,7 @@ impl KeyValue { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Deserialize)] pub enum BasicValue { Bytes(Bytes), Str(Arc), @@ -511,7 +511,7 @@ impl BasicValue { } } -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, PartialEq, Deserialize)] pub enum Value { #[default] Null, @@ -747,7 +747,7 @@ impl Value { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Deserialize)] pub struct FieldValues { pub fields: Vec>, } @@ -821,7 +821,7 @@ where } } -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ScopeValue(pub FieldValues); impl Deref for ScopeValue { diff --git a/src/py/convert.rs b/src/py/convert.rs index 327ba82..a234765 100644 --- a/src/py/convert.rs +++ b/src/py/convert.rs @@ -12,6 +12,7 @@ use std::sync::Arc; use super::IntoPyResult; use crate::base::{schema, value}; +#[derive(Debug)] pub struct Pythonized(pub T); impl<'py, T: DeserializeOwned> FromPyObject<'py> for Pythonized { @@ -168,6 +169,7 @@ fn field_values_from_py_object<'py>( list.len() ))); } + Ok(value::FieldValues { fields: schema .fields @@ -198,6 +200,7 @@ pub fn value_from_py_object<'py>( .into_iter() .map(|v| field_values_from_py_object(&schema.row, &v)) .collect::>>()?; + match schema.kind { schema::TableKind::UTable => { value::Value::UTable(values.into_iter().map(|v| v.into()).collect()) @@ -205,6 +208,7 @@ pub fn value_from_py_object<'py>( schema::TableKind::LTable => { value::Value::LTable(values.into_iter().map(|v| v.into()).collect()) } + schema::TableKind::KTable => value::Value::KTable( values .into_iter() @@ -226,3 +230,269 @@ pub fn value_from_py_object<'py>( }; Ok(result) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::base::schema; + use crate::base::value; + use crate::base::value::ScopeValue; + use pyo3::Python; + use std::collections::BTreeMap; + use std::sync::Arc; + + fn assert_roundtrip_conversion(original_value: &value::Value, value_type: &schema::ValueType) { + Python::with_gil(|py| { + // Convert Rust value to Python object + let pythonized_value = Pythonized(original_value.clone()); + let py_object = pythonized_value.into_pyobject(py).unwrap_or_else(|e| { + panic!("Failed to convert Rust value to Python object: {:?}", e) + }); + + println!("Python object: {:?}", py_object); + let roundtripped_value = + value_from_py_object(value_type, &py_object).unwrap_or_else(|e| { + panic!( + "Failed to convert Python object back to Rust value: {:?}", + e + ) + }); + + println!("Roundtripped value: {:?}", roundtripped_value); + // Compare values + match (&original_value, &roundtripped_value) { + (value::Value::Basic(orig), value::Value::Basic(round)) => { + assert_eq!(orig, round, "BasicValue mismatch"); + } + (value::Value::Struct(orig), value::Value::Struct(round)) => { + assert_eq!( + orig.fields.len(), + round.fields.len(), + "Struct field count mismatch" + ); + for (o, r) in orig.fields.iter().zip(round.fields.iter()) { + assert_eq!(o, r, "Struct field value mismatch"); + } + } + (value::Value::UTable(orig), value::Value::UTable(round)) => { + assert_eq!(orig.len(), round.len(), "UTable row count mismatch"); + for (o, r) in orig.iter().zip(round.iter()) { + assert_eq!( + o.fields.len(), + r.fields.len(), + "UTable field count mismatch" + ); + for (of, rf) in o.fields.iter().zip(r.fields.iter()) { + assert_eq!(of, rf, "UTable field value mismatch"); + } + } + } + (value::Value::LTable(orig), value::Value::LTable(round)) => { + assert_eq!(orig.len(), round.len(), "LTable row count mismatch"); + for (o, r) in orig.iter().zip(round.iter()) { + assert_eq!( + o.fields.len(), + r.fields.len(), + "LTable field count mismatch" + ); + for (of, rf) in o.fields.iter().zip(r.fields.iter()) { + assert_eq!(of, rf, "LTable field value mismatch"); + } + } + } + (value::Value::KTable(orig), value::Value::KTable(round)) => { + assert_eq!(orig.len(), round.len(), "KTable entry count mismatch"); + for (ok, ov) in orig.iter() { + let rv = round + .get(ok) + .unwrap_or_else(|| panic!("Missing key in KTable roundtrip: {:?}", ok)); + assert_eq!( + ov.fields.len(), + rv.fields.len(), + "KTable field count mismatch" + ); + for (of, rf) in ov.fields.iter().zip(rv.fields.iter()) { + assert_eq!(of, rf, "KTable field value mismatch"); + } + } + } + _ => panic!( + "Value type mismatch: expected {:?}, got {:?}", + original_value, roundtripped_value + ), + } + }); + } + + #[test] + fn test_roundtrip_basic_values() { + let values_and_types = vec![ + ( + value::Value::Basic(value::BasicValue::Int64(42)), + schema::ValueType::Basic(schema::BasicValueType::Int64), + ), + ( + value::Value::Basic(value::BasicValue::Float64(3.14)), + schema::ValueType::Basic(schema::BasicValueType::Float64), + ), + ( + value::Value::Basic(value::BasicValue::Str(Arc::from("hello"))), + schema::ValueType::Basic(schema::BasicValueType::Str), + ), + ( + value::Value::Basic(value::BasicValue::Bool(true)), + schema::ValueType::Basic(schema::BasicValueType::Bool), + ), + ]; + + for (val, typ) in values_and_types { + assert_roundtrip_conversion(&val, &typ); + } + } + + #[test] + fn test_roundtrip_struct() { + let struct_schema = schema::StructSchema { + description: Some(Arc::from("Test struct description")), + fields: Arc::new(vec![ + schema::FieldSchema { + name: "a".to_string(), + value_type: schema::EnrichedValueType { + typ: schema::ValueType::Basic(schema::BasicValueType::Int64), + nullable: false, + attrs: Default::default(), + }, + }, + schema::FieldSchema { + name: "b".to_string(), + value_type: schema::EnrichedValueType { + typ: schema::ValueType::Basic(schema::BasicValueType::Str), + nullable: false, + attrs: Default::default(), + }, + }, + ]), + }; + + let struct_val_data = value::FieldValues { + fields: vec![ + value::Value::Basic(value::BasicValue::Int64(10)), + value::Value::Basic(value::BasicValue::Str(Arc::from("world"))), + ], + }; + + let struct_val = value::Value::Struct(struct_val_data); + let struct_typ = schema::ValueType::Struct(struct_schema); // No clone needed + + assert_roundtrip_conversion(&struct_val, &struct_typ); + } + + #[test] + fn test_roundtrip_table_types() { + let row_schema_struct = Arc::new(schema::StructSchema { + description: Some(Arc::from("Test table row description")), + fields: Arc::new(vec![ + schema::FieldSchema { + name: "key_col".to_string(), // Will be used as key for KTable implicitly + value_type: schema::EnrichedValueType { + typ: schema::ValueType::Basic(schema::BasicValueType::Int64), + nullable: false, + attrs: Default::default(), + }, + }, + schema::FieldSchema { + name: "data_col_1".to_string(), + value_type: schema::EnrichedValueType { + typ: schema::ValueType::Basic(schema::BasicValueType::Str), + nullable: false, + attrs: Default::default(), + }, + }, + schema::FieldSchema { + name: "data_col_2".to_string(), + value_type: schema::EnrichedValueType { + typ: schema::ValueType::Basic(schema::BasicValueType::Bool), + nullable: false, + attrs: Default::default(), + }, + }, + ]), + }); + + let row1_fields = value::FieldValues { + fields: vec![ + value::Value::Basic(value::BasicValue::Int64(1)), + value::Value::Basic(value::BasicValue::Str(Arc::from("row1_data"))), + value::Value::Basic(value::BasicValue::Bool(true)), + ], + }; + let row1_scope_val: value::ScopeValue = row1_fields.into(); + + let row2_fields = value::FieldValues { + fields: vec![ + value::Value::Basic(value::BasicValue::Int64(2)), + value::Value::Basic(value::BasicValue::Str(Arc::from("row2_data"))), + value::Value::Basic(value::BasicValue::Bool(false)), + ], + }; + let row2_scope_val: value::ScopeValue = row2_fields.into(); + + // UTable + let utable_schema = schema::TableSchema { + kind: schema::TableKind::UTable, + row: (*row_schema_struct).clone(), + }; + let utable_val = value::Value::UTable(vec![row1_scope_val.clone(), row2_scope_val.clone()]); + let utable_typ = schema::ValueType::Table(utable_schema); + assert_roundtrip_conversion(&utable_val, &utable_typ); + + // LTable + let ltable_schema = schema::TableSchema { + kind: schema::TableKind::LTable, + row: (*row_schema_struct).clone(), + }; + let ltable_val = value::Value::LTable(vec![row1_scope_val.clone(), row2_scope_val.clone()]); + let ltable_typ = schema::ValueType::Table(ltable_schema); + assert_roundtrip_conversion(<able_val, <able_typ); + + // KTable + let ktable_schema = schema::TableSchema { + kind: schema::TableKind::KTable, + row: (*row_schema_struct).clone(), + }; + let mut ktable_data = BTreeMap::new(); + + // Create KTable entries where the ScopeValue doesn't include the key field + // This matches how the Python code will serialize/deserialize + let row1_fields = value::FieldValues { + fields: vec![ + value::Value::Basic(value::BasicValue::Str(Arc::from("row1_data"))), + value::Value::Basic(value::BasicValue::Bool(true)), + ], + }; + let row1_scope_val: value::ScopeValue = row1_fields.into(); + + let row2_fields = value::FieldValues { + fields: vec![ + value::Value::Basic(value::BasicValue::Str(Arc::from("row2_data"))), + value::Value::Basic(value::BasicValue::Bool(false)), + ], + }; + let row2_scope_val: value::ScopeValue = row2_fields.into(); + + // For KTable, the key is extracted from the first field of ScopeValue based on current serialization + let key1 = value::Value::::Basic(value::BasicValue::Int64(1)) + .into_key() + .unwrap(); + let key2 = value::Value::::Basic(value::BasicValue::Int64(2)) + .into_key() + .unwrap(); + + ktable_data.insert(key1, row1_scope_val.clone()); + ktable_data.insert(key2, row2_scope_val.clone()); + + let ktable_val = value::Value::KTable(ktable_data); + let ktable_typ = schema::ValueType::Table(ktable_schema); + assert_roundtrip_conversion(&ktable_val, &ktable_typ); + } +} diff --git a/src/server.rs b/src/server.rs index ef2c185..988c318 100644 --- a/src/server.rs +++ b/src/server.rs @@ -8,7 +8,7 @@ use tower_http::{ trace::TraceLayer, }; -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Debug, Clone)] pub struct ServerSettings { pub address: String, #[serde(default)] diff --git a/src/settings.rs b/src/settings.rs index 2cbcf14..350ec7b 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -1,13 +1,13 @@ use serde::Deserialize; -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Debug, Clone)] pub struct DatabaseConnectionSpec { pub url: String, pub user: Option, pub password: Option, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Debug, Clone)] pub struct Settings { pub database: DatabaseConnectionSpec, }