Skip to content

Commit

Permalink
Fix test cases (#640)
Browse files Browse the repository at this point in the history
* BAML test cases that are bools or ints need to allow for casting to
string

example:
args {
  foo 1
}

could work for foo: string or foo: int
  • Loading branch information
hellovai authored Jun 3, 2024
1 parent 171cbfe commit 34a2ba4
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 11 deletions.
12 changes: 9 additions & 3 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pub trait IRHelper {
&'a self,
function: &'a FunctionWalker<'a>,
params: &BamlMap<String, BamlValue>,
allow_implicit_cast_to_string: bool,
) -> Result<BamlValue>;
}

Expand Down Expand Up @@ -163,6 +164,7 @@ impl IRHelper for IntermediateRepr {
&'a self,
function: &'a FunctionWalker<'a>,
params: &BamlMap<String, BamlValue>,
allow_implicit_cast_to_string: bool,
) -> Result<BamlValue> {
let function_params = match function.inputs() {
either::Either::Left(_) => {
Expand All @@ -182,9 +184,13 @@ impl IRHelper for IntermediateRepr {
for (param_name, param_type) in function_params {
scope.push(param_name.to_string());
if let Some(param_value) = params.get(param_name.as_str()) {
if let Some(baml_arg) =
to_baml_arg::validate_arg(self, param_type, param_value, &mut scope)
{
if let Some(baml_arg) = to_baml_arg::validate_arg(
self,
param_type,
param_value,
&mut scope,
allow_implicit_cast_to_string,
) {
baml_arg_map.insert(param_name.to_string(), baml_arg);
}
} else {
Expand Down
34 changes: 30 additions & 4 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,22 @@ pub fn validate_arg(
field_type: &FieldType,
value: &BamlValue,
scope: &mut ScopeStack,
allow_implicit_cast_to_string: bool,
) -> Option<BamlValue> {
match field_type {
FieldType::Primitive(t) => match t {
TypeValue::String if matches!(value, BamlValue::String(_)) => Some(value.clone()),
TypeValue::String if allow_implicit_cast_to_string => match value {
BamlValue::Int(i) => Some(BamlValue::String(i.to_string())),
BamlValue::Float(f) => Some(BamlValue::String(f.to_string())),
BamlValue::Bool(true) => Some(BamlValue::String("true".to_string())),
BamlValue::Bool(false) => Some(BamlValue::String("false".to_string())),
BamlValue::Null => Some(BamlValue::String("null".to_string())),
_ => {
scope.push_error(format!("Expected type {:?}, got `{}`", t, value));
None
}
},
TypeValue::Int if matches!(value, BamlValue::Int(_)) => Some(value.clone()),
TypeValue::Float => match value {
BamlValue::Int(val) => Some(BamlValue::Float(*val as f64)),
Expand Down Expand Up @@ -111,7 +123,13 @@ pub fn validate_arg(
let mut fields = BamlMap::new();
for f in c.walk_fields() {
if let Some(v) = obj.get(f.name()) {
if let Some(v) = validate_arg(ir, f.r#type(), v, scope) {
if let Some(v) = validate_arg(
ir,
f.r#type(),
v,
scope,
allow_implicit_cast_to_string,
) {
fields.insert(f.name().to_string(), v);
}
} else if !f.r#type().is_optional() {
Expand All @@ -138,7 +156,8 @@ pub fn validate_arg(
BamlValue::List(arr) => {
let mut items = Vec::new();
for v in arr {
if let Some(v) = validate_arg(ir, item, v, scope) {
if let Some(v) = validate_arg(ir, item, v, scope, allow_implicit_cast_to_string)
{
items.push(v);
}
}
Expand All @@ -154,7 +173,8 @@ pub fn validate_arg(
FieldType::Union(options) => {
for option in options {
let mut scope = ScopeStack::new();
let result = validate_arg(ir, option, value, &mut scope);
let result =
validate_arg(ir, option, value, &mut scope, allow_implicit_cast_to_string);
if !scope.has_errors() {
return result;
}
Expand All @@ -167,7 +187,13 @@ pub fn validate_arg(
Some(value.clone())
} else {
let mut inner_scope = ScopeStack::new();
let baml_arg = validate_arg(ir, inner, value, &mut inner_scope);
let baml_arg = validate_arg(
ir,
inner,
value,
&mut inner_scope,
allow_implicit_cast_to_string,
);
if inner_scope.has_errors() {
scope.push_error(format!("Expected optional {}, got `{}`", inner, value));
None
Expand Down
10 changes: 6 additions & 4 deletions engine/baml-runtime/src/runtime/runtime_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl InternalRuntimeInterface for InternalBamlRuntime {
node_index: Option<usize>,
) -> Result<(RenderedPrompt, OrchestrationScope)> {
let func = self.get_function(function_name, ctx)?;
let baml_args = self.ir().check_function_params(&func, params)?;
let baml_args = self.ir().check_function_params(&func, params, false)?;

let renderer = PromptRenderer::from_function(&func, &self.ir(), ctx)?;
let client_name = renderer.client_name().to_string();
Expand Down Expand Up @@ -193,7 +193,9 @@ impl InternalRuntimeInterface for InternalBamlRuntime {
errors
));
}
Ok(params)

let baml_args = self.ir().check_function_params(&func, &params, true)?;
Ok(baml_args.as_map_owned().unwrap())
}
Err(e) => return Err(anyhow::anyhow!("Unable to resolve test params: {:?}", e)),
}
Expand Down Expand Up @@ -296,7 +298,7 @@ impl RuntimeInterface for InternalBamlRuntime {
ctx: RuntimeContext,
) -> Result<crate::FunctionResult> {
let func = self.get_function(&function_name, &ctx)?;
let baml_args = self.ir().check_function_params(&func, &params)?;
let baml_args = self.ir().check_function_params(&func, &params, false)?;

let renderer = PromptRenderer::from_function(&func, self.ir(), &ctx)?;
let client_name = renderer.client_name().to_string();
Expand Down Expand Up @@ -325,7 +327,7 @@ impl RuntimeInterface for InternalBamlRuntime {
let orchestrator = self.orchestration_graph(&client_name, &ctx)?;
let Some(baml_args) = self
.ir
.check_function_params(&func, &params)?
.check_function_params(&func, &params, false)?
.as_map_owned()
else {
anyhow::bail!("Expected parameters to be a map for: {}", function_name);
Expand Down

0 comments on commit 34a2ba4

Please sign in to comment.