From 34a2ba4065a6a314ee3b28f7001efd69f61b5c85 Mon Sep 17 00:00:00 2001 From: hellovai Date: Mon, 3 Jun 2024 16:51:24 -0400 Subject: [PATCH] Fix test cases (#640) * BAML test cases that are bools or ints need to allow for casting to string example: args { foo 1 } could work for foo: string or foo: int --- .../baml-core/src/ir/ir_helpers/mod.rs | 12 +++++-- .../src/ir/ir_helpers/to_baml_arg.rs | 34 ++++++++++++++++--- .../src/runtime/runtime_interface.rs | 10 +++--- 3 files changed, 45 insertions(+), 11 deletions(-) diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs index 92694a97d..045572bac 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs @@ -41,6 +41,7 @@ pub trait IRHelper { &'a self, function: &'a FunctionWalker<'a>, params: &BamlMap, + allow_implicit_cast_to_string: bool, ) -> Result; } @@ -163,6 +164,7 @@ impl IRHelper for IntermediateRepr { &'a self, function: &'a FunctionWalker<'a>, params: &BamlMap, + allow_implicit_cast_to_string: bool, ) -> Result { let function_params = match function.inputs() { either::Either::Left(_) => { @@ -182,9 +184,13 @@ impl IRHelper for IntermediateRepr { for (param_name, param_type) in function_params { scope.push(param_name.to_string()); if let Some(param_value) = params.get(param_name.as_str()) { - if let Some(baml_arg) = - to_baml_arg::validate_arg(self, param_type, param_value, &mut scope) - { + if let Some(baml_arg) = to_baml_arg::validate_arg( + self, + param_type, + param_value, + &mut scope, + allow_implicit_cast_to_string, + ) { baml_arg_map.insert(param_name.to_string(), baml_arg); } } else { diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs index 636bb5712..479a08722 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs @@ -29,10 +29,22 @@ pub fn validate_arg( field_type: &FieldType, value: &BamlValue, scope: &mut ScopeStack, + allow_implicit_cast_to_string: bool, ) -> Option { match field_type { FieldType::Primitive(t) => match t { TypeValue::String if matches!(value, BamlValue::String(_)) => Some(value.clone()), + TypeValue::String if allow_implicit_cast_to_string => match value { + BamlValue::Int(i) => Some(BamlValue::String(i.to_string())), + BamlValue::Float(f) => Some(BamlValue::String(f.to_string())), + BamlValue::Bool(true) => Some(BamlValue::String("true".to_string())), + BamlValue::Bool(false) => Some(BamlValue::String("false".to_string())), + BamlValue::Null => Some(BamlValue::String("null".to_string())), + _ => { + scope.push_error(format!("Expected type {:?}, got `{}`", t, value)); + None + } + }, TypeValue::Int if matches!(value, BamlValue::Int(_)) => Some(value.clone()), TypeValue::Float => match value { BamlValue::Int(val) => Some(BamlValue::Float(*val as f64)), @@ -111,7 +123,13 @@ pub fn validate_arg( let mut fields = BamlMap::new(); for f in c.walk_fields() { if let Some(v) = obj.get(f.name()) { - if let Some(v) = validate_arg(ir, f.r#type(), v, scope) { + if let Some(v) = validate_arg( + ir, + f.r#type(), + v, + scope, + allow_implicit_cast_to_string, + ) { fields.insert(f.name().to_string(), v); } } else if !f.r#type().is_optional() { @@ -138,7 +156,8 @@ pub fn validate_arg( BamlValue::List(arr) => { let mut items = Vec::new(); for v in arr { - if let Some(v) = validate_arg(ir, item, v, scope) { + if let Some(v) = validate_arg(ir, item, v, scope, allow_implicit_cast_to_string) + { items.push(v); } } @@ -154,7 +173,8 @@ pub fn validate_arg( FieldType::Union(options) => { for option in options { let mut scope = ScopeStack::new(); - let result = validate_arg(ir, option, value, &mut scope); + let result = + validate_arg(ir, option, value, &mut scope, allow_implicit_cast_to_string); if !scope.has_errors() { return result; } @@ -167,7 +187,13 @@ pub fn validate_arg( Some(value.clone()) } else { let mut inner_scope = ScopeStack::new(); - let baml_arg = validate_arg(ir, inner, value, &mut inner_scope); + let baml_arg = validate_arg( + ir, + inner, + value, + &mut inner_scope, + allow_implicit_cast_to_string, + ); if inner_scope.has_errors() { scope.push_error(format!("Expected optional {}, got `{}`", inner, value)); None diff --git a/engine/baml-runtime/src/runtime/runtime_interface.rs b/engine/baml-runtime/src/runtime/runtime_interface.rs index b1160aabf..013157f59 100644 --- a/engine/baml-runtime/src/runtime/runtime_interface.rs +++ b/engine/baml-runtime/src/runtime/runtime_interface.rs @@ -123,7 +123,7 @@ impl InternalRuntimeInterface for InternalBamlRuntime { node_index: Option, ) -> Result<(RenderedPrompt, OrchestrationScope)> { let func = self.get_function(function_name, ctx)?; - let baml_args = self.ir().check_function_params(&func, params)?; + let baml_args = self.ir().check_function_params(&func, params, false)?; let renderer = PromptRenderer::from_function(&func, &self.ir(), ctx)?; let client_name = renderer.client_name().to_string(); @@ -193,7 +193,9 @@ impl InternalRuntimeInterface for InternalBamlRuntime { errors )); } - Ok(params) + + let baml_args = self.ir().check_function_params(&func, ¶ms, true)?; + Ok(baml_args.as_map_owned().unwrap()) } Err(e) => return Err(anyhow::anyhow!("Unable to resolve test params: {:?}", e)), } @@ -296,7 +298,7 @@ impl RuntimeInterface for InternalBamlRuntime { ctx: RuntimeContext, ) -> Result { let func = self.get_function(&function_name, &ctx)?; - let baml_args = self.ir().check_function_params(&func, ¶ms)?; + let baml_args = self.ir().check_function_params(&func, ¶ms, false)?; let renderer = PromptRenderer::from_function(&func, self.ir(), &ctx)?; let client_name = renderer.client_name().to_string(); @@ -325,7 +327,7 @@ impl RuntimeInterface for InternalBamlRuntime { let orchestrator = self.orchestration_graph(&client_name, &ctx)?; let Some(baml_args) = self .ir - .check_function_params(&func, ¶ms)? + .check_function_params(&func, ¶ms, false)? .as_map_owned() else { anyhow::bail!("Expected parameters to be a map for: {}", function_name);