Skip to content

Commit

Permalink
Fix issue where literal inputs are not typechecked correctly (#1121)
Browse files Browse the repository at this point in the history
<!-- ELLIPSIS_HIDDEN -->


> [!IMPORTANT]
> Fixes type checking for literal inputs by refining subtype logic and
adding tests for literals and optional types.
> 
>   - **Behavior**:
> - Fixes type checking for literal inputs in `distribute_type()` in
`mod.rs` by refining subtype logic.
> - Updates `coerce_arg()` in `to_baml_arg.rs` to handle literals and
optional types more accurately.
>   - **Type System**:
> - Refines `is_subtype_of()` in `field_type/mod.rs` and `types.rs` to
improve handling of literals and optional types.
> - Adds `is_subtype_of()` logic to `Type` and `FieldType` to handle
nested and union types.
>   - **Tests**:
> - Adds tests for literal handling in `test_runtime.rs` and `mod.rs`.
> - Adds `prompt2.baml` for validation of function prompts with
literals.
>   - **Dependencies**:
>     - Adds `log` and `minijinja` to `Cargo.toml` dependencies.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 7dae7e7. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
hellovai authored Oct 31, 2024
1 parent 0ccf473 commit aa5dc85
Show file tree
Hide file tree
Showing 15 changed files with 1,213 additions and 111 deletions.
1 change: 1 addition & 0 deletions engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

64 changes: 30 additions & 34 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,35 +201,28 @@ impl IRHelper for IntermediateRepr {
field_type: FieldType,
) -> anyhow::Result<BamlValueWithMeta<FieldType>> {
match value {
BamlValue::String(s)
if FieldType::Primitive(TypeValue::String).is_subtype_of(&field_type) =>
{
Ok(BamlValueWithMeta::String(s, field_type))
}

BamlValue::String(s) => {
if let FieldType::Literal(LiteralValue::String(l)) = &field_type {
if s == *l {
return Ok(BamlValueWithMeta::String(s, field_type));
}
}
let literal_type = FieldType::Literal(LiteralValue::String(s.clone()));
let primitive_type = FieldType::Primitive(TypeValue::String);

if literal_type.is_subtype_of(&field_type)
|| primitive_type.is_subtype_of(&field_type)
{
return Ok(BamlValueWithMeta::String(s, field_type));
}
anyhow::bail!("Could not unify String with {:?}", field_type)
}

BamlValue::Int(i)
if FieldType::Literal(LiteralValue::Int(i.clone())).is_subtype_of(&field_type) =>
{
Ok(BamlValueWithMeta::Int(i, field_type))
}
BamlValue::Int(i)
if FieldType::Primitive(TypeValue::Int).is_subtype_of(&field_type) =>
{
Ok(BamlValueWithMeta::Int(i, field_type))
}

BamlValue::Int(i) => {
if let FieldType::Literal(LiteralValue::Int(l)) = &field_type {
if i == *l {
return Ok(BamlValueWithMeta::Int(i, field_type));
}
}

anyhow::bail!("Could not unify Int with {:?}", field_type)
}

Expand All @@ -240,20 +233,17 @@ impl IRHelper for IntermediateRepr {
}
BamlValue::Float(_) => anyhow::bail!("Could not unify Float with {:?}", field_type),

BamlValue::Bool(b)
if FieldType::Primitive(TypeValue::Bool).is_subtype_of(&field_type) =>
{
Ok(BamlValueWithMeta::Bool(b, field_type))
}

BamlValue::Bool(b) => {
if let FieldType::Literal(LiteralValue::Bool(l)) = &field_type {
if b == *l {
return Ok(BamlValueWithMeta::Bool(b, field_type));
}
}
let literal_type = FieldType::Literal(LiteralValue::Bool(b));
let primitive_type = FieldType::Primitive(TypeValue::Bool);

anyhow::bail!("Could not unify Bool with {:?}", field_type)
if literal_type.is_subtype_of(&field_type) {
Ok(BamlValueWithMeta::Bool(b, field_type))
} else if primitive_type.is_subtype_of(&field_type) {
Ok(BamlValueWithMeta::Bool(b, field_type))
} else {
anyhow::bail!("Could not unify Bool with {:?}", field_type)
}
}

BamlValue::Null if FieldType::Primitive(TypeValue::Null).is_subtype_of(&field_type) => {
Expand Down Expand Up @@ -426,7 +416,8 @@ pub fn infer_type<'a>(value: &'a BamlValue) -> Option<FieldType> {
mod tests {
use super::*;
use baml_types::{
BamlMedia, BamlMediaContent, BamlMediaType, BamlValue, Constraint, ConstraintLevel, FieldType, JinjaExpression, MediaBase64, TypeValue
BamlMedia, BamlMediaContent, BamlMediaType, BamlValue, Constraint, ConstraintLevel,
FieldType, JinjaExpression, MediaBase64, TypeValue,
};
use repr::make_test_ir;

Expand Down Expand Up @@ -685,8 +676,13 @@ mod tests {
)
.unwrap();
let function = ir.find_function("Foo").unwrap();
let params = vec![("a".to_string(), BamlValue::Int(1))].into_iter().collect();
let arg_coercer = ArgCoercer { span_path: None, allow_implicit_cast_to_string: true };
let params = vec![("a".to_string(), BamlValue::Int(1))]
.into_iter()
.collect();
let arg_coercer = ArgCoercer {
span_path: None,
allow_implicit_cast_to_string: true,
};
let res = ir.check_function_params(&function, &params, arg_coercer);
assert!(res.is_err());
}
Expand Down
7 changes: 5 additions & 2 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,11 @@ mod tests {
label: Some("foo".to_string()),
}],
};
let arg_coercer = ArgCoercer { span_path: None, allow_implicit_cast_to_string: true };
let res = arg_coercer.coerce_arg(&ir, &type_, &value, &mut ScopeStack::new());
let arg_coercer = ArgCoercer {
span_path: None,
allow_implicit_cast_to_string: true,
};
let res = arg_coercer.coerce_arg(&ir, &type_, &value, &mut ScopeStack::new());
assert!(res.is_err());
}
}
22 changes: 2 additions & 20 deletions engine/baml-lib/baml-types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,31 +13,13 @@ derive_builder.workspace = true
serde.workspace = true
serde_json.workspace = true
strum.workspace = true
minijinja.workspace = true
log.workspace = true

