diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs index 19b4bbeb2..f2dfcda5a 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs @@ -7,6 +7,7 @@ use crate::deserializer::{ deserialize_flags::{DeserializerConditions, Flag}, types::BamlValueWithFlags, }; +use regex::Regex; use super::{array_helper::coerce_array_to_singular, ParsingContext, ParsingError}; @@ -114,6 +115,10 @@ fn coerce_int( Ok(BamlValueWithFlags::Int( ((frac.round() as i64), Flag::FloatToInt(frac)).into(), )) + } else if let Some(frac) = float_from_comma_separated(s) { + Ok(BamlValueWithFlags::Int( + ((frac.round() as i64), Flag::FloatToInt(frac)).into(), + )) } else { Err(ctx.error_unexpected_type(target, value)) } @@ -144,6 +149,24 @@ fn float_from_maybe_fraction(value: &str) -> Option { } } +fn float_from_comma_separated(value: &str) -> Option { + let re = Regex::new(r"([-+]?)\$?(?:\d+(?:,\d+)*(?:\.\d+)?|\d+\.\d+|\d+|\.\d+)(?:e[-+]?\d+)?") + .unwrap(); + let matches: Vec<_> = re.find_iter(value).collect(); + + if matches.len() != 1 { + return None; + } + + let number_str = matches[0].as_str(); + let without_commas = number_str.replace(",", ""); + // Remove all Unicode currency symbols + let re_currency = Regex::new(r"\p{Sc}").unwrap(); + let without_currency = re_currency.replace_all(&without_commas, ""); + + without_currency.parse::().ok() +} + fn coerce_float( ctx: &ParsingContext, target: &FieldType, @@ -174,6 +197,8 @@ fn coerce_float( Ok(BamlValueWithFlags::Float((n as f64).into())) } else if let Some(frac) = float_from_maybe_fraction(s) { Ok(BamlValueWithFlags::Float(frac.into())) + } else if let Some(frac) = float_from_comma_separated(s) { + Ok(BamlValueWithFlags::Float(frac.into())) } else { Err(ctx.error_unexpected_type(target, value)) } @@ -226,3 +251,83 @@ fn coerce_bool( Err(ctx.error_unexpected_null(target)) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_float_from_comma_separated() { + // Note we don't handle european numbers correctly. + let test_cases = vec![ + // European Formats + // Valid German format (comma as decimal separator) + ("3,14", Some(314.0)), + ("1.234,56", None), + ("1.234.567,89", None), + ("€1.234,56", None), + ("-€1.234,56", None), + ("€1.234", Some(1.234)), // TODO - technically incorrect + ("1.234€", Some(1.234)), // TODO - technically incorrect + // Valid currencies with European formatting + ("€1.234,56", None), + ("€1,234.56", Some(1234.56)), // Incorrect format for Euro + // US Formats + // Valid US format (comma as thousands separator) + ("3,000", Some(3000.0)), + ("3,100.00", Some(3100.00)), + ("1,234.56", Some(1234.56)), + ("1,234,567.89", Some(1234567.89)), + ("$1,234.56", Some(1234.56)), + ("-$1,234.56", Some(-1234.56)), + ("$1,234", Some(1234.0)), + ("1,234$", Some(1234.0)), + ("$1,234.56", Some(1234.56)), + ("+$1,234.56", Some(1234.56)), + ("-$1,234.56", Some(-1234.56)), + ("$9,999,999,999", Some(9999999999.0)), + ("$1.23.456", None), + ("$1.234.567.890", None), + // Valid currencies with US formatting + ("$1,234", Some(1234.0)), + ("$314", Some(314.0)), + // Indian Formats + // Assuming Indian numbering system (not present in original tests, added for categorization) + ("$1,23,456", Some(123456.0)), + // Additional Indian format test cases can be added here + + // Percentages and Strings with Numbers + // Percentages + ("50%", Some(50.0)), + ("3.14%", Some(3.14)), + (".009%", Some(0.009)), + ("1.234,56%", None), + ("$1,234.56%", Some(1234.56)), + // Strings containing numbers + ("The answer is 10,000", Some(10000.0)), + ("The total is €1.234,56 today", None), + ("You owe $3,000 for the service", Some(3000.0)), + ("Save up to 20% on your purchase", Some(20.0)), + ("Revenue grew by 1,234.56 this quarter", Some(1234.56)), + ("Profit is -€1.234,56 in the last month", None), + // Sentences with Multiple Numbers + ("The answer is 10,000 and $3,000", None), + ("We earned €1.234,56 and $2,345.67 this year", None), + ("Increase of 5% and a profit of $1,000", None), + ("Loss of -€500 and a gain of 1,200.50", None), + ("Targets: 2,000 units and €3.000,75 revenue", None), + // trailing periods and commas + ("12,111,123.", Some(12111123.0)), + ("12,111,123,", Some(12111123.0)), + ]; + + for &(input, expected) in &test_cases { + let result = float_from_comma_separated(input); + assert_eq!( + result, expected, + "Failed to parse '{}'. Expected {:?}, got {:?}", + input, expected, result + ); + } + } +} diff --git a/engine/baml-lib/jsonish/src/tests/test_basics.rs b/engine/baml-lib/jsonish/src/tests/test_basics.rs index ea9afe677..9a53ee349 100644 --- a/engine/baml-lib/jsonish/src/tests/test_basics.rs +++ b/engine/baml-lib/jsonish/src/tests/test_basics.rs @@ -27,7 +27,7 @@ test_deserializer!( ); test_deserializer!(test_number, EMPTY_FILE, "12111", FieldType::int(), 12111); - +test_deserializer!(test_number_2, EMPTY_FILE, "12,111", FieldType::int(), 12111); test_deserializer!( test_string, EMPTY_FILE, @@ -49,6 +49,31 @@ test_deserializer!( 12111.123 ); +test_deserializer!( + test_float_comma_us, + EMPTY_FILE, + "12,111.123", + FieldType::float(), + 12111.123 +); + +// uncomment when we support european formatting. +// test_deserializer!( +// test_float_comma_german, +// EMPTY_FILE, +// "12.111,123", +// FieldType::float(), + +// ); + +test_deserializer!( + test_float_comma_german2, + EMPTY_FILE, + "12.11.", + FieldType::float(), + 12.11 +); + test_deserializer!(test_float_1, EMPTY_FILE, "1/5", FieldType::float(), 0.2); test_deserializer!( diff --git a/engine/baml-lib/jsonish/src/tests/test_lists.rs b/engine/baml-lib/jsonish/src/tests/test_lists.rs index 4b7a56a36..106d8c2ec 100644 --- a/engine/baml-lib/jsonish/src/tests/test_lists.rs +++ b/engine/baml-lib/jsonish/src/tests/test_lists.rs @@ -43,3 +43,82 @@ test_deserializer!( FieldType::List(FieldType::Class("Foo".to_string()).into()), [{"a": 1, "b": "hello"}, {"a": 2, "b": "world"}] ); + +test_deserializer!( + test_class_list, + r#" + class ListClass { + date string + description string + transaction_amount float + transaction_type string + } + "#, + r#" + [ + { + "date": "01/01", + "description": "Transaction 1", + "transaction_amount": -100.00, + "transaction_type": "Withdrawal" + }, + { + "date": "01/02", + "description": "Transaction 2", + "transaction_amount": -2,000.00, + "transaction_type": "Withdrawal" + }, + { + "date": "01/03", + "description": "Transaction 3", + "transaction_amount": -300.00, + "transaction_type": "Withdrawal" + }, + { + "date": "01/04", + "description": "Transaction 4", + "transaction_amount": -4,000.00, + "transaction_type": "Withdrawal" + }, + { + "date": "01/05", + "description": "Transaction 5", + "transaction_amount": -5,000.00, + "transaction_type": "Withdrawal" + } + ] + "#, + FieldType::List(FieldType::Class("ListClass".to_string()).into()), + [ + { + "date": "01/01", + "description": "Transaction 1", + "transaction_amount": -100.00, + "transaction_type": "Withdrawal" + }, + { + "date": "01/02", + "description": "Transaction 2", + "transaction_amount": -2000.00, + "transaction_type": "Withdrawal" + }, + { + "date": "01/03", + "description": "Transaction 3", + "transaction_amount": -300.00, + "transaction_type": "Withdrawal" + }, + { + "date": "01/04", + "description": "Transaction 4", + "transaction_amount": -4000.00, + "transaction_type": "Withdrawal" + }, + { + "date": "01/05", + "description": "Transaction 5", + "transaction_amount": -5000.00, + "transaction_type": "Withdrawal" + } + ] +); diff --git a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs index 8e1225d57..9698c0870 100644 --- a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs +++ b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs @@ -342,29 +342,33 @@ fn relevant_data_models<'a>( #[cfg(test)] mod tests { - use std::collections::HashMap; - use crate::BamlRuntime; use super::*; + use crate::BamlRuntime; + use std::collections::HashMap; #[test] fn skipped_variants_are_not_rendered() { - let files = vec![("test-file.baml",r#" + let files = vec![( + "test-file.baml", + r#" enum Foo { Bar Baz @skip - }"# - )].into_iter().collect(); + }"#, + )] + .into_iter() + .collect(); let env_vars: HashMap<&str, &str> = HashMap::new(); let baml_runtime = BamlRuntime::from_file_content(".", &files, env_vars).unwrap(); let ctx_manager = baml_runtime.create_ctx_manager(BamlValue::Null, None); let ctx: RuntimeContext = ctx_manager.create_ctx(None, None).unwrap(); let field_type = FieldType::Enum("Foo".to_string()); - let render_output = render_output_format( baml_runtime.inner.ir.as_ref(), &ctx, &field_type ).unwrap(); + let render_output = + render_output_format(baml_runtime.inner.ir.as_ref(), &ctx, &field_type).unwrap(); let foo_enum = render_output.find_enum("Foo").unwrap(); - assert_eq!(foo_enum.values[0].0.real_name(), "Bar".to_string()); + assert_eq!(foo_enum.values[0].0.real_name(), "Bar".to_string()); assert_eq!(foo_enum.values.len(), 1); } - }