[dependencies.indexmap]
workspace = true
optional = true

[dependencies.minijinja]
version = "1.0.16"
default-features = false
features = [
"macros",
"builtins",
"debug",
"preserve_order",
"adjacent_loop_items",
"unicode",
"json",
"unstable_machinery",
"unstable_machinery_serde",
"custom_syntax",
"internal_debug",
# We don't want to use these features:
# multi_template
# loader
#
]

[features]
default = ["stable_sort"]
Expand Down
24 changes: 17 additions & 7 deletions engine/baml-lib/baml-types/src/field_type/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ impl FieldType {
}
match (self, other) {
(FieldType::Primitive(TypeValue::Null), FieldType::Optional(_)) => true,
(FieldType::Optional(self_item), FieldType::Optional(other_item)) => {
self_item.is_subtype_of(other_item)
}
(_, FieldType::Optional(t)) => self.is_subtype_of(t),
(FieldType::Optional(_), _) => false,

// Handle types that nest other types.
(FieldType::List(self_item), FieldType::List(other_item)) => {
Expand All @@ -234,12 +239,6 @@ impl FieldType {
}
(FieldType::Map(_, _), _) => false,

(FieldType::Optional(self_item), FieldType::Optional(other_item)) => {
self_item.is_subtype_of(other_item)
}
(_, FieldType::Optional(other_item)) => self.is_subtype_of(other_item),
(FieldType::Optional(_), _) => false,

(
FieldType::Constrained {
base: self_base,
Expand All @@ -252,13 +251,24 @@ impl FieldType {
) => self_base.is_subtype_of(other_base) && self_cs == other_cs,
(FieldType::Constrained { base, .. }, _) => base.is_subtype_of(other),
(_, FieldType::Constrained { base, .. }) => self.is_subtype_of(base),

(
FieldType::Literal(LiteralValue::Bool(_)),
FieldType::Primitive(TypeValue::Bool),
) => true,
(FieldType::Literal(LiteralValue::Bool(_)), _) => {
self.is_subtype_of(&FieldType::Primitive(TypeValue::Bool))
}
(
FieldType::Literal(LiteralValue::Int(_)),
FieldType::Primitive(TypeValue::Int),
) => true,
(FieldType::Literal(LiteralValue::Int(_)), _) => {
self.is_subtype_of(&FieldType::Primitive(TypeValue::Int))
}
(
FieldType::Literal(LiteralValue::String(_)),
FieldType::Primitive(TypeValue::String),
) => true,
(FieldType::Literal(LiteralValue::String(_)), _) => {
self.is_subtype_of(&FieldType::Primitive(TypeValue::String))
}
Expand Down
26 changes: 26 additions & 0 deletions engine/baml-lib/baml/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,19 @@ use std::{env, fs, io::Write as _, path};

const VALIDATIONS_ROOT_DIR: &str = "tests/validation_files";
const CARGO_MANIFEST_DIR: &str = env!("CARGO_MANIFEST_DIR");
/// Directory holding the `baml init` starter project sources, written as a
/// suffix that gets appended to `CARGO_MANIFEST_DIR` (hence the leading `/..`).
const BAML_CLI_INIT_DIR: &str = "/../../baml-runtime/src/cli/initial_project/baml_src";
/// Directory holding the prompt-fiddle example project sources, written as a
/// suffix that gets appended to `CARGO_MANIFEST_DIR` (hence the leading `/..`).
const PROMPT_FIDDLE_EXAMPLE_DIR: &str =
    "/../../../typescript/fiddle-frontend/public/_examples/all-projects/baml_src";

/// Build-script entry point.
///
/// Bundles each example project directory into a single file under
/// `tests/validation_files/`, then generates the validation tests.
fn main() {
    // (source directory, generated bundle) pairs, processed in this order.
    let folder_bundles = [
        (
            BAML_CLI_INIT_DIR,
            "tests/validation_files/baml_cli_init.baml",
        ),
        (
            PROMPT_FIDDLE_EXAMPLE_DIR,
            "tests/validation_files/prompt_fiddle_example.baml",
        ),
    ];
    for (source_dir, bundle_path) in folder_bundles {
        build_folder_tests(source_dir, bundle_path);
    }

    // Kept after the bundling step: the bundles are written into the
    // validation directory, presumably so this call picks them up.
    build_validation_tests();
    // build_reformat_tests();
}
Expand All @@ -26,6 +37,21 @@ fn build_validation_tests() {
}
}

/// Concatenates every schema file discovered under `dir` into one output file.
///
/// * `dir` - source directory, given as a path that is appended to
///   `CARGO_MANIFEST_DIR`.
/// * `out_file_name` - output file, relative to `CARGO_MANIFEST_DIR`; it is
///   created, or truncated if it already exists.
///
/// # Panics
///
/// Panics (failing the build) if the output file cannot be created or written,
/// or if a discovered schema file cannot be read.
fn build_folder_tests(dir: &'static str, out_file_name: &str) {
    // The directories passed in from `main` start with `/`, so printing `dir`
    // on its own made Cargo watch an absolute path that does not exist, which
    // forces the script to re-run on every build. Anchor it at the manifest
    // directory, exactly like the reads below.
    println!("cargo:rerun-if-changed={CARGO_MANIFEST_DIR}/{dir}");

    let mut all_schemas = Vec::new();
    find_all_schemas("", &mut all_schemas, dir);
    // Directory listing order is platform-dependent; sort so the generated
    // bundle is byte-for-byte reproducible across machines.
    all_schemas.sort();

    // concatenate all the files in the directory into a single file
    let out_path = format!("{CARGO_MANIFEST_DIR}/{out_file_name}");
    let mut out_file = fs::File::create(&out_path)
        .unwrap_or_else(|err| panic!("failed to create {out_path}: {err}"));
    for schema_path in &all_schemas {
        let file_path = format!("{CARGO_MANIFEST_DIR}/{dir}{schema_path}");
        println!("Reading file: {}", file_path);
        let file_content = fs::read_to_string(&file_path)
            .unwrap_or_else(|err| panic!("failed to read {file_path}: {err}"));
        writeln!(out_file, "{}", file_content)
            .unwrap_or_else(|err| panic!("failed to write {out_path}: {err}"));
    }
}

fn find_all_schemas(prefix: &str, all_schemas: &mut Vec<String>, root_dir: &'static str) {
for entry in fs::read_dir(format!("{CARGO_MANIFEST_DIR}/{root_dir}/{prefix}")).unwrap() {
let entry = entry.unwrap();
Expand Down
116 changes: 116 additions & 0 deletions engine/baml-lib/baml/tests/validation_files/baml_cli_init.baml
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Defining a data model.
// `Resume` is the structured type returned by the `ExtractResume` function.
class Resume {
name string
email string
// `string[]` is a list of strings.
experience string[]
skills string[]
}

// Create a function to extract the resume from a string.
// Input: the raw resume text. Output: a `Resume` value.
function ExtractResume(resume: string) -> Resume {
// Specify a client as provider/model-name
// you can use custom LLM params with a custom client name from clients.baml like "client CustomHaiku"
client "openai/gpt-4o" // Set OPENAI_API_KEY to use this client.
// Prompt template: `{{ resume }}` interpolates the function argument;
// `{{ ctx.output_format }}` presumably expands to instructions describing the
// `Resume` return type - confirm against the BAML prompt docs.
prompt #"
Extract from this content:
{{ resume }}

{{ ctx.output_format }}
"#
}

// Test the function with a sample resume. Open the VSCode playground to run this.
test vaibhav_resume {
// Functions this test case applies to.
functions [ExtractResume]
// Arguments passed to the function, keyed by parameter name.
args {
// NOTE(review): the e-mail line in the sample below reads `[email protected]`,
// which looks like an e-mail-obfuscation placeholder rather than a real
// address - confirm against the original starter project.
resume #"
Vaibhav Gupta
[email protected]

Experience:
- Founder at BoundaryML
- CV Engineer at Google
- CV Engineer at Microsoft

Skills:
- Rust
- C++
"#
}
}

// Learn more about clients at https://docs.boundaryml.com/docs/snippets/clients/overview

// Named LLM client for the OpenAI provider using the "gpt-4o" model.
client<llm> CustomGPT4o {
provider openai
options {
model "gpt-4o"
// `env.OPENAI_API_KEY` refers to the OPENAI_API_KEY environment variable.
api_key env.OPENAI_API_KEY
}
}

// Named LLM client for the OpenAI provider using the "gpt-4o-mini" model.
client<llm> CustomGPT4oMini {
provider openai
// Refers to the `retry_policy Exponential` block declared in this file.
retry_policy Exponential
options {
model "gpt-4o-mini"
// `env.OPENAI_API_KEY` refers to the OPENAI_API_KEY environment variable.
api_key env.OPENAI_API_KEY
}
}

// Named LLM client for the Anthropic provider using a Claude 3.5 Sonnet model.
client<llm> CustomSonnet {
provider anthropic
options {
model "claude-3-5-sonnet-20241022"
// `env.ANTHROPIC_API_KEY` refers to the ANTHROPIC_API_KEY environment variable.
api_key env.ANTHROPIC_API_KEY
}
}


// Named LLM client for the Anthropic provider using a Claude 3 Haiku model.
client<llm> CustomHaiku {
provider anthropic
// Refers to the `retry_policy Constant` block declared in this file.
retry_policy Constant
options {
model "claude-3-haiku-20240307"
// `env.ANTHROPIC_API_KEY` refers to the ANTHROPIC_API_KEY environment variable.
api_key env.ANTHROPIC_API_KEY
}
}

// https://docs.boundaryml.com/docs/snippets/clients/round-robin
// Composite client that spreads calls over the two clients named in `strategy`,
// both of which are declared in this file.
client<llm> CustomFast {
provider round-robin
options {
// This will alternate between the two clients
strategy [CustomGPT4oMini, CustomHaiku]
}
}

// https://docs.boundaryml.com/docs/snippets/clients/fallback
// Composite client that tries the clients named in `strategy` in order.
client<llm> OpenaiFallback {
provider fallback
options {
// This will try the clients in order until one succeeds
// NOTE(review): the same client is listed twice, so the "fallback" just
// calls CustomGPT4oMini again - confirm whether a different second client
// (e.g. CustomGPT4o) was intended.
strategy [CustomGPT4oMini, CustomGPT4oMini]
}
}

// https://docs.boundaryml.com/docs/snippets/clients/retry
// Retry policy referenced by `CustomHaiku`: up to 3 retries with a fixed delay.
retry_policy Constant {
max_retries 3
// Strategy is optional
strategy {
type constant_delay
// Delay between attempts, in milliseconds.
delay_ms 200
}
}

// Retry policy referenced by `CustomGPT4oMini`: up to 2 retries with
// exponentially growing delays.
retry_policy Exponential {
  max_retries 2
  // Strategy is optional
  strategy {
    type exponential_backoff
    // Initial delay, in milliseconds.
    delay_ms 300
    // Factor applied to the delay after each attempt. The key was previously
    // misspelled `mutliplier`, which is not the documented option name.
    multiplier 1.5
    // Upper bound for the delay, in milliseconds.
    max_delay_ms 10000
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// One turn of a conversation, used as the element type of `Bot`'s `convo` argument.
class Message {
// Union type: the literal "user", the literal "assistant", or any other string.
role "user" | "assistant" | string
message string
}

// Chat-style function: takes a list of `Message` values and returns a string.
function Bot(convo: Message[]) -> string {
client "openai/gpt-4o"
// Prompt template: the loop emits one section per message in `convo`.
// `_.role(m.role)` presumably switches the chat role for the text that
// follows it - confirm against the BAML prompt docs.
prompt #"
You are a helpful assistant.
{{ ctx.output_format }}

{% for m in convo %}
{{ _.role(m.role) }}
{{ m.message }}
{% endfor %}
"#
}
Loading

0 comments on commit aa5dc85

Please sign in to comment.