From 4b77aa27b35c2e8a2e71336ead06b89ffd795610 Mon Sep 17 00:00:00 2001 From: Greg Hale Date: Mon, 25 Nov 2024 22:19:34 -0800 Subject: [PATCH 01/13] Add semantic streaming --- engine/Cargo.lock | 163 ++- engine/Cargo.toml | 2 +- .../baml-core/src/ir/ir_helpers/mod.rs | 695 +++++++++- .../src/ir/ir_helpers/to_baml_arg.rs | 7 +- .../baml-lib/baml-core/src/ir/json_schema.rs | 2 +- engine/baml-lib/baml-core/src/ir/repr.rs | 201 ++- engine/baml-lib/baml-core/src/ir/walker.rs | 20 + engine/baml-lib/baml-types/src/baml_value.rs | 422 ++++-- .../baml-lib/baml-types/src/field_type/mod.rs | 55 +- engine/baml-lib/baml-types/src/lib.rs | 8 +- .../streaming/streaming_ok.baml | 11 + .../jinja-runtime/src/output_format/types.rs | 239 +++- engine/baml-lib/jsonish/Cargo.toml | 6 + engine/baml-lib/jsonish/benches/classes.rs | 36 + .../jsonish/benches/jsonish_benchmark.rs | 23 + engine/baml-lib/jsonish/benches/lists.rs | 28 + engine/baml-lib/jsonish/benches/literals.rs | 28 + engine/baml-lib/jsonish/benches/partials.rs | 139 ++ engine/baml-lib/jsonish/benches/unions.rs | 72 + .../src/deserializer/coercer/coerce_array.rs | 11 +- .../deserializer/coercer/coerce_literal.rs | 13 +- .../src/deserializer/coercer/coerce_map.rs | 12 +- .../deserializer/coercer/coerce_primitive.rs | 74 +- .../src/deserializer/coercer/field_type.rs | 16 +- .../coercer/ir_ref/coerce_class.rs | 39 +- .../coercer/ir_ref/coerce_enum.rs | 3 +- .../src/deserializer/coercer/match_string.rs | 2 +- .../jsonish/src/deserializer/coercer/mod.rs | 2 +- .../src/deserializer/deserialize_flags.rs | 25 +- .../baml-lib/jsonish/src/deserializer/mod.rs | 3 +- .../jsonish/src/deserializer/score.rs | 9 + .../src/deserializer/semantic_streaming.rs | 636 +++++++++ .../jsonish/src/deserializer/types.rs | 28 +- engine/baml-lib/jsonish/src/helpers/common.rs | 70 + engine/baml-lib/jsonish/src/helpers/mod.rs | 289 ++++ .../jsonish/src/jsonish/iterative_parser.rs | 226 +-- .../jsonish/src/jsonish/parser/entry.rs | 171 ++- .../src/jsonish/parser/fixing_parser.rs | 29 +- .../parser/fixing_parser/json_collection.rs | 94 +- .../parser/fixing_parser/json_parse_state.rs | 170 ++- .../src/jsonish/parser/markdown_parser.rs | 21 +- .../src/jsonish/parser/multi_json_parser.rs | 39 +- engine/baml-lib/jsonish/src/jsonish/value.rs | 136 +- engine/baml-lib/jsonish/src/lib.rs | 314 +++-- .../baml-lib/jsonish/src/tests/animation.rs | 56 + engine/baml-lib/jsonish/src/tests/macros.rs | 85 +- engine/baml-lib/jsonish/src/tests/mod.rs | 315 +---- .../baml-lib/jsonish/src/tests/test_class.rs | 119 +- .../baml-lib/jsonish/src/tests/test_lists.rs | 10 +- .../baml-lib/jsonish/src/tests/test_maps.rs | 46 +- .../jsonish/src/tests/test_streaming.rs | 128 ++ .../parser-database/src/attributes/mod.rs | 9 + .../src/attributes/to_string_attribute.rs | 24 +- engine/baml-lib/parser-database/src/lib.rs | 3 +- .../src/names/validate_reserved_names.rs | 14 +- engine/baml-lib/schema-ast/src/ast/field.rs | 1 + .../schema-ast/src/parser/datamodel.pest | 3 +- .../schema-ast/src/parser/parse_field.rs | 38 +- .../schema-ast/src/parser/parse_identifier.rs | 27 + engine/baml-runtime/Cargo.toml | 7 + engine/baml-runtime/benches/bench.rs | 107 ++ engine/baml-runtime/benches/lib.rs | 1 + .../benches/sap_parser_benchmark.rs | 0 engine/baml-runtime/src/cli/serve/mod.rs | 9 +- .../src/internal/llm_client/mod.rs | 233 +++- .../internal/llm_client/orchestrator/call.rs | 16 +- .../llm_client/orchestrator/stream.rs | 29 +- .../src/internal/prompt_renderer/mod.rs | 17 +- .../prompt_renderer/render_output_format.rs | 23 +- engine/baml-runtime/src/lib.rs | 9 +- .../src/runtime/runtime_interface.rs | 5 +- .../{constraints.rs => test_constraints.rs} | 0 engine/baml-runtime/src/tracing/mod.rs | 15 +- engine/baml-runtime/src/types/response.rs | 57 +- engine/baml-runtime/src/types/stream.rs | 4 +- .../baml-schema-wasm/src/runtime_wasm/mod.rs | 103 +- engine/language_client_codegen/src/lib.rs | 3 +- engine/language_client_codegen/src/openapi.rs | 2 +- .../src/python/generate_types.rs | 168 ++- .../language_client_codegen/src/python/mod.rs | 61 +- .../src/python/templates/async_client.py.j2 | 6 +- .../src/python/templates/partial_types.py.j2 | 7 +- .../src/python/templates/sync_client.py.j2 | 6 +- .../src/python/templates/types.py.j2 | 1 - .../src/ruby/field_type.rs | 2 +- .../src/ruby/generate_types.rs | 96 +- .../language_client_codegen/src/ruby/mod.rs | 13 +- .../src/ruby/templates/client.rb.j2 | 2 +- .../src/typescript/generate_types.rs | 75 +- .../src/typescript/mod.rs | 297 +++- .../typescript/templates/async_client.ts.j2 | 15 +- .../typescript/templates/partial_types.ts.j2 | 35 + .../typescript/templates/sync_client.ts.j2 | 8 +- engine/language_client_python/Cargo.toml | 1 + .../src/types/function_results.rs | 143 +- engine/language_client_ruby/Gemfile.lock | 2 +- .../ext/ruby_ffi/Cargo.toml | 1 + .../ext/ruby_ffi/src/function_result.rs | 19 +- .../ext/ruby_ffi/src/lib.rs | 2 +- .../ext/ruby_ffi/src/ruby_to_json.rs | 154 ++- engine/language_client_ruby/lib/baml.rb | 3 + engine/language_client_ruby/lib/stream.rb | 23 +- engine/language_client_typescript/native.d.ts | 2 +- .../src/types/function_results.rs | 10 +- engine/language_client_typescript/stream.js | 4 +- .../typescript_src/stream.ts | 4 +- fern/01-guide/04-baml-basics/streaming.mdx | 257 +++- flake.lock | 117 ++ .../test-files/functions/output/class.baml | 2 +- .../semantic_streaming.baml | 33 + .../python/baml_client/async_client.py | 904 ++++++------ integ-tests/python/baml_client/inlinedbaml.py | 3 +- .../python/baml_client/partial_types.py | 29 +- integ-tests/python/baml_client/sync_client.py | 904 ++++++------ .../python/baml_client/type_builder.py | 2 +- integ-tests/python/baml_client/types.py | 23 +- integ-tests/python/tests/test_functions.py | 56 +- integ-tests/ruby/baml_client/client.rb | 445 +++--- integ-tests/ruby/baml_client/inlined.rb | 3 +- integ-tests/ruby/baml_client/partial-types.rb | 160 ++- integ-tests/ruby/baml_client/type-registry.rb | 2 +- integ-tests/ruby/baml_client/types.rb | 72 + integ-tests/ruby/test_functions.rb | 66 +- .../typescript/baml_client/async_client.ts | 1211 +++++++++-------- .../typescript/baml_client/inlinedbaml.ts | 3 +- .../typescript/baml_client/partial_types.ts | 490 +++++++ .../typescript/baml_client/sync_client.ts | 321 +++-- .../typescript/baml_client/type_builder.ts | 2 +- integ-tests/typescript/baml_client/types.ts | 30 + .../typescript/tests/input-output.test.ts | 41 + .../typescript/tests/integ-tests.test.ts.old | 17 +- 131 files changed, 9000 insertions(+), 3460 deletions(-) create mode 100644 engine/baml-lib/baml/tests/validation_files/streaming/streaming_ok.baml create mode 100644 engine/baml-lib/jsonish/benches/classes.rs create mode 100644 engine/baml-lib/jsonish/benches/jsonish_benchmark.rs create mode 100644 engine/baml-lib/jsonish/benches/lists.rs create mode 100644 engine/baml-lib/jsonish/benches/literals.rs create mode 100644 engine/baml-lib/jsonish/benches/partials.rs create mode 100644 engine/baml-lib/jsonish/benches/unions.rs create mode 100644 engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs create mode 100644 engine/baml-lib/jsonish/src/helpers/common.rs create mode 100644 engine/baml-lib/jsonish/src/helpers/mod.rs create mode 100644 engine/baml-lib/jsonish/src/tests/animation.rs create mode 100644 engine/baml-lib/jsonish/src/tests/test_streaming.rs create mode 100644 engine/baml-runtime/benches/bench.rs create mode 100644 engine/baml-runtime/benches/lib.rs create mode 100644 engine/baml-runtime/benches/sap_parser_benchmark.rs rename engine/baml-runtime/src/{constraints.rs => test_constraints.rs} (100%) create mode 100644 engine/language_client_codegen/src/typescript/templates/partial_types.ts.j2 create mode 100644 flake.lock create mode 100644 integ-tests/baml_src/test-files/semantic_streaming/semantic_streaming.baml create mode 100644 integ-tests/typescript/baml_client/partial_types.ts diff --git a/engine/Cargo.lock b/engine/Cargo.lock index d09af12348..d7ed3267a4 100644 --- a/engine/Cargo.lock +++ b/engine/Cargo.lock @@ -77,6 +77,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.14" @@ -918,6 +924,7 @@ dependencies = [ "futures", "indexmap 2.2.6", "internal-baml-codegen", + "jsonish", "libc", "log", "pyo3", @@ -959,6 +966,7 @@ dependencies = [ "clap", "colored", "console_log", + "criterion", "dashmap", "derive_more", "dissimilar", @@ -1012,7 +1020,7 @@ dependencies = [ "strum", "strum_macros", "test-log", - "thiserror 2.0.3", + "thiserror 2.0.9", "tokio", "tokio-stream", "tracing", @@ -1276,6 +1284,12 @@ dependencies = [ "either", ] +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cc" version = "1.1.2" @@ -1329,6 +1343,33 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clang-sys" version = "1.8.1" @@ -1498,6 +1539,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam-channel" version = "0.5.13" @@ -1532,6 +1609,12 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + [[package]] name = "crypto-common" version = "0.1.6" @@ -2197,6 +2280,16 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -2856,6 +2949,17 @@ dependencies = [ "once_cell", ] +[[package]] +name = "is-terminal" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" +dependencies = [ + "hermit-abi 0.4.0", + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "is-wsl" version = "0.4.0" @@ -2920,6 +3024,7 @@ dependencies = [ "baml-types", "bstd", "colored", + "criterion", "either", "indexmap 2.2.6", "indoc", @@ -2932,6 +3037,7 @@ dependencies = [ "serde_json", "strsim 0.10.0", "test-log", + "thiserror 2.0.9", ] [[package]] @@ -3432,6 +3538,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + [[package]] name = "open" version = "5.3.0" @@ -3684,6 +3796,34 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "polling" version = "2.8.0" @@ -4227,6 +4367,7 @@ dependencies = [ "env_logger", "futures", "indexmap 2.2.6", + "jsonish", "log", "magnus", "rb-sys", @@ -4921,11 +5062,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.3" +version = "2.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" +checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc" dependencies = [ - "thiserror-impl 2.0.3", + "thiserror-impl 2.0.9", ] [[package]] @@ -4941,9 +5082,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.3" +version = "2.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" +checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4" dependencies = [ "proc-macro2", "quote", @@ -4991,6 +5132,16 @@ dependencies = [ "time-core", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.8.0" diff --git a/engine/Cargo.toml b/engine/Cargo.toml index f85dc4b3b5..f88840ac5b 100644 --- a/engine/Cargo.toml +++ b/engine/Cargo.toml @@ -112,4 +112,4 @@ lto = false inherits = "dev" [profile.release] -lto = true +lto = true \ No newline at end of file diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs index b20a51beef..97a949eda8 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs @@ -2,6 +2,7 @@ mod error_utils; pub mod scope_diagnostics; mod to_baml_arg; +use indexmap::IndexMap; use itertools::Itertools; use self::scope_diagnostics::ScopeStack; @@ -16,7 +17,7 @@ use crate::{ use anyhow::Result; use baml_types::{ BamlMap, BamlValue, BamlValueWithMeta, Constraint, ConstraintLevel, FieldType, LiteralValue, - TypeValue, + StreamingBehavior, TypeValue, }; pub use to_baml_arg::ArgCoercer; @@ -55,18 +56,29 @@ pub trait IRHelper { params: &BamlMap, coerce_settings: ArgCoercer, ) -> Result; + fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool; fn distribute_type( &self, value: BamlValue, field_type: FieldType, - ) -> Result>; - fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool; + ) -> anyhow::Result>; + fn distribute_type_with_meta( + &self, + value: BamlValueWithMeta, + field_type: FieldType, + // default_meta: Option<&T>, + ) -> Result>; + fn distribute_metadata<'a>( + &'a self, + field_type: &'a FieldType, + ) -> (&'a FieldType, (Vec, StreamingBehavior)); fn distribute_constraints<'a>( &'a self, field_type: &'a FieldType, ) -> (&'a FieldType, Vec); fn type_has_constraints(&self, field_type: &FieldType) -> bool; fn type_has_checks(&self, field_type: &FieldType) -> bool; + fn recursive_alias_definition(&self, alias_name: &str) -> Option<&FieldType>; } impl IRHelper for IntermediateRepr { @@ -276,32 +288,18 @@ impl IRHelper for IntermediateRepr { (FieldType::Map(_, _), _) => false, ( - FieldType::Constrained { + FieldType::WithMetadata { base: constrained_base, - constraints: base_constraints, - }, - FieldType::Constrained { - base: other_base, - constraints: other_constraints, - }, - ) => { - self.is_subtype(constrained_base, other_base) - && base_constraints == other_constraints - } - ( - FieldType::Constrained { - base: contrained_base, .. }, _, - ) => self.is_subtype(contrained_base, other), + ) => self.is_subtype(constrained_base, other), ( _, - FieldType::Constrained { - base: constrained_base, - .. + FieldType::WithMetadata { + base: other_base, .. }, - ) => self.is_subtype(base, constrained_base), + ) => self.is_subtype(base, other_base), (FieldType::Literal(LiteralValue::Bool(_)), FieldType::Primitive(TypeValue::Bool)) => { true @@ -410,6 +408,13 @@ impl IRHelper for IntermediateRepr { _ => Some(FieldType::Union(item_types)), }; + // let key_type = match self.distribute_metadata(&field_type) { + // (FieldType::Map(annotation_key_type, _), _) => annotation_key_type.as_ref(), + // // TODO: Make the following a baml compiler error, too. + // (FieldType::RecursiveTypeAlias(_), _) => anyhow::bail!("Type aliases are not allowed as map keys."), + // _ => anyhow::bail!("Value was not a map."), + // }; + match maybe_item_type { Some(item_type) => { let map_type = FieldType::Map( @@ -521,12 +526,209 @@ impl IRHelper for IntermediateRepr { } } + /// For some `BamlValueWithMeta` with type `FieldType`, walk the structure of both the value + /// and the type simultaneously, associating each node in the `BamlValue` with its + /// `FieldType`. + /// TODO (Greg): Make this function DynamicTypes-aware. Right now it assigns default metadata + /// to unknown classes, which may have been created with TypeBuilder. + fn distribute_type_with_meta( + &self, + value: BamlValueWithMeta, + field_type: FieldType, + ) -> anyhow::Result> { + let field_base_type = self.distribute_metadata(&field_type).0; + match value { + BamlValueWithMeta::String(s, meta) => { + let literal_type = FieldType::Literal(LiteralValue::String(s.clone())); + let primitive_type = FieldType::Primitive(TypeValue::String); + + if self.is_subtype(&literal_type, &field_base_type) + || self.is_subtype(&primitive_type, &field_base_type) + { + return Ok(BamlValueWithMeta::String(s, (meta, field_type))); + } + anyhow::bail!("Could not unify String with {:?}", field_base_type) + } + BamlValueWithMeta::Int(i, meta) + if self.is_subtype(&FieldType::Literal(LiteralValue::Int(i)), &field_base_type) => + { + Ok(BamlValueWithMeta::Int(i, (meta, field_type))) + } + BamlValueWithMeta::Int(i, meta) + if self.is_subtype(&FieldType::Primitive(TypeValue::Int), &field_type) => + { + Ok(BamlValueWithMeta::Int(i, (meta, field_type))) + } + BamlValueWithMeta::Int(_i, _meta) => { + anyhow::bail!("Could not unify Int with {:?}", field_base_type) + } + + BamlValueWithMeta::Float(f, meta) + if self.is_subtype(&FieldType::Primitive(TypeValue::Float), &field_base_type) => + { + Ok(BamlValueWithMeta::Float(f, (meta, field_type))) + } + BamlValueWithMeta::Float(_, _) => { + anyhow::bail!("Could not unify Float with {:?}", field_base_type) + } + + BamlValueWithMeta::Bool(b, meta) => { + let literal_type = FieldType::Literal(LiteralValue::Bool(b)); + let primitive_type = FieldType::Primitive(TypeValue::Bool); + + if self.is_subtype(&literal_type, &field_base_type) + || self.is_subtype(&primitive_type, &field_base_type) + { + Ok(BamlValueWithMeta::Bool(b, (meta, field_type))) + } else { + anyhow::bail!("Could not unify Bool with {:?}", field_base_type) + } + } + + BamlValueWithMeta::Null(meta) => Ok(BamlValueWithMeta::Null((meta, field_type))), + + // TODO: Handle enums and literal keys. + BamlValueWithMeta::Map(pairs, meta) => { + let (annotation_key_type, annotation_value_type) = map_types(self, &field_type) + .ok_or(anyhow::anyhow!("Could not unify map with {field_type:?}"))?; + + let mapped_fields: BamlMap> = pairs + .into_iter() + .map(|(key, val)| { + let sub_value = item_type(self, &field_type, &val) + .ok_or(anyhow::anyhow!( + "Could not determine item_type of item in map" + )) + .and_then(|item_type| self.distribute_type_with_meta(val, item_type))?; + Ok((key, sub_value)) + }) + .collect::>>>( + )?; + + // let item_types: Vec<&FieldType> = mapped_fields + // .values() + // .map(|i| &i.meta().1) + // .dedup() + // .collect(); + // let items_type = match item_types.len() { + // 0 => None, + // 1 => Some(item_types[0].clone()), + // _ => Some(FieldType::Union( + // item_types.into_iter().map(|t| t.clone()).collect(), + // )), + // }; + // if let Some((key_ty, value_ty)) = map_types(self, &field_type) { + // let expected_type = FieldType::Map(Box::new(key_ty.clone()), Box::new(value_ty.clone())); + // if !self.is_subtype(&expected_type, &field_base_type) { + // anyhow::bail!("Could not unify {:?} with {:?}", expected_type, field_base_type); + // } + // } + + Ok(BamlValueWithMeta::Map(mapped_fields, (meta, field_type))) + } + + BamlValueWithMeta::List(items, meta) => { + let new_items = items + .into_iter() + .map(|i| { + // dbg!(&field_type); + // dbg!(&i); + item_type(self, &field_type, &i) + .ok_or({ + eprintln!("ty: {field_type:?}, i: {i:?}"); + anyhow::anyhow!("Could not infer child type") + }) + .and_then(|item_type| self.distribute_type_with_meta(i, item_type)) + }) + .collect::>>()?; + // dbg!(&new_items); + // let item_types: Vec<&FieldType> = + // new_items.iter().map(|i| &i.meta().1).dedup().collect(); + // let items_type = match item_types.len() { + // 0 => None, + // 1 => Some(item_types[0].clone()), + // _ => Some(FieldType::Union( + // item_types.into_iter().map(|t| t.clone()).collect(), + // )), + // }; + // if let Some(ty) = items_type { + // let expected_type = FieldType::List(Box::new(ty)); + // if !self.is_subtype(&expected_type, &field_base_type) { + // anyhow::bail!("Could not unify {:?} with {:?}", expected_type, field_base_type); + // } + // } + Ok(BamlValueWithMeta::List(new_items, (meta, field_type))) + } + + BamlValueWithMeta::Media(m, meta) + if self.is_subtype( + &FieldType::Primitive(TypeValue::Media(m.media_type)), + &field_base_type, + ) => + { + Ok(BamlValueWithMeta::Media(m, (meta, field_type))) + } + BamlValueWithMeta::Media(_, _) => { + anyhow::bail!("Could not unify Media with {:?}", field_base_type) + } + + BamlValueWithMeta::Enum(name, val, meta) => { + if self.is_subtype(&FieldType::Enum(name.clone()), &field_base_type) { + Ok(BamlValueWithMeta::Enum(name, val, (meta, field_type))) + } else { + anyhow::bail!("Could not unify Enum {} with {:?}", name, field_base_type) + } + } + + BamlValueWithMeta::Class(name, fields, meta) => { + if self.find_class(&name).is_err() { + // // Classes not present in the IR may be dynamically generated. + // // In this case, all types will be inferred, rather than distributed + // // from the `field_type` parameter. + return distribute_infer_class(self, &name, fields, meta); + } + if !self.is_subtype(&FieldType::Class(name.clone()), &field_base_type) { + anyhow::bail!("Could not unify Class {} with {:?}", name, field_base_type); + } else { + let class_type = &self.find_class(&name)?.item.elem; + let class_fields: BamlMap = class_type + .static_fields + .iter() + .map(|field_node| { + ( + field_node.elem.name.clone(), + field_node.elem.r#type.elem.clone(), + ) + }) + .collect(); + let mapped_fields = fields + .into_iter() + .map(|(k, v)| { + let field_type = match class_fields.get(k.as_str()) { + Some(ft) => ft.clone(), + None => infer_type_with_meta(&v).unwrap_or(UNIT_TYPE), + }; + let mapped_field = self.distribute_type_with_meta(v, field_type)?; + Ok((k, mapped_field)) + }) + .collect::>>>( + )?; + Ok(BamlValueWithMeta::Class( + name, + mapped_fields, + (meta, field_type), + )) + } + } + } + } + /// Constraints may live in several places. A constrained base type stors its - /// constraints by wrapping itself in the `FieldType::Constrained` constructor. + /// constraints by wrapping itself in the `FieldType::WithMetadata` constructor. /// Additionally, `FieldType::Class` may have constraints stored in its class node, /// and `FieldType::Enum` can store constraints in its `Enum` node. - /// And the `FieldType::Constrained` constructor might wrap another - /// `FieldType::Constrained` constructor. + /// And the `FieldType::WithMetadata` constructor might wrap another + /// `FieldType::WithMetadata` constructor. /// /// This function collects constraints for a given type from all these /// possible sources. Whenever querying a type for its constraints, you @@ -536,35 +738,67 @@ impl IRHelper for IntermediateRepr { &'a self, field_type: &'a FieldType, ) -> (&'a FieldType, Vec) { + let (field_type, metadata) = self.distribute_metadata(field_type); + (field_type, metadata.0) + } + + /// For any FieldType, check if the field type is FieldType::WithMetadata, + /// and if so, return the metadata alongside the base type. + /// All other field types will be returned as is, alongside default metadata. + fn distribute_metadata<'a>( + &'a self, + field_type: &'a FieldType, + ) -> (&'a FieldType, (Vec, StreamingBehavior)) { match field_type { FieldType::Class(class_name) => match self.find_class(class_name) { - Err(_) => (field_type, Vec::new()), - Ok(class_node) => (field_type, class_node.item.attributes.constraints.clone()), + Err(_) => (field_type, (Vec::new(), StreamingBehavior::default())), + Ok(class_node) => ( + field_type, + ( + class_node.item.attributes.constraints.clone(), + class_node.item.attributes.streaming_behavior(), + ), + ), }, FieldType::Enum(enum_name) => match self.find_enum(enum_name) { - Err(_) => (field_type, Vec::new()), - Ok(enum_node) => (field_type, enum_node.item.attributes.constraints.clone()), + Err(_) => (field_type, (Vec::new(), StreamingBehavior::default())), + Ok(enum_node) => ( + field_type, + ( + enum_node.item.attributes.constraints.clone(), + StreamingBehavior::default(), + ), + ), }, // Check the first level to see if it's constrained. - FieldType::Constrained { base, constraints } => { + FieldType::WithMetadata { + base, + constraints, + streaming_behavior, + } => { match base.as_ref() { // If so, we must check the second level to see if we need to combine // constraints across levels. - // The recursion here means that arbitrarily nested `FieldType::Constrained`s + // The recursion here means that arbitrarily nested `FieldType::WithMetadata`s // will be collapsed before the function returns. - FieldType::Constrained { .. } => { - let (sub_base, sub_constraints) = - self.distribute_constraints(base.as_ref()); + FieldType::WithMetadata { .. } => { + let (sub_base, (sub_constraints, sub_streaming_behavior)) = + self.distribute_metadata(base.as_ref()); let combined_constraints = vec![constraints.clone(), sub_constraints] .into_iter() .flatten() .collect(); - (sub_base, combined_constraints) + let combined_streaming_behavior = + streaming_behavior.combine(&sub_streaming_behavior); + ( + sub_base, + (combined_constraints, combined_streaming_behavior), + ) } - _ => (base, constraints.clone()), + _ => (base, (constraints.clone(), streaming_behavior.clone())), } } - _ => (field_type, Vec::new()), + _ => (field_type, (Vec::new(), StreamingBehavior::default())), } } @@ -579,12 +813,234 @@ impl IRHelper for IntermediateRepr { .iter() .any(|Constraint { level, .. }| *level == ConstraintLevel::Check) } + + fn recursive_alias_definition(&self, alias_name: &str) -> Option<&FieldType> { + if let Some(cycle) = self + .structural_recursive_alias_cycles() + .iter() + .find(|cycle| cycle.contains_key(alias_name)) + { + cycle.get(alias_name) + } else { + None + } + } +} + +/// For types of values that contain other values (e.g. lists, maps), compute +/// the type of the contained value. +/// TODO: Does this always terminate, especially in the case of recursive type +/// aliases? +/// +/// When the field_type is a union, different variants may have +/// children of different types. We take a baml_value itself as a +/// parameter, and typecheck it against every variant of the union. +/// The first typechecking union variant is used as the type of +/// the children. This feels unsound, but it's not clear what we +/// should declare as the `item_type` in the case of unions that +/// admit multiple different children. (Perhaps a union of all the +/// child-having variants?). +fn item_type( + ir: &IntermediateRepr, + field_type: &FieldType, + baml_child_values: &BamlValueWithMeta, +) -> Option { + // dbg!(&baml_child_value); + // dbg!(&field_type); + let res = match ir.distribute_metadata(field_type).0 { + FieldType::Class(_) => None, + FieldType::Enum(_) => None, + FieldType::List(inner) => Some(*inner.clone()), + FieldType::Literal(_) => None, + FieldType::Map(k, v) => Some(*v.clone()), + FieldType::Optional(inner) => item_type(ir, &*inner, baml_child_values), + FieldType::Primitive(_) => None, + FieldType::RecursiveTypeAlias(alias_name) => ir + .recursive_alias_definition(alias_name) + .and_then(|resolved_type| item_type(ir, resolved_type, baml_child_values)), + FieldType::Union(variants) => { + let variant_children = variants.iter().filter_map(|variant| item_type(ir, variant, baml_child_values)).collect::>(); + match variant_children.len() { + 0 => None, + 1 => Some(variant_children[0].clone()), + _ => Some(FieldType::Union(variant_children)), + } + }, + FieldType::Tuple(_) => None, + FieldType::WithMetadata { base, .. } => item_type(ir, base, baml_child_values), + }; + res +} + +fn typecheck_value_with_meta( + ir: &IntermediateRepr, + value: &BamlValueWithMeta, + field_type: &FieldType, +) -> bool { + let field_base_type = ir.distribute_metadata(&field_type).0; + match value { + BamlValueWithMeta::String(s, meta) => { + let literal_type = FieldType::Literal(LiteralValue::String(s.clone())); + let primitive_type = FieldType::Primitive(TypeValue::String); + + ir.is_subtype(&literal_type, &field_base_type) + || ir.is_subtype(&primitive_type, &field_base_type) + } + BamlValueWithMeta::Int(i, meta) => { + ir.is_subtype(&FieldType::Literal(LiteralValue::Int(*i)), &field_base_type) + } + BamlValueWithMeta::Float(f, meta) => { + ir.is_subtype(&FieldType::Primitive(TypeValue::Float), &field_base_type) + } + + BamlValueWithMeta::Bool(b, meta) => { + let literal_type = FieldType::Literal(LiteralValue::Bool(*b)); + let primitive_type = FieldType::Primitive(TypeValue::Bool); + + ir.is_subtype(&literal_type, &field_base_type) + || ir.is_subtype(&primitive_type, &field_base_type) + } + + BamlValueWithMeta::Null(meta) => true, + + // TODO: Handle enums and literal keys. + BamlValueWithMeta::Map(pairs, meta) => { + true + // TODO! + } + + BamlValueWithMeta::List(items, meta) => { + let items_ok = items + .iter() + .map(|i| { + item_type(ir, &field_type, i) + .map_or(true, |item_ty| typecheck_value_with_meta(ir, i, &item_ty)) + }) + .all(|x| x); + items_ok + } + + BamlValueWithMeta::Media(m, meta) => ir.is_subtype( + &FieldType::Primitive(TypeValue::Media(m.media_type)), + &field_base_type, + ), + + BamlValueWithMeta::Enum(name, val, meta) => { + ir.is_subtype(&FieldType::Enum(name.clone()), &field_base_type) + } + + BamlValueWithMeta::Class(name, fields, meta) => { + // // Classes not present in the IR may be dynamically generated. + // // In this case, all types will be inferred, rather than distributed + // // from the `field_type` parameter. + + // TODO + true + + // if ir.find_class(&name).is_err() { + // return distribute_infer_class(self, &name, fields, meta); + // } + // if !self.is_subtype(&FieldType::Class(name.clone()), &field_base_type) { + // anyhow::bail!("Could not unify Class {} with {:?}", name, field_base_type); + // } else { + // let class_type = &self.find_class(&name)?.item.elem; + // let class_fields: BamlMap = class_type + // .static_fields + // .iter() + // .map(|field_node| { + // ( + // field_node.elem.name.clone(), + // field_node.elem.r#type.elem.clone(), + // ) + // }) + // .collect(); + // let mapped_fields = fields + // .into_iter() + // .map(|(k, v)| { + // let field_type = match class_fields.get(k.as_str()) { + // Some(ft) => ft.clone(), + // None => infer_type_with_meta(&v).unwrap_or(UNIT_TYPE), + // }; + // let mapped_field = self.distribute_type_with_meta(v, field_type)?; + // Ok((k, mapped_field)) + // }) + // .collect::>>>( + // )?; + // Ok(BamlValueWithMeta::Class( + // name, + // mapped_fields, + // (meta, field_type), + // )) + // } + } + } +} + +/// Like item_type, but specialized for maps. +fn map_types<'ir, 'a>( + ir: &'ir IntermediateRepr, + field_type: &'a FieldType, +) -> Option<(&'a FieldType, &'a FieldType)> +where + 'ir: 'a, +{ + match ir.distribute_metadata(field_type).0 { + FieldType::Map(key, value) => Some((key.as_ref(), value.as_ref())), + FieldType::RecursiveTypeAlias(alias_name) => ir + .recursive_alias_definition(alias_name) + .and_then(|alias_definition| map_types(ir, &alias_definition)), + FieldType::Primitive(_) => None, + FieldType::Enum(_) => None, + FieldType::List(_) => None, + FieldType::Literal(_) => None, + FieldType::Optional(base) => map_types(ir, base.as_ref()), + FieldType::Tuple(_) => None, + FieldType::Union(variants) => { + // When encountering a union, we return the key/value types of the + // first map we find inside the union. + // TODO: Give more thought to what `map_types` should return for + // unions, because the current logic is faulty for unions containing + // multiple maps. + let mut variant_map_types = variants + .into_iter() + .filter_map(|variant| map_types(ir, variant)); + variant_map_types.next() + } + FieldType::Class(_) => None, + FieldType::WithMetadata { .. } => { + unreachable!("distribute_metadata never returns this variant") + } + } } const UNIT_TYPE: FieldType = FieldType::Tuple(vec![]); -/// Derive the simplest type that can categorize a given value. This is meant to be used -/// by `distribute_type`, for dynamic fields of classes, whose types are not known statically. +/// A helper function for `distribute_type_with_meta`, for cases where a class +/// is not present in the IR. In this case, when we don't have a class +/// definition in the IR (e.g. because the class was introduced through +/// TypeBuilder), we enhance the `BamlValueWithMeta` using types inferred from +/// each field of the class instance. +fn distribute_infer_class( + ir: &IntermediateRepr, + class_name: &str, + class_fields: IndexMap>, + meta: T, +) -> Result> { + let fields = class_fields + .into_iter() + .map(|(k, v)| { + let field_type = infer_type_with_meta(&v).unwrap_or(UNIT_TYPE); + let field = ir.distribute_type_with_meta(v, field_type)?; + Ok((k.to_string(), field)) + }) + .collect::>>()?; + Ok(BamlValueWithMeta::Class( + class_name.to_string(), + fields, + (meta, FieldType::class(class_name)), + )) +} + pub fn infer_type(value: &BamlValue) -> Option { let ret = match value { BamlValue::Int(_) => Some(FieldType::Primitive(TypeValue::Int)), @@ -626,12 +1082,58 @@ pub fn infer_type(value: &BamlValue) -> Option { ret } +/// Derive the simplest type that can categorize a given value. This is meant to be used +/// by `distribute_type`, for dynamic fields of classes, whose types are not known statically. +/// TODO: Tests. +pub fn infer_type_with_meta(value: &BamlValueWithMeta) -> Option { + let ret = match value { + BamlValueWithMeta::Int(_, _) => Some(FieldType::Primitive(TypeValue::Int)), + BamlValueWithMeta::Bool(_, _) => Some(FieldType::Primitive(TypeValue::Bool)), + BamlValueWithMeta::Float(_, _) => Some(FieldType::Primitive(TypeValue::Float)), + BamlValueWithMeta::String(_, _) => Some(FieldType::Primitive(TypeValue::String)), + BamlValueWithMeta::Null(_) => Some(FieldType::Primitive(TypeValue::Null)), + BamlValueWithMeta::Map(pairs, _) => { + let v_tys = pairs + .iter() + .filter_map(|(_, v)| infer_type_with_meta(v)) + .dedup() + .collect::>(); + let k_ty = FieldType::Primitive(TypeValue::String); + let v_ty = match v_tys.len() { + 0 => None, + 1 => Some(v_tys[0].clone()), + _ => Some(FieldType::Union(v_tys)), + }?; + Some(FieldType::Map(Box::new(k_ty), Box::new(v_ty))) + } + BamlValueWithMeta::List(items, _) => { + let item_tys = items + .iter() + .filter_map(infer_type_with_meta) + .dedup() + .collect::>(); + let item_ty = match item_tys.len() { + 0 => None, + 1 => Some(item_tys[0].clone()), + _ => Some(FieldType::Union(item_tys)), + }?; + Some(FieldType::List(Box::new(item_ty))) + } + BamlValueWithMeta::Media(m, _) => { + Some(FieldType::Primitive(TypeValue::Media(m.media_type))) + } + BamlValueWithMeta::Enum(enum_name, _, _) => Some(FieldType::Enum(enum_name.clone())), + BamlValueWithMeta::Class(class_name, _, _) => Some(FieldType::Class(class_name.clone())), + }; + ret +} + #[cfg(test)] mod tests { use super::*; use baml_types::{ BamlMedia, BamlMediaContent, BamlMediaType, BamlValue, Constraint, ConstraintLevel, - FieldType, JinjaExpression, MediaBase64, TypeValue, + FieldType, JinjaExpression, MediaBase64, StreamingBehavior, TypeValue, }; use repr::make_test_ir; @@ -912,12 +1414,15 @@ mod tests { } } - let input = FieldType::Constrained { + let input = FieldType::WithMetadata { constraints: vec![mk_constraint("a")], - base: Box::new(FieldType::Constrained { + streaming_behavior: StreamingBehavior::default(), + base: Box::new(FieldType::WithMetadata { constraints: vec![mk_constraint("b")], - base: Box::new(FieldType::Constrained { + streaming_behavior: StreamingBehavior::default(), + base: Box::new(FieldType::WithMetadata { constraints: vec![mk_constraint("c")], + streaming_behavior: StreamingBehavior::default(), base: Box::new(FieldType::Primitive(TypeValue::Int)), }), }), @@ -1008,6 +1513,20 @@ mod subtype_tests { assert!(!ir().is_subtype(&l_o, &l_i)); } + fn subtype_list_with_metadata() { + let l_i = FieldType::WithMetadata { + base: Box::new(mk_list(mk_int())), + constraints: vec![], + streaming_behavior: StreamingBehavior { + done: true, + state: false, + }, + }; + let l_o = mk_list(mk_int()); + assert!(ir().is_subtype(&l_i, &l_o)); + assert!(ir().is_subtype(&l_o, &l_i)); + } + #[test] fn subtype_tuple() { let x = mk_tuple(vec![mk_int(), mk_optional(mk_int())]); @@ -1032,4 +1551,94 @@ mod subtype_tests { let x = FieldType::Primitive(TypeValue::Media(BamlMediaType::Audio)); assert!(ir().is_subtype(&x, &x)); } + + // Given: + // BamlValue::List ["a", {}] + // field_type: RTA("JsonValue") + // + // List [ + // "a" (Meta: Type: JsonValue), + // {} (Meta: Type: JsonValue), + // ] (Meta: Type: JsonValue) + + #[test] + fn test_item_type() { + let ir = make_test_ir( + r##" + type A = A[] + type B = B[][] + type C = map + + type JsonValue = float | JsonValue[] | map + + type JsonValue2 = float | JsonValue2List | JsonValue2Object + type JsonValue2List = JsonValue2[] + type JsonValue2Object = map + + type Foo = float | JsonValue | JsonValue2 + type U = string | Foo + "##, + ) + .unwrap(); + + let example_a = BamlValueWithMeta::List(vec![], ()); + let example_b = BamlValueWithMeta::List(vec![BamlValueWithMeta::List(vec![], ())], ()); + let example_c = BamlValueWithMeta::Map(vec![].into_iter().collect(), ()); + let example_json = BamlValueWithMeta::Map( + vec![ + ("foo".to_string(), BamlValueWithMeta::Bool(true, ())), + ( + "bar".to_string(), + BamlValueWithMeta::List( + vec![ + BamlValueWithMeta::Int(1, ()), + BamlValueWithMeta::Int(2, ()), + BamlValueWithMeta::Int(3, ()), + ], + (), + ), + ), + ] + .into_iter() + .collect(), + (), + ); + assert_eq!( + item_type( + &ir, + &FieldType::RecursiveTypeAlias("A".to_string()), + &example_a + ), + Some(FieldType::RecursiveTypeAlias("A".to_string())) + ); + assert_eq!( + item_type( + &ir, + &FieldType::RecursiveTypeAlias("B".to_string()), + &example_b + ), + Some(FieldType::List(Box::new(FieldType::RecursiveTypeAlias( + "B".to_string() + )))) + ); + assert_eq!( + item_type( + &ir, + &FieldType::RecursiveTypeAlias("C".to_string()), + &example_c + ), + Some(FieldType::RecursiveTypeAlias("C".to_string())) + ); + assert_eq!( + item_type( + &ir, + &FieldType::RecursiveTypeAlias("JsonValue".to_string()), + &example_json + ), + Some(FieldType::Map( + Box::new(FieldType::Primitive(TypeValue::String)), + Box::new(FieldType::RecursiveTypeAlias("JsonValue".to_string())) + )) + ); + } } diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs index a13c5cc498..f357144e91 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs @@ -354,7 +354,7 @@ impl ArgCoercer { } } } - (FieldType::Constrained { .. }, _) => { + (FieldType::WithMetadata { .. }, _) => { unreachable!("The return value of distribute_constraints can never be FieldType::Constrainted"); } }?; @@ -417,7 +417,7 @@ fn first_failing_assert_nested<'a>( #[cfg(test)] mod tests { - use baml_types::JinjaExpression; + use baml_types::{JinjaExpression, StreamingBehavior}; use crate::ir::repr::make_test_ir; @@ -442,13 +442,14 @@ mod tests { ) .unwrap(); let value = BamlValue::Int(1); - let type_ = FieldType::Constrained { + let type_ = FieldType::WithMetadata { base: Box::new(FieldType::Primitive(TypeValue::Int)), constraints: vec![Constraint { level: ConstraintLevel::Assert, expression: JinjaExpression("this.length() > 0".to_string()), label: Some("foo".to_string()), }], + streaming_behavior: StreamingBehavior::default(), }; let arg_coercer = ArgCoercer { span_path: None, diff --git a/engine/baml-lib/baml-core/src/ir/json_schema.rs b/engine/baml-lib/baml-core/src/ir/json_schema.rs index 79c463ea10..a7eb3a7e52 100644 --- a/engine/baml-lib/baml-core/src/ir/json_schema.rs +++ b/engine/baml-lib/baml-core/src/ir/json_schema.rs @@ -262,7 +262,7 @@ impl WithJsonSchema for FieldType { } } } - FieldType::Constrained { base, .. } => base.json_schema(), + FieldType::WithMetadata { base, .. } => base.json_schema(), } } } diff --git a/engine/baml-lib/baml-core/src/ir/repr.rs b/engine/baml-lib/baml-core/src/ir/repr.rs index 19c58d63e0..baa4b7eb87 100644 --- a/engine/baml-lib/baml-core/src/ir/repr.rs +++ b/engine/baml-lib/baml-core/src/ir/repr.rs @@ -2,8 +2,10 @@ use std::collections::HashSet; use anyhow::{anyhow, Result}; use baml_types::{ - Constraint, ConstraintLevel, FieldType, JinjaExpression, StringOr, UnresolvedValue, + Constraint, ConstraintLevel, FieldType, JinjaExpression, Resolvable, StreamingBehavior, StringOr, + UnresolvedValue, }; +use either::Either; use indexmap::{IndexMap, IndexSet}; use internal_baml_parser_database::{ walkers::{ @@ -13,7 +15,7 @@ use internal_baml_parser_database::{ Attributes, ParserDatabase, PromptAst, RetryPolicyStrategy, TypeWalker, }; -use internal_baml_schema_ast::ast::{self, FieldArity, SubType, ValExpId, WithName, WithSpan}; +use internal_baml_schema_ast::ast::{self, Attribute, FieldArity, SubType, ValExpId, WithName, WithSpan}; use internal_llm_client::{ClientProvider, ClientSpec, UnresolvedClientProperty}; use serde::Serialize; @@ -279,6 +281,19 @@ impl NodeAttributes { pub fn get(&self, key: &str) -> Option<&UnresolvedValue<()>> { self.meta.get(key) } + + pub fn streaming_behavior(&self) -> StreamingBehavior { + fn is_some_true(maybe_value: Option<&UnresolvedValue<()>>) -> bool { + match maybe_value { + Some(Resolvable::Bool(true, _)) => true, + _ => false, + } + } + StreamingBehavior { + done: is_some_true(self.get("stream.done")), + state: is_some_true(self.get("stream.with_state")), + } + } } impl Default for NodeAttributes { @@ -303,6 +318,9 @@ fn to_ir_attributes( dynamic_type, skip, constraints, + streaming_done, + streaming_needed, + streaming_state, } = attributes; let description = description @@ -327,11 +345,49 @@ fn to_ir_attributes( None } }); + let streaming_done = streaming_done.as_ref().and_then(|v| { + if *v { + Some(( + "stream.done".to_string(), + UnresolvedValue::Bool(true, ()), + )) + } else { + None + } + }); + let streaming_needed = streaming_needed.as_ref().and_then(|v| { + if *v { + Some(( + "stream.not_null".to_string(), + UnresolvedValue::Bool(true, ()), + )) + } else { + None + } + }); + let streaming_state = streaming_state.as_ref().and_then(|v| { + if *v { + Some(( + "stream.with_state".to_string(), + UnresolvedValue::Bool(true, ()), + )) + } else { + None + } + }); - let meta = vec![description, alias, dynamic_type, skip] - .into_iter() - .flatten() - .collect(); + let meta = vec![ + description, + alias, + dynamic_type, + skip, + streaming_done, + streaming_needed, + streaming_state, + ] + .into_iter() + .filter_map(|s| s) + .collect(); (meta, constraints.clone()) }) } @@ -374,7 +430,7 @@ fn type_with_arity(t: FieldType, arity: &FieldArity) -> FieldType { impl WithRepr for ast::FieldType { // TODO: (Greg) This code only extracts constraints, and ignores any // other types of attributes attached to the type directly. - fn attributes(&self, _db: &ParserDatabase) -> NodeAttributes { + fn attributes(&self, db: &ParserDatabase) -> NodeAttributes { let constraints = self .attributes() .iter() @@ -407,8 +463,27 @@ impl WithRepr for ast::FieldType { }) }) .collect::>(); + let mut meta = IndexMap::new(); + if self + .attributes() + .iter() + .find(|Attribute { name, .. }| name.name() == "stream.done") + .is_some() + { + let val: UnresolvedValue<()> = Resolvable::Bool(true, ()); + meta.insert("stream.done".to_string(), val); + } + if self + .attributes() + .iter() + .find(|Attribute { name, .. }| name.name() == "stream.with_state") + .is_some() + { + let val: UnresolvedValue<()> = Resolvable::Bool(true, ()); + meta.insert("stream.with_state".to_string(), val); + } let attributes = NodeAttributes { - meta: IndexMap::new(), + meta, constraints, span: Some(self.span().clone()), }; @@ -417,8 +492,10 @@ impl WithRepr for ast::FieldType { } fn repr(&self, db: &ParserDatabase) -> Result { - let constraints = WithRepr::attributes(self, db).constraints; - let has_constraints = !constraints.is_empty(); + let attributes = WithRepr::attributes(self, db); + let has_constraints = !attributes.constraints.is_empty(); + let streaming_behavior = attributes.streaming_behavior(); + let has_special_streaming_behavior = streaming_behavior != StreamingBehavior::default(); let base = match self { ast::FieldType::Primitive(arity, typeval, ..) => { let repr = FieldType::Primitive(*typeval); @@ -442,9 +519,10 @@ impl WithRepr for ast::FieldType { let base_class = FieldType::Class(class_walker.name().to_string()); match class_walker.get_constraints(SubType::Class) { Some(constraints) if !constraints.is_empty() => { - FieldType::Constrained { + FieldType::WithMetadata { base: Box::new(base_class), constraints, + streaming_behavior: streaming_behavior.clone(), } } _ => base_class, @@ -454,9 +532,10 @@ impl WithRepr for ast::FieldType { let base_type = FieldType::Enum(enum_walker.name().to_string()); match enum_walker.get_constraints(SubType::Enum) { Some(constraints) if !constraints.is_empty() => { - FieldType::Constrained { + FieldType::WithMetadata { base: Box::new(base_type), constraints, + streaming_behavior: streaming_behavior.clone(), } } _ => base_type, @@ -515,10 +594,13 @@ impl WithRepr for ast::FieldType { ), }; - let with_constraints = if has_constraints { - FieldType::Constrained { + + let use_metadata = has_constraints || has_special_streaming_behavior; + let with_constraints = if use_metadata { + FieldType::WithMetadata { base: Box::new(base.clone()), - constraints, + constraints: attributes.constraints, + streaming_behavior, } } else { base @@ -671,19 +753,17 @@ impl WithRepr for FieldWalker<'_> { } fn repr(&self, db: &ParserDatabase) -> Result { + let ast_field_type = self.ast_field().expr.as_ref().ok_or(anyhow!( + "Internal error occurred while resolving repr of field {:?}", + self.name(), + ))?; + let field_type_attributes = WithRepr::attributes(ast_field_type, db); + let field_type = ast_field_type.repr(db)?; Ok(Field { name: self.name().to_string(), r#type: Node { - elem: self - .ast_field() - .expr - .clone() - .ok_or(anyhow!( - "Internal error occurred while resolving repr of field {:?}", - self.name(), - ))? - .repr(db)?, - attributes: self.attributes(db), + elem: field_type, + attributes: field_type_attributes, }, docstring: self.get_documentation().map(Docstring), }) @@ -1180,6 +1260,20 @@ pub fn make_test_ir(source_code: &str) -> anyhow::Result { Ok(ir) } +/// Pull out `StreamingBehavior` from `NodeAttributes`. +fn streaming_behavior_from_attributes(attributes: &NodeAttributes) -> StreamingBehavior { + fn is_some_true(maybe_value: Option<&UnresolvedValue<()>>) -> bool { + match maybe_value { + Some(Resolvable::Bool(true, _)) => true, + _ => false, + } + } + StreamingBehavior { + done: is_some_true(attributes.get("stream.done")), + state: is_some_true(attributes.get("stream.with_state")), + } +} + #[cfg(test)] mod tests { use super::*; @@ -1278,11 +1372,61 @@ mod tests { } #[test] + fn test_streaming_attributes() { + let ir = make_test_ir( + r##" + class Foo { + foo_int int @stream.not_null + foo_bool bool @stream.with_state + foo_list int[] @stream.done + } + + class Bar { + name string @stream.done + message string + @@stream.done + } + "##, + ) + .unwrap(); + let foo = ir.find_class("Foo").unwrap(); + assert!(!foo.streaming_done()); + match foo.walk_fields().collect::>().as_slice() { + [field1, field2, field3] => { + let type1 = &field1.item.elem.r#type; + assert!(field1.streaming_needed()); + assert!(type1.attributes.get("stream.not_null").is_none()); + let type2 = &field2.item.elem.r#type; + assert!(!field2.streaming_state()); + assert!(type2.attributes.get("stream.with_state").is_some()); + let type3 = &field3.item.elem.r#type; + assert!(!field3.streaming_done()); + assert!(type3.attributes.get("stream.done").is_some()); + } + _ => panic!("Expected exactly 3 fields"), + } + let bar = ir.find_class("Bar").unwrap(); + assert!(bar.streaming_done()); + match bar.walk_fields().collect::>().as_slice() { + [field1, field2] => { + assert!(!field1.streaming_done()); + assert!(field1 + .item + .elem + .r#type + .attributes + .get("stream.done") + .is_some()); + } + _ => panic!("Expected exactly 2 fields"), + } + } + fn test_resolve_type_alias() { let ir = make_test_ir( r##" - type One = int - type Two = One + type One = int + type Two = One type Three = Two class Test { @@ -1296,6 +1440,7 @@ mod tests { let alias = class.find_field("field").unwrap(); assert_eq!(*alias.r#type(), FieldType::Primitive(TypeValue::Int)); + } #[test] @@ -1316,7 +1461,7 @@ mod tests { let class = ir.find_class("Test").unwrap(); let alias = class.find_field("field").unwrap(); - let FieldType::Constrained { base, constraints } = alias.r#type() else { + let FieldType::WithMetadata { base, constraints, .. } = alias.r#type() else { panic!( "expected resolved constrained type, found {:?}", alias.r#type() diff --git a/engine/baml-lib/baml-core/src/ir/walker.rs b/engine/baml-lib/baml-core/src/ir/walker.rs index 1a7d1a24ba..624b224ce3 100644 --- a/engine/baml-lib/baml-core/src/ir/walker.rs +++ b/engine/baml-lib/baml-core/src/ir/walker.rs @@ -245,6 +245,14 @@ impl<'a> Walker<'a, &'a Class> { .transpose() } + pub fn streaming_done(&self) -> bool { + self.item.attributes.get("stream.done").is_some() + } + + pub fn streaming_state(&self) -> bool { + self.item.attributes.get("stream.with_state").is_some() + } + pub fn walk_fields(&'a self) -> impl Iterator> { self.item.elem.static_fields.iter().map(|f| Walker { db: self.db, @@ -390,6 +398,18 @@ impl<'a> Walker<'a, &'a Field> { .transpose() } + pub fn streaming_done(&self) -> bool { + self.item.attributes.get("stream.done").is_some() + } + + pub fn streaming_needed(&self) -> bool { + self.item.attributes.get("stream.not_null").is_some() + } + + pub fn streaming_state(&self) -> bool { + self.item.attributes.get("stream.with_state").is_some() + } + pub fn span(&self) -> Option<&crate::Span> { self.item.attributes.span.as_ref() } diff --git a/engine/baml-lib/baml-types/src/baml_value.rs b/engine/baml-lib/baml-types/src/baml_value.rs index c85c4d93c5..61e5b3af91 100644 --- a/engine/baml-lib/baml-types/src/baml_value.rs +++ b/engine/baml-lib/baml-types/src/baml_value.rs @@ -1,14 +1,15 @@ -use std::collections::HashMap; use std::{ collections::{HashSet, VecDeque}, fmt, }; -use serde::ser::SerializeMap; -use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer}; +use anyhow::{Context,Result}; +use std::collections::HashMap; +use indexmap::IndexMap; +use serde::{de::Visitor, ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer}; -use crate::media::BamlMediaType; -use crate::{BamlMap, BamlMedia, ResponseCheck}; +use crate::{media::BamlMediaType, ResponseCheck}; +use crate::{BamlMap, BamlMedia}; #[derive(Clone, Debug, PartialEq)] pub enum BamlValue { @@ -35,26 +36,6 @@ impl serde::Serialize for BamlValue { BamlValue::List(l) => l.serialize(serializer), BamlValue::Media(m) => { m.serialize(serializer) - // let struct_name = match m.media_type() { - // BamlMediaType::Image => "BamlImage", - // BamlMediaType::Audio => "BamlAudio", - // }; - // let mut s = serializer.serialize_struct(struct_name, 2)?; - // match m { - // BamlMedia::File(_, f) => { - // s.serialize_field("path", &f.path)?; - // s.serialize_field("media_type", &f.media_type)?; - // } - // BamlMedia::Url(_, u) => { - // s.serialize_field("url", &u.url)?; - // s.serialize_field("media_type", &u.media_type)?; - // } - // BamlMedia::Base64(_, b) => { - // s.serialize_field("base64", &b.base64)?; - // s.serialize_field("media_type", &b.media_type)?; - // } - // } - // s.end() } BamlValue::Enum(_, v) => serializer.serialize_str(v), BamlValue::Class(_, m) => m.serialize(serializer), @@ -465,9 +446,9 @@ impl BamlValueWithMeta { } } - pub fn map_meta(&self, f: F) -> BamlValueWithMeta + pub fn map_meta<'a, F, U>(&'a self, f: F) -> BamlValueWithMeta where - F: Fn(&T) -> U + Copy, + F: Fn(&'a T) -> U + Copy, { match self { BamlValueWithMeta::String(v, m) => BamlValueWithMeta::String(v.clone(), f(m)), @@ -493,6 +474,98 @@ impl BamlValueWithMeta { BamlValueWithMeta::Null(m) => BamlValueWithMeta::Null(f(m)), } } + + pub fn map_meta_owned(self, f: F) -> BamlValueWithMeta + where + F: Fn(T) -> U + Copy, + { + match self { + BamlValueWithMeta::String(v, m) => BamlValueWithMeta::String(v, f(m)), + BamlValueWithMeta::Int(v, m) => BamlValueWithMeta::Int(v, f(m)), + BamlValueWithMeta::Float(v, m) => BamlValueWithMeta::Float(v, f(m)), + BamlValueWithMeta::Bool(v, m) => BamlValueWithMeta::Bool(v, f(m)), + BamlValueWithMeta::Map(v, m) => BamlValueWithMeta::Map( + v.into_iter().map(|(k, v)| (k, v.map_meta_owned(f))).collect(), + f(m), + ), + BamlValueWithMeta::List(v, m) => { + BamlValueWithMeta::List(v.into_iter().map(|v| v.map_meta_owned(f)).collect(), f(m)) + } + BamlValueWithMeta::Media(v, m) => BamlValueWithMeta::Media(v, f(m)), + BamlValueWithMeta::Enum(v, e, m) => BamlValueWithMeta::Enum(v, e, f(m)), + BamlValueWithMeta::Class(n, fs, m) => BamlValueWithMeta::Class( + n, + fs.into_iter() + .map(|(k, v)| (k, v.map_meta_owned(f))) + .collect(), + f(m), + ), + BamlValueWithMeta::Null(m) => BamlValueWithMeta::Null(f(m)), + } + } + + /// Combine two similar shaped baml values by tupling their metadata + /// on a node-by-node basis. + /// + /// The baml value calling `zip_meta` is the "primary" one, whose value + /// data will live on in the returned baml value. + pub fn zip_meta(self, other: &BamlValueWithMeta) -> Result> + where T: std::fmt::Debug + { + let other_meta: U = other.meta().clone(); + let error_msg = String::new(); // format!("Could not unify {:?} with {:?}.", self, other); + let ret = match (self, other) { + (BamlValueWithMeta::Null(meta1), _) => { + Result::<_,_>::Ok(BamlValueWithMeta::Null((meta1, other_meta))) + }, + (BamlValueWithMeta::String(s1, meta1), BamlValueWithMeta::String(_s2, _)) if true => Ok(BamlValueWithMeta::String(s1, (meta1, other_meta))), + (BamlValueWithMeta::String(_,_), _) => anyhow::bail!("Unification error"), + (BamlValueWithMeta::Int(s1, meta1), BamlValueWithMeta::Int(_s2, _)) if true => Ok(BamlValueWithMeta::Int(s1, (meta1, other_meta))), + (BamlValueWithMeta::Int(_,_), _) => anyhow::bail!("Unification error"), + (BamlValueWithMeta::Float(s1, meta1), BamlValueWithMeta::Float(_s2, _)) if true => Ok(BamlValueWithMeta::Float(s1, (meta1, other_meta))), + (BamlValueWithMeta::Float(_,_), _) => anyhow::bail!("Unification error"), + (BamlValueWithMeta::Bool(s1, meta1), BamlValueWithMeta::Bool(_s2, _)) if true => Ok(BamlValueWithMeta::Bool(s1, (meta1, other_meta))), + (BamlValueWithMeta::Bool(_,_), _) => anyhow::bail!("Unification error"), + (BamlValueWithMeta::Map(s1, meta1), BamlValueWithMeta::Map(s2, _)) => { + let map_result = s1.into_iter().zip(s2).map(|((k1,v1), (_k2,v2))| { + v1.zip_meta(v2).map(|res| (k1, res)) + }).collect::>>()?; + Ok(BamlValueWithMeta::Map(map_result, (meta1, other_meta))) + }, + (BamlValueWithMeta::Map(_,_), _) => anyhow::bail!("Unification error"), + (BamlValueWithMeta::List(l1, meta1), BamlValueWithMeta::List(l2, _)) => { + let list_result = l1.into_iter().zip(l2).map(|(item1, item2)| { + item1.zip_meta(item2) + }).collect::>>()?; + Ok( BamlValueWithMeta::List(list_result, (meta1, other_meta))) + + } + (BamlValueWithMeta::List(_,_), _) => anyhow::bail!("Unification error"), + (BamlValueWithMeta::Media(m1, meta1), BamlValueWithMeta::Media(_m2, _)) if true => { + Ok(BamlValueWithMeta::Media(m1, (meta1, other_meta))) + } + (BamlValueWithMeta::Media(_, _), _) => anyhow::bail!("Unification error"), + (BamlValueWithMeta::Enum(x1, y1, meta1), BamlValueWithMeta::Enum(_x2, _y2, _)) if true => { + Ok(BamlValueWithMeta::Enum(x1, y1, (meta1, other_meta))) + } + (BamlValueWithMeta::Enum(_, _, _), _) => anyhow::bail!("Unification error"), + (BamlValueWithMeta::Class(name1, fields1, meta1), BamlValueWithMeta::Class(_name2, fields2, _)) if true => { + // TODO: We can remove a `clone` by checking that the fields + // are ordered the same way between the two classes, then consuming + // both classs' fields in parallel. + // let map_result = fields1.into_iter().zip(fields2).map(|((k1,v1),(_k2,v2))| { + // v1.zip_meta(v2).map(|r| (k1, r)) + // }).collect::>>()?; + let map_result = fields1.into_iter().map(|(k1, v1)| { + let v2 = fields2.get(&k1).context("Missing expected key")?; + v1.zip_meta(v2).map(|r| (k1, r)) + }).collect::>>()?; + Ok(BamlValueWithMeta::Class(name1, map_result, (meta1, other_meta))) + } + (BamlValueWithMeta::Class(_, _, _), _) => anyhow::bail!("Unification error"), + }; + ret.map_err(|_: anyhow::Error| anyhow::anyhow!(error_msg)) + } } /// An iterator over a BamlValue and all of its sub-values. @@ -608,6 +681,7 @@ impl Serialize for BamlValueWithMeta> { where S: Serializer, { + eprintln!("ABOUT TO SERIALIZE"); match self { BamlValueWithMeta::String(v, cr) => serialize_with_checks(v, cr, serializer), BamlValueWithMeta::Int(v, cr) => serialize_with_checks(v, cr, serializer), @@ -665,128 +739,186 @@ where fn add_checks<'a, S: SerializeMap>( map: &'a mut S, checks: &'a [ResponseCheck], -) -> Result<(), S::Error> { +) -> Result<(), S::Error> +{ if !checks.is_empty() { - let checks_map: HashMap<_, _> = checks - .iter() - .map(|check| (check.name.clone(), check)) - .collect(); + let checks_map: HashMap<_,_> = checks.iter().map(|check| (check.name.clone(), check)).collect(); map.serialize_entry("checks", &checks_map)?; } Ok(()) } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_baml_value_with_meta_serialization() { - let baml_value: BamlValueWithMeta> = - BamlValueWithMeta::String("hi".to_string(), vec![]); - let baml_value_2: BamlValueWithMeta> = BamlValueWithMeta::Class( - "ContactInfo".to_string(), - vec![( - "primary".to_string(), - BamlValueWithMeta::Class( - "PhoneNumber".to_string(), - vec![( - "value".to_string(), - BamlValueWithMeta::String( - "123-456-7890".to_string(), - vec![ResponseCheck { - name: "foo".to_string(), - expression: "foo".to_string(), - status: "succeeded".to_string(), - }], - ), - )] - .into_iter() - .collect(), - vec![], - ), - )] - .into_iter() - .collect(), - vec![], - ); - assert!(serde_json::to_value(baml_value).is_ok()); - assert!(serde_json::to_value(baml_value_2).is_ok()); - } - - #[test] - fn test_serialize_class_checks() { - let baml_value: BamlValueWithMeta> = BamlValueWithMeta::Class( - "Foo".to_string(), - vec![ - ("foo".to_string(), BamlValueWithMeta::Int(1, vec![])), - ( - "bar".to_string(), - BamlValueWithMeta::String("hi".to_string(), vec![]), - ), - ] - .into_iter() - .collect(), - vec![ResponseCheck { - name: "bar_len_lt_foo".to_string(), - expression: "this.bar|length < this.foo".to_string(), - status: "failed".to_string(), - }], - ); - let expected = serde_json::json!({ - "value": {"foo": 1, "bar": "hi"}, - "checks": { - "bar_len_lt_foo": { - "name": "bar_len_lt_foo", - "expression": "this.bar|length < this.foo", - "status": "failed" - } - } - }); - let json = serde_json::to_value(baml_value).unwrap(); - assert_eq!(json, expected); - } - - #[test] - fn test_serialize_nested_class_checks() { - // Prepare an object for wrapping. - let foo: BamlValueWithMeta> = BamlValueWithMeta::Class( - "Foo".to_string(), - vec![ - ("foo".to_string(), BamlValueWithMeta::Int(1, vec![])), - ( - "bar".to_string(), - BamlValueWithMeta::String("hi".to_string(), vec![]), - ), - ] - .into_iter() - .collect(), - vec![ResponseCheck { - name: "bar_len_lt_foo".to_string(), - expression: "this.bar|length < this.foo".to_string(), - status: "failed".to_string(), - }], - ); - - // Prepare the top-level value. - let baml_value = BamlValueWithMeta::Class( - "FooWrapper".to_string(), - vec![("foo".to_string(), foo)].into_iter().collect(), - vec![], - ); - let expected = serde_json::json!({ - "foo": { - "value": {"foo": 1, "bar": "hi"}, - "checks": { - "bar_len_lt_foo": { - "name": "bar_len_lt_foo", - "expression": "this.bar|length < this.foo", - "status": "failed" - } - } - } - }); - let json = serde_json::to_value(baml_value).unwrap(); - assert_eq!(json, expected); - } +// impl Serialize for BamlValueWithMeta +// where T: SerializeMetadata + std::fmt::Debug, +// { +// +// fn serialize(&self, serializer: S) -> Result +// where +// S: Serializer, +// { +// let bare_value = self.value(); +// let metadata_fields = &self.meta().metadata_fields(&bare_value)?; +// match self { +// BamlValueWithMeta::String(v, _metadata) => serialize_with_checks(v, metadata_fields, serializer), +// BamlValueWithMeta::Int(v, _metadata) => serialize_with_checks(v, metadata_fields, serializer), +// BamlValueWithMeta::Float(v, _metadata) => serialize_with_checks(v, metadata_fields, serializer), +// BamlValueWithMeta::Bool(v, _metadata) => serialize_with_checks(v, metadata_fields, serializer), +// BamlValueWithMeta::Map(v, _metadata) => { +// let mut map = serializer.serialize_map(None)?; +// for (key, value) in v { +// map.serialize_entry::>(key, value)?; +// } +// add_checks(&mut map, &self.meta().metadata_fields(&bare_value)?)?; +// map.end() +// } +// BamlValueWithMeta::List(v, _metadata) => serialize_with_checks(v, metadata_fields, serializer), +// BamlValueWithMeta::Media(v, _metadata) => serialize_with_checks(v, metadata_fields, serializer), +// BamlValueWithMeta::Enum(_enum_name, v, _metadata) => serialize_with_checks(v, metadata_fields, serializer), +// BamlValueWithMeta::Class(_class_name, v, _metadata) => { +// let metadata_fields = self.meta().metadata_fields(&bare_value); +// if metadata_fields.is_empty() { +// let mut map = serializer.serialize_map(None)?; +// v.into_iter().try_for_each(|(key, value)| { +// map.serialize_entry(key, value) +// })?; +// add_checks(&mut map, &metadata_fields)?; +// map.end() +// } else { +// let mut checked_value = serializer.serialize_map(Some(2))?; +// checked_value.serialize_entry("value", &v)?; +// add_checks(&mut checked_value, &metadata_fields)?; +// checked_value.end() +// } +// } +// BamlValueWithMeta::Null(_) => serialize_with_checks(&(), &self.meta().metadata_fields(), serializer), +// } +// } +// } +// +// fn serialize_with_checks( +// value: &T, +// metadata_fields: &Vec<(String, serde_json::Value)>, +// serializer: S, +// ) -> Result +// where +// S: Serializer, +// { +// if !metadata_fields.is_empty() { +// let mut map = serializer.serialize_map(Some(2))?; +// map.serialize_entry("value", value)?; +// add_checks(&mut map, metadata_fields)?; +// map.end() +// } else { +// value.serialize(serializer) +// } +// } +// +// fn add_checks<'a, S: SerializeMap>( +// map: &'a mut S, +// metadata_fields: &Vec<(String, serde_json::Value)>, +// ) -> Result<(), S::Error> +// { +// metadata_fields.iter().try_for_each(|(field_name, value)| { +// map.serialize_entry(&field_name, &value) +// })?; +// Ok(()) +// } +// +// pub trait SerializeMetadata { +// fn metadata_fields(&self, bare_value: &BamlValue) -> Result, serde_json::Error>; +// } +// +// // This instance is used in constraint tests. +// // Consider modifying that test and deleting this instance. +// impl SerializeMetadata for Vec { +// fn metadata_fields(&self, _bare_value: &serde_json::Value) -> Result, serde_json::Error> { +// if !self.is_empty() { +// let checks_map: HashMap<_,_> = self.iter().map(|check| (check.name.clone(), check)).collect(); +// let json_checks_map = serde_json::to_value(checks_map).expect("serialization of checks is safe"); +// Ok(vec![("checks", json_checks_map)]) +// } else { +// Ok(Vec::new()) +// } +// } +// } +// +// impl SerializeMetadata for (T, Vec, Option) { +// +// // If there are only checks: +// // [("checks", checks), ("value", value)] +// // If there is completion state: +// // [("state", state), ("value", value)] +// // If there are checks and completion state: +// // [("state", state), ("value": { "checks": checks, "value": value })] +// // If there are neither checks nor completion state: +// // [("value", value)] +// fn metadata_fields(&self, bare_value: &BamlValue) -> Result, serde_json::Error> { +// let checks: Vec<(&str, &ResponseCheck)> = self.1.iter().map(|check| (check.name.as_str(), check)).collect(); +// let completion_state: Option<&CompletionState> = self.2.as_ref(); +// +// let checks_json = serde_json::to_value(&checks)?; +// let bare_value_json = serde_json::to_value(bare_value)?; +// +// match (checks.len(), completion_state) { +// (0, None) => Ok(vec![("value", bare_value_json)]), +// (_, None) => Ok(vec![("value", bare_value_json), ("checks", checks_json)]), +// (0, Some(state)) => Ok(vec![("value", bare_value_json), ("state", serde_json::to_value(state)?)]), +// (_, Some(state)) => Ok(vec![ +// ("state", serde_json::to_value(state)?), +// ("value", serde_json::to_value(&vec![ +// ("value", bare_value_json), +// ("checks", checks_json) +// ].into_iter().collect::>())?) +// ]), +// +// } +// // if !checks.is_empty() { +// // let checks_json = serde_json::to_value(checks).expect("Serializing checks is safe."); +// // fields.push(("checks".to_string(), checks_json)); +// // } +// +// // let value_considering_checks = if checks.is_empty() { +// // serde_json::to_value(bare_value)? +// // } else { +// // let object = vec![ +// // ("value", serde_json::to_value(bare_value)?), +// // ("checks", serde_json::to_value(fields)?), +// // ].into_iter().collect::>(); +// // serde_json::to_value(object)? +// // }; +// +// // let value_considering_completion_state = if let Some(state) = completion_state { +// // vec![ ("state", serde_json::to_value(&state)?) ] +// // } else { +// // value_considering_checks +// // } +// +// // Ok(value_considering_completion_state) +// } +// +// } + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct Completion { + pub state: CompletionState, + pub display: bool, + pub required_done: bool, } + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub enum CompletionState { + Pending, + Incomplete, + Complete, +} + +impl Default for Completion { + fn default() -> Self { + panic!("I hope we don't use this default"); + Completion { + state: CompletionState::Complete, + display: false, + required_done: false, + } + } +} \ No newline at end of file diff --git a/engine/baml-lib/baml-types/src/field_type/mod.rs b/engine/baml-lib/baml-types/src/field_type/mod.rs index 52f59fae02..aa8b9e62d9 100644 --- a/engine/baml-lib/baml-types/src/field_type/mod.rs +++ b/engine/baml-lib/baml-types/src/field_type/mod.rs @@ -86,9 +86,10 @@ pub enum FieldType { Tuple(Vec), Optional(Box), RecursiveTypeAlias(String), - Constrained { + WithMetadata { base: Box, constraints: Vec, + streaming_behavior: StreamingBehavior, }, } @@ -126,7 +127,7 @@ impl std::fmt::Display for FieldType { FieldType::Map(k, v) => write!(f, "map<{k}, {v}>"), FieldType::List(t) => write!(f, "{t}[]"), FieldType::Optional(t) => write!(f, "{t}?"), - FieldType::Constrained { base, .. } => base.fmt(f), + FieldType::WithMetadata { base, .. } => base.fmt(f), } } } @@ -137,7 +138,7 @@ impl FieldType { FieldType::Primitive(_) => true, FieldType::Optional(t) => t.is_primitive(), FieldType::List(t) => t.is_primitive(), - FieldType::Constrained { base, .. } => base.is_primitive(), + FieldType::WithMetadata { base, .. } => base.is_primitive(), _ => false, } } @@ -147,7 +148,7 @@ impl FieldType { FieldType::Optional(_) => true, FieldType::Primitive(TypeValue::Null) => true, FieldType::Union(types) => types.iter().any(FieldType::is_optional), - FieldType::Constrained { base, .. } => base.is_optional(), + FieldType::WithMetadata { base, .. } => base.is_optional(), _ => false, } } @@ -156,8 +157,52 @@ impl FieldType { match self { FieldType::Primitive(TypeValue::Null) => true, FieldType::Optional(t) => t.is_null(), - FieldType::Constrained { base, .. } => base.is_null(), + FieldType::WithMetadata { base, .. } => base.is_null(), _ => false, } } + + pub fn streaming_behavior(&self) -> Option<&StreamingBehavior> { + match self { + FieldType::WithMetadata { + streaming_behavior, .. + } => Some(streaming_behavior), + _ => None, + } + } +} + +/// Metadata on a type that determines how it behaves under streaming conditions. +#[derive(Clone, Debug, PartialEq, serde::Serialize)] +pub struct StreamingBehavior { + /// A type with the `done` property will not be visible in a stream until + /// we are certain that it is completely available (i.e. the parser did + /// not finalize it through any early termination, enough tokens were available + /// from the LLM response to be certain that it is done). + pub done: bool, + + /// A type with the `state` property will be represented in client code as + /// a struct: `{value: T, streaming_state: "incomplete" | "complete"}`. + pub state: bool, +} + +impl StreamingBehavior { + pub fn combine(&self, other: &StreamingBehavior) -> StreamingBehavior { + StreamingBehavior { + done: self.done || other.done, + state: self.state || other.state, + } + } +} + +impl Default for StreamingBehavior { + fn default() -> Self { + StreamingBehavior { + done: false, + state: false, + } + } } + +#[cfg(test)] +mod tests {} diff --git a/engine/baml-lib/baml-types/src/lib.rs b/engine/baml-lib/baml-types/src/lib.rs index 2b56a79351..bc7c3f9064 100644 --- a/engine/baml-lib/baml-types/src/lib.rs +++ b/engine/baml-lib/baml-types/src/lib.rs @@ -8,11 +8,13 @@ mod field_type; mod generator; mod value_expr; -pub use baml_value::{BamlValue, BamlValueWithMeta}; +pub use baml_value::{BamlValue, BamlValueWithMeta, Completion, CompletionState}; pub use constraint::*; -pub use field_type::{FieldType, LiteralValue, TypeValue}; +pub use field_type::{FieldType, LiteralValue, StreamingBehavior, TypeValue}; pub use generator::{GeneratorDefaultClientMode, GeneratorOutputType}; pub use map::Map as BamlMap; pub use media::{BamlMedia, BamlMediaContent, BamlMediaType, MediaBase64, MediaUrl}; pub use minijinja::JinjaExpression; -pub use value_expr::{EvaluationContext, GetEnvVar, ResolvedValue, StringOr, UnresolvedValue}; +pub use value_expr::{ + EvaluationContext, GetEnvVar, Resolvable, ResolvedValue, StringOr, UnresolvedValue, +}; diff --git a/engine/baml-lib/baml/tests/validation_files/streaming/streaming_ok.baml b/engine/baml-lib/baml/tests/validation_files/streaming/streaming_ok.baml new file mode 100644 index 0000000000..651cf67e06 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/streaming/streaming_ok.baml @@ -0,0 +1,11 @@ +class Foo { + bar int @stream.not_null + baz string | int @stream.done + quux bool @stream.with_state +} + +class Foo2 { + something int + tomething_else bool + @@stream.done +} diff --git a/engine/baml-lib/jinja-runtime/src/output_format/types.rs b/engine/baml-lib/jinja-runtime/src/output_format/types.rs index 188f37035d..9357d99c74 100644 --- a/engine/baml-lib/jinja-runtime/src/output_format/types.rs +++ b/engine/baml-lib/jinja-runtime/src/output_format/types.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use anyhow::Result; -use baml_types::{Constraint, FieldType, TypeValue}; +use baml_types::{Constraint, FieldType, StreamingBehavior, TypeValue}; use indexmap::{IndexMap, IndexSet}; #[derive(Debug)] @@ -48,9 +48,10 @@ pub struct Enum { #[derive(Debug)] pub struct Class { pub name: Name, - // fields have name, type and description. - pub fields: Vec<(Name, FieldType, Option)>, + // fields have name, type, description, and streaming_needed. + pub fields: Vec<(Name, FieldType, Option, bool)>, pub constraints: Vec, + pub streaming_behavior: StreamingBehavior, } #[derive(Debug, Clone)] @@ -365,7 +366,7 @@ impl OutputFormatContent { FieldType::Optional(_) => Some(String::from("Answer in JSON using this schema:\n")), FieldType::Map(_, _) => Some(String::from("Answer in JSON using this schema:\n")), FieldType::Tuple(_) => None, - FieldType::Constrained { base, .. } => { + FieldType::WithMetadata { base, .. } => { auto_prefix(base, options, output_format_content) } } @@ -446,7 +447,7 @@ impl OutputFormatContent { } }, FieldType::Literal(v) => v.to_string(), - FieldType::Constrained { base, .. } => self.render_possibly_recursive_type( + FieldType::WithMetadata { base, .. } => self.render_possibly_recursive_type( options, base, render_state, @@ -491,7 +492,7 @@ impl OutputFormatContent { values: class .fields .iter() - .map(|(name, field_type, description)| { + .map(|(name, field_type, description, _streaming_needed)| { Ok(ClassFieldRender { name: name.rendered_name().to_string(), description: description.clone(), @@ -788,14 +789,17 @@ mod tests { Name::new("name".to_string()), FieldType::string(), Some("The person's name".to_string()), + false, ), ( Name::new("age".to_string()), FieldType::int(), Some("The person's age".to_string()), + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }]; let content = OutputFormatContent::target(FieldType::class("Person")) @@ -819,15 +823,18 @@ mod tests { Name::new("school".to_string()), FieldType::optional(FieldType::string()), Some("111\n ".to_string()), + false, ), ( Name::new("degree".to_string()), FieldType::string(), Some("2222222".to_string()), + false, ), - (Name::new("year".to_string()), FieldType::int(), None), + (Name::new("year".to_string()), FieldType::int(), None, false), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }]; let content = OutputFormatContent::target(FieldType::class("Education")) @@ -852,30 +859,35 @@ mod tests { Name::new("description".to_string()), FieldType::string(), None, + false, ), - (Name::new("severity".to_string()), FieldType::string(), None), + (Name::new("severity".to_string()), FieldType::string(), None, false), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Enhancement".to_string()), fields: vec![ - (Name::new("title".to_string()), FieldType::string(), None), + (Name::new("title".to_string()), FieldType::string(), None, false), ( Name::new("description".to_string()), FieldType::string(), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Documentation".to_string()), fields: vec![ - (Name::new("module".to_string()), FieldType::string(), None), - (Name::new("format".to_string()), FieldType::string(), None), + (Name::new("module".to_string()), FieldType::string(), None, false), + (Name::new("format".to_string()), FieldType::string(), None, false), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -920,10 +932,12 @@ r#"Answer in JSON using any of these schemas: FieldType::class("Documentation"), ]), None, + false, ), - (Name::new("date".to_string()), FieldType::string(), None), + (Name::new("date".to_string()), FieldType::string(), None, false), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Bug".to_string()), @@ -932,30 +946,35 @@ r#"Answer in JSON using any of these schemas: Name::new("description".to_string()), FieldType::string(), None, + false, ), - (Name::new("severity".to_string()), FieldType::string(), None), + (Name::new("severity".to_string()), FieldType::string(), None, false), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Enhancement".to_string()), fields: vec![ - (Name::new("title".to_string()), FieldType::string(), None), + (Name::new("title".to_string()), FieldType::string(), None, false), ( Name::new("description".to_string()), FieldType::string(), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Documentation".to_string()), fields: vec![ - (Name::new("module".to_string()), FieldType::string(), None), - (Name::new("format".to_string()), FieldType::string(), None), + (Name::new("module".to_string()), FieldType::string(), None, false), + (Name::new("format".to_string()), FieldType::string(), None, false), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -990,14 +1009,16 @@ r#"Answer in JSON using this schema: let classes = vec![Class { name: Name::new("Node".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("next".to_string()), FieldType::optional(FieldType::class("Node")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }]; let content = OutputFormatContent::target(FieldType::class("Node")) @@ -1025,14 +1046,16 @@ Answer in JSON using this schema: Node"# Class { name: Name::new("Node".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("next".to_string()), FieldType::optional(FieldType::class("Node")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("LinkedList".to_string()), @@ -1041,10 +1064,12 @@ Answer in JSON using this schema: Node"# Name::new("head".to_string()), FieldType::optional(FieldType::class("Node")), None, + false, ), - (Name::new("len".to_string()), FieldType::int(), None), + (Name::new("len".to_string()), FieldType::int(), None, false), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -1080,8 +1105,10 @@ Answer in JSON using this schema: Name::new("pointer".to_string()), FieldType::class("B"), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("B".to_string()), @@ -1089,8 +1116,10 @@ Answer in JSON using this schema: Name::new("pointer".to_string()), FieldType::class("C"), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("C".to_string()), @@ -1098,8 +1127,10 @@ Answer in JSON using this schema: Name::new("pointer".to_string()), FieldType::optional(FieldType::class("A")), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -1140,8 +1171,10 @@ Answer in JSON using this schema: A"# Name::new("pointer".to_string()), FieldType::class("B"), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("B".to_string()), @@ -1149,8 +1182,10 @@ Answer in JSON using this schema: A"# Name::new("pointer".to_string()), FieldType::class("C"), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("C".to_string()), @@ -1158,8 +1193,10 @@ Answer in JSON using this schema: A"# Name::new("pointer".to_string()), FieldType::optional(FieldType::class("A")), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("NonRecursive".to_string()), @@ -1168,11 +1205,13 @@ Answer in JSON using this schema: A"# Name::new("pointer".to_string()), FieldType::class("A"), None, + false, ), - (Name::new("data".to_string()), FieldType::int(), None), - (Name::new("field".to_string()), FieldType::bool(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), + (Name::new("field".to_string()), FieldType::bool(), None, false), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -1219,14 +1258,17 @@ Answer in JSON using this schema: Name::new("pointer".to_string()), FieldType::class("B"), None, + false, ), ( Name::new("nested".to_string()), FieldType::class("Nested"), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("B".to_string()), @@ -1234,8 +1276,10 @@ Answer in JSON using this schema: Name::new("pointer".to_string()), FieldType::class("C"), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("C".to_string()), @@ -1243,8 +1287,10 @@ Answer in JSON using this schema: Name::new("pointer".to_string()), FieldType::optional(FieldType::class("A")), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("NonRecursive".to_string()), @@ -1253,19 +1299,22 @@ Answer in JSON using this schema: Name::new("pointer".to_string()), FieldType::class("A"), None, + false, ), - (Name::new("data".to_string()), FieldType::int(), None), - (Name::new("field".to_string()), FieldType::bool(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), + (Name::new("field".to_string()), FieldType::bool(), None, false), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Nested".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), - (Name::new("field".to_string()), FieldType::bool(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), + (Name::new("field".to_string()), FieldType::bool(), None, false), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -1312,14 +1361,16 @@ Answer in JSON using this schema: Class { name: Name::new("Tree".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("children".to_string()), FieldType::class("Forest"), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Forest".to_string()), @@ -1327,8 +1378,10 @@ Answer in JSON using this schema: Name::new("trees".to_string()), FieldType::list(FieldType::class("Tree")), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -1369,8 +1422,10 @@ Answer in JSON using this schema: Tree"# FieldType::optional(FieldType::class("SelfReferential")), ]), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }]; let content = OutputFormatContent::target(FieldType::class("SelfReferential")) @@ -1399,26 +1454,30 @@ Answer in JSON using this schema: SelfReferential"# Class { name: Name::new("Node".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("next".to_string()), FieldType::optional(FieldType::class("Node")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Tree".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("children".to_string()), FieldType::list(FieldType::class("Tree")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -1462,39 +1521,46 @@ Node or Tree"# Name::new("data_type".to_string()), FieldType::Union(vec![FieldType::class("Node"), FieldType::class("Tree")]), None, + false, ), - (Name::new("len".to_string()), FieldType::int(), None), + (Name::new("len".to_string()), FieldType::int(), None, false), ( Name::new("description".to_string()), FieldType::string(), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Node".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("next".to_string()), FieldType::optional(FieldType::class("Node")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Tree".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("children".to_string()), FieldType::list(FieldType::class("Tree")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -1535,34 +1601,39 @@ Answer in JSON using this schema: Class { name: Name::new("Node".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("next".to_string()), FieldType::optional(FieldType::class("Node")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Tree".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("children".to_string()), FieldType::list(FieldType::class("Tree")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("NonRecursive".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), - (Name::new("tag".to_string()), FieldType::string(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), + (Name::new("tag".to_string()), FieldType::string(), None, false), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -1614,47 +1685,55 @@ Node or Tree or { FieldType::class("NonRecursive"), ]), None, + false, ), - (Name::new("len".to_string()), FieldType::int(), None), + (Name::new("len".to_string()), FieldType::int(), None, false), ( Name::new("description".to_string()), FieldType::string(), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Node".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("next".to_string()), FieldType::optional(FieldType::class("Node")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Tree".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("children".to_string()), FieldType::list(FieldType::class("Tree")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("NonRecursive".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), - (Name::new("tag".to_string()), FieldType::string(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), + (Name::new("tag".to_string()), FieldType::string(), None, false), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -1701,8 +1780,10 @@ Answer in JSON using this schema: Name::new("pointer".to_string()), FieldType::class("B"), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("B".to_string()), @@ -1710,8 +1791,10 @@ Answer in JSON using this schema: Name::new("pointer".to_string()), FieldType::class("C"), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("C".to_string()), @@ -1719,8 +1802,10 @@ Answer in JSON using this schema: Name::new("pointer".to_string()), FieldType::optional(FieldType::class("A")), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("NonRecursive".to_string()), @@ -1729,11 +1814,13 @@ Answer in JSON using this schema: Name::new("pointer".to_string()), FieldType::class("A"), None, + false, ), - (Name::new("data".to_string()), FieldType::int(), None), - (Name::new("field".to_string()), FieldType::bool(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), + (Name::new("field".to_string()), FieldType::bool(), None, false), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -1778,26 +1865,30 @@ Answer in JSON using this interface: Class { name: Name::new("Node".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("next".to_string()), FieldType::optional(FieldType::class("Node")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Tree".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("children".to_string()), FieldType::list(FieldType::class("Tree")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -1837,26 +1928,30 @@ Node or int or string or Tree"# Class { name: Name::new("Node".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("next".to_string()), FieldType::optional(FieldType::class("Node")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Tree".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("children".to_string()), FieldType::list(FieldType::class("Tree")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("NonRecursive".to_string()), @@ -1868,11 +1963,13 @@ Node or int or string or Tree"# FieldType::Union(vec![FieldType::string(), FieldType::class("Tree")]), ]), None, + false, ), - (Name::new("data".to_string()), FieldType::int(), None), - (Name::new("field".to_string()), FieldType::bool(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), + (Name::new("field".to_string()), FieldType::bool(), None, false), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -1912,14 +2009,16 @@ Answer in JSON using this schema: let classes = vec![Class { name: Name::new("Node".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("next".to_string()), FieldType::optional(FieldType::class("Node")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }]; let content = OutputFormatContent::target(FieldType::list(FieldType::class("Node"))) @@ -1950,8 +2049,10 @@ Node[]"# Name::new("data".to_string()), FieldType::map(FieldType::string(), FieldType::class("RecursiveMap")), None, + false )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }]; let content = OutputFormatContent::target(FieldType::class("RecursiveMap")) @@ -1981,8 +2082,10 @@ Answer in JSON using this schema: RecursiveMap"# Name::new("data".to_string()), FieldType::map(FieldType::string(), FieldType::class("RecursiveMap")), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("NonRecursive".to_string()), @@ -1990,8 +2093,10 @@ Answer in JSON using this schema: RecursiveMap"# Name::new("rec_map".to_string()), FieldType::Class("RecursiveMap".to_string()), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -2021,14 +2126,16 @@ Answer in JSON using this schema: let classes = vec![Class { name: Name::new("Node".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("next".to_string()), FieldType::optional(FieldType::class("Node")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }]; let content = OutputFormatContent::target(FieldType::map( @@ -2063,20 +2170,24 @@ map"# Name::new("data".to_string()), FieldType::map(FieldType::string(), FieldType::class("Node")), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Node".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("next".to_string()), FieldType::optional(FieldType::class("Node")), None, + false ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -2114,20 +2225,24 @@ Answer in JSON using this schema: FieldType::optional(FieldType::class("Node")), ), None, + false, )], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("Node".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("next".to_string()), FieldType::optional(FieldType::class("Node")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -2159,22 +2274,25 @@ Answer in JSON using this schema: Class { name: Name::new("Node".to_string()), fields: vec![ - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("data".to_string()), FieldType::int(), None, false), ( Name::new("next".to_string()), FieldType::optional(FieldType::class("Node")), None, + false, ), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, Class { name: Name::new("NonRecursive".to_string()), fields: vec![ - (Name::new("field".to_string()), FieldType::string(), None), - (Name::new("data".to_string()), FieldType::int(), None), + (Name::new("field".to_string()), FieldType::string(), None, false), + (Name::new("data".to_string()), FieldType::int(), None, false), ], constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), }, ]; @@ -2224,28 +2342,33 @@ map +} +"#; + +pub fn bench_partials(c: &mut Criterion) { + let mut group = c.benchmark_group("partials"); + + // Test partial parsing of a deeply nested object + let target = FieldType::Class("NestedObject".to_string()); + let of = Builder::new(target.clone()).build(); + + group.bench_function("partial_nested_shallow", |b| b.iter(|| from_str( + &of, + &target, + r#"{ + "id": 1, + "name": "test" + }"#, + true, + ))); + + group.bench_function("partial_nested_mid", |b| b.iter(|| from_str( + &of, + &target, + r#"{ + "id": 1, + "name": "test", + "metadata": { + "created_at": "2024-01-10", + "tags": ["tag1", "tag2"] + } + }"#, + true, + ))); + + // Test partial with optional fields + let target = FieldType::Class("ComplexPartial".to_string()); + let of = Builder::new(target.clone()).build(); + + group.bench_function("partial_with_optional", |b| b.iter(|| from_str( + &of, + &target, + r#"{ + "required_field": "required", + "list_field": ["item1"] + }"#, + true, + ))); + + // Test partial with array fields + group.bench_function("partial_with_arrays", |b| b.iter(|| from_str( + &of, + &target, + r#"{ + "required_field": "required", + "list_field": ["item1", "item2", "item3", "item4", "item5"], + "nested": { + "id": 1, + "name": "test", + "metadata": { + "tags": ["tag1", "tag2", "tag3"] + } + } + }"#, + true, + ))); + + // Test partial with map fields + group.bench_function("partial_with_maps", |b| b.iter(|| from_str( + &of, + &target, + r#"{ + "required_field": "required", + "map_field": { + "key1": "value1", + "key2": "value2", + "key3": "value3" + } + }"#, + true, + ))); + + // Test partial with mixed fields + group.bench_function("partial_mixed", |b| b.iter(|| from_str( + &of, + &target, + r#"{ + "required_field": "required", + "optional_field": "optional", + "list_field": ["item1", "item2"], + "map_field": { + "key1": "value1" + }, + "nested": { + "id": 1, + "name": "test" + } + }"#, + true, + ))); + + group.finish(); +} \ No newline at end of file diff --git a/engine/baml-lib/jsonish/benches/unions.rs b/engine/baml-lib/jsonish/benches/unions.rs new file mode 100644 index 0000000000..6b378cf505 --- /dev/null +++ b/engine/baml-lib/jsonish/benches/unions.rs @@ -0,0 +1,72 @@ +use baml_types::FieldType; +use criterion::Criterion; +use internal_baml_jinja::types::Builder; +use jsonish::{from_str, helpers::common::UNION_SCHEMA, jsonish as internal_jsonish}; + +pub fn bench_unions(c: &mut Criterion) { + let mut group = c.benchmark_group("unions"); + + let target = FieldType::union(vec![ + FieldType::Class("VideoContent".to_string()), + FieldType::Class("TextContent".to_string()), + FieldType::Class("ImageContent".to_string()), + FieldType::Class("AudioContent".to_string()), + ]); + let ir = jsonish::helpers::load_test_ir(UNION_SCHEMA); + let of = jsonish::helpers::render_output_format(&ir, &target, &Default::default()).unwrap(); + + // let of = Builder::new(target.clone()).build(); + + group.bench_function("text_content", |b| { + b.iter(|| from_str(&of, &target, r#"{"text": "Hello World"}"#, false)) + }); + + group.bench_function("image_content", |b| { + b.iter(|| { + from_str( + &of, + &target, + r#"{"url": "https://example.com/img.jpg", "width": 800, "height": 600}"#, + false, + ) + }) + }); + + group.bench_function("video_content_jsonish_only", |b| { + b.iter(|| { + internal_jsonish::parse( + r#"{"url": "https://example.com/video.mp4", "duration": 120}"#, + internal_jsonish::ParseOptions::default(), + ) + }) + }); + + group.bench_function("video_content", |b| { + b.iter(|| { + from_str( + &of, + &target, + r#"{"url": "https://example.com/video.mp4", "duration": 120}"#, + false, + ) + }) + }); + + let target = FieldType::RecursiveTypeAlias("JSONValue".to_string()); + let of = jsonish::helpers::render_output_format(&ir, &target, &Default::default()).unwrap(); + + group.bench_function("json_value_jsonish_only", |b| { + b.iter(|| { + internal_jsonish::parse( + jsonish::helpers::common::JSON_STRING, + internal_jsonish::ParseOptions::default(), + ) + }) + }); + + group.bench_function("json_value", |b| { + b.iter(|| from_str(&of, &target, jsonish::helpers::common::JSON_STRING, true)) + }); + + group.finish(); +} diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_array.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_array.rs index 2fcb32b211..f9839a151d 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_array.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_array.rs @@ -1,10 +1,12 @@ use anyhow::Result; use internal_baml_core::ir::FieldType; -use crate::deserializer::{ +use baml_types::CompletionState; + +use crate::{deserializer::{ deserialize_flags::{DeserializerConditions, Flag}, types::BamlValueWithFlags, -}; +}}; use super::{ParsingContext, ParsingError, TypeCoercer}; @@ -31,7 +33,10 @@ pub(super) fn coerce_array( let mut flags = DeserializerConditions::new(); match &value { - Some(crate::jsonish::Value::Array(arr)) => { + Some(crate::jsonish::Value::Array(arr, completion_state)) => { + if *completion_state == CompletionState::Incomplete { + flags.add_flag(Flag::Incomplete); + } for (i, item) in arr.iter().enumerate() { match inner.coerce(&ctx.enter_scope(&format!("{i}")), inner, Some(item)) { Ok(v) => items.push(v), diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_literal.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_literal.rs index ccda9ab3b9..44c2754ebb 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_literal.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_literal.rs @@ -3,6 +3,7 @@ use std::vec; use anyhow::Result; use baml_types::LiteralValue; use internal_baml_core::ir::FieldType; +use internal_baml_jinja::CompletionOptions; use crate::{ deserializer::{ @@ -48,14 +49,14 @@ impl TypeCoercer for LiteralValue { }; // If we get an object with a single key-value pair, try to extract the value - if let jsonish::Value::Object(obj) = value { + if let jsonish::Value::Object(obj, completion_state) = value { if obj.len() == 1 { let (key, inner_value) = obj.iter().next().unwrap(); // only extract value if it's a primitive (not an object or array, hoping to god its fixed) match inner_value { - jsonish::Value::Number(_) | jsonish::Value::Boolean(_) | jsonish::Value::String(_) => { - let mut result = self.coerce(ctx, target, Some(inner_value))?; - result.add_flag(Flag::ObjectToPrimitive(jsonish::Value::Object(obj.clone()))); + jsonish::Value::Number(_, _) | jsonish::Value::Boolean(_) | jsonish::Value::String(_, _) => { + let mut result = self.coerce(ctx, target, Some(&inner_value))?; + result.add_flag(Flag::ObjectToPrimitive(jsonish::Value::Object(obj.clone(), completion_state.clone()))); return Ok(result); } _ => {} @@ -73,7 +74,7 @@ impl TypeCoercer for LiteralValue { if coerced_int.value() == literal_int { Ok(BamlValueWithFlags::Int(coerced_int)) } else { - Err(ctx.error_unexpected_type(target, value)) + Err(ctx.error_unexpected_type(target, &value)) } } @@ -86,7 +87,7 @@ impl TypeCoercer for LiteralValue { if coerced_bool.value() == literal_bool { Ok(BamlValueWithFlags::Bool(coerced_bool)) } else { - Err(ctx.error_unexpected_type(target, value)) + Err(ctx.error_unexpected_type(target, &value)) } } diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_map.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_map.rs index e10d6a36ff..e909378c6b 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_map.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_map.rs @@ -9,7 +9,7 @@ use crate::{ }, jsonish, }; -use baml_types::{BamlMap, FieldType, LiteralValue, TypeValue}; +use baml_types::{BamlMap, CompletionState, FieldType, LiteralValue, TypeValue}; use super::{ParsingContext, ParsingError, TypeCoercer}; @@ -64,7 +64,7 @@ pub(super) fn coerce_map( flags.add_flag(Flag::ObjectToMap(value.clone())); match &value { - jsonish::Value::Object(obj) => { + jsonish::Value::Object(obj, completion_state) => { let mut items = BamlMap::new(); for (idx, (key, value)) in obj.iter().enumerate() { let coerced_value = @@ -85,11 +85,12 @@ pub(super) fn coerce_map( // TODO: Is it necessary to check that values match here? This // is also checked at `coerce_arg` in // baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs - let key_as_jsonish = jsonish::Value::String(key.to_owned()); + // TODO: Is it Ok that we assume keys are complete? + let key_as_jsonish = jsonish::Value::String(key.to_owned(), CompletionState::Complete); match key_type.coerce(ctx, key_type, Some(&key_as_jsonish)) { Ok(_) => { // Hack to avoid cloning the key twice. - let jsonish::Value::String(owned_key) = key_as_jsonish else { + let jsonish::Value::String(owned_key, CompletionState::Complete) = key_as_jsonish else { unreachable!("key_as_jsonish is defined as jsonish::Value::String"); }; @@ -103,6 +104,9 @@ pub(super) fn coerce_map( Err(e) => flags.add_flag(Flag::MapKeyParseError(idx, e)), } } + if *completion_state == CompletionState::Incomplete { + flags.add_flag(Flag::Incomplete); + } Ok(BamlValueWithFlags::Map(flags, items)) } // TODO: first map in an array that matches diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs index ebac337200..8e1605d35b 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/coerce_primitive.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use baml_types::BamlMediaType; +use baml_types::{BamlMediaType, CompletionState}; use internal_baml_core::ir::{FieldType, TypeValue}; use crate::deserializer::{ @@ -69,7 +69,13 @@ fn coerce_string( }; match value { - crate::jsonish::Value::String(s) => Ok(BamlValueWithFlags::String(s.to_string().into())), + crate::jsonish::Value::String(s, completion_state) => { + let mut baml_value = BamlValueWithFlags::String(s.to_string().into()); + if completion_state == &CompletionState::Incomplete { + baml_value.add_flag(Flag::Incomplete); + } + Ok(baml_value) + }, crate::jsonish::Value::Null => Err(ctx.error_unexpected_null(target)), v => Ok(BamlValueWithFlags::String( (v.to_string(), Flag::JsonToString(v.clone())).into(), @@ -86,8 +92,8 @@ pub(super) fn coerce_int( return Err(ctx.error_unexpected_null(target)); }; - match value { - crate::jsonish::Value::Number(n) => { + let mut result = match value { + crate::jsonish::Value::Number(n, _) => { if let Some(n) = n.as_i64() { Ok(BamlValueWithFlags::Int(n.into())) } else if let Some(n) = n.as_u64() { @@ -97,10 +103,10 @@ pub(super) fn coerce_int( ((n.round() as i64), Flag::FloatToInt(n)).into(), )) } else { - Err(ctx.error_unexpected_type(target, value)) + Err(ctx.error_unexpected_type(target, &value)) } } - crate::jsonish::Value::String(s) => { + crate::jsonish::Value::String(s, _) => { let s = s.trim(); // Trim trailing commas let s = s.trim_end_matches(','); @@ -121,16 +127,24 @@ pub(super) fn coerce_int( ((frac.round() as i64), Flag::FloatToInt(frac)).into(), )) } else { - Err(ctx.error_unexpected_type(target, value)) + Err(ctx.error_unexpected_type(target, &value)) } } - crate::jsonish::Value::Array(items) => { + crate::jsonish::Value::Array(items, _) => { coerce_array_to_singular(ctx, target, &items.iter().collect::>(), &|value| { coerce_int(ctx, target, Some(value)) }) } - _ => Err(ctx.error_unexpected_type(target, value)), + _ => Err(ctx.error_unexpected_type(target, &value)), + }; + match value.completion_state() { + CompletionState::Complete => {}, + CompletionState::Incomplete => { + result.iter_mut().for_each(|v| v.add_flag(Flag::Incomplete)); + }, + CompletionState::Pending => unreachable!("jsonish::Value may never be in a Pending state."), } + result } fn float_from_maybe_fraction(value: &str) -> Option { @@ -174,8 +188,8 @@ fn coerce_float( return Err(ctx.error_unexpected_null(target)); }; - match value { - crate::jsonish::Value::Number(n) => { + let mut result = match value { + crate::jsonish::Value::Number(n, _) => { if let Some(n) = n.as_f64() { Ok(BamlValueWithFlags::Float(n.into())) } else if let Some(n) = n.as_i64() { @@ -183,10 +197,10 @@ fn coerce_float( } else if let Some(n) = n.as_u64() { Ok(BamlValueWithFlags::Float((n as f64).into())) } else { - Err(ctx.error_unexpected_type(target, value)) + Err(ctx.error_unexpected_type(target, &value)) } } - crate::jsonish::Value::String(s) => { + crate::jsonish::Value::String(s, _) => { let s = s.trim(); // Trim trailing commas let s = s.trim_end_matches(','); @@ -208,16 +222,24 @@ fn coerce_float( baml_value.add_flag(Flag::StringToFloat(s.to_string())); Ok(baml_value) } else { - Err(ctx.error_unexpected_type(target, value)) + Err(ctx.error_unexpected_type(target, &value)) } } - crate::jsonish::Value::Array(items) => { + crate::jsonish::Value::Array(items, _) => { coerce_array_to_singular(ctx, target, &items.iter().collect::>(), &|value| { coerce_float(ctx, target, Some(value)) }) } - _ => Err(ctx.error_unexpected_type(target, value)), + _ => Err(ctx.error_unexpected_type(target, &value)), + }; + match value.completion_state() { + CompletionState::Complete => {}, + CompletionState::Incomplete => { + result.iter_mut().for_each(|v| v.add_flag(Flag::Incomplete)); + }, + CompletionState::Pending => unreachable!("jsonish::Value may never be in pending state"), } + result } pub(super) fn coerce_bool( @@ -229,9 +251,9 @@ pub(super) fn coerce_bool( return Err(ctx.error_unexpected_null(target)); }; - match value { + let mut result = match value { crate::jsonish::Value::Boolean(b) => Ok(BamlValueWithFlags::Bool((*b).into())), - crate::jsonish::Value::String(s) => match s.to_lowercase().as_str() { + crate::jsonish::Value::String(s, _) => match s.to_lowercase().as_str() { "true" => Ok(BamlValueWithFlags::Bool( (true, Flag::StringToBool(s.clone())).into(), )), @@ -258,19 +280,27 @@ pub(super) fn coerce_bool( "false" => Ok(BamlValueWithFlags::Bool( (false, Flag::StringToBool(val.value().clone())).into(), )), - _ => Err(ctx.error_unexpected_type(target, value)), + _ => Err(ctx.error_unexpected_type(target, &value)), }, - Err(_) => Err(ctx.error_unexpected_type(target, value)), + Err(_) => Err(ctx.error_unexpected_type(target, &value)), } } }, - crate::jsonish::Value::Array(items) => { + crate::jsonish::Value::Array(items, _) => { coerce_array_to_singular(ctx, target, &items.iter().collect::>(), &|value| { coerce_bool(ctx, target, Some(value)) }) } - _ => Err(ctx.error_unexpected_type(target, value)), + _ => Err(ctx.error_unexpected_type(target, &value)), + }; + match value.completion_state() { + CompletionState::Complete => {}, + CompletionState::Incomplete => { + result.iter_mut().for_each(|v| v.add_flag(Flag::Incomplete)); + }, + CompletionState::Pending => unreachable!("jsonish::Value may never be in pending state."), } + result } #[cfg(test)] diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs index 26d973dd32..675c00cb16 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use baml_types::{BamlMap, Constraint, ConstraintLevel}; +use baml_types::{BamlMap, CompletionState, Constraint, ConstraintLevel}; use internal_baml_core::{ir::FieldType, ir::TypeValue}; use crate::deserializer::{ @@ -25,7 +25,7 @@ impl TypeCoercer for FieldType { target: &FieldType, value: Option<&crate::jsonish::Value>, ) -> Result { - match value { + let mut result = match value { Some(crate::jsonish::Value::AnyOf(candidates, primitive)) => { log::debug!( "scope: {scope} :: coercing to: {name} (current: {current})", @@ -37,7 +37,7 @@ impl TypeCoercer for FieldType { self.coerce( ctx, target, - Some(&crate::jsonish::Value::String(primitive.clone())), + Some(&crate::jsonish::Value::String(primitive.clone(), CompletionState::Incomplete)), ) } else { array_helper::coerce_array_to_singular( @@ -48,7 +48,7 @@ impl TypeCoercer for FieldType { ) } } - Some(crate::jsonish::Value::Markdown(_t, v)) => { + Some(crate::jsonish::Value::Markdown(_t, v, _completion)) => { log::debug!( "scope: {scope} :: coercing to: {name} (current: {current})", name = target.to_string(), @@ -89,7 +89,7 @@ impl TypeCoercer for FieldType { FieldType::Optional(_) => coerce_optional(ctx, self, value), FieldType::Map(_, _) => coerce_map(ctx, self, value), FieldType::Tuple(_) => Err(ctx.error_internal("Tuple not supported")), - FieldType::Constrained { base, .. } => { + FieldType::WithMetadata { base, .. } => { let mut coerced_value = base.coerce(ctx, base, value)?; let constraint_results = run_user_checks(&coerced_value.clone().into(), self) .map_err(|e| ParsingError { @@ -110,7 +110,11 @@ impl TypeCoercer for FieldType { Ok(coerced_value) } }, + }; + if let Some(CompletionState::Incomplete) = value.map(|v| v.completion_state()) { + result.iter_mut().for_each(|v| v.add_flag(Flag::Incomplete)); } + result } } @@ -189,7 +193,7 @@ impl DefaultValue for FieldType { } FieldType::Primitive(_) => None, // If it has constraints, we can't assume our defaults meet them. - FieldType::Constrained { .. } => None, + FieldType::WithMetadata { .. } => None, } } } diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs index f45d2dadee..35231d010b 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use baml_types::{BamlMap, Constraint}; +use baml_types::{BamlMap, Constraint, StreamingBehavior}; use internal_baml_core::ir::FieldType; use internal_baml_jinja::types::{Class, Name}; @@ -12,8 +12,8 @@ use crate::deserializer::{ use super::ParsingContext; -// Name, type, description. -type FieldValue = (Name, FieldType, Option); +// Name, type, description, streaming_needed. +type FieldValue = (Name, FieldType, Option, bool); impl TypeCoercer for Class { fn coerce( @@ -58,10 +58,10 @@ impl TypeCoercer for Class { let (optional, required): (Vec<_>, Vec<_>) = self.fields.iter().partition(|f| f.1.is_optional()); - let constraints = ctx + let (constraints, streaming_behavior) = ctx .of .find_class(self.name.real_name()) - .map_or(vec![], |class| class.constraints.clone()); + .map_or((vec![], StreamingBehavior::default()), |class| (class.constraints.clone(), class.streaming_behavior.clone())); let mut optional_values = optional .iter() @@ -80,8 +80,9 @@ impl TypeCoercer for Class { None => { // Do nothing } - Some(crate::jsonish::Value::Object(obj)) => { + Some(crate::jsonish::Value::Object(obj, completion)) => { // match keys, if that fails, then do something fancy later. + // dbg!(&obj); let mut extra_keys = vec![]; let mut found_keys = false; obj.iter().for_each(|(key, v)| { @@ -108,7 +109,7 @@ impl TypeCoercer for Class { .coerce( &scope, &field.1, - Some(&crate::jsonish::Value::Object(obj.clone())), + Some(&crate::jsonish::Value::Object(obj.clone(), completion.clone())), ) .map(|mut v| { v.add_flag(Flag::ImpliedKey(field.0.real_name().into())); @@ -133,7 +134,7 @@ impl TypeCoercer for Class { }); } } - Some(crate::jsonish::Value::Array(items)) => { + Some(crate::jsonish::Value::Array(items, completion)) => { if self.fields.len() == 1 { let field = &self.fields[0]; let scope = ctx.enter_scope(&format!("", field.0.real_name())); @@ -154,7 +155,7 @@ impl TypeCoercer for Class { &items.iter().collect::>(), &|value| self.coerce(ctx, target, Some(value)), ) - .and_then(|value| apply_constraints(target, vec![], value, constraints.clone())); + .and_then(|value| apply_constraints(target, vec![], value, constraints.clone(), streaming_behavior.clone())); if let Ok(option1) = option1_result { completed_cls.push(Ok(option1)); } @@ -195,7 +196,8 @@ impl TypeCoercer for Class { // If we're missing a field, thats ok! None => Some(BamlValueWithFlags::Null( DeserializerConditions::new() - .with_flag(Flag::OptionalDefaultFromNoValue), + .with_flag(Flag::OptionalDefaultFromNoValue) + .with_flag(Flag::Incomplete), )), }; @@ -211,7 +213,8 @@ impl TypeCoercer for Class { if ctx.allow_partials { Some(BamlValueWithFlags::Null( DeserializerConditions::new() - .with_flag(Flag::OptionalDefaultFromNoValue), + .with_flag(Flag::OptionalDefaultFromNoValue) + .with_flag(Flag::Incomplete), )) } else { None @@ -221,7 +224,8 @@ impl TypeCoercer for Class { if ctx.allow_partials { Some(BamlValueWithFlags::Null( DeserializerConditions::new() - .with_flag(Flag::OptionalDefaultFromNoValue), + .with_flag(Flag::OptionalDefaultFromNoValue) + .with_flag(Flag::Incomplete), )) } else { None @@ -302,12 +306,13 @@ impl TypeCoercer for Class { // Decide if null is a better option. (k.to_string(), v) } - None => (k.to_string(), BamlValueWithFlags::Null(Default::default())), + None => (k.to_string(), BamlValueWithFlags::Null(DeserializerConditions::new().with_flag(Flag::Incomplete))), Some(Err(e)) => ( k.to_string(), BamlValueWithFlags::Null( DeserializerConditions::new() - .with_flag(Flag::DefaultButHadUnparseableValue(e)), + .with_flag(Flag::DefaultButHadUnparseableValue(e)) + .with_flag(Flag::Incomplete), ), ), } @@ -328,7 +333,7 @@ impl TypeCoercer for Class { flags, ordered_valid_fields.clone(), )) - .and_then(|value| apply_constraints(target, vec![], value, constraints.clone())); + .and_then(|value| apply_constraints(target, vec![], value, constraints.clone(), streaming_behavior)); completed_cls.insert(0, completed_instance); } @@ -345,13 +350,15 @@ pub fn apply_constraints( scope: Vec, mut value: BamlValueWithFlags, constraints: Vec, + streaming_behavior: StreamingBehavior, ) -> Result { if constraints.is_empty() { Ok(value) } else { - let constrained_class = FieldType::Constrained { + let constrained_class = FieldType::WithMetadata { base: Box::new(class_type.clone()), constraints, + streaming_behavior, }; let constraint_results = run_user_checks(&value.clone().into(), &constrained_class) .map_err(|e| ParsingError { diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_enum.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_enum.rs index 56aa5020a7..c901ec68e0 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_enum.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_enum.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use baml_types::FieldType; +use baml_types::{FieldType, StreamingBehavior}; use internal_baml_jinja::types::Enum; use crate::deserializer::{ @@ -56,6 +56,7 @@ impl TypeCoercer for Enum { vec![], BamlValueWithFlags::Enum(self.name.real_name().to_string(), variant_match), constraints.clone(), + StreamingBehavior::default(), )?; Ok(enum_match) diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/match_string.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/match_string.rs index 9e7e27c750..833d96e955 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/match_string.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/match_string.rs @@ -38,7 +38,7 @@ pub(super) fn match_string( // Grab context. let jsonish_string = match value { - jsonish::Value::String(s) => s.clone(), + jsonish::Value::String(s, _) => s.clone(), jsonish::Value::AnyOf(_, s) => { flags.add_flag(Flag::ObjectToString(value.clone())); s.clone() diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs index 6dcc7bc5de..34a51334f2 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs @@ -271,7 +271,7 @@ pub fn run_user_checks( type_: &FieldType, ) -> Result> { match type_ { - FieldType::Constrained { constraints, .. } => constraints + FieldType::WithMetadata { constraints, .. } => constraints .iter() .map(|constraint| { let result = evaluate_predicate(baml_value, &constraint.expression)?; diff --git a/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs b/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs index 106981ebcd..4d5c3e5e21 100644 --- a/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs +++ b/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs @@ -48,11 +48,15 @@ pub enum Flag { /// Constraint results (only contains checks) ConstraintResults(Vec<(String, JinjaExpression, bool)>), + + /// Completion state for the top-level node of the value is Incomplete. + Incomplete, + Pending } #[derive(Clone)] pub struct DeserializerConditions { - pub(super) flags: Vec, + pub flags: Vec, } impl DeserializerConditions { @@ -98,6 +102,8 @@ impl DeserializerConditions { Flag::UnionMatch(_idx, _) => None, Flag::DefaultButHadUnparseableValue(e) => Some(e.clone()), Flag::ConstraintResults(_) => None, + Flag::Incomplete => None, + Flag::Pending => None, }) .collect::>() } @@ -114,6 +120,15 @@ impl DeserializerConditions { } } +pub fn constraint_results(flags: &Vec) -> Vec<(String, JinjaExpression, bool)> { + flags.iter().filter_map(|flag| match flag { + Flag::ConstraintResults(cs) => Some(cs.clone()), + _ => None, + }) + .flatten() + .collect() +} + impl std::fmt::Debug for DeserializerConditions { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) @@ -122,7 +137,7 @@ impl std::fmt::Debug for DeserializerConditions { impl std::fmt::Display for DeserializerConditions { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.flags.is_empty() { + if true { return Ok(()); } @@ -261,6 +276,12 @@ impl std::fmt::Display for Flag { )?; } } + Flag::Incomplete => { + write!(f, "Value is incompletely streamed")?; + } + Flag::Pending => { + write!(f, "Value not yet started")?; + } } Ok(()) } diff --git a/engine/baml-lib/jsonish/src/deserializer/mod.rs b/engine/baml-lib/jsonish/src/deserializer/mod.rs index fe848bff90..b603ac344e 100644 --- a/engine/baml-lib/jsonish/src/deserializer/mod.rs +++ b/engine/baml-lib/jsonish/src/deserializer/mod.rs @@ -1,5 +1,6 @@ pub mod coercer; pub mod deserialize_flags; // pub mod schema; -mod score; +pub mod score; +pub mod semantic_streaming; pub mod types; diff --git a/engine/baml-lib/jsonish/src/deserializer/score.rs b/engine/baml-lib/jsonish/src/deserializer/score.rs index 7fbfefd8a5..9f10935943 100644 --- a/engine/baml-lib/jsonish/src/deserializer/score.rs +++ b/engine/baml-lib/jsonish/src/deserializer/score.rs @@ -69,6 +69,9 @@ impl WithScore for Flag { Flag::NoFields(_) => 1, // No scores for contraints Flag::ConstraintResults(_) => 0, + // No scores for incompleteness. + Flag::Incomplete => 0, + Flag::Pending => 0, } } } @@ -84,3 +87,9 @@ impl WithScore for DeserializerConditions { self.flags.iter().map(WithScore::score).sum() } } + +impl WithScore for Vec { + fn score(&self) -> i32 { + self.iter().map(WithScore::score).sum() + } +} diff --git a/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs b/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs new file mode 100644 index 0000000000..51e61c73a7 --- /dev/null +++ b/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs @@ -0,0 +1,636 @@ +// This module helps resolve baml values with attached streaming state +// in the context of the streaming behavior associated with their types. + +use crate::deserializer::coercer::ParsingError; +use crate::{BamlValueWithFlags, Flag}; +use indexmap::IndexMap; +use internal_baml_core::ir::repr::{IntermediateRepr, Walker}; +use internal_baml_core::ir::{Field, IRHelper}; + +use baml_types::{ + BamlValueWithMeta, Completion, CompletionState, FieldType, ResponseCheck, StreamingBehavior, TypeValue, +}; + +use anyhow::{Context, Error}; +use std::collections::{HashMap, HashSet}; +use thiserror; + +#[derive(Debug, thiserror::Error)] +pub enum StreamingError { + #[error("Expected to encounter a class")] + ExpectedClass, + #[error("Value was marked Done, but was incomplete in the stream")] + IncompleteDoneValue, + #[error("Class instance did not contain fields marked as needed")] + MissingNeededFields, + #[error("Failed to distribute_type_with_meta: {0}")] + DistributeTypeWithMetaFailure(#[from] anyhow::Error), +} + +/// For a given baml value, traverse its nodes, comparing the completion state +/// of each node against the streaming behavior of the node's type. +pub fn validate_streaming_state( + ir: &IntermediateRepr, + baml_value: &BamlValueWithFlags, + field_type: &FieldType, + allow_partials: bool, +) -> Result, StreamingError> { + let baml_value_with_meta_flags: BamlValueWithMeta> = baml_value.clone().into(); + let typed_baml_value: BamlValueWithMeta<(Vec, FieldType)> = + ir.distribute_type_with_meta(baml_value_with_meta_flags, field_type.clone())?; + let baml_value_with_streaming_state_and_behavior = + typed_baml_value.map_meta(|(flags, r#type)| (completion_state(&flags), r#type)); + + let res = process_node( + ir, + baml_value_with_streaming_state_and_behavior, + allow_partials, + ); + res +} + +/// Like validate_state, but specialized to the metadata we happen to have already. +/// (This is a performance hack to allow us to skip several map_meta, zip_meta +/// steps). +pub fn validate_streaming_state2( + ir: &IntermediateRepr, + baml_value: BamlValueWithMeta<(Vec, Vec)>, + field_type: &FieldType, + allow_partials: bool, +) -> Result< + BamlValueWithMeta<(Vec, Vec, Completion)>, + StreamingError, +> { + let typed_baml_value = ir.distribute_type_with_meta(baml_value, field_type.clone())?; + let res = process_node2(ir, typed_baml_value, allow_partials); + res +} + +/// Consider a node's type, streaming state, and streaming behavior annotations. Return +/// an error if streaming state doesn't meet the streaming requirements. Also attach +/// the streaming state to the node as metadata, if this was requested by the user +/// vial `@stream.with_state`. +/// +/// This function descends into child nodes when the argument is a compound value. +/// +/// Params: +/// value: A done in the BamlValue tree. +/// allow_partials: +fn process_node( + ir: &IntermediateRepr, + value: BamlValueWithMeta<(CompletionState, &FieldType)>, + allow_partials: bool, +) -> Result, StreamingError> { + // let value_copy = value.clone(); // For debugging later. Delete me. + let (completion_state, field_type) = value.meta().clone(); + let (base_type, (_, streaming_behavior)) = ir.distribute_metadata(field_type); + + let must_be_done = required_done(ir, field_type); + let allow_partials_in_sub_nodes = allow_partials && !must_be_done; + + let new_meta = Completion { + state: completion_state.clone(), + display: streaming_behavior.state, + required_done: must_be_done, + }; + // let new_meta = if streaming_behavior.state && allow_partials { + // Some(completion_state.clone()) + // } else { + // None + // }; + + if must_be_done && allow_partials && !(completion_state == CompletionState::Complete) { + return Err(StreamingError::IncompleteDoneValue); + } + + let new_value = match value { + BamlValueWithMeta::String(s, _) => Ok(BamlValueWithMeta::String(s, new_meta)), + BamlValueWithMeta::Media(m, _) => Ok(BamlValueWithMeta::Media(m, new_meta)), + BamlValueWithMeta::Null(_) => Ok(BamlValueWithMeta::Null(new_meta)), + BamlValueWithMeta::Int(i, _) => Ok(BamlValueWithMeta::Int(i, new_meta)), + BamlValueWithMeta::Float(f, _) => Ok(BamlValueWithMeta::Float(f, new_meta)), + BamlValueWithMeta::Bool(b, _) => Ok(BamlValueWithMeta::Bool(b, new_meta)), + BamlValueWithMeta::List(items, _) => Ok(BamlValueWithMeta::List( + items + .into_iter() + .filter_map(|item| process_node(ir, item, allow_partials_in_sub_nodes).ok()) + .collect(), + new_meta, + )), + BamlValueWithMeta::Class(ref class_name, value_fields, _) => { + let value_field_names: HashSet = value_fields + .keys() + .into_iter() + .map(|s| s.to_string()) + .collect(); + let needed_fields: HashSet = needed_fields(ir, field_type, allow_partials_in_sub_nodes)?; + + let present_nonnull_fields: HashSet = value_fields.iter().filter_map(|(field_name, field_value)| { + if matches!(field_value, BamlValueWithMeta::Null(_)) { + None + } else { + Some(field_name.to_string()) + } + }).collect(); + // let missing_needed_fields = needed_fields.difference(&new_field_names); + let missing_needed_fields: Vec<_> = needed_fields.difference(&present_nonnull_fields).into_iter().collect(); + // if (class_name == "SmallThing") { + if false { + dbg!(class_name); + dbg!(&value_field_names); + dbg!(&present_nonnull_fields); + dbg!(&needed_fields); + dbg!(&missing_needed_fields); + dbg!(&value_fields); + } + + + // The fields that need to be filled in by Null are initially the + // fields in the Class type that are not present in the input + // value. + let fields_needing_null = + fields_needing_null_filler(ir, field_type, value_field_names, allow_partials)?; + + let mut deleted_fields: HashMap> = + HashMap::new(); + + // let unneeded_fields = field_names.difference(&needed_fields); + let needed_nulls = fields_needing_null + .into_iter() + .filter_map(|ref null_field_name| { + let field = value_fields + .get(null_field_name) + .expect("This field is guaranteed to be in the field set"); + let use_state = type_streaming_behavior(ir, field.meta().1).state; + let field_stream_state = Completion { + state: CompletionState::Incomplete, + display: use_state, + required_done: false, + }; + Some(( + null_field_name.to_string(), + BamlValueWithMeta::Null(field_stream_state), + )) + }) + .collect::>>(); + + let mut new_fields = value_fields + .into_iter() + .filter_map(|(field_name, field_value)| { + let with_state = field_value + .meta() + .1 + .streaming_behavior() + .as_ref() + .map_or(false, |b| b.state); + let completion_state = field_value.meta().0.clone(); + match process_node(ir, field_value, allow_partials_in_sub_nodes) { + Ok(res) => Some((field_name, res)), + _ => { + let state = Completion { + state: completion_state, + display: with_state, + required_done: false, + }; + let null = BamlValueWithMeta::Null(state); + deleted_fields.insert(field_name, null); + None + } + } + }) + .collect::>>(); + + let derived_present_nonnull_fields: HashSet = new_fields.iter().filter_map(|(field_name, field_value)| { + if matches!(field_value, BamlValueWithMeta::Null(_)) { + None + } else { + Some(field_name.to_string()) + } + }).collect(); + let missing_needed_fields: Vec<_> = needed_fields.difference(&derived_present_nonnull_fields).into_iter().collect(); + + new_fields.extend(needed_nulls); + new_fields.extend(deleted_fields); + + let res = BamlValueWithMeta::Class(class_name.clone(), new_fields, new_meta); + if missing_needed_fields.clone().len() == 0 { + Ok(res) + } else { + Err(StreamingError::MissingNeededFields) + } + } + BamlValueWithMeta::Enum(name, value, _) => { + Ok(BamlValueWithMeta::Enum(name, value, new_meta)) + } + BamlValueWithMeta::Map(kvs, _) => { + let new_kvs = kvs + .into_iter() + .filter_map(|(k, v)| process_node(ir, v, allow_partials_in_sub_nodes).ok().map(|v| (k, v))) + .collect(); + Ok(BamlValueWithMeta::Map(new_kvs, new_meta)) + } + }; + + new_value +} + +fn process_node2( + ir: &IntermediateRepr, + value: BamlValueWithMeta<((Vec, Vec), FieldType)>, + allow_partials: bool, +) -> Result< + BamlValueWithMeta<(Vec, Vec, Completion)>, + StreamingError, +> { + // let value_copy = value.clone(); + let ((flags, checks), field_type) = value.meta().clone(); + let complete = completion_state(&flags); + let (base_type, (_, streaming_behavior)) = ir.distribute_metadata(&field_type); + + let must_be_done = required_done(ir, &field_type) && allow_partials; + + let new_meta = ( + flags, + checks, + Completion { + state: complete.clone(), + display: streaming_behavior.state, + required_done: must_be_done, + } + ); + + if must_be_done && !(complete == CompletionState::Complete) { + return Err(StreamingError::IncompleteDoneValue); + // return Ok(BamlValueWithMeta::Null(new_meta)) + } + + let new_value = match value { + BamlValueWithMeta::String(s, _) => Ok(BamlValueWithMeta::String(s, new_meta)), + BamlValueWithMeta::Media(m, _) => Ok(BamlValueWithMeta::Media(m, new_meta)), + BamlValueWithMeta::Null(_) => Ok(BamlValueWithMeta::Null(new_meta)), + BamlValueWithMeta::Int(i, _) => Ok(BamlValueWithMeta::Int(i, new_meta)), + BamlValueWithMeta::Float(f, _) => Ok(BamlValueWithMeta::Float(f, new_meta)), + BamlValueWithMeta::Bool(b, _) => Ok(BamlValueWithMeta::Bool(b, new_meta)), + BamlValueWithMeta::List(items, _) => Ok(BamlValueWithMeta::List( + items + .into_iter() + .filter_map(|item| process_node2(ir, item, allow_partials).ok()) + .collect(), + new_meta, + )), + BamlValueWithMeta::Class(ref class_name, fields, _) => { + let field_names: HashSet = + fields.keys().into_iter().map(|s| s.to_string()).collect(); + let needed_fields: HashSet = needed_fields(ir, &field_type, allow_partials)?; + // let missing_needed_fields = needed_fields.difference(&new_field_names); + let present_nonnull_fields: HashSet = fields.iter().filter_map(|(field_name, field_value)| { + if matches!(field_value, BamlValueWithMeta::Null(_)) { + None + } else { + Some(field_name.to_string()) + } + }).collect(); + + let missing_needed_fields: HashSet<&String> = needed_fields.difference(&present_nonnull_fields).into_iter().collect(); + let unneeded_fields = field_names.difference(&needed_fields); + + let fields_needing_null = + fields_needing_null_filler(ir, &field_type, field_names, allow_partials)?; + + let mut deleted_fields: HashMap< + String, + BamlValueWithMeta<(Vec, Vec, Completion)>, + > = HashMap::new(); + + // let unneeded_fields = field_names.difference(&needed_fields); + let needed_nulls = fields_needing_null + .into_iter() + .filter_map(|ref null_field_name| { + let field = fields + .get(null_field_name) + .expect("This field is guaranteed to be in the field set"); + let use_state = type_streaming_behavior(ir, &field.meta().1).state; + let field_stream_state = Completion { + state: CompletionState::Incomplete, + display: use_state, + required_done: false, + }; + Some(( + null_field_name.to_string(), + BamlValueWithMeta::Null((Vec::new(), Vec::new(), field_stream_state)), + )) + }) + .collect::>>(); + + let mut new_fields = fields + .into_iter() + .filter_map(|(field_name, field_value)| { + let with_state = field_value + .meta() + .1 + .streaming_behavior() + .as_ref() + .map_or(false, |b| b.state); + let complete: CompletionState = completion_state(&field_value.meta().0 .0); + + match process_node2(ir, field_value, allow_partials) { + Ok(res) => Some((field_name, res)), + _ => { + let state = Completion { state: complete, display: with_state, required_done: false }; + let null = BamlValueWithMeta::Null((Vec::new(), Vec::new(), state)); + deleted_fields.insert(field_name, null); + None + } + } + }) + .collect::>>(); + + new_fields.extend(needed_nulls); + new_fields.extend(deleted_fields); + let res = BamlValueWithMeta::Class(class_name.clone(), new_fields, new_meta); + if missing_needed_fields.clone().len() == 0 { + Ok(res) + } else { + Err(StreamingError::MissingNeededFields) + } + } + BamlValueWithMeta::Enum(name, value, _) => { + Ok(BamlValueWithMeta::Enum(name, value, new_meta)) + } + BamlValueWithMeta::Map(kvs, _) => { + let new_kvs = kvs + .into_iter() + .filter_map(|(k, v)| process_node2(ir, v, allow_partials).ok().map(|v| (k, v))) + .collect(); + Ok(BamlValueWithMeta::Map(new_kvs, new_meta)) + } + }; + + new_value +} + +fn fields_needing_null_filler<'a>( + ir: &'a IntermediateRepr, + field_type: &'a FieldType, + value_names: HashSet, + allow_partials: bool, +) -> Result, anyhow::Error> { + if allow_partials == false { + return Ok(HashSet::new()); + } + let res = match ir.distribute_metadata(field_type).0 { + FieldType::Class(class_name) => match ir.find_class(class_name) { + Err(_) => Ok(HashSet::new()), + Ok(class) => { + let missing_fields = class + .walk_fields() + .filter_map(|field: Walker<'_, &Field>| { + if !value_names.contains(field.name()) { + Some(field.name().to_string()) + } else { + None + } + }) + .collect(); + Ok(missing_fields) + } + }, + _ => Err(StreamingError::ExpectedClass).context(format!( + "needed_fields expected Class, got type {field_type:?}" + )), + }; + res +} + +/// For a given type, assume that it is a class, and list the fields of that +/// class that were marked `@stream.not_null`. +/// The parameter must have already been passed through `distribute_metadata`, +/// it's an error to call this function with undistributed metadata. +/// +/// When allow_partials==false, we are in a context where we are done with +/// streaming, so we override the normal implemenation of this function +/// and return an empty set (because we are ignoring the "needed" property, +/// which only applies to mid-stream messages). +fn needed_fields( + ir: &IntermediateRepr, + field_type: &FieldType, + allow_partials: bool, +) -> Result, anyhow::Error> { + if allow_partials == false { + return Ok(HashSet::new()); + } + match ir.distribute_metadata(field_type).0 { + FieldType::Class(class_name) => { + let class = ir + .find_class(class_name) + .map_err(|_| StreamingError::ExpectedClass) + .context("needed_fields failed to lookup class")?; + let needed_fields = class + .walk_fields() + .filter_map(|field: Walker<'_, &Field>| { + if field.streaming_needed() { + Some(field.name().to_string()) + } else { + None + } + }) + .collect(); + Ok(needed_fields) + } + _ => Err(StreamingError::ExpectedClass).context(format!( + "needed_fields expected Class got field type {field_type:?}" + )), // TODO: Handle type aliases?. + } +} + +fn unneeded_fields( + ir: &IntermediateRepr, + field_type: &FieldType, +) -> Result, anyhow::Error> { + match ir.distribute_metadata(field_type).0 { + FieldType::Class(class_name) => { + let class = ir + .find_class(class_name) + .map_err(|_| StreamingError::ExpectedClass) + .context(format!( + "unneeded_fields could not look up class {class_name}", + ))?; + let unneeded_fields = class + .walk_fields() + .filter_map(|field: Walker<'_, &Field>| { + if field.streaming_needed() { + None + } else { + Some(field.name().to_string()) + } + }) + .collect(); + Ok(unneeded_fields) + } + _ => Err(StreamingError::ExpectedClass) + .context(format!("unneeded_fields expected Class got {field_type:?}")), + } +} + +/// Whether a type must be complete before being included as a node +/// in a streamed value. +fn required_done(ir: &IntermediateRepr, field_type: &FieldType) -> bool { + let (base_type, (_, streaming_behavior)) = ir.distribute_metadata(field_type); + let type_implies_done = match base_type { + FieldType::Primitive(tv) => match tv { + TypeValue::String => false, + TypeValue::Int => true, + TypeValue::Float => true, + TypeValue::Media(_) => true, + TypeValue::Bool => true, + TypeValue::Null => true, + }, + FieldType::Optional(_) => false, // TODO: Think so? Or depends on Optional's base? + FieldType::Literal(_) => true, + FieldType::List(_) => false, + FieldType::Map(_, _) => false, + FieldType::Enum(_) => true, + FieldType::Tuple(_) => false, + FieldType::RecursiveTypeAlias(_) => false, + FieldType::Class(_) => false, + FieldType::Union(_) => false, + FieldType::WithMetadata { .. } => { + unreachable!("distribute_metadata always consumes `WithMetadata`.") + } + }; + let res = type_implies_done || streaming_behavior.done; + res +} + +fn completion_state(flags: &Vec) -> CompletionState { + if flags + .iter() + .any(|f| matches!(f, Flag::Incomplete) || matches!(f, Flag::Pending)) + { + CompletionState::Incomplete + } else { + CompletionState::Complete + } +} + +fn type_streaming_behavior(ir: &IntermediateRepr, r#type: &FieldType) -> StreamingBehavior { + let (_base_type, (_constraints, streaming_behavior)) = ir.distribute_metadata(r#type); + streaming_behavior +} + +#[cfg(test)] +mod tests { + use internal_baml_core::ir::repr::make_test_ir; + + use crate::deserializer::{deserialize_flags::DeserializerConditions, types::ValueWithFlags}; + + use super::*; + + fn mk_null() -> BamlValueWithFlags { + BamlValueWithFlags::Null(DeserializerConditions::default()) + } + + fn mk_string(s: &str) -> BamlValueWithFlags { + BamlValueWithFlags::String(ValueWithFlags { + value: s.to_string(), + flags: DeserializerConditions::default(), + }) + } + fn mk_float(s: f64) -> BamlValueWithFlags { + BamlValueWithFlags::Float(ValueWithFlags { + value: s, + flags: DeserializerConditions::default(), + }) + } + + #[test] + fn recursive_type_alias() { + let ir = make_test_ir( + r##" + type A = A[] + "##, + ) + .unwrap(); + + fn mk_list(items: Vec) -> BamlValueWithFlags { + BamlValueWithFlags::List(DeserializerConditions::default(), items) + } + + let value = mk_list(vec![ + mk_list(vec![]), + mk_list(vec![]), + mk_list(vec![mk_list(vec![]), mk_list(vec![])]), + ]); + + let res = validate_streaming_state( + &ir, + &value, + &FieldType::RecursiveTypeAlias("A".to_string()), + true, + ) + .unwrap(); + + assert_eq!(res.into_iter().count(), 6); + } + + #[test] + fn stable_keys() { + let ir = make_test_ir( + r##" + class Address { + street string + state string + } + class Name { + first string + last string? + } + class Info { + name Name + address Address? + hair_color string + height float + } + "##, + ) + .unwrap(); + + let value = BamlValueWithFlags::Class( + "Info".to_string(), + DeserializerConditions::default(), + vec![ + ( + "name".to_string(), + BamlValueWithFlags::Class( + "Name".to_string(), + DeserializerConditions::default(), + vec![ + ("first".to_string(), mk_string("Greg")), + ("last".to_string(), mk_string("Hale")), + ] + .into_iter() + .collect(), + ), + ), + ("address".to_string(), mk_null()), + ("hair_color".to_string(), mk_string("Grey")), + ("height".to_string(), mk_float(1.75)), + ] + .into_iter() + .collect(), + ); + let field_type = FieldType::class("Info"); + + let res = validate_streaming_state(&ir, &value, &field_type, true).unwrap(); + + + // The first key should be "Name", matching the order specified in the + // original value. + match res { + BamlValueWithMeta::Class(_name, fields, _meta) => { + assert_eq!(fields.into_iter().next().unwrap().0.as_str(), "name"); + } + _ => panic!("Expected Class"), + } + } +} diff --git a/engine/baml-lib/jsonish/src/deserializer/types.rs b/engine/baml-lib/jsonish/src/deserializer/types.rs index b614b915e4..fdbb18f950 100644 --- a/engine/baml-lib/jsonish/src/deserializer/types.rs +++ b/engine/baml-lib/jsonish/src/deserializer/types.rs @@ -89,7 +89,31 @@ impl BamlValueWithFlags { } } -trait ParsingErrorToUiJson { +impl From for BamlValueWithMeta> { + fn from(baml_value: BamlValueWithFlags) -> BamlValueWithMeta> { + match baml_value { + BamlValueWithFlags::String(v) => BamlValueWithMeta::String(v.value, v.flags.flags), + BamlValueWithFlags::Int(v) => BamlValueWithMeta::Int(v.value, v.flags.flags), + BamlValueWithFlags::Float(v) => BamlValueWithMeta::Float(v.value, v.flags.flags), + BamlValueWithFlags::Bool(v) => BamlValueWithMeta::Bool(v.value, v.flags.flags), + BamlValueWithFlags::List(conditions, items) => { + BamlValueWithMeta::List(items.into_iter().map(|v| BamlValueWithMeta::from(v)).collect(), conditions.flags) + }, + BamlValueWithFlags::Map(conditions, fields) => BamlValueWithMeta::Map( + // NOTE: For some reason, Map is map, even though `v` contains conds. + // Maybe the extra conds are for the field, not the value? + fields.into_iter().map(|(k,v)| (k, BamlValueWithMeta::from(v.1))).collect(), conditions.flags + ), + BamlValueWithFlags::Enum(n,v) => BamlValueWithMeta::Enum(n, v.value, v.flags.flags), + BamlValueWithFlags::Class(name, conds, fields) => + BamlValueWithMeta::Class(name, fields.into_iter().map(|(k,v)| (k, BamlValueWithMeta::from(v))).collect(), conds.flags), + BamlValueWithFlags::Null(v) => BamlValueWithMeta::Null(v.flags), + BamlValueWithFlags::Media(v) => BamlValueWithMeta::Media(v.value, v.flags.flags), + } + } +} + +pub trait ParsingErrorToUiJson { fn to_ui_json(&self) -> serde_json::Value; } @@ -246,7 +270,7 @@ impl BamlValueWithFlags { #[derive(Debug, Clone)] pub struct ValueWithFlags { pub value: T, - pub(super) flags: DeserializerConditions, + pub flags: DeserializerConditions, } impl ValueWithFlags { diff --git a/engine/baml-lib/jsonish/src/helpers/common.rs b/engine/baml-lib/jsonish/src/helpers/common.rs new file mode 100644 index 0000000000..adc19ce7ce --- /dev/null +++ b/engine/baml-lib/jsonish/src/helpers/common.rs @@ -0,0 +1,70 @@ +use baml_types::{FieldType, LiteralValue}; +use internal_baml_jinja::types::{Builder, OutputFormatContent}; + +pub const CLASS_SCHEMA: &str = r#" +class Book { + title string + author string + year int + tags string[] + ratings Rating[] +} + +class Rating { + score int + reviewer string + date string +} +"#; + +pub const UNION_SCHEMA: &str = r#" +class TextContent { + text string +} + +class ImageContent { + url string + width int + height int +} + +class VideoContent { + url string + duration int +} + +class AudioContent { + type string + url string + duration int +} + +type JSONValue = int | float | bool | string | null | JSONValue[] | map +"#; + +pub const JSON_STRING: &str = r#" + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3], + "object": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + }, + "json": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3], + "object": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + } + } + } +"#; diff --git a/engine/baml-lib/jsonish/src/helpers/mod.rs b/engine/baml-lib/jsonish/src/helpers/mod.rs new file mode 100644 index 0000000000..e7d4f57d93 --- /dev/null +++ b/engine/baml-lib/jsonish/src/helpers/mod.rs @@ -0,0 +1,289 @@ +pub mod common; +use std::{collections::HashSet, path::PathBuf}; + +use anyhow::Result; +use baml_types::EvaluationContext; +use baml_types::{BamlValueWithMeta, ResponseCheck, StreamingBehavior}; +use indexmap::{IndexMap, IndexSet}; +use internal_baml_core::{ + ast::Field, + internal_baml_diagnostics::SourceFile, + ir::{repr::IntermediateRepr, ClassWalker, EnumWalker, FieldType, IRHelper, TypeValue}, + validate, +}; +use internal_baml_jinja::types::{Builder, Name, OutputFormatContent}; +use internal_baml_jinja::types::{Class, Enum}; + +use crate::deserializer::deserialize_flags::{constraint_results, Flag}; +use crate::deserializer::semantic_streaming::validate_streaming_state2; +use crate::{BamlValueWithFlags, ResponseBamlValue}; + +pub fn load_test_ir(file_content: &str) -> IntermediateRepr { + let mut schema = validate( + &PathBuf::from("./baml_src"), + vec![SourceFile::from(( + PathBuf::from("./baml_src/example.baml"), + file_content.to_string(), + ))], + ); + match schema.diagnostics.to_result() { + Ok(_) => {} + Err(e) => { + panic!("Failed to validate schema: {}", e); + } + } + + IntermediateRepr::from_parser_database(&schema.db, schema.configuration).unwrap() +} + +pub fn render_output_format( + ir: &IntermediateRepr, + output: &FieldType, + env_values: &EvaluationContext<'_>, +) -> Result { + let (enums, classes, recursive_classes, structural_recursive_aliases) = + relevant_data_models(ir, output, env_values)?; + + Ok(OutputFormatContent::target(output.clone()) + .enums(enums) + .classes(classes) + .recursive_classes(recursive_classes) + .structural_recursive_aliases(structural_recursive_aliases) + .build()) +} + +fn find_existing_class_field( + class_name: &str, + field_name: &str, + class_walker: &Result>, + env_values: &EvaluationContext<'_>, +) -> Result<(Name, FieldType, Option, bool)> { + let Ok(class_walker) = class_walker else { + anyhow::bail!("Class {} does not exist", class_name); + }; + + let Some(field_walker) = class_walker.find_field(field_name) else { + anyhow::bail!("Class {} does not have a field: {}", class_name, field_name); + }; + + let name = Name::new_with_alias(field_name.to_string(), field_walker.alias(env_values)?); + let desc = field_walker.description(env_values)?; + let r#type = field_walker.r#type(); + let streaming_needed = field_walker + .item + .attributes + .get("stream.not_null") + .is_some(); + Ok((name, r#type.clone(), desc, streaming_needed)) +} + +fn find_enum_value( + enum_name: &str, + value_name: &str, + enum_walker: &Result>, + env_values: &EvaluationContext<'_>, +) -> Result)>> { + if enum_walker.is_err() { + anyhow::bail!("Enum {} does not exist", enum_name); + } + + let value_walker = match enum_walker { + Ok(e) => e.find_value(value_name), + Err(_) => None, + }; + + let value_walker = match value_walker { + Some(v) => v, + None => return Ok(None), + }; + + if value_walker.skip(env_values)? { + return Ok(None); + } + + let name = Name::new_with_alias(value_name.to_string(), value_walker.alias(env_values)?); + let desc = value_walker.description(env_values)?; + + Ok(Some((name, desc))) +} + +// TODO: This function is "almost" a duplicate of `relevant_data_models` at +// baml-runtime/src/internal/prompt_renderer/render_output_format.rs +// +// Should be refactored. +// +// TODO: (Greg) Is the use of `String` as a hash key safe? Is there some way to +// get a collision that results in some type not getting put onto the stack? +fn relevant_data_models<'a>( + ir: &'a IntermediateRepr, + output: &'a FieldType, + env_values: &EvaluationContext<'_>, +) -> Result<( + Vec, + Vec, + IndexSet, + IndexMap, +)> { + let mut checked_types: HashSet = HashSet::new(); + let mut enums = Vec::new(); + let mut classes: Vec = Vec::new(); + let mut recursive_classes = IndexSet::new(); + let mut structural_recursive_aliases = IndexMap::new(); + let mut start: Vec = vec![output.clone()]; + + while let Some(output) = start.pop() { + match ir.distribute_constraints(&output) { + (FieldType::Enum(enm), constraints) => { + if checked_types.insert(output.to_string()) { + let walker = ir.find_enum(enm); + + let real_values = walker + .as_ref() + .map(|e| e.walk_values().map(|v| v.name().to_string())) + .ok(); + let values = real_values + .into_iter() + .flatten() + .map(|value| { + let meta = find_enum_value(enm.as_str(), &value, &walker, env_values)?; + Ok(meta) + }) + .filter_map(|v| v.transpose()) + .collect::>>()?; + + enums.push(Enum { + name: Name::new_with_alias(enm.to_string(), walker?.alias(env_values)?), + values, + constraints, + }); + } + } + (FieldType::List(inner), _constraints) | (FieldType::Optional(inner), _constraints) => { + if !checked_types.contains(&inner.to_string()) { + start.push(inner.as_ref().clone()); + } + } + (FieldType::Map(k, v), _constraints) => { + if checked_types.insert(output.to_string()) { + if !checked_types.contains(&k.to_string()) { + start.push(k.as_ref().clone()); + } + if !checked_types.contains(&v.to_string()) { + start.push(v.as_ref().clone()); + } + } + } + (FieldType::Tuple(options), _constraints) + | (FieldType::Union(options), _constraints) => { + if checked_types.insert(output.to_string()) { + for inner in options { + if !checked_types.contains(&inner.to_string()) { + start.push(inner.clone()); + } + } + } + } + (FieldType::Class(cls), constraints) => { + if checked_types.insert(output.to_string()) { + let walker = ir.find_class(cls); + + let real_fields = walker + .as_ref() + .map(|e| e.walk_fields().map(|v| v.name().to_string())) + .ok(); + + let fields = real_fields.into_iter().flatten().map(|field| { + let meta = find_existing_class_field(cls, &field, &walker, env_values)?; + Ok(meta) + }); + + let fields = fields.collect::>>()?; + + for (_, t, _, _) in fields.iter().as_ref() { + if !checked_types.contains(&t.to_string()) { + start.push(t.clone()); + } + } + + // TODO: O(n) algorithm. Maybe a Merge-Find Set can optimize + // this to O(log n) or something like that + // (maybe, IDK though ¯\_(ツ)_/¯) + // + // Also there's a lot of cloning in this process of going + // from Parser DB to IR to Jinja Output Format, not only + // with recursive classes but also the rest of models. + // There's room for optimization here. + // + // Also take a look at the TODO on top of this function. + for cycle in ir.finite_recursive_cycles() { + if cycle.contains(cls) { + recursive_classes.extend(cycle.iter().map(ToOwned::to_owned)); + } + } + + classes.push(Class { + name: Name::new_with_alias(cls.to_string(), walker?.alias(env_values)?), + fields, + constraints, + streaming_behavior: StreamingBehavior::default(), + }); + } + } + (FieldType::RecursiveTypeAlias(name), _) => { + // TODO: Same O(n) problem as above. + for cycle in ir.structural_recursive_alias_cycles() { + if cycle.contains_key(name) { + for (alias, target) in cycle.iter() { + structural_recursive_aliases.insert(alias.to_owned(), target.clone()); + } + } + } + } + (FieldType::Literal(_), _) => {} + (FieldType::Primitive(_), _constraints) => {} + (_, _) => { + // TODO: Don't use this wildcard. + unreachable!("It is guaranteed that a call to distribute_constraints will not return FieldType::Constrained") + } + } + } + + Ok(( + enums, + classes, + recursive_classes, + structural_recursive_aliases, + )) +} + +/// Validate a parsed value, checking asserts and checks. +pub fn parsed_value_to_response( + ir: &IntermediateRepr, + baml_value: BamlValueWithFlags, + field_type: &FieldType, + allow_partials: bool, +) -> Result { + let meta_flags: BamlValueWithMeta> = baml_value.into(); + + let baml_value_with_streaming2 = meta_flags.map_meta_owned(|flags| { + let constraint_results = constraint_results(&flags); + let response_checks: Vec = constraint_results + .iter() + .map(|(label, expr, result)| { + let status = (if *result { "succeeded" } else { "failed" }).to_string(); + ResponseCheck { + name: label.clone(), + expression: expr.0.clone(), + status, + } + }) + .collect(); + (flags, response_checks) + }); + + let response_value2 = + validate_streaming_state2(ir, baml_value_with_streaming2, field_type, allow_partials) + .map_err(|s| anyhow::anyhow!("{s}"))?; + + Ok(crate::ResponseBamlValue(response_value2)) +} diff --git a/engine/baml-lib/jsonish/src/jsonish/iterative_parser.rs b/engine/baml-lib/jsonish/src/jsonish/iterative_parser.rs index 4dc08cc371..5121e130a9 100644 --- a/engine/baml-lib/jsonish/src/jsonish/iterative_parser.rs +++ b/engine/baml-lib/jsonish/src/jsonish/iterative_parser.rs @@ -3,6 +3,8 @@ use std::iter::Peekable; use anyhow::Result; +use baml_types::CompletionState; +use crate::jsonish::Value; /* Try and see if there is a json object somewhere in the string * Could be a "[...] some text" or "{...} some text" or even a: @@ -11,7 +13,7 @@ use anyhow::Result; * ``` * block. */ -fn find_in_json_markdown(str: &str, options: &JSONishOptions) -> Result { +fn find_in_json_markdown(str: &str, options: &JSONishOptions) -> Result { let mut values = vec![]; let mut remaining = str; @@ -46,11 +48,11 @@ fn find_in_json_markdown(str: &str, options: &JSONishOptions) -> Result Err(anyhow::anyhow!("No JSON object found")), 1 => Ok(values[0].clone()), - _ => Ok(serde_json::Value::Array(values)), + _ => Ok(Value::Array(values, CompletionState::Complete)), } } -fn find_all_json_objects(input: &str, options: &JSONishOptions) -> Result { +fn find_all_json_objects(input: &str, options: &JSONishOptions) -> Result { let mut stack = Vec::new(); let mut json_str_start = None; let mut json_objects = Vec::new(); @@ -93,69 +95,69 @@ fn find_all_json_objects(input: &str, options: &JSONishOptions) -> Result Err(anyhow::anyhow!("No JSON objects found")), 1 => Ok(json_objects[0].clone()), - _ => Ok(json_objects.into()), + _ => Ok(Value::Array(json_objects, CompletionState::Incomplete)), } } #[derive(Debug)] enum JsonCollection { // Key, Value - Object(Vec, Vec), - Array(Vec), - QuotedString(String), - SingleQuotedString(String), + Object(Vec, Vec, CompletionState), + Array(Vec, CompletionState), + QuotedString(String, CompletionState), + SingleQuotedString(String, CompletionState), // Handles numbers, booleans, null, and unquoted strings - UnquotedString(String), + UnquotedString(String, CompletionState), // Starting with // or # - TrailingComment(String), + TrailingComment(String, CompletionState), // Content between /* and */ - BlockComment(String), + BlockComment(String, CompletionState), } impl JsonCollection { fn name(&self) -> &'static str { match self { - JsonCollection::Object(_, _) => "Object", - JsonCollection::Array(_) => "Array", - JsonCollection::QuotedString(_) => "String", - JsonCollection::SingleQuotedString(_) => "String", - JsonCollection::UnquotedString(_) => "UnquotedString", - JsonCollection::TrailingComment(_) => "Comment", - JsonCollection::BlockComment(_) => "Comment", + JsonCollection::Object(_, _, _) => "Object", + JsonCollection::Array(_, _) => "Array", + JsonCollection::QuotedString(_, _) => "String", + JsonCollection::SingleQuotedString(_, _) => "String", + JsonCollection::UnquotedString(_, _) => "UnquotedString", + JsonCollection::TrailingComment(_, _) => "Comment", + JsonCollection::BlockComment(_, _) => "Comment", } } } -impl From for Option { - fn from(collection: JsonCollection) -> Option { +impl From for Option { + fn from(collection: JsonCollection) -> Option { Some(match collection { - JsonCollection::TrailingComment(_) | JsonCollection::BlockComment(_) => return None, - JsonCollection::Object(keys, values) => { - let mut object = serde_json::Map::new(); + JsonCollection::TrailingComment(_, _) | JsonCollection::BlockComment(_, _) => return None, + JsonCollection::Object(keys, values, object_completion) => { + let mut object = Vec::new(); for (key, value) in keys.into_iter().zip(values.into_iter()) { - object.insert(key, value); + object.push((key, value)); } - serde_json::Value::Object(object) + Value::Object(object, object_completion) } - JsonCollection::Array(values) => serde_json::Value::Array(values), - JsonCollection::QuotedString(s) => serde_json::Value::String(s), - JsonCollection::SingleQuotedString(s) => serde_json::Value::String(s), - JsonCollection::UnquotedString(s) => { + JsonCollection::Array(values, completion_state) => Value::Array(values, completion_state), + JsonCollection::QuotedString(s, completion_state) => Value::String(s, completion_state), + JsonCollection::SingleQuotedString(s, completion_state) => Value::String(s, completion_state), + JsonCollection::UnquotedString(s, completion_state) => { let s = s.trim(); if s == "true" { - serde_json::Value::Bool(true) + Value::Boolean(true) } else if s == "false" { - serde_json::Value::Bool(false) + Value::Boolean(false) } else if s == "null" { - serde_json::Value::Null + Value::Null } else if let Ok(n) = s.parse::() { - serde_json::Value::Number(n.into()) + Value::Number(n.into(), completion_state) } else if let Ok(n) = s.parse::() { - serde_json::Value::Number(n.into()) + Value::Number(n.into(), completion_state) } else if let Ok(n) = s.parse::() { - serde_json::Value::Number(serde_json::Number::from_f64(n).unwrap()) + Value::Number(serde_json::Number::from_f64(n).unwrap(), completion_state) } else { - serde_json::Value::String(s.into()) + Value::String(s.into(), completion_state) } } }) @@ -166,7 +168,7 @@ struct JsonParseState { collection_stack: Vec, // Technically we may find multiple values in a single string - completed_values: Vec<(&'static str, serde_json::Value)>, + completed_values: Vec<(&'static str, Value)>, } impl JsonParseState { @@ -177,7 +179,7 @@ impl JsonParseState { } } - fn complete_collection(&mut self) { + fn complete_collection(&mut self, completion_state: CompletionState) { let collection = match self.collection_stack.pop() { Some(collection) => collection, None => return, @@ -187,24 +189,27 @@ impl JsonParseState { log::debug!("Completed: {:?} -> {:?}", name, collection); - let value: serde_json::Value = match collection.into() { + let mut value: crate::jsonish::Value = match collection.into() { Some(value) => value, None => return, }; + if completion_state == CompletionState::Complete { + value.complete_deeply(); + } if let Some(last) = self.collection_stack.last_mut() { match last { - JsonCollection::Object(keys, values) => { + JsonCollection::Object(keys, values, completion_state) => { if keys.len() == values.len() { match value { - serde_json::Value::String(s) => keys.push(s), + Value::String(s, completion_state) => keys.push(s), _ => keys.push(value.to_string()), } } else { values.push(value); } } - JsonCollection::Array(values) => { + JsonCollection::Array(values, _) => { values.push(value); } _ => { @@ -223,11 +228,11 @@ impl JsonParseState { fn consume(&mut self, token: char) -> Result { let last = self.collection_stack.last_mut().unwrap(); match last { - JsonCollection::QuotedString(s) - | JsonCollection::BlockComment(s) - | JsonCollection::SingleQuotedString(s) - | JsonCollection::UnquotedString(s) - | JsonCollection::TrailingComment(s) => { + JsonCollection::QuotedString(s, _) + | JsonCollection::BlockComment(s, _) + | JsonCollection::SingleQuotedString(s, _) + | JsonCollection::UnquotedString(s, _) + | JsonCollection::TrailingComment(s, _) => { // println!("Consuming: {s} + {:?}", token); s.push(token); } @@ -239,7 +244,8 @@ impl JsonParseState { } fn is_string_complete(&self) -> bool { - let Some(JsonCollection::UnquotedString(v)) = self.collection_stack.last() else { + // TODO: Do we need to consider the CompletionState here? + let Some(JsonCollection::UnquotedString(v, _)) = self.collection_stack.last() else { return false; }; @@ -259,19 +265,19 @@ impl JsonParseState { fn should_close_unescaped_string( &mut self, mut next: Peekable>, - ) -> Option { + ) -> CloseStringResult { let pos = if self.collection_stack.len() >= 2 { self.collection_stack .get(self.collection_stack.len() - 2) .map(|c| match c { - JsonCollection::Object(keys, values) => { + JsonCollection::Object(keys, values, _) => { if keys.len() == values.len() { 2 } else { 3 } } - JsonCollection::Array(_) => 4, + JsonCollection::Array(_, _) => 4, _ => 1, }) .unwrap() @@ -286,28 +292,28 @@ impl JsonParseState { counter = idx; match c { // If at some point we find a valid json character, we'll close the string - '{' | '[' => return Some(idx), + '{' | '[' => return CloseStringResult::Close(idx, CompletionState::Complete), x => { let _ = self.consume(x); } } } - Some(counter) + CloseStringResult::Close(counter, CompletionState::Incomplete) } - 1 => None, + 1 => CloseStringResult::Continue, 2 => { // in object key let mut counter = 0; for (idx, c) in next.by_ref() { counter = idx; match c { - ':' => return Some(idx), + ':' => return CloseStringResult::Close(idx, CompletionState::Complete), x => { let _ = self.consume(x); } } } - Some(counter) + CloseStringResult::Close(counter, CompletionState::Incomplete) } 3 => { // in object value @@ -319,23 +325,23 @@ impl JsonParseState { if let Some((_, next_c)) = next.peek() { match next_c { '\n' => { - return Some(idx); + return CloseStringResult::Close(idx, CompletionState::Complete); } _ => { let _ = self.consume(c); } } } else { - return Some(idx); + return CloseStringResult::Close(idx, CompletionState::Complete); } } - '}' => return Some(idx), + '}' => return CloseStringResult::Close(idx, CompletionState::Complete), x => { let _ = self.consume(x); } } } - Some(counter) + CloseStringResult::Close(counter, CompletionState::Incomplete) } 4 => { // in array @@ -343,14 +349,14 @@ impl JsonParseState { for (idx, c) in next { counter = idx; match c { - ',' => return Some(idx), - ']' => return Some(idx), + ',' => return CloseStringResult::Close(idx, CompletionState::Complete), + ']' => return CloseStringResult::Close(idx, CompletionState::Complete), x => { let _ = self.consume(x); } } } - Some(counter) + CloseStringResult::Close(counter, CompletionState::Incomplete) } _ => unreachable!("Invalid position"), } @@ -366,14 +372,14 @@ impl JsonParseState { self.collection_stack .get(self.collection_stack.len() - 2) .map(|c| match c { - JsonCollection::Object(keys, values) => { + JsonCollection::Object(keys, values, _) => { if keys.len() == values.len() { (true, false, false) } else { (false, true, true) } } - JsonCollection::Array(_) => (false, false, true), + JsonCollection::Array(_, _) => (false, false, true), _ => (false, false, false), }) .map(|(a, b, c)| (true, a, b, c)) @@ -459,11 +465,11 @@ impl JsonParseState { // println!("Processing: {:?}..{:?}", token, next.peek()); if let Some(last) = self.collection_stack.last() { match last { - JsonCollection::Object(_, _) => { + JsonCollection::Object(_, _, _) => { match token { '}' => { // We're ready to close the object - self.complete_collection(); + self.complete_collection(CompletionState::Complete); Ok(0) } // We can safely ignore these tokens @@ -472,7 +478,7 @@ impl JsonParseState { _ => self.find_any_starting_value(token, next), } } - JsonCollection::Array(_) => { + JsonCollection::Array(_, _) => { // We could be expecting: // - A value // - a comma @@ -480,7 +486,7 @@ impl JsonParseState { match token { ']' => { // We're ready to close the array - self.complete_collection(); + self.complete_collection(CompletionState::Complete); Ok(0) } // Skip these tokens @@ -488,7 +494,7 @@ impl JsonParseState { _ => self.find_any_starting_value(token, next), } } - JsonCollection::QuotedString(_) => { + JsonCollection::QuotedString(_, _) => { // We could be expecting: // - A closing quote // - A character @@ -497,7 +503,7 @@ impl JsonParseState { // It's possible that the LLM messed up the escaping // We'll try to fix it. if self.should_close_string(next, '"') { - self.complete_collection(); + self.complete_collection(CompletionState::Complete); Ok(0) } else { self.consume(token) @@ -506,7 +512,7 @@ impl JsonParseState { _ => self.consume(token), } } - JsonCollection::SingleQuotedString(_) => { + JsonCollection::SingleQuotedString(_, _) => { // We could be expecting: // - A closing quote // - A character @@ -516,7 +522,7 @@ impl JsonParseState { // It's possible that the LLM messed up the escaping // We'll try to fix it. if self.should_close_string(next, '\'') { - self.complete_collection(); + self.complete_collection(CompletionState::Complete); Ok(0) } else { self.consume(token) @@ -525,32 +531,32 @@ impl JsonParseState { _ => self.consume(token), } } - JsonCollection::UnquotedString(_) => { + JsonCollection::UnquotedString(_, _) => { // We could be expecting: // - A terminating json character (comma, colon, bracket, space, newline) // - A character let res = self.consume(token); - if let Some(count) = self.should_close_unescaped_string(next) { - self.complete_collection(); + if let CloseStringResult::Close(count, completion_state) = self.should_close_unescaped_string(next) { + self.complete_collection(completion_state); Ok(count) } else { res } } - JsonCollection::TrailingComment(_) => { + JsonCollection::TrailingComment(_, _) => { // We could be expecting: // - A newline // - A character match token { '\n' => { // We're ready to close the comment - self.complete_collection(); + self.complete_collection(CompletionState::Complete); Ok(0) } _ => self.consume(token), } } - JsonCollection::BlockComment(_) => { + JsonCollection::BlockComment(_, _) => { // We could be expecting: // - A closing comment // - A character @@ -560,7 +566,7 @@ impl JsonParseState { match next.peek() { Some((_, '/')) => { // We're ready to close the comment - self.complete_collection(); + self.complete_collection(CompletionState::Complete); Ok(1) } _ => Ok(0), @@ -588,30 +594,30 @@ impl JsonParseState { match token { '{' => { self.collection_stack - .push(JsonCollection::Object(vec![], vec![])); + .push(JsonCollection::Object(vec![], vec![], CompletionState::Incomplete)); } '[' => { - self.collection_stack.push(JsonCollection::Array(vec![])); + self.collection_stack.push(JsonCollection::Array(vec![], CompletionState::Incomplete)); } '"' => { self.collection_stack - .push(JsonCollection::QuotedString(String::new())); + .push(JsonCollection::QuotedString(String::new(), CompletionState::Incomplete)); } '\'' => { self.collection_stack - .push(JsonCollection::SingleQuotedString(String::new())); + .push(JsonCollection::SingleQuotedString(String::new(), CompletionState::Incomplete)); } '/' => { // Could be a comment match next.peek() { Some((_, '/')) => { self.collection_stack - .push(JsonCollection::TrailingComment(String::new())); + .push(JsonCollection::TrailingComment(String::new(), CompletionState::Incomplete)); return Ok(1); } Some((_, '*')) => { self.collection_stack - .push(JsonCollection::BlockComment(String::new())); + .push(JsonCollection::BlockComment(String::new(), CompletionState::Incomplete)); return Ok(1); } _ => {} @@ -620,9 +626,9 @@ impl JsonParseState { x if x.is_whitespace() => {} x => { self.collection_stack - .push(JsonCollection::UnquotedString(x.into())); - if let Some(count) = self.should_close_unescaped_string(next) { - self.complete_collection(); + .push(JsonCollection::UnquotedString(x.into(), CompletionState::Incomplete)); + if let CloseStringResult::Close(count, completion_state) = self.should_close_unescaped_string(next) { + self.complete_collection(completion_state); return Ok(count); } } @@ -632,7 +638,7 @@ impl JsonParseState { } } -pub fn try_fix_jsonish(str: &str) -> Result { +pub fn try_fix_jsonish(str: &str) -> Result { // Try to fix some common JSON issues // - Unquoted single word strings // - Single quoted strings @@ -670,7 +676,7 @@ pub fn try_fix_jsonish(str: &str) -> Result { // If we still have a collection open, close it while !state.collection_stack.is_empty() { - state.complete_collection(); + state.complete_collection(CompletionState::Incomplete); } // Determine what to return. @@ -683,12 +689,13 @@ pub fn try_fix_jsonish(str: &str) -> Result { } _ => { if state.completed_values.iter().all(|f| f.0 == "string") { - Ok(serde_json::Value::Array( + Ok(Value::Array( state.completed_values.iter().map(|f| f.1.clone()).collect(), + CompletionState::Incomplete )) } else { // Filter for only objects and arrays - let values: Vec = state + let values: Vec = state .completed_values .iter() .filter_map(|f| { @@ -702,7 +709,7 @@ pub fn try_fix_jsonish(str: &str) -> Result { match values.len() { 0 => Err(anyhow::anyhow!("No JSON objects found")), 1 => Ok(values[0].clone()), - _ => Ok(serde_json::Value::Array(values)), + _ => Ok(Value::Array(values, CompletionState::Incomplete)), // TODO: Correct completion state? } } } @@ -742,7 +749,7 @@ impl JSONishOptions { // Responsible for taking a string --> valid JSON // TODO: @hellovai add max recursive loop -pub fn parse_jsonish_value(str: &str, options: JSONishOptions) -> Result { +pub fn parse_jsonish_value(str: &str, options: JSONishOptions) -> Result { log::debug!("Parsing:\n{:?}\n-------\n{:?}\n-------", options, str); if options.depth > 10 { @@ -751,7 +758,9 @@ pub fn parse_jsonish_value(str: &str, options: JSONishOptions) -> Result return Ok(value), + Ok(value) => { + return Ok(value) + }, Err(e) => { log::trace!("Failed to parse JSON: {:?}\n{str}", e); } @@ -763,10 +772,10 @@ pub fn parse_jsonish_value(str: &str, options: JSONishOptions) -> Result 0 { return Ok(value); } - return Ok(serde_json::Value::Array(vec![ + return Ok(Value::Array(vec![ value, - serde_json::Value::String(str.into()), - ])); + Value::String(str.into(), CompletionState::Incomplete), // TODO: Correct? + ], CompletionState::Complete)); // TODO: Correct? } } @@ -776,10 +785,10 @@ pub fn parse_jsonish_value(str: &str, options: JSONishOptions) -> Result 0 { return Ok(value); } - return Ok(serde_json::Value::Array(vec![ + return Ok(Value::Array(vec![ value, - serde_json::Value::String(str.into()), - ])); + Value::String(str.into(), CompletionState::Complete), // TODO: Correct? + ], CompletionState::Complete)); // TODO: Correct? } } @@ -787,10 +796,10 @@ pub fn parse_jsonish_value(str: &str, options: JSONishOptions) -> Result { - return Ok(serde_json::Value::Array(vec![ + return Ok(Value::Array(vec![ value, - serde_json::Value::String(str.into()), - ])); + Value::String(str.into(), CompletionState::Complete), // TODO: Correct completion state? + ], CompletionState::Complete)); // TODO: Correct completion state? } Err(e) => { log::trace!("Failed to fix JSON: {:?}", e); @@ -801,8 +810,13 @@ pub fn parse_jsonish_value(str: &str, options: JSONishOptions) -> Result Result { } match serde_json::from_str(str) { - Ok(v) => return Ok(Value::AnyOf(vec![v], str.to_string())), + Ok(mut v) => { + match &mut v { + Value::String(_, completion_state) => { + // The string must have been contained in quotes in order + // to parse as a JSON string, therefore it is complete. + *completion_state = CompletionState::Complete; + } + Value::Number(_, completion_state) => { + *completion_state = CompletionState::Incomplete; + } + Value::Boolean(_) => {} + Value::Object(_, _) => {} + Value::Array(_, _) => {} + Value::Null => {} + Value::Markdown(_, _, completion_state) => { + *completion_state = CompletionState::Incomplete; + } + Value::FixedJson(_, _) => { + unreachable!("Serde deserializes into concrete values, not FixedJson") + } + Value::AnyOf(_, _) => { + unreachable!("Serde deserializes into concrete values, not AnyOf") + } + } + return Ok(Value::AnyOf(vec![v], str.to_string())); + } Err(e) => { log::debug!("Invalid JSON: {:?}", e); } @@ -38,7 +64,11 @@ pub fn parse(str: &str, mut options: ParseOptions) -> Result { match res { Some(MarkdownResult::CodeBlock(s, v)) => { return Ok(Value::AnyOf( - vec![Value::Markdown(s.to_string(), Box::new(v))], + vec![Value::Markdown( + s.to_string(), + Box::new(v), + CompletionState::Incomplete, + )], str.to_string(), )); } @@ -59,7 +89,9 @@ pub fn parse(str: &str, mut options: ParseOptions) -> Result { let others = items .iter() .filter_map(|res| match res { - MarkdownResult::String(s) => Some(Value::String(s.to_string())), + MarkdownResult::String(s) => { + Some(Value::String(s.to_string(), CompletionState::Incomplete)) + } _ => None, }) .map(|v| { @@ -85,9 +117,11 @@ pub fn parse(str: &str, mut options: ParseOptions) -> Result { MarkdownResult::CodeBlock(s, v) => Some((s, v)), _ => None, }) - .map(|(s, v)| Value::Markdown(s.to_string(), Box::new(v))) + .map(|(s, v)| { + Value::Markdown(s.to_string(), Box::new(v.clone()), v.completion_state().clone()) + }) .collect::>(); - let array = Value::Array(items.clone()); + let array = Value::Array(items.clone(), CompletionState::Incomplete); let items = items .into_iter() .chain(std::iter::once(array)) @@ -107,7 +141,9 @@ pub fn parse(str: &str, mut options: ParseOptions) -> Result { Ok(items) => match items.len() { 0 => {} 1 => { - return Ok(Value::AnyOf( + // eprintln!("MULTI_JSON: {items:?}"); + let ret = + Value::AnyOf( vec![Value::FixedJson( items .into_iter() @@ -117,14 +153,21 @@ pub fn parse(str: &str, mut options: ParseOptions) -> Result { vec![Fixes::GreppedForJSON], )], str.to_string(), - )) + ); + // eprintln!("ret: {ret:?}"); + return Ok(ret); } - _ => { - let items_clone = Value::Array(items.clone()); + n => { + let items_clone = Value::Array(items.clone(), CompletionState::Incomplete); let items = items .into_iter() .chain(std::iter::once(items_clone)) - .map(|v| Value::FixedJson(v.into(), vec![Fixes::GreppedForJSON])) + .map(|v| { + Value::FixedJson( + v.into(), + vec![Fixes::GreppedForJSON], + ) + }) .collect::>(); return Ok(Value::AnyOf(items, str.to_string())); } @@ -145,7 +188,10 @@ pub fn parse(str: &str, mut options: ParseOptions) -> Result { anyhow::anyhow!("Expected 1 item when performing fixes") })?; return Ok(Value::AnyOf( - vec![Value::FixedJson(v.into(), fixes)], + vec![Value::FixedJson( + v.into(), + fixes, + )], str.to_string(), )); } @@ -160,10 +206,12 @@ pub fn parse(str: &str, mut options: ParseOptions) -> Result { let items = items .into_iter() - .map(|(v, fixes)| Value::FixedJson(v.into(), fixes)) + .map(|(v, fixes)| { + Value::FixedJson(v.into(), fixes) + }) .collect::>(); - let items_clone = Value::Array(items.clone()); + let items_clone = Value::Array(items.clone(), CompletionState::Incomplete); let items = items .into_iter() @@ -180,8 +228,103 @@ pub fn parse(str: &str, mut options: ParseOptions) -> Result { } if options.allow_as_string { - return Ok(Value::String(str.to_string())); + return Ok(Value::String(str.to_string(), CompletionState::Incomplete)); } Err(anyhow::anyhow!("Failed to parse JSON")) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::jsonish::Value; + use baml_types::CompletionState; + + fn to_any_of(inner: Value, s: &str) -> Value { + Value::AnyOf(vec![inner], s.to_string()) + } + + fn to_fixed(inner: Value, fixes: &[Fixes]) -> Value { + Value::FixedJson(Box::new(inner), fixes.to_vec()) + } + + #[test] + fn test_partial_int() { + let res = parse("1", ParseOptions::default()).unwrap(); + assert_eq!( + res, + to_any_of(Value::Number(1.into(), CompletionState::Incomplete), "1") + ); + } + + #[test] + fn test_complete_list() { + let res = parse("[1]", ParseOptions::default()).unwrap(); + assert_eq!( + res, + to_any_of( + Value::Array( + vec![Value::Number(1.into(), CompletionState::Complete)], + CompletionState::Complete + ), + "[1]" + ) + ); + } + + #[test] + fn test_incomplete_list() { + let res = parse("[1, 2", ParseOptions::default()).unwrap(); + assert_eq!( + res, + to_any_of( + to_fixed( + to_any_of( + to_fixed( + Value::Array( + vec![ + Value::Number(1.into(), CompletionState::Complete), + Value::Number(2.into(), CompletionState::Incomplete), + ], + CompletionState::Incomplete + ), + &[] + ), + "[1, 2" + ), + &[Fixes::GreppedForJSON] + ), "[1, 2" + )); + } + + #[test] + fn test_incomplete_nested_list() { + let res = parse("[1, 2, [3", ParseOptions::default()).unwrap(); + assert_eq!( + res, + to_any_of( + to_fixed( + to_any_of( + to_fixed( + Value::Array( + vec![ + Value::Number(1.into(), CompletionState::Complete), + Value::Number(2.into(), CompletionState::Complete), + Value::Array( + vec![Value::Number(3.into(), CompletionState::Incomplete),], + CompletionState::Incomplete + ) + ], + CompletionState::Incomplete + ), + &[] + ), + "[1, 2, [3" + ), + &[Fixes::GreppedForJSON] + ), + "[1, 2, [3" + ) + ); + } +} diff --git a/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser.rs b/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser.rs index ee8cece485..0cf0bd15c7 100644 --- a/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser.rs +++ b/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser.rs @@ -1,6 +1,7 @@ mod json_collection; mod json_parse_state; +use baml_types::CompletionState; use crate::jsonish::{value::Fixes, Value}; use self::json_parse_state::JsonParseState; @@ -46,7 +47,7 @@ pub fn parse(str: &str, _options: &ParseOptions) -> Result Result { + Value::Array(xs, array_cmplt) => { assert_eq!(xs.len(), 1); + assert_eq!(array_cmplt, CompletionState::Incomplete); match &xs[0] { - Value::Number(n) => { + Value::Number(n, n_cmplt) => { dbg!(&n); assert_eq!(n, &serde_json::Number::from(12)); + assert_eq!(n_cmplt, &CompletionState::Incomplete); } _ => panic!("Expected number"), } @@ -125,14 +132,17 @@ mod tests { let vals = parse(r#"{"a": 11, "b": 22"#, &opts).unwrap(); dbg!(&vals); match &vals[0].0 { - Value::Object(fields) => { + Value::Object(fields, obj_cmplt) => { assert_eq!(fields.len(), 2); + assert_eq!(obj_cmplt, &CompletionState::Incomplete); match (&fields[0], &fields[1]) { - ((key_a, Value::Number(a)), (key_b, Value::Number(b))) => { + ((key_a, Value::Number(a, a_cmplt)), (key_b, Value::Number(b, b_cmplt))) => { assert_eq!(key_a.as_str(), "a"); assert_eq!(key_b.as_str(), "b"); assert_eq!(a, &serde_json::Number::from(11)); assert_eq!(b, &serde_json::Number::from(22)); + assert_eq!(a_cmplt, &CompletionState::Complete); + assert_eq!(b_cmplt, &CompletionState::Incomplete); } _ => panic!("Expected two numbers."), } @@ -147,14 +157,17 @@ mod tests { let vals = parse("{\n \"a\": 11, \n \"b\": 22", &opts).unwrap(); dbg!(&vals); match &vals[0].0 { - Value::Object(fields) => { + Value::Object(fields, obj_cmplt) => { assert_eq!(fields.len(), 2); + assert_eq!(obj_cmplt, &CompletionState::Incomplete); match (&fields[0], &fields[1]) { - ((key_a, Value::Number(a)), (key_b, Value::Number(b))) => { + ((key_a, Value::Number(a, a_cmplt)), (key_b, Value::Number(b, b_cmplt))) => { assert_eq!(key_a.as_str(), "a"); assert_eq!(key_b.as_str(), "b"); assert_eq!(a, &serde_json::Number::from(11)); assert_eq!(b, &serde_json::Number::from(22)); + assert_eq!(a_cmplt, &CompletionState::Complete); + assert_eq!(b_cmplt, &CompletionState::Incomplete); } _ => panic!("Expected two numbers."), } diff --git a/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_collection.rs b/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_collection.rs index 95ff3507ea..ebd5656bf6 100644 --- a/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_collection.rs +++ b/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_collection.rs @@ -1,16 +1,17 @@ use baml_types::BamlMap; use bstd::dedent; +use baml_types::CompletionState; use crate::jsonish::Value; #[derive(Debug)] pub enum JsonCollection { // Key, Value - Object(Vec, Vec), - Array(Vec), - QuotedString(String), - TripleQuotedString(String), - SingleQuotedString(String), + Object(Vec, Vec, CompletionState), + Array(Vec, CompletionState), + QuotedString(String, CompletionState), + TripleQuotedString(String, CompletionState), + SingleQuotedString(String, CompletionState), // edge cases that need handling: // - triple backticks in a triple backtick string // - will the LLM terminate a triple backtick with a single backtick? probably not @@ -20,32 +21,47 @@ pub enum JsonCollection { // - do we dedent the output? // - is it an acceptable heuristic to discard the first line of a triple backtick block? TripleBacktickString { - lang: Option, - path: Option, - content: String, + lang: Option<(String, CompletionState)>, + path: Option<(String, CompletionState)>, + content: (String, CompletionState), }, - BacktickString(String), + BacktickString(String, CompletionState), // Handles numbers, booleans, null, and unquoted strings - UnquotedString(String), + UnquotedString(String, CompletionState), // Starting with // or # - TrailingComment(String), + TrailingComment(String, CompletionState), // Content between /* and */ - BlockComment(String), + BlockComment(String, CompletionState), } impl JsonCollection { pub fn name(&self) -> &'static str { match self { - JsonCollection::Object(_, _) => "Object", - JsonCollection::Array(_) => "Array", - JsonCollection::QuotedString(_) => "String", - JsonCollection::SingleQuotedString(_) => "String", + JsonCollection::Object(_, _, _) => "Object", + JsonCollection::Array(_, _) => "Array", + JsonCollection::QuotedString(_, _) => "String", + JsonCollection::SingleQuotedString(_, _) => "String", JsonCollection::TripleBacktickString { .. } => "TripleBacktickString", - JsonCollection::BacktickString(_) => "String", - JsonCollection::TripleQuotedString(_) => "TripleQuotedString", - JsonCollection::UnquotedString(_) => "UnquotedString", - JsonCollection::TrailingComment(_) => "Comment", - JsonCollection::BlockComment(_) => "Comment", + JsonCollection::BacktickString(_, _) => "String", + JsonCollection::TripleQuotedString(_, _) => "TripleQuotedString", + JsonCollection::UnquotedString(_, _) => "UnquotedString", + JsonCollection::TrailingComment(_, _) => "Comment", + JsonCollection::BlockComment(_, _) => "Comment", + } + } + + pub fn completion_state(&self) -> &CompletionState { + match self { + JsonCollection::Object(_, _, s) => s, + JsonCollection::Array(_, s) => s, + JsonCollection::QuotedString(_, s) => s, + JsonCollection::SingleQuotedString(_, s) => s, + JsonCollection::TripleBacktickString { content, .. } => &content.1, // TODO: correct? + JsonCollection::BacktickString(_, s) => s, + JsonCollection::TripleQuotedString(_, s) => s, + JsonCollection::UnquotedString(_, s) => s, + JsonCollection::TrailingComment(_, s) => s, + JsonCollection::BlockComment(_, s) => s, } } } @@ -53,29 +69,29 @@ impl JsonCollection { impl From for Option { fn from(collection: JsonCollection) -> Option { Some(match collection { - JsonCollection::TrailingComment(_) | JsonCollection::BlockComment(_) => return None, - JsonCollection::Object(keys, values) => { + JsonCollection::TrailingComment(_, _) | JsonCollection::BlockComment(_, _) => return None, + JsonCollection::Object(keys, values, object_completion) => { // log::debug!("keys: {:?}", keys); - let mut object = Vec::new(); + let mut object: Vec<_> = Vec::new(); for (key, value) in keys.into_iter().zip(values.into_iter()) { object.push((key, value)); } - Value::Object(object) + Value::Object(object, object_completion) } - JsonCollection::Array(values) => Value::Array(values), - JsonCollection::QuotedString(s) => Value::String(s), - JsonCollection::TripleQuotedString(s) => Value::String(s), - JsonCollection::SingleQuotedString(s) => Value::String(s), + JsonCollection::Array(values, completion_state) => Value::Array(values, completion_state), + JsonCollection::QuotedString(s, completion_state) => Value::String(s, completion_state), + JsonCollection::TripleQuotedString(s, completion_state) => Value::String(s, completion_state), + JsonCollection::SingleQuotedString(s, completion_state) => Value::String(s, completion_state), JsonCollection::TripleBacktickString { content, .. } => { - let Some((fenced_codeblock_info, codeblock_contents)) = content.split_once("\n") + let Some((fenced_codeblock_info, codeblock_contents)) = content.0.split_once("\n") else { - return Some(Value::String(content)); + return Some(Value::String(content.0, content.1)); }; - Value::String(dedent(codeblock_contents).content) + Value::String(dedent(codeblock_contents).content, content.1) } - JsonCollection::BacktickString(s) => Value::String(s), - JsonCollection::UnquotedString(s) => { + JsonCollection::BacktickString(s, completion_state) => Value::String(s, completion_state), + JsonCollection::UnquotedString(s, completion_state) => { let s = s.trim(); if s == "true" { Value::Boolean(true) @@ -84,16 +100,16 @@ impl From for Option { } else if s == "null" { Value::Null } else if let Ok(n) = s.parse::() { - Value::Number(n.into()) + Value::Number(n.into(), completion_state) } else if let Ok(n) = s.parse::() { - Value::Number(n.into()) + Value::Number(n.into(), completion_state) } else if let Ok(n) = s.parse::() { match serde_json::Number::from_f64(n) { - Some(n) => Value::Number(n), - None => Value::String(s.into()), + Some(n) => Value::Number(n, completion_state), + None => Value::String(s.into(), completion_state), } } else { - Value::String(s.into()) + Value::String(s.into(), completion_state) } } }) diff --git a/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs b/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs index 85969db58c..e88bdeb0c6 100644 --- a/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs +++ b/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs @@ -1,14 +1,21 @@ use std::iter::Peekable; use crate::jsonish::{value::Fixes, Value}; +use baml_types::CompletionState; use anyhow::Result; use super::json_collection::JsonCollection; +#[derive(Debug)] pub struct JsonParseState { + /// The stack of Json collection values being assembled. + /// The stack-ness is used in order to parse nested values, + /// e.g. an object with fields of list, or lists of lists. pub collection_stack: Vec<(JsonCollection, Vec)>, - // Technically we may find multiple values in a single string + /// Values for which parsing is completed, and popped off of the + /// collection stack. + /// Technically we may find multiple values in a single string pub completed_values: Vec<(&'static str, Value, Vec)>, } @@ -20,7 +27,13 @@ impl JsonParseState { } } - pub fn complete_collection(&mut self) { + /// Examine the top of the collection stack, popping it off and + /// adding it to `completed_values` if it is ready. + /// + /// The `completion_state` parameter is applied to the value being + /// completed. If it is `CompletionState::Complete`, we also apply + /// that state to the children of the value being completed. + pub fn complete_collection(&mut self, completion_state: CompletionState) { let (collection, fixes) = match self.collection_stack.pop() { Some(collection) => collection, None => return, @@ -28,17 +41,20 @@ impl JsonParseState { let name = collection.name(); - let value: Value = match collection.into() { + let mut value: Value = match collection.into() { Some(value) => value, None => return, }; + if completion_state == CompletionState::Complete { + value.complete_deeply(); + } if let Some((last, _fixes)) = self.collection_stack.last_mut() { match last { - JsonCollection::Object(keys, values) => { + JsonCollection::Object(keys, values, _) => { if keys.len() == values.len() { match value { - Value::String(s) => keys.push(s), + Value::String(s,_) => keys.push(s), Value::AnyOf(_, s) => keys.push(s), _ => keys.push(value.to_string()), } @@ -46,7 +62,7 @@ impl JsonParseState { values.push(value); } } - JsonCollection::Array(values) => { + JsonCollection::Array(values, _) => { values.push(value); } _ => { @@ -70,18 +86,18 @@ impl JsonParseState { )); }; match last { - JsonCollection::QuotedString(s) - | JsonCollection::TripleQuotedString(s) - | JsonCollection::BlockComment(s) - | JsonCollection::SingleQuotedString(s) - | JsonCollection::BacktickString(s) - | JsonCollection::TripleBacktickString { content: s, .. } - | JsonCollection::UnquotedString(s) - | JsonCollection::TrailingComment(s) => { + JsonCollection::QuotedString(s, _) + | JsonCollection::TripleQuotedString(s, _) + | JsonCollection::BlockComment(s, _) + | JsonCollection::SingleQuotedString(s, _) + | JsonCollection::BacktickString(s, _) + | JsonCollection::TripleBacktickString { content: (s, _), .. } + | JsonCollection::UnquotedString(s, _) + | JsonCollection::TrailingComment(s, _) => { // println!("Consuming: {s} + {:?}", token); s.push(token); } - JsonCollection::Object(_, _) | JsonCollection::Array(_) => { + JsonCollection::Object(_, _, _) | JsonCollection::Array(_, _) => { panic!("Unexpected token: {:?} in: {:?}", token, last); } } @@ -89,7 +105,7 @@ impl JsonParseState { } fn is_string_complete(&self) -> bool { - let Some((JsonCollection::UnquotedString(v), _)) = self.collection_stack.last() else { + let Some((JsonCollection::UnquotedString(v, _), _)) = self.collection_stack.last() else { return false; }; @@ -109,19 +125,19 @@ impl JsonParseState { fn should_close_unescaped_string( &mut self, mut next: Peekable>, - ) -> Option { + ) -> CloseStringResult { let pos = if self.collection_stack.len() >= 2 { self.collection_stack .get(self.collection_stack.len() - 2) .map(|(c, _)| match c { - JsonCollection::Object(keys, values) => { + JsonCollection::Object(keys, values, _) => { if keys.len() == values.len() { 2 } else { 3 } } - JsonCollection::Array(_) => 4, + JsonCollection::Array(_, _) => 4, _ => 1, }) .unwrap() @@ -136,28 +152,28 @@ impl JsonParseState { counter = idx; match c { // If at some point we find a valid json character, we'll close the string - '{' | '[' => return Some(idx), + '{' | '[' => return CloseStringResult::Close(idx, CompletionState::Complete), x => { let _ = self.consume(x); } } } - Some(counter) + CloseStringResult::Close(counter, CompletionState::Incomplete) } - 1 => None, + 1 => CloseStringResult::Continue, 2 => { // in object key let mut counter = 0; for (idx, c) in next.by_ref() { counter = idx; match c { - ':' => return Some(idx), + ':' => return CloseStringResult::Close(idx, CompletionState::Complete), x => { let _ = self.consume(x); } } } - Some(counter) + CloseStringResult::Close(counter, CompletionState::Incomplete) } 3 => { // in object value @@ -167,10 +183,10 @@ impl JsonParseState { match c { ',' => { // Check if we have just numeric values in the string so far. - let Some((JsonCollection::UnquotedString(current_value), _)) = + let Some((JsonCollection::UnquotedString(current_value, _), _)) = self.collection_stack.last() else { - return Some(idx); + return CloseStringResult::Close(idx, CompletionState::Complete); }; // current value could be a numeric looking things. @@ -184,12 +200,12 @@ impl JsonParseState { match next_c { '\n' => { log::debug!("Closing due to: newline after comma"); - return Some(idx); + return CloseStringResult::Close(idx, CompletionState::Complete); } ' ' => { log::debug!("Testing for comment after space + comma"); if is_possible_value { - return Some(idx); + return CloseStringResult::Close(idx, CompletionState::Complete); } // If after the space we have "//" or "/*" or the beginning of a key, we'll close the string let mut buffer = ",".to_string(); @@ -206,17 +222,17 @@ impl JsonParseState { // Likely end of the key as the LLM generated a ", " token by mistake instead of a "," // so drop the comma log::debug!("Closing due to: newline after comma + space"); - return Some(idx); + return CloseStringResult::Close(idx, CompletionState::Complete); } } '/' => match next.peek() { Some((_, '/')) => { // This is likely a comment - return Some(idx); + return CloseStringResult::Close(idx, CompletionState::Complete); } Some((_, '*')) => { // This is likely a comment - return Some(idx); + return CloseStringResult::Close(idx, CompletionState::Complete); } _ => { // let _ = self.consume(c); @@ -225,7 +241,7 @@ impl JsonParseState { '"' => { // This is likely a new key log::debug!("Closing due to: new key after space + comma"); - return Some(idx); + return CloseStringResult::Close(idx, CompletionState::Complete); } _x => { break; @@ -242,16 +258,16 @@ impl JsonParseState { } } else { // Don't include the comma - return Some(idx); + return CloseStringResult::Close(idx, CompletionState::Complete); } } - '}' => return Some(idx), + '}' => return CloseStringResult::Close(idx, CompletionState::Complete), x => { let _ = self.consume(x); } } } - Some(counter) + CloseStringResult::Close(counter, CompletionState::Incomplete) } 4 => { // in array @@ -259,15 +275,15 @@ impl JsonParseState { for (idx, c) in next { counter = idx; match c { - ',' => return Some(idx), - ']' => return Some(idx), + ',' => return CloseStringResult::Close(idx, CompletionState::Complete), + ']' => return CloseStringResult::Close(idx, CompletionState::Complete), x => { let _ = self.consume(x); } } } counter += 1; // Indicate that we called next() one time after the final `Some`. - Some(counter) + CloseStringResult::Close(counter, CompletionState::Incomplete) } _ => unreachable!("Invalid position"), } @@ -283,14 +299,14 @@ impl JsonParseState { self.collection_stack .get(self.collection_stack.len() - 2) .map(|(c, _)| match c { - JsonCollection::Object(keys, values) => { + JsonCollection::Object(keys, values, _) => { if keys.len() == values.len() { (true, false, false) } else { (false, true, true) } } - JsonCollection::Array(_) => (false, false, true), + JsonCollection::Array(_, _) => (false, false, true), _ => (false, false, false), }) .map(|(a, b, c)| (true, a, b, c)) @@ -376,11 +392,11 @@ impl JsonParseState { // println!("Processing: {:?}..{:?}", token, next.peek()); match self.collection_stack.last() { Some((last, _)) => match last { - JsonCollection::Object(_, _) => { + JsonCollection::Object(_, _, _) => { match token { '}' => { // We're ready to close the object - self.complete_collection(); + self.complete_collection(CompletionState::Complete); Ok(0) } // We can safely ignore these tokens @@ -389,7 +405,7 @@ impl JsonParseState { _ => self.find_any_starting_value(token, next), } } - JsonCollection::Array(_) => { + JsonCollection::Array(_, _) => { // We could be expecting: // - A value // - a comma @@ -397,7 +413,7 @@ impl JsonParseState { match token { ']' => { // We're ready to close the array - self.complete_collection(); + self.complete_collection(CompletionState::Complete); Ok(0) } // Skip these tokens @@ -405,7 +421,7 @@ impl JsonParseState { _ => self.find_any_starting_value(token, next), } } - JsonCollection::TripleQuotedString(_) => { + JsonCollection::TripleQuotedString(_, _) => { // We should be expecting: if token == '"' { // TODO: this logic is busted. peekable.peek() does not @@ -419,7 +435,7 @@ impl JsonParseState { }; if is_triple_quoted { - self.complete_collection(); + self.complete_collection(CompletionState::Complete); Ok(3) } else { self.consume(token) @@ -428,7 +444,7 @@ impl JsonParseState { self.consume(token) } } - JsonCollection::QuotedString(_) => { + JsonCollection::QuotedString(_, _) => { // We could be expecting: // - A closing quote // - A character @@ -437,7 +453,7 @@ impl JsonParseState { // It's possible that the LLM messed up the escaping // We'll try to fix it. if self.should_close_string(next, '"') { - self.complete_collection(); + self.complete_collection(CompletionState::Complete); Ok(0) } else { self.consume(token) @@ -512,7 +528,7 @@ impl JsonParseState { }; if is_triple_quoted { - self.complete_collection(); + self.complete_collection(CompletionState::Complete); Ok(3) } else { self.consume(token) @@ -521,14 +537,14 @@ impl JsonParseState { self.consume(token) } } - JsonCollection::BacktickString(_) => { + JsonCollection::BacktickString(_, _) => { // We could be expecting: // - A closing backtick // - A character match token { '`' => { if self.should_close_string(next, '`') { - self.complete_collection(); + self.complete_collection(CompletionState::Complete); Ok(0) } else { self.consume(token) @@ -537,7 +553,7 @@ impl JsonParseState { _ => self.consume(token), } } - JsonCollection::SingleQuotedString(_) => { + JsonCollection::SingleQuotedString(_, _) => { // We could be expecting: // - A closing quote // - A character @@ -547,7 +563,7 @@ impl JsonParseState { // It's possible that the LLM messed up the escaping // We'll try to fix it. if self.should_close_string(next, '\'') { - self.complete_collection(); + self.complete_collection(CompletionState::Complete); Ok(0) } else { self.consume(token) @@ -556,32 +572,32 @@ impl JsonParseState { _ => self.consume(token), } } - JsonCollection::UnquotedString(_) => { + JsonCollection::UnquotedString(_, _) => { // We could be expecting: // - A terminating json character (comma, colon, bracket, space, newline) // - A character let res = self.consume(token); - if let Some(count) = self.should_close_unescaped_string(next) { - self.complete_collection(); + if let CloseStringResult::Close(count, completion) = self.should_close_unescaped_string(next) { + self.complete_collection(completion); Ok(count) } else { res } } - JsonCollection::TrailingComment(_) => { + JsonCollection::TrailingComment(_, _) => { // We could be expecting: // - A newline // - A character match token { '\n' => { // We're ready to close the comment - self.complete_collection(); + self.complete_collection(CompletionState::Complete); Ok(0) } _ => self.consume(token), } } - JsonCollection::BlockComment(_) => { + JsonCollection::BlockComment(_, _) => { // We could be expecting: // - A closing comment // - A character @@ -591,7 +607,7 @@ impl JsonParseState { match next.peek() { Some((_, '/')) => { // We're ready to close the comment - self.complete_collection(); + self.complete_collection(CompletionState::Complete); Ok(1) } _ => Ok(0), @@ -620,11 +636,11 @@ impl JsonParseState { match token { '{' => { self.collection_stack - .push((JsonCollection::Object(vec![], vec![]), Default::default())); + .push((JsonCollection::Object(vec![], vec![], CompletionState::Incomplete), Default::default())); } '[' => { self.collection_stack - .push((JsonCollection::Array(vec![]), Default::default())); + .push((JsonCollection::Array(vec![], CompletionState::Incomplete), Default::default())); } '"' => { // Peek if next 2 characters are also quotes @@ -636,20 +652,20 @@ impl JsonParseState { if is_triple_quoted { self.collection_stack.push(( - JsonCollection::TripleQuotedString(String::new()), + JsonCollection::TripleQuotedString(String::new(), CompletionState::Incomplete), Default::default(), )); return Ok(2); } else { self.collection_stack.push(( - JsonCollection::QuotedString(String::new()), + JsonCollection::QuotedString(String::new(), CompletionState::Incomplete), Default::default(), )) } } '\'' => { self.collection_stack.push(( - JsonCollection::SingleQuotedString(String::new()), + JsonCollection::SingleQuotedString(String::new(), CompletionState::Incomplete), Default::default(), )); } @@ -666,14 +682,14 @@ impl JsonParseState { JsonCollection::TripleBacktickString { lang: None, path: None, - content: String::new(), + content: (String::new(), CompletionState::Incomplete), }, Default::default(), )); return Ok(2); } else { self.collection_stack.push(( - JsonCollection::BacktickString(String::new()), + JsonCollection::BacktickString(String::new(), CompletionState::Incomplete), Default::default(), )) } @@ -683,14 +699,14 @@ impl JsonParseState { match next.peek() { Some((_, '/')) => { self.collection_stack.push(( - JsonCollection::TrailingComment(String::new()), + JsonCollection::TrailingComment(String::new(), CompletionState::Incomplete), Default::default(), )); return Ok(1); } Some((_, '*')) => { self.collection_stack.push(( - JsonCollection::BlockComment(String::new()), + JsonCollection::BlockComment(String::new(), CompletionState::Incomplete), Default::default(), )); return Ok(1); @@ -700,10 +716,10 @@ impl JsonParseState { // say a path? if matches!( self.collection_stack.last(), - Some((JsonCollection::Object(_, _), _)) + Some((JsonCollection::Object(_, _, _), _)) ) { self.collection_stack.push(( - JsonCollection::UnquotedString(token.into()), + JsonCollection::UnquotedString(token.into(), CompletionState::Incomplete), Default::default(), )); return Ok(0); @@ -714,9 +730,9 @@ impl JsonParseState { x if x.is_whitespace() => {} x => { self.collection_stack - .push((JsonCollection::UnquotedString(x.into()), Default::default())); - if let Some(count) = self.should_close_unescaped_string(next) { - self.complete_collection(); + .push((JsonCollection::UnquotedString(x.into(), CompletionState::Incomplete), Default::default())); + if let CloseStringResult::Close(count, completion) = self.should_close_unescaped_string(next) { + self.complete_collection(completion); return Ok(count); } } @@ -725,3 +741,9 @@ impl JsonParseState { Ok(0) } } + +#[derive(Debug, PartialEq)] +enum CloseStringResult { + Close(usize, CompletionState), + Continue +} diff --git a/engine/baml-lib/jsonish/src/jsonish/parser/markdown_parser.rs b/engine/baml-lib/jsonish/src/jsonish/parser/markdown_parser.rs index 974d90a7ec..93f168b292 100644 --- a/engine/baml-lib/jsonish/src/jsonish/parser/markdown_parser.rs +++ b/engine/baml-lib/jsonish/src/jsonish/parser/markdown_parser.rs @@ -13,7 +13,7 @@ pub enum MarkdownResult { } pub fn parse(str: &str, options: &ParseOptions) -> Result> { - let mut values = vec![]; + let mut values: Vec = vec![]; let mut remaining = str; // Find regex for markdown blocks (```) @@ -47,6 +47,7 @@ pub fn parse(str: &str, options: &ParseOptions) -> Result> { match res { Ok(v) => { + // eprintln!("Pushing value {v:?}"); // TODO: Add any more additional strings here. values.push(MarkdownResult::CodeBlock( if tag.len() > 3 { @@ -80,6 +81,8 @@ pub fn parse(str: &str, options: &ParseOptions) -> Result> { #[cfg(test)] mod test { + use baml_types::CompletionState; + use super::*; use test_log::test; @@ -118,9 +121,13 @@ print("Hello, world!") panic!("Expected AnyOf, got {:#?}", value); }; assert!(value.contains(&Value::Object( - [("a".to_string(), Value::Number((1).into()))] - .into_iter() - .collect() + [( + "a".to_string(), + Value::Number((1).into(), CompletionState::Complete) + )] + .into_iter() + .collect(), + CompletionState::Complete ))); } { @@ -134,7 +141,11 @@ print("Hello, world!") let Value::AnyOf(value, _) = value else { panic!("Expected AnyOf, got {:#?}", value); }; - assert!(value.contains(&Value::String("This is a test".to_string()))); + // dbg!(&value); + assert!(value.contains(&Value::String( + "This is a test".to_string(), + CompletionState::Complete + ))); } Ok(()) diff --git a/engine/baml-lib/jsonish/src/jsonish/parser/multi_json_parser.rs b/engine/baml-lib/jsonish/src/jsonish/parser/multi_json_parser.rs index 99c5743d37..cfac769532 100644 --- a/engine/baml-lib/jsonish/src/jsonish/parser/multi_json_parser.rs +++ b/engine/baml-lib/jsonish/src/jsonish/parser/multi_json_parser.rs @@ -38,7 +38,9 @@ pub fn parse(str: &str, options: &ParseOptions) -> Result> { json_str, options.next_from_mode(super::ParsingMode::AllJsonObjects), ) { - Ok(json) => json_objects.push(json), + Ok(json) => { + json_objects.push(json) + }, Err(e) => { // Ignore errors log::error!("Failed to parse JSON object: {:?}", e); @@ -59,7 +61,11 @@ pub fn parse(str: &str, options: &ParseOptions) -> Result> { json_str, options.next_from_mode(super::ParsingMode::AllJsonObjects), ) { - Ok(json) => json_objects.push(json), + Ok(json) => { + complete_stack_head(&mut json_objects); + json_objects.push(json) + + }, Err(e) => { // Ignore errors log::error!("Failed to parse JSON object: {:?}", e); @@ -78,8 +84,17 @@ pub fn parse(str: &str, options: &ParseOptions) -> Result> { } } +fn complete_stack_head(stack: &mut Vec) { + match stack.last_mut() { + Some(v) => { v.complete_deeply(); }, + None => {}, + } +} + #[cfg(test)] mod test { + use baml_types::CompletionState; + use super::*; use test_log::test; @@ -112,9 +127,13 @@ print("Hello, world!") panic!("Expected AnyOf, got {:#?}", value); }; assert!(value.contains(&Value::Object( - [("a".to_string(), Value::Number((1).into()))] - .into_iter() - .collect() + [( + "a".to_string(), + Value::Number((1).into(), CompletionState::Complete) + )] + .into_iter() + .collect(), + CompletionState::Complete ))); } { @@ -122,9 +141,13 @@ print("Hello, world!") let Value::AnyOf(value, _) = value else { panic!("Expected AnyOf, got {:#?}", value); }; - assert!(value.contains(&Value::Array(vec![Value::String( - "This is a test".to_string() - )]))); + assert!(value.contains(&Value::Array( + vec![Value::String( + "This is a test".to_string(), + CompletionState::Complete + )], + CompletionState::Complete + ))); } Ok(()) diff --git a/engine/baml-lib/jsonish/src/jsonish/value.rs b/engine/baml-lib/jsonish/src/jsonish/value.rs index 5a43d08f78..fb568fa6cd 100644 --- a/engine/baml-lib/jsonish/src/jsonish/value.rs +++ b/engine/baml-lib/jsonish/src/jsonish/value.rs @@ -3,7 +3,7 @@ use std::{ hash::{Hash, Hasher}, }; -use baml_types::BamlMap; +use baml_types::{BamlMap, CompletionState}; #[derive(Debug, Clone, PartialEq, Eq)] pub enum Fixes { @@ -14,42 +14,47 @@ pub enum Fixes { #[derive(Debug, Clone, PartialEq, Eq)] pub enum Value { // Primitive Types - String(String), - Number(serde_json::Number), + String(String, CompletionState), + Number(serde_json::Number, CompletionState), Boolean(bool), Null, // Complex Types - Object(Vec<(String, Value)>), - Array(Vec), + // Note: Greg - should keys carry completion state? + // During parsing, if we hare an incomplete key, does the parser + // complete it and set its value to null? Or drop it? + // If the parser drops it, we don't need to carry CompletionState. + Object(Vec<(String, Value)>, CompletionState), + Array(Vec, CompletionState), // Fixed types - Markdown(String, Box), - FixedJson(Box, Vec), + Markdown(String, Box, CompletionState), + FixedJson(Box, Vec), // TODO: Does this really need a CompletionState? AnyOf(Vec, String), } impl Hash for Value { + // Hashing a Value ignores CompletationState. fn hash(&self, state: &mut H) { std::mem::discriminant(self).hash(state); match self { - Value::String(s) => s.hash(state), - Value::Number(n) => n.to_string().hash(state), + Value::String(s, _) => s.hash(state), + Value::Number(n, _) => n.to_string().hash(state), Value::Boolean(b) => b.hash(state), Value::Null => "null".hash(state), - Value::Object(o) => { + Value::Object(o, _) => { for (k, v) in o { k.hash(state); v.hash(state); } } - Value::Array(a) => { + Value::Array(a, _) => { for v in a { v.hash(state); } } - Value::Markdown(s, v) => { + Value::Markdown(s, v, _) => { s.hash(state); v.hash(state); } @@ -63,14 +68,15 @@ impl Hash for Value { } } + impl Value { pub fn r#type(&self) -> String { match self { - Value::String(_) => "String".to_string(), - Value::Number(_) => "Number".to_string(), + Value::String(_, _) => "String".to_string(), + Value::Number(_, _) => "Number".to_string(), Value::Boolean(_) => "Boolean".to_string(), Value::Null => "Null".to_string(), - Value::Object(k) => { + Value::Object(k, _) => { let mut s = "Object{".to_string(); for (key, value) in k.iter() { s.push_str(&format!("{}: {}, ", key, value.r#type())); @@ -78,7 +84,7 @@ impl Value { s.push('}'); s } - Value::Array(i) => { + Value::Array(i, _) => { let mut s = "Array[".to_string(); let items = i .iter() @@ -91,7 +97,7 @@ impl Value { s.push(']'); s } - Value::Markdown(tag, item) => { + Value::Markdown(tag, item, _) => { format!("Markdown:{} - {}", tag, item.r#type()) } Value::FixedJson(inner, fixes) => { @@ -108,16 +114,81 @@ impl Value { } } } + + pub fn completion_state(&self) -> &CompletionState { + match self { + Value::String(_, s) => s, + Value::Number(_, s) => s, + Value::Boolean(_) => &CompletionState::Complete, + Value::Null => &CompletionState::Complete, + Value::Object(_, s) => s, + Value::Array(_, s) => s, + Value::Markdown(_, _, s) => s, + Value::FixedJson(_, _) => &CompletionState::Complete, + Value::AnyOf(choices, _) => { + if choices + .iter() + .any(|c| c.completion_state() == &CompletionState::Incomplete) + { + &CompletionState::Incomplete + } else { + &CompletionState::Complete + } + } + } + } + + pub fn complete_deeply(&mut self) { + match self { + Value::String(_, s) => *s = CompletionState::Complete, + Value::Number(_, s) => *s = CompletionState::Complete, + Value::Boolean(_) => {} + Value::Null => {} + Value::Object(kv_pairs, s) => { + *s = CompletionState::Complete; + kv_pairs.iter_mut().for_each(|(_, v)| v.complete_deeply()); + } + Value::Array(elems, s) => { + *s = CompletionState::Complete; + elems.iter_mut().for_each(|v| v.complete_deeply()); + } + Value::Markdown(_, _, s) => *s = CompletionState::Complete, + Value::FixedJson(val, fixes) => { + val.complete_deeply(); + }, + Value::AnyOf(choices, _) => choices.iter_mut().for_each(|v| v.complete_deeply()), + } + } + + pub fn completed_deeply(self) -> Self { + match self { + Value::String(v, _) => Value::String(v, CompletionState::Complete), + Value::Number(v, _) => Value::Number(v, CompletionState::Complete), + Value::Boolean(v) => Value::Boolean(v), + Value::Null => Value::Null, + Value::Object(v, _) => Value::Object(v, CompletionState::Complete), + Value::Array(v, _) => Value::Array( + v.into_iter().map(|v| v.completed_deeply()).collect(), + CompletionState::Complete, + ), + Value::Markdown(x, y, _) => Value::Markdown(x, y, CompletionState::Complete), + Value::FixedJson(x, y) => Value::FixedJson(x, y), + Value::AnyOf(choices, s) => Value::AnyOf( + choices.into_iter().map(|v| v.completed_deeply()).collect(), + s, + ), + } + } } impl std::fmt::Display for Value { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Value::String(s) => write!(f, "{}", s), - Value::Number(n) => write!(f, "{}", n), + Value::String(s, _) => write!(f, "{}", s), + Value::Number(n, _) => write!(f, "{}", n), Value::Boolean(b) => write!(f, "{}", b), Value::Null => write!(f, "null"), - Value::Object(o) => { + Value::Object(o, _) => { write!(f, "{{")?; for (i, (k, v)) in o.iter().enumerate() { if i > 0 { @@ -127,7 +198,7 @@ impl std::fmt::Display for Value { } write!(f, "}}") } - Value::Array(a) => { + Value::Array(a, _) => { write!(f, "[")?; for (i, v) in a.iter().enumerate() { if i > 0 { @@ -137,7 +208,7 @@ impl std::fmt::Display for Value { } write!(f, "]") } - Value::Markdown(s, v) => write!(f, "{}\n{}", s, v), + Value::Markdown(s, v, _) => write!(f, "{}\n{}", s, v), Value::FixedJson(v, _) => write!(f, "{}", v), Value::AnyOf(items, s) => { write!(f, "AnyOf[{},", s)?; @@ -150,6 +221,19 @@ impl std::fmt::Display for Value { } } +// The serde implementation is used as one of our parsing options. +// We deserialize into a "complete" value, and this property is +// true for nested values, because serde will call the same `deserialize` +// method on children of a serde container. +// +// Numbers and strings should be considered Incomplete if they are encountered +// at the top level. Therefore the non-recursive callsite of `deserialize` +// is responsible for setting completion state to Incomplete for top-level +// strings and numbers. +// +// Lists and objects at the top level are necessarily complete, because +// serde will not parse an array or an object unless the closing delimiter +// is present. impl<'de> serde::Deserialize<'de> for Value { fn deserialize(deserializer: D) -> Result where @@ -157,8 +241,8 @@ impl<'de> serde::Deserialize<'de> for Value { { let value = serde_json::Value::deserialize(deserializer)?; match value { - serde_json::Value::String(s) => Ok(Value::String(s)), - serde_json::Value::Number(n) => Ok(Value::Number(n)), + serde_json::Value::String(s) => Ok(Value::String(s, CompletionState::Complete)), + serde_json::Value::Number(n) => Ok(Value::Number(n, CompletionState::Complete)), serde_json::Value::Bool(b) => Ok(Value::Boolean(b)), serde_json::Value::Null => Ok(Value::Null), serde_json::Value::Object(o) => { @@ -168,7 +252,7 @@ impl<'de> serde::Deserialize<'de> for Value { serde_json::from_value(v).map_err(serde::de::Error::custom)?; map.push((k, parsed_value)); } - Ok(Value::Object(map)) + Ok(Value::Object(map, CompletionState::Complete)) } serde_json::Value::Array(a) => { let mut vec = Vec::new(); @@ -177,7 +261,7 @@ impl<'de> serde::Deserialize<'de> for Value { serde_json::from_value(v).map_err(serde::de::Error::custom)?; vec.push(parsed_value); } - Ok(Value::Array(vec)) + Ok(Value::Array(vec, CompletionState::Complete)) } } } diff --git a/engine/baml-lib/jsonish/src/lib.rs b/engine/baml-lib/jsonish/src/lib.rs index b82055c158..6d664bb12f 100644 --- a/engine/baml-lib/jsonish/src/lib.rs +++ b/engine/baml-lib/jsonish/src/lib.rs @@ -1,19 +1,155 @@ -#[cfg(test)] -mod tests; +pub mod helpers; +pub mod tests; use anyhow::Result; +use indexmap::IndexMap; pub mod deserializer; -mod jsonish; +use std::collections::HashMap; +pub mod jsonish; -use baml_types::FieldType; -use deserializer::coercer::{ParsingContext, TypeCoercer}; +use baml_types::{BamlValue, BamlValueWithMeta, FieldType, JinjaExpression, ResponseCheck}; +use deserializer::{ + coercer::{ParsingContext, ParsingError, TypeCoercer}, + deserialize_flags::DeserializerConditions, +}; pub use deserializer::types::BamlValueWithFlags; use internal_baml_core::ir::TypeValue; use internal_baml_jinja::types::OutputFormatContent; +use crate::deserializer::score::WithScore; +use baml_types::{Completion, CompletionState}; use deserializer::deserialize_flags::Flag; +use deserializer::types::ParsingErrorToUiJson; use jsonish::Value; +use serde::{ser::SerializeMap, ser::SerializeStruct, Serialize, Serializer}; + +#[derive(Clone, Debug)] +pub struct ResponseBamlValue( + pub BamlValueWithMeta<(Vec, Vec, Completion)>, +); + +#[derive(Debug, Clone, PartialEq)] +pub enum SerializeMode { + Final, + Partial, +} + +// impl serde::Serialize for (ResponseBamlValue, SerializeMode) { +// fn serialize(&self, serializer: S) -> Result { +// SerializeResponseBamlValue{ value: &self.0.0, serialize_mode: self.1 }.serialize(serializer) +// } +// } + +pub struct SerializeResponseBamlValue<'a>{ + pub value: &'a BamlValueWithMeta<(Vec, Vec, Completion)>, + pub serialize_mode: SerializeMode, +} + +impl ResponseBamlValue { + pub fn serialize_final<'a> (&'a self) -> SerializeResponseBamlValue<'a> { + SerializeResponseBamlValue { + value: &self.0, + serialize_mode: SerializeMode::Final + } + } + + pub fn serialize_partial<'a> (&'a self) -> SerializeResponseBamlValue<'a> { + SerializeResponseBamlValue { + value: &self.0, + serialize_mode: SerializeMode::Partial + } + } +} + + +impl serde::Serialize for SerializeResponseBamlValue<'_> { + fn serialize(&self, serializer: S) -> Result { + use BamlValueWithMeta::*; + let serialize_mode = &self.serialize_mode; + match &self.value { + String(s, ref meta) => serialize_with_meta(&s, &meta, serialize_mode, serializer), + Int(i, ref meta) => serialize_with_meta(&i, &meta, serialize_mode, serializer), + Float(f, ref meta) => serialize_with_meta(&f, &meta, serialize_mode, serializer), + Bool(b, ref meta) => serialize_with_meta(&b, &meta, serialize_mode, serializer), + Media(v, ref meta) => serialize_with_meta(&v, &meta, serialize_mode, serializer), + Enum(ref _name, v, ref meta) => serialize_with_meta(&v, &meta, serialize_mode, serializer), + Map(items, ref meta) => { + let new_items = items + .into_iter() + .map(|(k, v)| (k.clone(), SerializeResponseBamlValue{value: &v, serialize_mode: serialize_mode.clone()})) + .collect::>>(); + serialize_with_meta(&new_items, &meta, serialize_mode, serializer) + } + List(items, ref meta) => { + let new_items = items + .into_iter() + .map(|v| SerializeResponseBamlValue{value: v, serialize_mode: serialize_mode.clone()}) + .collect::>(); + serialize_with_meta(&new_items, &meta, serialize_mode, serializer) + } + Class(_name, fields, ref meta) => { + let new_fields = fields + .into_iter() + .map(|(k, v)| { + let subvalue_serialize_mode = match (&serialize_mode, v.meta().2.required_done) { + (SerializeMode::Final, _) => SerializeMode::Final, + (SerializeMode::Partial, true) => SerializeMode::Final, + (SerializeMode::Partial, false) => SerializeMode::Partial, + }; + (k, SerializeResponseBamlValue{value: v, serialize_mode: subvalue_serialize_mode}) + }) + .collect::>(); + serialize_with_meta(&new_fields, &meta, serialize_mode, serializer) + } + Null(ref meta) => serialize_with_meta(&(), &meta, serialize_mode, serializer), + } + } +} + +/// This newtype wrapper exists solely for the purpose of defining a +/// `Serialize` impl. +pub struct ResponseChecksMetadata<'a, T: Serialize>(pub (&'a T, &'a Vec)); + +impl<'a, T: Serialize> serde::Serialize for ResponseChecksMetadata<'a, T> { + fn serialize(&self, serializer: S) -> Result { + let checks_map: HashMap<_, _> = self + .0 + .1 + .iter() + .map(|check| (check.name.clone(), check)) + .collect(); + let mut state = serializer.serialize_struct("Checked", 2)?; + state.serialize_field("value", &self.0 .0)?; + state.serialize_field("checks", &checks_map)?; + state.end() + } +} + +fn serialize_with_meta( + value: &T, + meta: &(Vec, Vec, Completion), + serialize_mode: &SerializeMode, + serializer: S, +) -> Result { + let should_display_stream_state = meta.2.display && matches!(serialize_mode, SerializeMode::Partial); + match (meta.1.len(), should_display_stream_state) { + (0, false) => value.serialize(serializer), + (_, false) => ResponseChecksMetadata((value, &meta.1)).serialize(serializer), + (0, true) => { + let mut state = serializer.serialize_struct("StreamState", 2)?; + state.serialize_field("state", &meta.2.state)?; + state.serialize_field("value", value)?; + state.end() + } + (_, true) => { + let mut outer_value = serializer.serialize_struct("StreamState", 2)?; + outer_value.serialize_field("state", &meta.2.state)?; + outer_value.serialize_field("value", &ResponseChecksMetadata((value, &meta.1)))?; + outer_value.end() + } + } +} pub fn from_str( of: &OutputFormatContent, @@ -26,12 +162,13 @@ pub fn from_str( } // When the schema is just a string, i should really just return the raw_string w/o parsing it. - let mut value = jsonish::parse(raw_string, jsonish::ParseOptions::default())?; + let value = jsonish::parse(raw_string, jsonish::ParseOptions::default())?; // let schema = deserializer::schema::from_jsonish_value(&value, None); + // eprintln!("value: {value:?}"); // See Note [Streaming Number Invalidation] if allow_partials { - invalidate_numbers_in_progress(&mut value, raw_string); + // invalidate_numbers_in_progress(&mut value, raw_string); } // Pick the schema that is the most specific. @@ -49,114 +186,93 @@ pub fn from_str( // Determine the best way to get the desired schema from the parsed schema. // Lets try to now coerce the value into the expected schema. - match target.coerce(&ctx, target, Some(&value)) { + let parsed_value: BamlValueWithFlags = match target.coerce(&ctx, target, Some(&value)) { Ok(v) => { if v.conditions() .flags() .iter() - .any(|f| matches!(f, Flag::InferedObject(jsonish::Value::String(_)))) + .any(|f| matches!(f, Flag::InferedObject(jsonish::Value::String(_, _)))) { anyhow::bail!("Failed to coerce value: {:?}", v.conditions().flags()); } - Ok(v) + Ok::(v) } Err(e) => anyhow::bail!("Failed to coerce value: {}", e), - } + }?; + + Ok(parsed_value) } -/// Nullify numbers that may still be streaming in. -/// -/// See note [Streaming Number Invalidation] -fn invalidate_numbers_in_progress(value: &mut Value, raw_string: &str) { - let ends_in_digit = raw_string - .chars() - .last() - .map_or(false, |c| c.is_numeric() || c == '.'); - let last_values = last_value_as_number(value); - if ends_in_digit { - last_values.into_iter().for_each(|v| { - *v = Value::Null; +impl ResponseBamlValue { + pub fn score(&self) -> i32 { + self.0.iter().map(|node| node.meta().0.score()).sum() + } + + pub fn explanation_json(&self) -> Vec { + let mut expl = vec![]; + self.explanation_impl(vec!["".to_string()], &mut expl); + expl.into_iter().map(|e| e.to_ui_json()).collect::>() + } + + fn explanation_impl(&self, scope: Vec, expls: &mut Vec) { + self.0.iter().for_each(|node| { + let message = match node { + BamlValueWithMeta::String(_, _) => "error while parsing string".to_string(), + BamlValueWithMeta::Int(_, _) => "error while parsing int".to_string(), + BamlValueWithMeta::Float(_, _) => "error while parsing float".to_string(), + BamlValueWithMeta::Bool(_, _) => "error while parsing bool".to_string(), + BamlValueWithMeta::List(_, _) => "error while parsing list".to_string(), + BamlValueWithMeta::Map(_, _) => "error while parsing map".to_string(), + BamlValueWithMeta::Enum(enum_name, _, _) => { + format!("error while parsing {enum_name} enum value") + } + BamlValueWithMeta::Class(class_name, _, _) => { + format!("error while parsing class {class_name}") + } + BamlValueWithMeta::Null(_) => "error while parsing null".to_string(), + BamlValueWithMeta::Media(_, _) => "error while parsing media".to_string(), + }; + let parsing_error = ParsingError { + scope: scope.clone(), + reason: message, + causes: DeserializerConditions { + flags: node.meta().0.clone(), + } + .explanation(), + }; + if node.meta().0.len() > 0 { + expls.push(parsing_error) + } }) } } -// Find the "last" element of a Value and return a mutable pointer to it. -// There may be multiple last elements only in the case of `Value::Anyof`. -// Every other case returns 0 or 1 pointers. -// -// The algorithm for finding the last value has several cases: -// - Base case: Raw values like `String` and `Number` are themselves the -// last value. -// - Simple compound case: The last value of a flat array is trivially -// the array's last element. The last value of an Object is the last -// key-value pair to be parsed. Because we store objects' key-value -// pairs in the order in which they are defined in the input tokens, -// we simply look up the last field from the list of fields. -// - Inductive case for lists and objects: When a list or object contains -// other lists or objects, the transitively last element of the parent -// type is the last element of the last direct element. -// - AnyOf case: AnyOf represents multiple `jsonish::Value` parses of the -// token stream. We have to compute the last item of each variant, and -// handle them all, because any one of them could be selected downstream -// by the coercer. AnyOf is the reason this function returns a `Vec` of -// references rather than an `Optional` reference. -fn last_value_as_number(value: &mut Value) -> Vec<&mut Value> { - match value { - Value::String(_) => vec![], - Value::Number(_) => vec![value], - Value::Boolean(_) => vec![], - Value::Null => vec![], - Value::Array(items) => items - .last_mut() - .map(|i| last_value_as_number(i)) - .unwrap_or_default(), - Value::Object(fields) => fields - .last_mut() - .map(|(_k, v)| last_value_as_number(v)) - .unwrap_or_default(), - Value::Markdown(_, sub_value) => last_value_as_number(sub_value), - Value::FixedJson(fixed_val, _fixes) => last_value_as_number(fixed_val), - Value::AnyOf(variants, _) => variants - .iter_mut() - .flat_map(|variant| last_value_as_number(variant)) - .collect(), +impl From for BamlValue { + fn from(v: ResponseBamlValue) -> BamlValue { + v.0.into() + } +} + +impl WithScore for ResponseBamlValue { + fn score(&self) -> i32 { + self.0.iter().map(|node| node.meta().0.score()).sum() } } -/* - * Note: Streaming Number Invalidation - * - * Displaying partial results can be more harmful for certain datatypes, - * like ints and floats, despite being useful for strings. This is because - * the prefix of a number conveys something very differenet about the full - * number than what the prefix of a string conveys about the full string. - * - * To prevent confusing users with streamed number prefixes, we have - * implemented a specific and slightly hacky workaround, which we may replace by - * something more robust in the future. (We won't spend time here describing - * this future solution). - * - * Our temporary solution works like this: - * - Flexibly parse LLM responses into `jsonish::Value` as usual. - * - Determine whether the last tokens represent a number that might - * be extended by subsequent tokens. - * - If the last tokens represent an in-progress number, identify the part - * of the `jsonish::Value` that is currently being extended, and convert - * it to `jsonish::Value::Null`. - * - * This algorithm is implemented in `invalidate_numbers_in_progress`. Finding - * the currently-in-progress part of the `jsonish::Value` structure is - * implemented in `last_value_as_number`. - * - * - * Consider these examples of streamed tokens and their parses into - * `Value`: - * - * - "123" => 123. This `Value::Number` will be rewritten to - * `Value::Null` because it is the final element in the ADT - * and the input string ends in a digit. - * - * - "[123, 456" => [123, 456]. The `456` will be nulled. - * - "[123, 456]" => [123, 456]. No change. - */ +// impl SerializeMetadata for ResponseBamlValue { +// fn metadata_fields(&self) -> Vec<(String, serde_json::Value)> { +// let mut fields = Vec::new(); +// let checks: Vec<(&str, &ResponseCheck)> = self.0.meta().1.iter().map(|check| (check.name.as_str(), check)).collect(); +// if !checks.is_empty() { +// let checks_json = serde_json::to_value(checks).expect("Serializing checks is safe."); +// fields.push(("checks".to_string(), checks_json)); +// } +// let completion_state: Option<&CompletionState> = self.0.meta().2.as_ref(); +// if let Some(state) = completion_state { +// let completion_state_json = serde_json::to_value(&state).expect("Serializing completion state is safe."); +// fields.push(("completion_state".to_string(), completion_state_json)); +// } +// fields +// } +// } diff --git a/engine/baml-lib/jsonish/src/tests/animation.rs b/engine/baml-lib/jsonish/src/tests/animation.rs new file mode 100644 index 0000000000..9200dcb078 --- /dev/null +++ b/engine/baml-lib/jsonish/src/tests/animation.rs @@ -0,0 +1,56 @@ +use crate::helpers::render_output_format; +#[cfg(test)] +use crate::{from_str, helpers::parsed_value_to_response}; +use baml_types::{FieldType, StreamingBehavior}; +use internal_baml_core::ir::repr::make_test_ir; + +#[test] +pub fn make_test_data1() { + let ir = make_test_ir( + r##" + class PersonAssignment { + person Person @stream.with_state + assignment string @stream.with_state + } + + class Person { + name string @stream.done @stream.with_state + age int @stream.with_state + } + "##, + ) + .unwrap(); + + let target_type = FieldType::WithMetadata { + base: Box::new(FieldType::class("PersonAssignment")), + constraints: vec![], + streaming_behavior: StreamingBehavior { + done: false, + state: true, + }, + }; + let target = render_output_format(&ir, &target_type, &Default::default()).unwrap(); + + let llm_data = r#"{"person": {"name": "Greg", "age": 42}, "assignment": "Write"}"#; + + let results = (0..llm_data.len() + 1) + // let results = (0..2) + .map(|i| { + let partial_llm_data = &llm_data[0..i]; + let parsed_value = from_str(&target, &target_type, partial_llm_data, true); + let value = + parsed_value_to_response(&ir, parsed_value.unwrap(), &target_type, true).unwrap(); + + serde_json::to_value(&vec![ + serde_json::to_value(partial_llm_data).unwrap(), + serde_json::to_value(&value).unwrap(), + ]) + .unwrap() + }) + .collect::>(); + + let json = serde_json::to_string(&results).unwrap(); + eprintln!("{}", json); + + // assert!(false); +} diff --git a/engine/baml-lib/jsonish/src/tests/macros.rs b/engine/baml-lib/jsonish/src/tests/macros.rs index 4fa6fc6817..241be5212d 100644 --- a/engine/baml-lib/jsonish/src/tests/macros.rs +++ b/engine/baml-lib/jsonish/src/tests/macros.rs @@ -2,8 +2,10 @@ macro_rules! test_failing_deserializer { ($name:ident, $file_content:expr, $raw_string:expr, $target_type:expr) => { #[test_log::test] fn $name() { - let ir = load_test_ir($file_content); - let target = render_output_format(&ir, &$target_type, &Default::default()).unwrap(); + let ir = crate::helpers::load_test_ir($file_content); + let target = + crate::helpers::render_output_format(&ir, &$target_type, &Default::default()) + .unwrap(); let result = from_str(&target, &$target_type, $raw_string, false); @@ -26,7 +28,7 @@ macro_rules! test_failing_deserializer { /// /// Example /// -/// ```rust +/// ```ignore /// test_deserializer!( /// my_test, /// "schema_content", @@ -39,8 +41,8 @@ macro_rules! test_deserializer { ($name:ident, $file_content:expr, $raw_string:expr, $target_type:expr, $($json:tt)+) => { #[test_log::test] fn $name() { - let ir = load_test_ir($file_content); - let target = render_output_format(&ir, &$target_type, &Default::default()).unwrap(); + let ir = crate::helpers::load_test_ir($file_content); + let target = crate::helpers::render_output_format(&ir, &$target_type, &Default::default()).unwrap(); let result = from_str( &target, @@ -68,8 +70,10 @@ macro_rules! test_deserializer_with_expected_score { ($name:ident, $file_content:expr, $raw_string:expr, $target_type:expr, $target_score:expr) => { #[test_log::test] fn $name() { - let ir = load_test_ir($file_content); - let target = render_output_format(&ir, &$target_type, &Default::default()).unwrap(); + let ir = crate::helpers::load_test_ir($file_content); + let target = + crate::helpers::render_output_format(&ir, &$target_type, &Default::default()) + .unwrap(); let result = from_str(&target, &$target_type, $raw_string, false); @@ -87,8 +91,8 @@ macro_rules! test_partial_deserializer { ($name:ident, $file_content:expr, $raw_string:expr, $target_type:expr, $($json:tt)+) => { #[test_log::test] fn $name() { - let ir = load_test_ir($file_content); - let target = render_output_format(&ir, &$target_type, &Default::default()).unwrap(); + let ir = crate::helpers::load_test_ir($file_content); + let target = crate::helpers::render_output_format(&ir, &$target_type, &Default::default()).unwrap(); let result = from_str( &target, @@ -111,3 +115,66 @@ macro_rules! test_partial_deserializer { } }; } + +macro_rules! test_partial_deserializer_streaming { + ($name:ident, $file_content:expr, $raw_string:expr, $target_type:expr, $($json:tt)+) => { + #[test_log::test] + fn $name() { + let ir = crate::helpers::load_test_ir($file_content); + let target = crate::helpers::render_output_format(&ir, &$target_type, &Default::default()).unwrap(); + + let parsed = from_str( + &target, + &$target_type, + $raw_string, + true, + ); + + // dbg!(&target); + // dbg!(&$target_type); + dbg!(&parsed); + + assert!(parsed.is_ok(), "Failed to parse: {:?}", parsed); + + let result = crate::helpers::parsed_value_to_response(&ir, parsed.unwrap(), &$target_type, true).unwrap(); + + dbg!(&result); + + let value = result; + log::trace!("Score: {}", value.score()); + let json_value = json!(value); + + let expected = serde_json::json!($($json)+); + + assert_json_diff::assert_json_eq!(json_value, expected); + } + }; +} + +macro_rules! test_partial_deserializer_streaming_failure { + ($name:ident, $file_content:expr, $raw_string:expr, $target_type:expr) => { + #[test_log::test] + fn $name() { + let ir = load_test_ir($file_content); + let target = + crate::helpers::render_output_format(&ir, &$target_type, &Default::default()) + .unwrap(); + + let parsed = from_str(&target, &$target_type, $raw_string, true); + + dbg!(&target); + dbg!(&$target_type); + + assert!(parsed.is_ok(), "Failed to parse: {:?}", parsed); + + let result = + crate::helpers::parsed_value_to_response(&ir, parsed.unwrap(), &$target_type, true); + + assert!( + result.is_err(), + "Failed not to parse: {:?}", + result.unwrap() + ); + } + }; +} diff --git a/engine/baml-lib/jsonish/src/tests/mod.rs b/engine/baml-lib/jsonish/src/tests/mod.rs index 00da9d6192..9db8257a8b 100644 --- a/engine/baml-lib/jsonish/src/tests/mod.rs +++ b/engine/baml-lib/jsonish/src/tests/mod.rs @@ -1,9 +1,7 @@ -use anyhow::Result; -use internal_baml_jinja::types::{Class, Enum, Name, OutputFormatContent}; - #[macro_use] pub mod macros; +mod animation; mod test_aliases; mod test_basics; mod test_class; @@ -15,6 +13,7 @@ mod test_lists; mod test_literals; mod test_maps; mod test_partials; +mod test_streaming; mod test_unions; use indexmap::{IndexMap, IndexSet}; @@ -23,248 +22,27 @@ use std::{ path::PathBuf, }; -use baml_types::{BamlValue, EvaluationContext}; +use crate::deserializer::deserialize_flags::Flag; +use crate::deserializer::semantic_streaming::validate_streaming_state; +use crate::{BamlValueWithFlags, ResponseBamlValue}; +use anyhow::Result; +use baml_types::{ + BamlValue, BamlValueWithMeta, CompletionState, EvaluationContext, FieldType, JinjaExpression, + ResponseCheck, StreamingBehavior, +}; +use internal_baml_core::ir::repr::IntermediateRepr; +use internal_baml_jinja::types::{Class, Enum, Name, OutputFormatContent}; + use internal_baml_core::{ ast::Field, internal_baml_diagnostics::SourceFile, - ir::{repr::IntermediateRepr, ClassWalker, EnumWalker, FieldType, IRHelper, TypeValue}, + ir::{ClassWalker, EnumWalker, IRHelper, TypeValue}, validate, }; use serde_json::json; use crate::from_str; -fn load_test_ir(file_content: &str) -> IntermediateRepr { - let mut schema = validate( - &PathBuf::from("./baml_src"), - vec![SourceFile::from(( - PathBuf::from("./baml_src/example.baml"), - file_content.to_string(), - ))], - ); - match schema.diagnostics.to_result() { - Ok(_) => {} - Err(e) => { - panic!("Failed to validate schema: {}", e); - } - } - - IntermediateRepr::from_parser_database(&schema.db, schema.configuration).unwrap() -} - -fn render_output_format( - ir: &IntermediateRepr, - output: &FieldType, - env_values: &EvaluationContext<'_>, -) -> Result { - let (enums, classes, recursive_classes, structural_recursive_aliases) = - relevant_data_models(ir, output, env_values)?; - - Ok(OutputFormatContent::target(output.clone()) - .enums(enums) - .classes(classes) - .recursive_classes(recursive_classes) - .structural_recursive_aliases(structural_recursive_aliases) - .build()) -} - -fn find_existing_class_field( - class_name: &str, - field_name: &str, - class_walker: &Result>, - env_values: &EvaluationContext<'_>, -) -> Result<(Name, FieldType, Option)> { - let Ok(class_walker) = class_walker else { - anyhow::bail!("Class {} does not exist", class_name); - }; - - let Some(field_walker) = class_walker.find_field(field_name) else { - anyhow::bail!("Class {} does not have a field: {}", class_name, field_name); - }; - - let name = Name::new_with_alias(field_name.to_string(), field_walker.alias(env_values)?); - let desc = field_walker.description(env_values)?; - let r#type = field_walker.r#type(); - Ok((name, r#type.clone(), desc)) -} - -fn find_enum_value( - enum_name: &str, - value_name: &str, - enum_walker: &Result>, - env_values: &EvaluationContext<'_>, -) -> Result)>> { - if enum_walker.is_err() { - anyhow::bail!("Enum {} does not exist", enum_name); - } - - let value_walker = match enum_walker { - Ok(e) => e.find_value(value_name), - Err(_) => None, - }; - - let value_walker = match value_walker { - Some(v) => v, - None => return Ok(None), - }; - - if value_walker.skip(env_values)? { - return Ok(None); - } - - let name = Name::new_with_alias(value_name.to_string(), value_walker.alias(env_values)?); - let desc = value_walker.description(env_values)?; - - Ok(Some((name, desc))) -} - -// TODO: This function is "almost" a duplicate of `relevant_data_models` at -// baml-runtime/src/internal/prompt_renderer/render_output_format.rs -// -// Should be refactored. -// -// TODO: (Greg) Is the use of `String` as a hash key safe? Is there some way to -// get a collision that results in some type not getting put onto the stack? -fn relevant_data_models<'a>( - ir: &'a IntermediateRepr, - output: &'a FieldType, - env_values: &EvaluationContext<'_>, -) -> Result<( - Vec, - Vec, - IndexSet, - IndexMap, -)> { - let mut checked_types: HashSet = HashSet::new(); - let mut enums = Vec::new(); - let mut classes: Vec = Vec::new(); - let mut recursive_classes = IndexSet::new(); - let mut structural_recursive_aliases = IndexMap::new(); - let mut start: Vec = vec![output.clone()]; - - while let Some(output) = start.pop() { - match ir.distribute_constraints(&output) { - (FieldType::Enum(enm), constraints) => { - if checked_types.insert(output.to_string()) { - let walker = ir.find_enum(enm); - - let real_values = walker - .as_ref() - .map(|e| e.walk_values().map(|v| v.name().to_string())) - .ok(); - let values = real_values - .into_iter() - .flatten() - .map(|value| { - let meta = find_enum_value(enm.as_str(), &value, &walker, env_values)?; - Ok(meta) - }) - .filter_map(|v| v.transpose()) - .collect::>>()?; - - enums.push(Enum { - name: Name::new_with_alias(enm.to_string(), walker?.alias(env_values)?), - values, - constraints, - }); - } - } - (FieldType::List(inner), _constraints) | (FieldType::Optional(inner), _constraints) => { - if !checked_types.contains(&inner.to_string()) { - start.push(inner.as_ref().clone()); - } - } - (FieldType::Map(k, v), _constraints) => { - if checked_types.insert(output.to_string()) { - if !checked_types.contains(&k.to_string()) { - start.push(k.as_ref().clone()); - } - if !checked_types.contains(&v.to_string()) { - start.push(v.as_ref().clone()); - } - } - } - (FieldType::Tuple(options), _constraints) - | (FieldType::Union(options), _constraints) => { - if checked_types.insert(output.to_string()) { - for inner in options { - if !checked_types.contains(&inner.to_string()) { - start.push(inner.clone()); - } - } - } - } - (FieldType::Class(cls), constraints) => { - if checked_types.insert(output.to_string()) { - let walker = ir.find_class(cls); - - let real_fields = walker - .as_ref() - .map(|e| e.walk_fields().map(|v| v.name().to_string())) - .ok(); - - let fields = real_fields.into_iter().flatten().map(|field| { - let meta = find_existing_class_field(cls, &field, &walker, env_values)?; - Ok(meta) - }); - - let fields = fields.collect::>>()?; - - for (_, t, _) in fields.iter().as_ref() { - if !checked_types.contains(&t.to_string()) { - start.push(t.clone()); - } - } - - // TODO: O(n) algorithm. Maybe a Merge-Find Set can optimize - // this to O(log n) or something like that - // (maybe, IDK though ¯\_(ツ)_/¯) - // - // Also there's a lot of cloning in this process of going - // from Parser DB to IR to Jinja Output Format, not only - // with recursive classes but also the rest of models. - // There's room for optimization here. - // - // Also take a look at the TODO on top of this function. - for cycle in ir.finite_recursive_cycles() { - if cycle.contains(cls) { - recursive_classes.extend(cycle.iter().map(ToOwned::to_owned)); - } - } - - classes.push(Class { - name: Name::new_with_alias(cls.to_string(), walker?.alias(env_values)?), - fields, - constraints, - }); - } - } - (FieldType::RecursiveTypeAlias(name), _) => { - // TODO: Same O(n) problem as above. - for cycle in ir.structural_recursive_alias_cycles() { - if cycle.contains_key(name) { - for (alias, target) in cycle.iter() { - structural_recursive_aliases.insert(alias.to_owned(), target.clone()); - } - } - } - } - (FieldType::Literal(_), _) => {} - (FieldType::Primitive(_), _constraints) => {} - (FieldType::Constrained { .. }, _) => { - unreachable!("It is guaranteed that a call to distribute_constraints will not return FieldType::Constrained") - } - } - } - - Ok(( - enums, - classes, - recursive_classes, - structural_recursive_aliases, - )) -} - const EMPTY_FILE: &str = r#" "#; @@ -750,68 +528,3 @@ test_deserializer!( "four": "four" }) ); - -#[test] -/// Test that when partial parsing, if we encounter an int in a context -/// where it could possibly be extended further, it must be returned -/// as Null. -fn singleton_list_int_deleted() { - let target = FieldType::List(Box::new(FieldType::Primitive(TypeValue::Int))); - let output_format = OutputFormatContent::target(target.clone()).build(); - let res = from_str(&output_format, &target, "[123", true).expect("Can parse"); - let baml_value: BamlValue = res.into(); - assert_eq!(baml_value, BamlValue::List(vec![])); -} - -#[test] -/// Test that when partial parsing, if we encounter an int in a context -/// where it could possibly be extended further, it must be returned -/// as Null. -fn list_int_deleted() { - let target = FieldType::List(Box::new(FieldType::Primitive(TypeValue::Int))); - let output_format = OutputFormatContent::target(target.clone()).build(); - let res = from_str(&output_format, &target, "[123, 456", true).expect("Can parse"); - let baml_value: BamlValue = res.into(); - assert_eq!(baml_value, BamlValue::List(vec![BamlValue::Int(123)])); -} - -#[test] -/// Test that when partial parsing, if we encounter an int in a context -/// where it could possibly be extended further, it must be returned -/// as Null. -fn list_int_not_deleted() { - let target = FieldType::List(Box::new(FieldType::Primitive(TypeValue::Int))); - let output_format = OutputFormatContent::target(target.clone()).build(); - let res = from_str(&output_format, &target, "[123, 456 // Done", true).expect("Can parse"); - let baml_value: BamlValue = res.into(); - assert_eq!( - baml_value, - BamlValue::List(vec![BamlValue::Int(123), BamlValue::Int(456)]) - ); -} - -#[test] -/// Test that when partial parsing, if we encounter an int in a context -/// where it could possibly be extended further, it must be returned -/// as Null. -fn partial_int_deleted() { - let target = FieldType::Optional(Box::new(FieldType::Primitive(TypeValue::Int))); - let output_format = OutputFormatContent::target(target.clone()).build(); - let res = from_str(&output_format, &target, "123", true).expect("Can parse"); - let baml_value: BamlValue = res.into(); - // Note: This happens to parse as a List, but Null also seems appropriate. - assert_eq!(baml_value, BamlValue::Null); -} - -#[test] -/// Test that when partial parsing, if we encounter an int in a context -/// where it could possibly be extended further, it must be returned -/// as Null. -fn partial_int_not_deleted() { - let target = FieldType::List(Box::new(FieldType::Primitive(TypeValue::Int))); - let output_format = OutputFormatContent::target(target.clone()).build(); - let res = from_str(&output_format, &target, "123", true).expect("Can parse"); - let baml_value: BamlValue = res.into(); - // Note: This happens to parse as a List, but Null also seems appropriate. - assert_eq!(baml_value, BamlValue::List(vec![])); -} diff --git a/engine/baml-lib/jsonish/src/tests/test_class.rs b/engine/baml-lib/jsonish/src/tests/test_class.rs index 2e43ab30fc..0ec2fd4774 100644 --- a/engine/baml-lib/jsonish/src/tests/test_class.rs +++ b/engine/baml-lib/jsonish/src/tests/test_class.rs @@ -1083,22 +1083,6 @@ class Bar { } "#; -test_partial_deserializer!( - test_object_streaming_ints, - OBJECT_STREAM_TEST, - r#"{"a": 11, "b": 22"#, - FieldType::Class("Foo".to_string()), - {"a": 11, "b": null, "c": null} -); - -test_partial_deserializer!( - test_object_streaming_ints_newlines, - OBJECT_STREAM_TEST, - "{\n\"a\":11,\n\"b\": 22", - FieldType::Class("Foo".to_string()), - {"a": 11, "b": null, "c": null} -); - test_partial_deserializer!( test_object_finished_ints, OBJECT_STREAM_TEST, @@ -1107,59 +1091,6 @@ test_partial_deserializer!( {"a": 1234, "b": 1234, "c": 1234} ); -test_partial_deserializer!( - test_nested_object_streaming, - OBJECT_STREAM_TEST, - r#"{"a": 1234, "foo": { "c": 33, "a": 11"#, - FieldType::Class("Bar".to_string()), - {"a": 1234, "foo": { "a": null, "b": null, "c": 33}} -); - -const BIG_OBJECT_STREAM_TEST: &str = r#" -class BigNumbers { - a int - b float -} - -class CompoundBigNumbers { - big BigNumbers - big_nums BigNumbers[] - another BigNumbers -} -"#; - -test_partial_deserializer!( - test_big_object_empty, - BIG_OBJECT_STREAM_TEST, - "{", - FieldType::Class("CompoundBigNumbers".to_string()), - {"big": null, "big_nums": [], "another": null} -); - -test_partial_deserializer!( - test_big_object_start_big, - BIG_OBJECT_STREAM_TEST, - r#"{"big": {"a": 11, "b": 12"#, - FieldType::Class("CompoundBigNumbers".to_string()), - {"big": {"a": 11, "b": null}, "big_nums": [], "another": null} -); - -test_partial_deserializer!( - test_big_object_start_big_into_list, - BIG_OBJECT_STREAM_TEST, - r#"json```{"big": {"a": 11, "b": 12}, "big_nums": [{"a": 22, "b": 33"#, - FieldType::Class("CompoundBigNumbers".to_string()), - {"big": {"a": 11, "b": 12.0}, "big_nums": [{"a": 22, "b": null}], "another": null} -); - -test_partial_deserializer!( - test_big_object_start_big_into_list2, - BIG_OBJECT_STREAM_TEST, - r#"json```{"big": {"a": 11, "b": 12.2}, "big_nums": [{"a": 22, "b": 33}, {"a": 1, "b": 2.2}], "another": {"a": 45, "b": 0.1"#, - FieldType::Class("CompoundBigNumbers".to_string()), - {"big": {"a": 11, "b": 12.2}, "big_nums": [{"a": 22, "b": 33.0}, {"a": 1, "b": 2.2}], "another": {"a": 45, "b": null}} -); - test_deserializer!( test_empty_string_value, r#" @@ -1486,3 +1417,53 @@ test_deserializer!( }, } ); + +const OPTIONAL_LIST_AND_MAP: &str = r#" +class OptionalListAndMap { + p string[]? + q map? +} +"#; + +test_partial_deserializer_streaming!( + test_optional_list, + OPTIONAL_LIST_AND_MAP, + r#" + ```json + { + "p": ["test"], + "q": { + "test": "ok" + } + } + ``` + "#, + FieldType::class("OptionalListAndMap"), + {"p": ["test"], "q": { "test": "ok" }} +); + +const INTEG_TEST_FAILURE_STR: &str = r#" +[ + { + "prop1": "In the realm of artificial intelligence, advancements have been remarkable. Between neural networks and cutting-edge algorithms, the landscape of machine learning has evolved dramatically. From the development of self-driving cars to sophisticated chatbots that can engage in human-like conversations, AI technology has become an integral aspect of modern life. Researchers are continually pushing the boundaries of what is possible, exploring deep learning techniques that enable machines to learn from extensive datasets. The application of AI spans various industries including healthcare, where predictive analytics aids in diagnostics, to finance, where algorithms manage investment portfolios. As AI continues to adapt and grow, ethical considerations surrounding data privacy and decision-making processes become increasingly important. Ongoing debates question the implications of relying on AI and machine learning for critical functions in society. Moreover, governments and organizations alike are grappling with the challenges of regulation and oversight in this fast-paced field. The future of AI seems bright, but it also poses inquiries into trust, accountability, and the long-term effects on the job market. As we look ahead, the collaboration between humans and machines could redefine productivity and creativity, paving the way for innovative solutions to complex problems that society faces.", + "prop2": 1 + } +] +"#; + +const INTEG_TEST_FAILURE_SCHEMA: &str = r#" +class TestOutputClass { + prop1 string @description("A long string with about 200 words") + prop2 int +} +"#; + +test_partial_deserializer_streaming!( + test_integ_test_failure, + INTEG_TEST_FAILURE_SCHEMA, + INTEG_TEST_FAILURE_STR, + FieldType::Class("TestOutputClass".to_string()), + { "prop1": "In the realm of artificial intelligence, advancements have been remarkable. Between neural networks and cutting-edge algorithms, the landscape of machine learning has evolved dramatically. From the development of self-driving cars to sophisticated chatbots that can engage in human-like conversations, AI technology has become an integral aspect of modern life. Researchers are continually pushing the boundaries of what is possible, exploring deep learning techniques that enable machines to learn from extensive datasets. The application of AI spans various industries including healthcare, where predictive analytics aids in diagnostics, to finance, where algorithms manage investment portfolios. As AI continues to adapt and grow, ethical considerations surrounding data privacy and decision-making processes become increasingly important. Ongoing debates question the implications of relying on AI and machine learning for critical functions in society. Moreover, governments and organizations alike are grappling with the challenges of regulation and oversight in this fast-paced field. The future of AI seems bright, but it also poses inquiries into trust, accountability, and the long-term effects on the job market. As we look ahead, the collaboration between humans and machines could redefine productivity and creativity, paving the way for innovative solutions to complex problems that society faces.", + "prop2": 1 + } +); diff --git a/engine/baml-lib/jsonish/src/tests/test_lists.rs b/engine/baml-lib/jsonish/src/tests/test_lists.rs index 66049b964e..8804d2c6a1 100644 --- a/engine/baml-lib/jsonish/src/tests/test_lists.rs +++ b/engine/baml-lib/jsonish/src/tests/test_lists.rs @@ -137,12 +137,4 @@ test_deserializer!( r#"[1234"#, FieldType::List(FieldType::Primitive(TypeValue::Int).into()), [1234] -); - -test_partial_deserializer!( - test_list_streaming_partial, - "", - r#"[1234, 5678"#, - FieldType::List(FieldType::Primitive(TypeValue::Int).into()), - [1234] -); +); \ No newline at end of file diff --git a/engine/baml-lib/jsonish/src/tests/test_maps.rs b/engine/baml-lib/jsonish/src/tests/test_maps.rs index 5c1b1d6415..2b14391c66 100644 --- a/engine/baml-lib/jsonish/src/tests/test_maps.rs +++ b/engine/baml-lib/jsonish/src/tests/test_maps.rs @@ -1,4 +1,5 @@ use crate::BamlValueWithFlags; +use baml_types::LiteralValue; use super::*; @@ -126,8 +127,8 @@ fn test_union_of_class_and_map() { let llm_output = r#"{"a": 1, "b": "hello"}"#; let expected = json!({"a": "1", "b": "hello"}); - let ir = load_test_ir(file_content); - let target = render_output_format(&ir, &target_type, &Default::default()).unwrap(); + let ir = crate::helpers::load_test_ir(file_content); + let target = crate::helpers::render_output_format(&ir, &target_type, &Default::default()).unwrap(); let result = from_str(&target, &target_type, llm_output, false); @@ -158,8 +159,8 @@ fn test_union_of_map_and_class() { let llm_output = r#"{"a": 1, "b": "hello"}"#; let expected = json!({"a": "1", "b": "hello"}); - let ir = load_test_ir(file_content); - let target = render_output_format(&ir, &target_type, &Default::default()).unwrap(); + let ir = crate::helpers::load_test_ir(file_content); + let target = crate::helpers::render_output_format(&ir, &target_type, &Default::default()).unwrap(); let result = from_str(&target, &target_type, llm_output, false); @@ -176,3 +177,40 @@ fn test_union_of_map_and_class() { assert_json_diff::assert_json_eq!(json_value, expected); } + +test_deserializer!( + test_map_with_enum_keys, + r#" + enum Key { + A + B + } + "#, + r#"{"A": "one", "B": "two"}"#, + FieldType::map(FieldType::Enum("Key".to_string()), FieldType::string()), + {"A": "one", "B": "two"} +); + +test_partial_deserializer_streaming!( + test_map_with_enum_keys_streaming, + r#" + enum Key { + A + B + } + "#, + r#"{"A": "one", "B": "two"}"#, + FieldType::map(FieldType::Enum("Key".to_string()), FieldType::string()), + {"A": "one", "B": "two"} +); + +test_partial_deserializer_streaming!( + test_map_with_literal_keys_streaming, + "", + r#"{"A": "one", "B": "two"}"#, + FieldType::map(FieldType::Union(vec![ + FieldType::Literal(LiteralValue::String("A".to_string())), + FieldType::Literal(LiteralValue::String("B".to_string())), + ]), FieldType::string()), + {"A": "one", "B": "two"} +); \ No newline at end of file diff --git a/engine/baml-lib/jsonish/src/tests/test_streaming.rs b/engine/baml-lib/jsonish/src/tests/test_streaming.rs new file mode 100644 index 0000000000..26a77c220e --- /dev/null +++ b/engine/baml-lib/jsonish/src/tests/test_streaming.rs @@ -0,0 +1,128 @@ +use super::*; + +use crate::helpers::load_test_ir; + +const NUMBERS: &str = r#" +class Foo { + nums int[] +} +"#; + +test_partial_deserializer_streaming!( + test_number_list, + NUMBERS, + "{'nums': [1,2", + FieldType::class("Foo"), + {"nums": [1]} +); + +const NUMBERS_STATE: &str = r#" +class Foo { + nums int[] @stream.with_state +} +"#; + +test_partial_deserializer_streaming!( + test_number_list_state_incomplete, + NUMBERS_STATE, + "{'nums': [1,2", + FieldType::class("Foo"), + {"nums": {"value": [1], "state": "Incomplete"}} +); + +const TOPLEVEL_DONE: &str = r#" +class Foo { + nums int[] + @@stream.done +} +"#; + +test_partial_deserializer_streaming_failure!( + test_toplevel_done, + TOPLEVEL_DONE, + "{'nums': [1,2]", + FieldType::class("Foo") +); + +const NESTED_DONE: &str = r#" +class Foo { + nums int[] + @@stream.done +} + +class Bar { + foos Foo[] +} +"#; + +test_partial_deserializer_streaming!( + test_nested_done, + NESTED_DONE, + r#"{ + 'foos': [ + {'nums': [1, 2]}, + {'nums': [3, 4] + "#, + FieldType::class("Bar"), + {"foos": [ {"nums": [1, 2]}]} +); + +const NESTED_DONE_WITH_TOPLEVEL_DONE: &str = r#" +class Foo { + nums int[] + @@stream.done +} + +class Bar { + message string @stream.done + foos Foo[] +} +"#; + +test_partial_deserializer_streaming!( + test_nested_done_with_toplevel_done, + NESTED_DONE_WITH_TOPLEVEL_DONE, + r#"{ + 'message': "Hello", + 'foos': [ + {'nums': [1, 2]}, + {'nums': [3, 4] + "#, + FieldType::class("Bar"), + {"message": "Hello", "foos": [ {"nums": [1, 2]}]} +); + +const NEEDED_FIELD: &str = r#" +class Foo { + my_int int + my_string string @stream.not_null +} + +class Bar { + foos Foo[] +} +"#; + +test_partial_deserializer_streaming!( + test_needed_field, + NEEDED_FIELD, + // r#"{"foos": [{"my_int": 1, "my_string": "hi"}, {"my_int": 10,"#, + r#"{"foos": [{"my_int": 1, "my"#, + FieldType::class("Bar"), + {"foos": []} +); + +const DONE_FIELD: &str = r#" +class Foo { + foo string @stream.done + bar string +} +"#; + +test_partial_deserializer_streaming!( + test_done_field, + DONE_FIELD, + r#"{"foo": ""#, + FieldType::class("Foo"), + {"foo": null, "bar": null} +); diff --git a/engine/baml-lib/parser-database/src/attributes/mod.rs b/engine/baml-lib/parser-database/src/attributes/mod.rs index 92e8c58bfc..f8c32a6633 100644 --- a/engine/baml-lib/parser-database/src/attributes/mod.rs +++ b/engine/baml-lib/parser-database/src/attributes/mod.rs @@ -29,6 +29,15 @@ pub struct Attributes { /// @check and @assert attributes attached to the node. pub constraints: Vec, + + /// Whether the node has a `@sstream.done` attribute. + pub streaming_done: Option, + + /// Whether the node has a `@stream.not_null` attribute. + pub streaming_needed: Option, + + /// Whether the node has a `@stream.with_state` attribute. + pub streaming_state: Option, } impl Attributes { diff --git a/engine/baml-lib/parser-database/src/attributes/to_string_attribute.rs b/engine/baml-lib/parser-database/src/attributes/to_string_attribute.rs index 7899198288..355cddbc2a 100644 --- a/engine/baml-lib/parser-database/src/attributes/to_string_attribute.rs +++ b/engine/baml-lib/parser-database/src/attributes/to_string_attribute.rs @@ -37,12 +37,32 @@ pub(super) fn visit(ctx: &mut Context<'_>, span: &Span, as_block: bool) -> Optio ctx.validate_visited_arguments(); } - if as_block && ctx.visit_optional_single_attr("dynamic") { - attributes.set_dynamic_type(); + if ctx.visit_optional_single_attr("stream.done") { + attributes.streaming_done = Some(true); modified = true; ctx.validate_visited_arguments(); } + if ctx.visit_optional_single_attr("stream.not_null") { + attributes.streaming_needed = Some(true); + modified = true; + ctx.validate_visited_arguments(); + } + + if ctx.visit_optional_single_attr("stream.with_state") { + attributes.streaming_state = Some(true); + modified = true; + ctx.validate_visited_arguments(); + } + + if as_block { + if ctx.visit_optional_single_attr("dynamic") { + attributes.set_dynamic_type(); + modified = true; + ctx.validate_visited_arguments(); + } + } + if modified { Some(attributes) } else { diff --git a/engine/baml-lib/parser-database/src/lib.rs b/engine/baml-lib/parser-database/src/lib.rs index 775cc9dad4..c81b4c7e45 100644 --- a/engine/baml-lib/parser-database/src/lib.rs +++ b/engine/baml-lib/parser-database/src/lib.rs @@ -72,7 +72,8 @@ use names::Names; /// - Global validations are then performed on the mostly validated schema. /// Currently only index name collisions. pub struct ParserDatabase { - ast: ast::SchemaAst, + /// The AST. + pub ast: ast::SchemaAst, interner: interner::StringInterner, names: Names, types: Types, diff --git a/engine/baml-lib/parser-database/src/names/validate_reserved_names.rs b/engine/baml-lib/parser-database/src/names/validate_reserved_names.rs index 6389c32de4..b75c12385e 100644 --- a/engine/baml-lib/parser-database/src/names/validate_reserved_names.rs +++ b/engine/baml-lib/parser-database/src/names/validate_reserved_names.rs @@ -108,11 +108,15 @@ fn validate_name( "env.* is reserved.", span.clone(), )), - ast::Identifier::Ref(_, span) => Err(DatamodelError::new_name_error( - _type, - "Namespace imports (using '.') are not yet supported.", - span.clone(), - )), + ast::Identifier::Ref(path, span) => { + let valid_paths = ["stream.done", "stream.not_null", "stream.with_state"]; + if !valid_paths.contains(&path.full_name.as_str()) { + Err(DatamodelError::new_name_error( + _type, + "Namespace imports (using '.') are not yet supported.", + span.clone(), + ))} else { Ok(()) } + }, ast::Identifier::Invalid(_, span) | ast::Identifier::String(_, span) => { Err(DatamodelError::new_name_error( diff --git a/engine/baml-lib/schema-ast/src/ast/field.rs b/engine/baml-lib/schema-ast/src/ast/field.rs index 6a14e02681..f1d8a33c26 100644 --- a/engine/baml-lib/schema-ast/src/ast/field.rs +++ b/engine/baml-lib/schema-ast/src/ast/field.rs @@ -348,6 +348,7 @@ impl FieldType { } } } + } // Impl display for FieldType diff --git a/engine/baml-lib/schema-ast/src/parser/datamodel.pest b/engine/baml-lib/schema-ast/src/parser/datamodel.pest index 623b746add..d5ff707aeb 100644 --- a/engine/baml-lib/schema-ast/src/parser/datamodel.pest +++ b/engine/baml-lib/schema-ast/src/parser/datamodel.pest @@ -72,8 +72,9 @@ non_union = { array_notation | map | identifier | group | tuple | literal_type } parenthesized_type = { openParan ~ field_type_with_attr ~ closeParan } path_identifier = { single_word ~ ("." ~ single_word)+ } -identifier = { path_identifier | single_word } +identifier = { path_identifier | namespaced_identifier | single_word } single_word = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_" | "-")* } +namespaced_identifier = { single_word ~ ("::" ~ single_word)+ } // ###################################### // Type Alias diff --git a/engine/baml-lib/schema-ast/src/parser/parse_field.rs b/engine/baml-lib/schema-ast/src/parser/parse_field.rs index ed42d2ac3e..05767d23b6 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_field.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_field.rs @@ -60,16 +60,21 @@ pub(crate) fn parse_value_expr( } } +/// Sort all attributes on a field into either field attributes or type attributes. +/// The name of the attribute fully determines whether it will be associated with +/// the field, or with the type. fn reassociate_type_attributes(field_attributes: &mut Vec, field_type: &mut FieldType) { let mut all_attrs = field_type.attributes().to_owned(); all_attrs.append(field_attributes); - let (attrs_for_type, attrs_for_field): (Vec, Vec) = all_attrs + let (attrs_for_type, attrs_for_field): (Vec<_>, Vec<_>) = all_attrs .into_iter() - .partition(|attr| ["assert", "check"].contains(&attr.name())); - field_type.set_attributes(attrs_for_type); + .partition(|attr| TYPE_ATTRIBUTE_NAMES.contains(&attr.name())); + field_type.set_attributes(attrs_for_type.clone()); *field_attributes = attrs_for_field; } +const TYPE_ATTRIBUTE_NAMES: [&str; 4] = ["assert", "check", "stream.done", "stream.with_state"]; + pub(crate) fn parse_type_expr( model_name: &Option, container_type: &'static str, @@ -96,7 +101,8 @@ pub(crate) fn parse_type_expr( field_type = parse_field_type_chain(current, diagnostics); } Rule::field_attribute => { - field_attributes.push(parse_attribute(current, false, diagnostics)) + let attribute = parse_attribute(current, false, diagnostics); + field_attributes.push(attribute); } _ => parsing_catch_all(current, "field"), } @@ -477,6 +483,19 @@ mod tests { } } + #[test] + fn streaming_attributes() { + test_parse_baml_type! { + source: r#"int @stream.done @stream.not_null @stream.with_state"#, + target: FieldType::Primitive( + FieldArity::Required, + TypeValue::Int, + Span::fake(), + Some(vec![mk_bare_attribute("stream.done"), mk_bare_attribute("stream.not_null"), mk_bare_attribute("stream.with_state")]) + ), + } + } + // Convenience functions. fn mk_int(attrs: Option>) -> FieldType { @@ -506,4 +525,15 @@ mod tests { span: Span::fake(), } } + + fn mk_bare_attribute(value: &'static str) -> Attribute { + Attribute { + name: (value, Span::fake()).into(), + parenthesized: false, + arguments: ArgumentsList { + arguments: Vec::new() + }, + span: Span::fake() + } + } } diff --git a/engine/baml-lib/schema-ast/src/parser/parse_identifier.rs b/engine/baml-lib/schema-ast/src/parser/parse_identifier.rs index 376a0a6899..fd5f7d473a 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_identifier.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_identifier.rs @@ -15,6 +15,7 @@ pub fn parse_identifier(pair: Pair<'_>, diagnostics: &mut Diagnostics) -> Identi if let Some(inner) = pair.into_inner().next() { return match inner.as_rule() { Rule::path_identifier => parse_path_identifier(inner, diagnostics), + Rule::namespaced_identifier => parse_namespaced_identifier(inner, diagnostics), Rule::single_word => parse_single_word(inner, diagnostics), _ => unreachable_rule!(inner, Rule::identifier), }; @@ -64,3 +65,29 @@ fn parse_path_identifier(pair: Pair<'_>, diagnostics: &mut Diagnostics) -> Ident span, ); } + +/// Parse an identifier of the form `word::word::word` directly into a that string. +/// TODO: `Identifier` should eventually store the namespace components +/// individually. +fn parse_namespaced_identifier(pair: Pair<'_>, diagnostics: &mut Diagnostics) -> Identifier { + assert_correct_parser!(pair, Rule::namespaced_identifier); + + let raw_str = pair.as_str(); + let span = diagnostics.span(pair.as_span()); + let mut name_parts = Vec::new(); + for inner in pair.into_inner() { + match inner.as_rule() { + Rule::single_word => name_parts.push(inner.as_str()), + _ => unreachable_rule!(inner, Rule::namespaced_identifier), + } + } + + assert!( + name_parts.len() > 1, + "Namespaced identifier must have at least 2 elements. Parts({}) Raw({})", + name_parts.join("::"), + raw_str + ); + + Identifier::Local(name_parts.join("::"), span) +} \ No newline at end of file diff --git a/engine/baml-runtime/Cargo.toml b/engine/baml-runtime/Cargo.toml index 25dc816281..e6d430d580 100644 --- a/engine/baml-runtime/Cargo.toml +++ b/engine/baml-runtime/Cargo.toml @@ -134,6 +134,7 @@ aws-config = "1.5.3" aws-sdk-bedrockruntime = "1.37.0" axum = "0.7.5" axum-extra = { version = "0.9.3", features = ["erased-json", "typed-header"] } +criterion = "0.5.1" gcp_auth = "0.12.3" hostname = "0.3.1" jsonwebtoken = { version = "9.3.0" } @@ -154,6 +155,7 @@ skip-integ-tests = [] [dev-dependencies] assert_cmd = "2" console_log = "1" +criterion = "0.5.1" dissimilar = "1.0.4" expect-test = "1.1.0" indoc.workspace = true @@ -162,3 +164,8 @@ rstest = "0.22.0" wasm-bindgen-test = "0.3.42" walkdir = "2.5.0" wasm-logger = "0.2.0" + +[[ bench ]] +name = "bench" +path = "benches/bench.rs" +harness = false diff --git a/engine/baml-runtime/benches/bench.rs b/engine/baml-runtime/benches/bench.rs new file mode 100644 index 0000000000..d730d1d42f --- /dev/null +++ b/engine/baml-runtime/benches/bench.rs @@ -0,0 +1,107 @@ +use criterion::{criterion_group, criterion_main, Criterion}; +use std::hint::black_box; + +use baml_types::{BamlValueWithMeta, EvaluationContext, FieldType}; +use internal_baml_core::ir::repr::make_test_ir; +use jsonish::{from_str, helpers::render_output_format, BamlValueWithFlags}; + +use baml_runtime::internal::llm_client::{parsed_value_to_response, ResponseBamlValue}; + +criterion_group!(benches, parse_benchmarks, response_benchmarks); +criterion_main!(benches); + +fn parse( + schema: &str, + target_type: &FieldType, + msg: &str, + allow_partials: bool, +) -> BamlValueWithFlags { + let ir = make_test_ir(schema).unwrap(); + let target = render_output_format(&ir, target_type, &EvaluationContext::default()).unwrap(); + from_str(&target, target_type, msg, allow_partials).unwrap() +} + +const SCHEMA: &str = r#" +class Foo { + i int +} + +type JSONValue = int | float | bool | string | null | JSONValue[] | map +"#; + +const BIG_JSON: &str = r#" + { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3], + "object": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + }, + "json": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3], + "object": { + "number": 1, + "string": "test", + "bool": true, + "list": [1, 2, 3] + } + } + } +"#; + +fn parse_benchmarks(c: &mut Criterion) { + // c.bench_function("parse basic", |b| b.iter(|| parse( + // black_box(SCHEMA), + // black_box(&FieldType::class("Foo")), + // black_box(r#"{"i": 1}"#), + // black_box(false), + // )) + // ); + + // c.bench_function("parse JSONValue", |b| b.iter(|| parse( + // SCHEMA, + // &FieldType::RecursiveTypeAlias("JSONValue".to_string()), + // BIG_JSON, + // false, + // )) + // ); +} + +fn response_benchmarks(c: &mut Criterion) { + // c.bench_function("response basic", |b| b.iter(|| to_response( + // black_box(SCHEMA), + // black_box(&FieldType::class("Foo")), + // black_box(r#"{"i": 1}"#), + // black_box(false), + // ))); + + c.bench_function("response JSONValue", |b| { + b.iter(|| { + to_response( + black_box(SCHEMA), + black_box(&FieldType::RecursiveTypeAlias("JSONValue".to_string())), + black_box(BIG_JSON), + black_box(false), + ) + }) + }); +} + +fn to_response( + schema: &str, + target_type: &FieldType, + msg: &str, + allow_partials: bool, +) -> ResponseBamlValue { + let ir = make_test_ir(schema).unwrap(); + let target = render_output_format(&ir, target_type, &EvaluationContext::default()).unwrap(); + let parsed = from_str(&target, target_type, msg, allow_partials).unwrap(); + parsed_value_to_response(&ir, parsed, target_type, true).unwrap() +} diff --git a/engine/baml-runtime/benches/lib.rs b/engine/baml-runtime/benches/lib.rs new file mode 100644 index 0000000000..d33fc01002 --- /dev/null +++ b/engine/baml-runtime/benches/lib.rs @@ -0,0 +1 @@ +pub mod sap_parser_benchmark; \ No newline at end of file diff --git a/engine/baml-runtime/benches/sap_parser_benchmark.rs b/engine/baml-runtime/benches/sap_parser_benchmark.rs new file mode 100644 index 0000000000..e69de29bb2 diff --git a/engine/baml-runtime/src/cli/serve/mod.rs b/engine/baml-runtime/src/cli/serve/mod.rs index 1feab0f6d1..7c4a6e0d0d 100644 --- a/engine/baml-runtime/src/cli/serve/mod.rs +++ b/engine/baml-runtime/src/cli/serve/mod.rs @@ -31,11 +31,12 @@ use serde_json::json; use std::{path::PathBuf, sync::Arc, task::Poll}; use tokio::{net::TcpListener, sync::RwLock}; use tokio_stream::StreamExt; +use jsonish::ResponseBamlValue; use crate::{ client_registry::ClientRegistry, errors::ExposedError, - internal::llm_client::{LLMResponse, ResponseBamlValue}, + internal::llm_client::LLMResponse, BamlRuntime, FunctionResult, RuntimeContextManager, }; use internal_baml_codegen::openapi::OpenApiSchema; @@ -369,7 +370,7 @@ Tip: test that the server is up using `curl http://localhost:{}/_debug/ping` LLMResponse::Success(_) => { match function_result.result_with_constraints_content() { // Just because the LLM returned 2xx doesn't mean that it returned parse-able content! - Ok(parsed) => (StatusCode::OK, Json::(parsed.clone())) + Ok(parsed) => (StatusCode::OK, Json(parsed.serialize_final())) .into_response(), Err(e) => { if let Some(ExposedError::ValidationError { @@ -483,7 +484,7 @@ Tip: test that the server is up using `curl http://localhost:{}/_debug/ping` match function_result.result_with_constraints_content() { // Just because the LLM returned 2xx doesn't mean that it returned parse-able content! Ok(parsed) => { - (StatusCode::OK, Json::(parsed.clone())) + (StatusCode::OK, Json(&parsed.serialize_partial())) .into_response() } @@ -657,7 +658,7 @@ impl Stream for EventStream { match self.receiver.poll_recv(cx) { Poll::Ready(Some(item)) => match item.result_with_constraints_content() { // TODO: not sure if this is the correct way to implement this. - Ok(parsed) => Poll::Ready(Some(parsed.into())), + Ok(parsed) => Poll::Ready(Some(parsed.0.clone().into())), Err(_) => Poll::Pending, }, Poll::Ready(None) => Poll::Ready(None), diff --git a/engine/baml-runtime/src/internal/llm_client/mod.rs b/engine/baml-runtime/src/internal/llm_client/mod.rs index a40def101f..a404a7dc40 100644 --- a/engine/baml-runtime/src/internal/llm_client/mod.rs +++ b/engine/baml-runtime/src/internal/llm_client/mod.rs @@ -9,13 +9,21 @@ pub mod retry_policy; mod strategy; pub mod traits; -use anyhow::Result; +use anyhow::{Context, Result}; -use baml_types::{BamlMap, BamlValueWithMeta, JinjaExpression, ResponseCheck}; -use internal_baml_core::ir::ClientWalker; +use baml_types::{BamlMap, BamlValueWithMeta, FieldType, JinjaExpression, ResponseCheck}; +use internal_baml_core::ir::{repr::IntermediateRepr, ClientWalker}; use internal_baml_jinja::RenderedPrompt; use internal_llm_client::AllowedRoleMetadata; -use jsonish::BamlValueWithFlags; +pub use jsonish::ResponseBamlValue; +use jsonish::{ + deserializer::{ + deserialize_flags::{constraint_results, DeserializerConditions, Flag}, + semantic_streaming::validate_streaming_state, + semantic_streaming::validate_streaming_state2, + }, + BamlValueWithFlags, +}; use serde::{Deserialize, Serialize}; use std::error::Error; @@ -24,14 +32,57 @@ use reqwest::StatusCode; #[cfg(target_arch = "wasm32")] use wasm_bindgen::JsValue; -pub type ResponseBamlValue = BamlValueWithMeta>; - /// Validate a parsed value, checking asserts and checks. -pub fn parsed_value_to_response(baml_value: &BamlValueWithFlags) -> ResponseBamlValue { +pub fn parsed_value_to_response( + ir: &IntermediateRepr, + baml_value: BamlValueWithFlags, + field_type: &FieldType, + allow_partials: bool, +) -> Result { + let meta_flags: BamlValueWithMeta> = baml_value.clone().into(); let baml_value_with_meta: BamlValueWithMeta> = baml_value.clone().into(); - baml_value_with_meta.map_meta(|cs| { - cs.iter() + + let value_with_response_checks: BamlValueWithMeta> = baml_value_with_meta + .map_meta(|cs| { + cs.iter() + .map(|(label, expr, result)| { + let status = (if *result { "succeeded" } else { "failed" }).to_string(); + ResponseCheck { + name: label.clone(), + expression: expr.0.clone(), + status, + } + }) + .collect() + }); + + let baml_value_with_streaming = + validate_streaming_state(ir, &baml_value, field_type, allow_partials) + .map_err(|s| anyhow::anyhow!("{s:?}"))?; + + // Combine the baml_value, its types, the parser flags, and the streaming state + // into a final value. + // Node that we set the StreamState to `None` unless `allow_partials`. + let response_value = baml_value_with_streaming + .zip_meta(&value_with_response_checks)? + .zip_meta(&meta_flags)? + .map_meta(|((x, y), z)| (z.clone(), y.clone(), x.clone() )); + Ok(ResponseBamlValue(response_value)) +} + +/// Validate a parsed value, checking asserts and checks. +pub fn parsed_value_to_response2( + ir: &IntermediateRepr, + baml_value: BamlValueWithFlags, + field_type: &FieldType, + allow_partials: bool, +) -> Result { + let meta_flags: BamlValueWithMeta> = baml_value.into(); + let baml_value_with_streaming2 = meta_flags.map_meta_owned(|flags| { + let constraint_results = constraint_results(&flags); + let response_checks: Vec = constraint_results + .iter() .map(|(label, expr, result)| { let status = (if *result { "succeeded" } else { "failed" }).to_string(); ResponseCheck { @@ -40,8 +91,15 @@ pub fn parsed_value_to_response(baml_value: &BamlValueWithFlags) -> ResponseBaml status, } }) - .collect() - }) + .collect(); + (flags, response_checks) + }); + + let response_value2 = + validate_streaming_state2(ir, baml_value_with_streaming2, field_type, allow_partials) + .map_err(|s| anyhow::anyhow!("TODO {s:?}"))?; + + Ok(ResponseBamlValue(response_value2)) } #[derive(Clone, Copy, PartialEq)] @@ -355,3 +413,156 @@ impl crate::tracing::Visualize for LLMErrorResponse { s.join("\n") } } + +#[cfg(test)] +mod tests { + use super::*; + use baml_types::{BamlValueWithMeta, FieldType}; + use internal_baml_core::ir::repr::{make_test_ir, IntermediateRepr}; + use jsonish::{ + deserializer::{deserialize_flags::DeserializerConditions, types::ValueWithFlags}, + BamlValueWithFlags, + }; + + fn mk_ir() -> IntermediateRepr { + make_test_ir( + r##" + class Foo { + i int + s string @stream.done + } + "##, + ) + .expect("Source is valid") + } + + #[test] + fn to_response() { + let ir = mk_ir(); + let val = BamlValueWithFlags::Class( + "Foo".to_string(), + DeserializerConditions { + flags: vec![Flag::Incomplete], + }, + vec![ + ( + "i".to_string(), + BamlValueWithFlags::Int(ValueWithFlags { + value: 1, + flags: DeserializerConditions { flags: Vec::new() }, + }), + ), + ( + "s".to_string(), + BamlValueWithFlags::String(ValueWithFlags { + value: "H".to_string(), + flags: DeserializerConditions { + flags: vec![Flag::Incomplete], + }, + }), + ), + ] + .into_iter() + .collect(), + ); + let response = parsed_value_to_response(&ir, val, &FieldType::class("Foo"), true); + assert!(response.is_ok()); + } + + fn mk_null() -> BamlValueWithFlags { + BamlValueWithFlags::Null(DeserializerConditions::default()) + } + + fn mk_string(s: &str) -> BamlValueWithFlags { + BamlValueWithFlags::String(ValueWithFlags { + value: s.to_string(), + flags: DeserializerConditions::default(), + }) + } + fn mk_float(s: f64) -> BamlValueWithFlags { + BamlValueWithFlags::Float(ValueWithFlags { + value: s, + flags: DeserializerConditions::default(), + }) + } + + #[test] + fn stable_keys2() { + let ir = make_test_ir( + r##" + class Address { + street string + state string + } + class Name { + first string + last string? + } + class Info { + name Name + address Address? + hair_color string + height float + } + "##, + ) + .unwrap(); + + let value = BamlValueWithFlags::Class( + "Info".to_string(), + DeserializerConditions::default(), + vec![ + ( + "name".to_string(), + BamlValueWithFlags::Class( + "Name".to_string(), + DeserializerConditions::default(), + vec![ + ("first".to_string(), mk_string("Greg")), + ("last".to_string(), mk_string("Hale")), + ] + .into_iter() + .collect(), + ), + ), + ("address".to_string(), mk_null()), + ("hair_color".to_string(), mk_string("Grey")), + ("height".to_string(), mk_float(1.75)), + ] + .into_iter() + .collect(), + ); + let field_type = FieldType::class("Info"); + + let res = parsed_value_to_response(&ir, value, &field_type, true).unwrap(); + + let json = serde_json::to_value(&res).unwrap(); + + match &json { + serde_json::Value::Object(items) => { + let (k, _) = items.iter().next().unwrap(); + assert_eq!(k, "name") + } + _ => panic!("Expected json object"), + } + } + + #[test] + fn integ_test_failure() { + let ir = make_test_ir(r#" + class Foo { + prop1 string + prop2 int + } + "#).unwrap(); + let target_type = FieldType::class("Foo"); + let target = jsonish::helpers::render_output_format(&ir, &target_type, &Default::default()).unwrap(); + + let msg = r#"{"prop1": "something", "prop2": 2}"#; + + let parsed = jsonish::from_str(&target, &target_type, msg, true).unwrap(); + let response = parsed_value_to_response(&ir, parsed, &target_type, true).unwrap(); + let json = serde_json::to_string(&response).unwrap(); + assert_eq!(json, r#"{"prop1":"something","prop2":2}"#); + } +} \ No newline at end of file diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs index 1156b43e82..b68861a090 100644 --- a/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs +++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs @@ -1,7 +1,7 @@ use anyhow::Result; use baml_types::BamlValue; use internal_baml_core::ir::repr::IntermediateRepr; -use jsonish::BamlValueWithFlags; +use jsonish::{BamlValueWithFlags, ResponseBamlValue}; use web_time::Duration; use crate::{ @@ -9,7 +9,7 @@ use crate::{ llm_client::{ parsed_value_to_response, traits::{WithClientProperties, WithPrompt, WithSingleCallable}, - LLMResponse, ResponseBamlValue, + LLMResponse, }, prompt_renderer::PromptRenderer, }, @@ -24,12 +24,11 @@ pub async fn orchestrate( ctx: &RuntimeContext, prompt: &PromptRenderer, params: &BamlValue, - parse_fn: impl Fn(&str) -> Result, + parse_fn: impl Fn(&str) -> Result, ) -> ( Vec<( OrchestrationScope, LLMResponse, - Option>, Option>, )>, Duration, @@ -45,7 +44,6 @@ pub async fn orchestrate( node.scope, LLMResponse::InternalFailure(e.to_string()), None, - None, )); continue; } @@ -71,22 +69,16 @@ pub async fn orchestrate( }; let sleep_duration = node.error_sleep_duration().cloned(); - let (parsed_response, response_with_constraints) = match parsed_response { - Some(Ok(v)) => (Some(Ok(v.clone())), Some(Ok(parsed_value_to_response(&v)))), - Some(Err(e)) => (None, Some(Err(e))), - None => (None, None), - }; results.push(( node.scope, response, parsed_response, - response_with_constraints, )); // Currently, we break out of the loop if an LLM responded, even if we couldn't parse the result. if results .last() - .map_or(false, |(_, r, _, _)| matches!(r, LLMResponse::Success(_))) + .map_or(false, |(_, r, _)| matches!(r, LLMResponse::Success(_))) { break; } else if let Some(duration) = sleep_duration { diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs index c3cf30fa26..68422676fb 100644 --- a/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs +++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs @@ -25,14 +25,13 @@ pub async fn orchestrate_stream( ctx: &RuntimeContext, prompt: &PromptRenderer, params: &BamlValue, - partial_parse_fn: impl Fn(&str) -> Result, - parse_fn: impl Fn(&str) -> Result, + partial_parse_fn: impl Fn(&str) -> Result, + parse_fn: impl Fn(&str) -> Result, on_event: Option, ) -> ( Vec<( OrchestrationScope, LLMResponse, - Option>, Option>, )>, Duration, @@ -52,7 +51,6 @@ where node.scope, LLMResponse::InternalFailure(e.to_string()), None, - None, )); continue; } @@ -65,18 +63,11 @@ where .map(|stream_part| { if let Some(on_event) = on_event.as_ref() { if let LLMResponse::Success(s) = &stream_part { - let parsed = partial_parse_fn(&s.content); - let (parsed, response_value) = match parsed { - Ok(v) => { - (Some(Ok(v.clone())), Some(Ok(parsed_value_to_response(&v)))) - } - Err(e) => (None, Some(Err(e))), - }; + let response_value = partial_parse_fn(&s.content); on_event(FunctionResult::new( node.scope.clone(), LLMResponse::Success(s.clone()), - parsed, - response_value, + Some(response_value), )); } } @@ -99,7 +90,7 @@ where Err(response) => response, }; - let parsed_response = match &final_response { + let response_value = match &final_response { LLMResponse::Success(s) => { if !node .finish_reason_filter() @@ -117,19 +108,13 @@ where }, _ => None, }; - let (parsed_response, response_value) = match parsed_response { - Some(Ok(v)) => (Some(Ok(v.clone())), Some(Ok(parsed_value_to_response(&v)))), - Some(Err(e)) => (None, Some(Err(e))), - None => (None, None), - }; - // parsed_response.map(|r| r.and_then(|v| parsed_value_to_response(v))); let sleep_duration = node.error_sleep_duration().cloned(); - results.push((node.scope, final_response, parsed_response, response_value)); + results.push((node.scope, final_response, response_value)); // Currently, we break out of the loop if an LLM responded, even if we couldn't parse the result. if results .last() - .map_or(false, |(_, r, _, _)| matches!(r, LLMResponse::Success(_))) + .map_or(false, |(_, r, _)| matches!(r, LLMResponse::Success(_))) { break; } else if let Some(duration) = sleep_duration { diff --git a/engine/baml-runtime/src/internal/prompt_renderer/mod.rs b/engine/baml-runtime/src/internal/prompt_renderer/mod.rs index 2dbbfa65a3..4ce4f4fcff 100644 --- a/engine/baml-runtime/src/internal/prompt_renderer/mod.rs +++ b/engine/baml-runtime/src/internal/prompt_renderer/mod.rs @@ -1,6 +1,6 @@ mod render_output_format; use internal_llm_client::ClientSpec; -use jsonish::BamlValueWithFlags; +use jsonish::{BamlValueWithFlags, ResponseBamlValue}; use render_output_format::render_output_format; use anyhow::Result; @@ -16,6 +16,8 @@ use internal_baml_jinja::{ use crate::RuntimeContext; +use super::llm_client::parsed_value_to_response; + pub struct PromptRenderer { function_name: String, client_spec: ClientSpec, @@ -49,13 +51,20 @@ impl PromptRenderer { &self.client_spec } - pub fn parse(&self, raw_string: &str, allow_partials: bool) -> Result { - jsonish::from_str( + pub fn parse( + &self, + ir: &IntermediateRepr, + raw_string: &str, + allow_partials: bool + ) -> Result { + let parsed = jsonish::from_str( &self.output_defs, &self.output_type, raw_string, allow_partials, - ) + )?; + let res = parsed_value_to_response(ir, parsed, &self.output_type, allow_partials); + res } pub fn render_prompt( diff --git a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs index 069ced74e0..2417386ebf 100644 --- a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs +++ b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs @@ -74,7 +74,7 @@ fn find_new_class_field( class_walker: &Result>, overrides: &RuntimeClassOverride, _ctx: &RuntimeContext, -) -> Result<(Name, FieldType, Option)> { +) -> Result<(Name, FieldType, Option, bool)> { let Some(field_overrides) = overrides.new_fields.get(field_name) else { anyhow::bail!("Class {} does not have a field: {}", class_name, field_name); }; @@ -96,7 +96,7 @@ fn find_new_class_field( let name = Name::new_with_alias(field_name.to_string(), alias.value()); let desc = desc.value(); - Ok((name, field_overrides.0.clone(), desc)) + Ok((name, field_overrides.0.clone(), desc, false)) // TODO: Field overrides are not "stream.not_nul". Should this be configurable? } fn find_existing_class_field( @@ -105,7 +105,7 @@ fn find_existing_class_field( class_walker: &Result>, overrides: &Option<&RuntimeClassOverride>, ctx: &RuntimeContext, -) -> Result<(Name, FieldType, Option)> { +) -> Result<(Name, FieldType, Option, bool)> { let Ok(class_walker) = class_walker else { anyhow::bail!("Class {} does not exist", class_name); }; @@ -118,10 +118,12 @@ fn find_existing_class_field( let mut alias = OverridableValue::Unset; let mut desc = OverridableValue::Unset; + let mut needed = OverridableValue::Unset; if let Some(attrs) = field_overrides { alias = OverridableValue::::from(attrs.alias.as_ref()); desc = OverridableValue::::from(attrs.meta.get("description")); + needed = OverridableValue::::from(attrs.meta.get("stream.not_null")); } let eval_ctx = ctx.eval_ctx(false); @@ -138,10 +140,12 @@ fn find_existing_class_field( } } + let name = Name::new_with_alias(field_name.to_string(), alias.value()); let desc = desc.value(); let r#type = field_walker.r#type(); - Ok((name, r#type.clone(), desc)) + let needed = needed.value().unwrap_or(false); + Ok((name, r#type.clone(), desc, needed)) } fn find_enum_value( @@ -228,8 +232,8 @@ fn relevant_data_models<'a>( let eval_ctx = ctx.eval_ctx(false); while let Some(output) = start.pop() { - match ir.distribute_constraints(&output) { - (FieldType::Enum(enm), constraints) => { + match ir.distribute_metadata(&output) { + (FieldType::Enum(enm), (constraints, streaming_behavior)) => { if checked_types.insert(output.to_string()) { let overrides = ctx.enum_overrides.get(enm); let walker = ir.find_enum(enm); @@ -297,7 +301,7 @@ fn relevant_data_models<'a>( } } } - (FieldType::Class(cls), constraints) => { + (FieldType::Class(cls), (constraints, streaming_behavior)) => { if checked_types.insert(output.to_string()) { let overrides = ctx.class_override.get(cls); let walker = ir.find_class(cls); @@ -345,7 +349,7 @@ fn relevant_data_models<'a>( let fields = fields.chain(new_fields).collect::>>()?; - for (_, t, _) in fields.iter().as_ref() { + for (_, t, _, _) in fields.iter().as_ref() { if !checked_types.contains(&t.to_string()) { start.push(t.clone()); } @@ -371,6 +375,7 @@ fn relevant_data_models<'a>( name: Name::new_with_alias(cls.to_string(), alias.value()), fields, constraints, + streaming_behavior, }); } else { // TODO: @antonio This one was nasty! If aliases are not @@ -396,7 +401,7 @@ fn relevant_data_models<'a>( } (FieldType::Literal(_), _) => {} (FieldType::Primitive(_), _) => {} - (FieldType::Constrained { .. }, _) => { + (FieldType::WithMetadata { .. }, _) => { unreachable!("It is guaranteed that a call to distribute_constraints will not return FieldType::Constrained") } } diff --git a/engine/baml-runtime/src/lib.rs b/engine/baml-runtime/src/lib.rs index 8af6ddcdba..5dc493846e 100644 --- a/engine/baml-runtime/src/lib.rs +++ b/engine/baml-runtime/src/lib.rs @@ -8,7 +8,7 @@ pub(crate) mod internal; #[cfg(not(target_arch = "wasm32"))] pub mod cli; pub mod client_registry; -pub mod constraints; +pub mod test_constraints; pub mod errors; pub mod request; mod runtime; @@ -64,7 +64,7 @@ pub use internal_baml_core::internal_baml_diagnostics; pub use internal_baml_core::internal_baml_diagnostics::Diagnostics as DiagnosticsError; pub use internal_baml_core::ir::{scope_diagnostics, FieldType, IRHelper, TypeValue}; -use crate::constraints::{evaluate_test_constraints, TestConstraintsResult}; +use crate::test_constraints::{evaluate_test_constraints, TestConstraintsResult}; use crate::internal::llm_client::LLMResponse; #[cfg(not(target_arch = "wasm32"))] @@ -241,7 +241,7 @@ impl BamlRuntime { let (response_res, span_uuid) = stream.run(on_event, ctx, None, None).await; log::info!("response_res: {:#?}", response_res); let res = response_res?; - let (_, llm_resp, _, val) = res + let (_, llm_resp, val) = res .event_chain() .iter() .last() @@ -263,7 +263,8 @@ impl BamlRuntime { } else { match val { Some(Ok(value)) => { - evaluate_test_constraints(¶ms, value, complete_resp, constraints) + let value_with_constraints = value.0.map_meta(|(_,constraints,_)| constraints.clone()); + evaluate_test_constraints(¶ms, &value_with_constraints, complete_resp, constraints) } _ => TestConstraintsResult::empty(), } diff --git a/engine/baml-runtime/src/runtime/runtime_interface.rs b/engine/baml-runtime/src/runtime/runtime_interface.rs index c1c9669bf3..9655baff67 100644 --- a/engine/baml-runtime/src/runtime/runtime_interface.rs +++ b/engine/baml-runtime/src/runtime/runtime_interface.rs @@ -340,7 +340,6 @@ impl RuntimeInterface for InternalBamlRuntime { e )), None, - None, )) } }; @@ -379,7 +378,9 @@ impl RuntimeInterface for InternalBamlRuntime { // Now actually execute the code. let (history, _) = orchestrate_call(orchestrator, self.ir(), &ctx, &renderer, &baml_args, |s| { - renderer.parse(s, false) + // eprintln!("RAW"); + // eprintln!("{}", s); + renderer.parse(self.ir(), s, false) }) .await; diff --git a/engine/baml-runtime/src/constraints.rs b/engine/baml-runtime/src/test_constraints.rs similarity index 100% rename from engine/baml-runtime/src/constraints.rs rename to engine/baml-runtime/src/test_constraints.rs diff --git a/engine/baml-runtime/src/tracing/mod.rs b/engine/baml-runtime/src/tracing/mod.rs index 0b53378876..64af6ff282 100644 --- a/engine/baml-runtime/src/tracing/mod.rs +++ b/engine/baml-runtime/src/tracing/mod.rs @@ -3,7 +3,7 @@ pub mod api_wrapper; use crate::on_log_event::LogEventCallbackSync; use crate::InnerTraceStats; use anyhow::{Context, Result}; -use baml_types::{BamlMap, BamlMediaType, BamlValue}; +use baml_types::{BamlMap, BamlMediaType, BamlValue, BamlValueWithMeta}; use cfg_if::cfg_if; use colored::{ColoredString, Colorize}; use internal_baml_jinja::RenderedPrompt; @@ -12,6 +12,7 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; use uuid::Uuid; +use jsonish::ResponseBamlValue; use crate::{ client_registry::ClientRegistry, internal::llm_client::LLMResponse, @@ -125,9 +126,9 @@ impl Visualize for FunctionResult { Some(Ok(val)) => { s.push(format!( "{}", - format!("---Parsed Response ({})---", val.r#type()).blue() + format!("---Parsed Response ({})---", val.0.r#type()).blue() )); - let json_str = serde_json::to_string_pretty(&val).unwrap(); + let json_str = serde_json::to_string_pretty(&val.serialize_final()).unwrap(); s.push(truncate_string(&json_str, max_chunk_size).to_string()); } Some(Err(e)) => { @@ -501,12 +502,12 @@ impl BamlTracer { .result_with_constraints() .as_ref() .and_then(|r| r.as_ref().ok()) - .map(|v| v.r#type().to_string()), + .map(|v| v.0.r#type().to_string()), parsed_response: response .result_with_constraints() .as_ref() .and_then(|r| r.as_ref().ok()) - .map(|v| serde_json::to_string(v).unwrap_or_default()), + .map(|v| serde_json::to_string(&v.serialize_final()).unwrap_or_default()), error, }, LLMResponse::LLMFailure(err) => BamlEventJson { @@ -830,7 +831,7 @@ impl ToLogSchema for FunctionResult { .as_ref() .and_then(|r| r.as_ref().ok()) .map(|r| { - let v: BamlValue = r.into(); + let v: BamlValue = r.0.clone().into(); IOValue::from(&v) }), }, @@ -846,7 +847,7 @@ impl From<&FunctionResult> for MetadataType { result .event_chain() .iter() - .map(|(_, r, _, _)| r.into()) + .map(|(_, r, _)| r.into()) .collect::>(), ) } diff --git a/engine/baml-runtime/src/types/response.rs b/engine/baml-runtime/src/types/response.rs index 2c04bee01d..2f5dacb472 100644 --- a/engine/baml-runtime/src/types/response.rs +++ b/engine/baml-runtime/src/types/response.rs @@ -1,21 +1,20 @@ pub use crate::internal::llm_client::LLMResponse; use crate::{ - constraints::TestConstraintsResult, + test_constraints::TestConstraintsResult, errors::ExposedError, - internal::llm_client::{orchestrator::OrchestrationScope, ResponseBamlValue}, + internal::llm_client::orchestrator::OrchestrationScope, }; use anyhow::Result; use colored::*; -use baml_types::BamlValue; -use jsonish::BamlValueWithFlags; +use baml_types::{BamlValue, BamlValueWithMeta}; +use jsonish::{deserializer::deserialize_flags::Flag, BamlValueWithFlags, ResponseBamlValue, SerializeMode}; #[derive(Debug)] pub struct FunctionResult { event_chain: Vec<( OrchestrationScope, LLMResponse, - Option>, Option>, )>, } @@ -36,9 +35,9 @@ impl std::fmt::Display for FunctionResult { writeln!( f, "{}", - format!("---Parsed Response ({})---", val.r#type()).blue() + format!("---Parsed Response ({})---", val.0.r#type()).blue() )?; - write!(f, "{:#}", serde_json::json!(val)) + write!(f, "{:#}", serde_json::json!(val.serialize_partial())) } Some(Err(e)) => { writeln!(f, "{}", "---Parsed Response---".blue())?; @@ -53,11 +52,10 @@ impl FunctionResult { pub fn new( scope: OrchestrationScope, response: LLMResponse, - parsed: Option>, baml_value: Option>, ) -> Self { Self { - event_chain: vec![(scope, response, parsed, baml_value)], + event_chain: vec![(scope, response, baml_value)], } } @@ -66,7 +64,6 @@ impl FunctionResult { ) -> &Vec<( OrchestrationScope, LLMResponse, - Option>, Option>, )> { &self.event_chain @@ -76,7 +73,6 @@ impl FunctionResult { chain: Vec<( OrchestrationScope, LLMResponse, - Option>, Option>, )>, ) -> Result { @@ -99,41 +95,18 @@ impl FunctionResult { &self.event_chain.last().unwrap().0 } - pub fn parsed(&self) -> &Option> { - &self.event_chain.last().unwrap().2 - } - - /// Get the parsed result. This logic is strange because parsing errors can - /// be forwarded to a different field in the orchestrator. - /// TODO: (Greg) Fix the strange logic. - /// Historical note: Most of the consumers of the orchestrator use a final - /// `ResponseBamlValue`, a type designed to hold only the information needed - /// in those responses. But one consumer, the wasm client, requires extra info - /// from the parsing stage. Therefore we preserve both the parsing stage data - /// and the `ResponseValue` side by side. And because `anyhow::Error` is not - /// `Clone`, errors from the parsing stage are handled the most easily by - /// migrating them to the `ResponseValue` in cases where parsing failed. - /// The proper solution is to create a `RuntimeBamlValue` that contains - /// enough information for all clients, and then types like - /// `SDKClientResponseBamlValue` and `WasmResponseBamlValue` which derive - /// from `RuntimeBamlValue` where needed. - pub fn parsed_content(&self) -> Result<&BamlValueWithFlags> { - match (self.parsed(), self.result_with_constraints()) { - // Error at parse time was forwarded to later result. - (None, Some(Err(e))) => Err(self.format_err(e)), - // Parsing succeeded. - (Some(Ok(v)), _) => Ok(v), - // Error at parse time was not forwarded to later results. - (Some(Err(e)), _) => Err(self.format_err(e)), - (None, None) => Err(anyhow::anyhow!(self.llm_response().clone())), - (None, Some(_)) => { - unreachable!("A response could not have been created without a successful parse") - } + pub fn parsed(&self) -> &Option> { + match self.event_chain.last() { + Some((_,_,result)) => result, + None => &None, } } pub fn result_with_constraints(&self) -> &Option> { - &self.event_chain.last().unwrap().3 + match self.event_chain.last() { + Some((_, _, result)) => result, + None => &None + } } pub fn result_with_constraints_content(&self) -> Result<&ResponseBamlValue> { diff --git a/engine/baml-runtime/src/types/stream.rs b/engine/baml-runtime/src/types/stream.rs index f9abe81237..a4d02b529b 100644 --- a/engine/baml-runtime/src/types/stream.rs +++ b/engine/baml-runtime/src/types/stream.rs @@ -100,8 +100,8 @@ impl FunctionResultStream { &rctx, &self.renderer, &baml_types::BamlValue::Map(local_params), - |content| self.renderer.parse(content, true), - |content| self.renderer.parse(content, false), + |content| self.renderer.parse(self.ir.as_ref(), content, true), + |content| self.renderer.parse(self.ir.as_ref(), content, false), on_event, ) .await; diff --git a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs index bf4e2af60f..1b4ecca501 100644 --- a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs +++ b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs @@ -11,6 +11,8 @@ use baml_runtime::RenderCurlSettings; use baml_runtime::{ internal::llm_client::LLMResponse, BamlRuntime, DiagnosticsError, IRHelper, RenderedPrompt, }; +use baml_types::BamlValueWithMeta; +use baml_types::ResponseCheck; use baml_types::{BamlMediaType, BamlValue, GeneratorOutputType, TypeValue}; use indexmap::IndexMap; use internal_baml_codegen::version_check::GeneratorType; @@ -23,6 +25,7 @@ use baml_runtime::internal::llm_client::orchestrator::ExecutionScope; use itertools::join; use js_sys::Promise; use js_sys::Uint8Array; +use jsonish::ResponseBamlValue; use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::HashMap; @@ -485,7 +488,7 @@ impl WasmFunctionResponse { pub fn parsed_response(&self) -> Option { self.function_response .result_with_constraints_content() - .map(|p| serde_json::to_string(&BamlValue::from(p))) + .map(|p| serde_json::to_string(&p.serialize_partial())) .map_or_else(|_| None, |s| s.ok()) } @@ -507,83 +510,24 @@ impl WasmFunctionResponse { } } -// TODO: What is supposed to happen with the serialized baml_value? -// That value has checks nested inside. Are they meant to be removed -// during flattening? Or duplicated into the top-level list of checks? -fn flatten_checks(value: &BamlValueWithFlags) -> (serde_json::Value, usize) { - // // Note: (Greg) depending on the goal of this function, we may be able - // // to replace most of it like this: - // let value_with_meta: BamlValueWithMeta> = parsed_value_to_response(value); - // let n_checks: usize = value_with_meta.iter().map(|node| node.meta().len()).sum(); - // let bare_baml_value: BamlValue = value_with_meta.into(); - // let json_value: serde_json::Value = serde_json::to_value(bare_baml_value).unwrap_or( - // "Error converting value to JSON".into() - // ); - +fn serialize_value_counting_checks(value: &ResponseBamlValue) -> (serde_json::Value, usize) { type J = serde_json::Value; let checks = value - .conditions() - .flags() + .0 + .meta() + .1 .iter() - .flat_map(|f| match f { - Flag::ConstraintResults(c) => c - .iter() - .map(|(label, _expr, b)| (label.clone(), *b)) - .collect::>(), - _ => vec![], - }) - .collect::>(); - - let (retval, sub_check_count) = match value { - BamlValueWithFlags::String(s) => (J::String(s.value.clone()), 0), - BamlValueWithFlags::Int(i) => (i.value.into(), 0), - BamlValueWithFlags::Float(f) => (f.value.into(), 0), - BamlValueWithFlags::Bool(b) => (J::Bool(b.value), 0), - BamlValueWithFlags::List(_, v) => { - let (values, counts): (Vec<_>, Vec<_>) = v.iter().map(|e| flatten_checks(e)).unzip(); - (J::Array(values), counts.iter().sum()) - } - BamlValueWithFlags::Map(_, m) => { - let (values, counts): (serde_json::Map, Vec<_>) = m - .iter() - .map(|(k, (_, v))| { - let (value, count) = flatten_checks(v); - ((k.clone(), value), count) - }) - .unzip(); - (J::Object(values), counts.iter().sum()) - } - BamlValueWithFlags::Enum(_, v) => (J::String(v.value.clone()), 0), - BamlValueWithFlags::Class(_, _, m) => { - let (values, counts): (serde_json::Map, Vec<_>) = m - .iter() - .map(|(k, v)| { - let (value, count) = flatten_checks(v); - ((k.clone(), value), count) - }) - .unzip(); - (J::Object(values), counts.iter().sum()) - } - BamlValueWithFlags::Null(_) => (J::Null, 0), - BamlValueWithFlags::Media(_) => ( - serde_json::Value::String("media type not supported".to_string()), - 0, - ), - }; + .map(|ResponseCheck { name, status, .. }| (name.clone(), status.clone())) + .collect::>(); - let check_count = checks.len() + sub_check_count; + let sub_check_count: usize = value.0.iter().map(|node| node.meta().1.len()).sum(); + let json_value: serde_json::Value = + serde_json::to_value(value.serialize_final()).unwrap_or("Error converting value to JSON".into()); - let final_value = if checks.is_empty() { - retval - } else { - json!({ - "value": retval, - "checks": checks, - }) - }; + let check_count = checks.len() + sub_check_count; - (final_value, check_count) + (json_value, check_count) } #[wasm_bindgen] @@ -617,14 +561,23 @@ impl WasmTestResponse { } fn parsed_response_impl(&self) -> anyhow::Result { - let parsed_response = self + let maybe_parsed_response = &self .test_response .as_ref() .ok() .context("No test response")? .function_response - .parsed_content()?; - let (flattened_checks, check_count) = flatten_checks(&parsed_response); + .parsed() + .as_ref(); + let parsed_response = match maybe_parsed_response { + Some(Ok(value)) => Ok(value), + _ => Err(anyhow::anyhow!("No parsed value")), + } + .context("No parsed value")?; + // let baml_value_with_response_checks = parsed_response + // .0 + // .map_meta(|(_, response_checks, _)| response_checks.clone()); + let (flattened_checks, check_count) = serialize_value_counting_checks(&parsed_response); Ok(WasmParsedTestResponse { value: serde_json::to_string(&flattened_checks)?, check_count, @@ -927,7 +880,7 @@ fn get_dummy_value( Some(format!("({},)", dummy)) } baml_runtime::FieldType::Optional(_) => None, - baml_runtime::FieldType::Constrained { base, .. } => { + baml_runtime::FieldType::WithMetadata { base, .. } => { get_dummy_value(indent, allow_multiline, base) } } diff --git a/engine/language_client_codegen/src/lib.rs b/engine/language_client_codegen/src/lib.rs index 1775ead3dd..a67167550a 100644 --- a/engine/language_client_codegen/src/lib.rs +++ b/engine/language_client_codegen/src/lib.rs @@ -305,9 +305,10 @@ pub fn type_check_attributes(ir: &IntermediateRepr) -> HashSet Option { match field_type { - FieldType::Constrained { base, constraints } => { + FieldType::WithMetadata { base, constraints, .. } => { let direct_sub_attributes = field_type_attributes(base); let mut check_names = TypeCheckAttributes( constraints diff --git a/engine/language_client_codegen/src/openapi.rs b/engine/language_client_codegen/src/openapi.rs index 9fb2cca911..cfd1406e59 100644 --- a/engine/language_client_codegen/src/openapi.rs +++ b/engine/language_client_codegen/src/openapi.rs @@ -650,7 +650,7 @@ impl<'ir> ToTypeReferenceInTypeDefinition<'ir> for FieldType { // something i saw suggested doing this inner.to_type_spec(_ir)? } - FieldType::Constrained { base, .. } => match field_type_attributes(self) { + FieldType::WithMetadata { base, .. } => match field_type_attributes(self) { Some(checks) => { let base_type_ref = base.to_type_spec(_ir)?; let checks_type_spec = type_def_for_checks(checks); diff --git a/engine/language_client_codegen/src/python/generate_types.rs b/engine/language_client_codegen/src/python/generate_types.rs index a59b05be8d..2216c2c072 100644 --- a/engine/language_client_codegen/src/python/generate_types.rs +++ b/engine/language_client_codegen/src/python/generate_types.rs @@ -124,8 +124,7 @@ impl<'ir> From> for PythonClass<'ir> { ( Cow::Borrowed(f.elem.name.as_str()), add_default_value( - &f.elem.r#type.elem, - &f.elem.r#type.elem.to_type_ref(c.db), + &f.elem.r#type.elem.to_type_ref(c.db, false), ), f.elem.docstring.as_ref().map(render_docstring), ) @@ -146,7 +145,7 @@ impl<'ir> From> for PythonTypeAlias<' ) -> Self { PythonTypeAlias { name: Cow::Borrowed(name), - target: target.to_type_ref(db), + target: target.to_type_ref(db, false), } } } @@ -175,12 +174,25 @@ impl<'ir> From> for PartialPythonClass<'ir> { .static_fields .iter() .map(|f| { + // Fields with @stream.done should take their type from + let needed: bool = f.attributes.get("stream.not_null").is_some(); + let (_, metadata) = c.db.distribute_metadata(&f.elem.r#type.elem); + let done: bool = metadata.1.done; + let field = match (done, needed) { + // A normal partial field. + (false, false) => add_default_value( + &f.elem.r#type.elem.to_partial_type_ref(c.db, false, false)), + // A field with @stream.done and no @stream.not_null + (true, false) => add_default_value( + &optional(&f.elem.r#type.elem.to_type_ref(c.db, true)) + ), + (false, true) => add_default_value( + &f.elem.r#type.elem.to_partial_type_ref(c.db, false, true)), + (true, true) => f.elem.r#type.elem.to_type_ref(c.db, true), // TODO: Fix. + }; ( f.elem.name.as_str(), - add_default_value( - &f.elem.r#type.elem, - &f.elem.r#type.elem.to_partial_type_ref(c.db, false), - ), + field, f.elem.docstring.as_ref().map(render_docstring), ) }) @@ -190,7 +202,8 @@ impl<'ir> From> for PartialPythonClass<'ir> { } } -pub fn add_default_value(node: &FieldType, type_str: &String) -> String { +/// For a field whose type +pub fn add_default_value(type_str: &String) -> String { if type_str.starts_with("Optional[") { format!("{} = None", type_str) } else { @@ -226,12 +239,14 @@ pub fn to_python_literal(literal: &LiteralValue) -> String { } trait ToTypeReferenceInTypeDefinition { - fn to_type_ref(&self, ir: &IntermediateRepr) -> String; - fn to_partial_type_ref(&self, ir: &IntermediateRepr, wrapped: bool) -> String; + fn to_type_ref(&self, ir: &IntermediateRepr, module_prefix: bool) -> String; + fn to_partial_type_ref(&self, ir: &IntermediateRepr, wrapped: bool, needed: bool) -> String; } impl ToTypeReferenceInTypeDefinition for FieldType { - fn to_type_ref(&self, ir: &IntermediateRepr) -> String { + // TODO: use_module_prefix boolean blindness. Replace with str? + fn to_type_ref(&self, ir: &IntermediateRepr, use_module_prefix: bool) -> String { + let module_prefix = if use_module_prefix { "types." } else {""}; match self { FieldType::Enum(name) => { if ir @@ -239,24 +254,24 @@ impl ToTypeReferenceInTypeDefinition for FieldType { .map(|e| e.item.attributes.get("dynamic_type").is_some()) .unwrap_or(false) { - format!("Union[\"{name}\", str]") + format!("Union[\"{module_prefix}{name}\", str]") } else { - format!("\"{name}\"") + format!("\"{module_prefix}{name}\"") } } FieldType::RecursiveTypeAlias(name) => format!("\"{name}\""), FieldType::Literal(value) => to_python_literal(value), - FieldType::Class(name) => format!("\"{name}\""), - FieldType::List(inner) => format!("List[{}]", inner.to_type_ref(ir)), + FieldType::Class(name) => format!("\"{module_prefix}{name}\""), + FieldType::List(inner) => format!("List[{}]", inner.to_type_ref(ir, use_module_prefix)), FieldType::Map(key, value) => { - format!("Dict[{}, {}]", key.to_type_ref(ir), value.to_type_ref(ir)) + format!("Dict[{}, {}]", key.to_type_ref(ir, use_module_prefix), value.to_type_ref(ir, use_module_prefix)) } FieldType::Primitive(r#type) => r#type.to_python(), FieldType::Union(inner) => format!( "Union[{}]", inner .iter() - .map(|t| t.to_type_ref(ir)) + .map(|t| t.to_type_ref(ir, use_module_prefix)) .collect::>() .join(", ") ), @@ -264,29 +279,36 @@ impl ToTypeReferenceInTypeDefinition for FieldType { "Tuple[{}]", inner .iter() - .map(|t| t.to_type_ref(ir)) + .map(|t| t.to_type_ref(ir, use_module_prefix)) .collect::>() .join(", ") ), - FieldType::Optional(inner) => format!("Optional[{}]", inner.to_type_ref(ir)), - FieldType::Constrained { base, .. } => match field_type_attributes(self) { + FieldType::Optional(inner) => format!("Optional[{}]", inner.to_type_ref(ir, use_module_prefix)), + FieldType::WithMetadata { base, .. } => match field_type_attributes(self) { Some(checks) => { - let base_type_ref = base.to_type_ref(ir); + let base_type_ref = base.to_type_ref(ir, use_module_prefix); let checks_type_ref = type_name_for_checks(&checks); format!("Checked[{base_type_ref},{checks_type_ref}]") } - None => base.to_type_ref(ir), + None => base.to_type_ref(ir, use_module_prefix), }, } } - fn to_partial_type_ref(&self, ir: &IntermediateRepr, wrapped: bool) -> String { - match self { + + fn to_partial_type_ref(&self, ir: &IntermediateRepr, wrapped: bool, needed: bool) -> String { + let (base_type, metadata) = ir.distribute_metadata(self); + let is_partial_type = !metadata.1.done; + let use_module_prefix = !is_partial_type; + let with_state = metadata.1.state; + let constraints = metadata.0; + let module_prefix = if is_partial_type { "" } else { "types." }; + let base_rep = match base_type { FieldType::Class(name) => { - if wrapped { - format!("\"{name}\"") + if wrapped || needed { + format!("\"{module_prefix}{name}\"") } else { - format!("Optional[\"{name}\"]") + format!("Optional[\"{module_prefix}{name}\"]") } } FieldType::Enum(name) => { @@ -297,7 +319,11 @@ impl ToTypeReferenceInTypeDefinition for FieldType { { format!("Optional[Union[types.{name}, str]]") } else { - format!("Optional[types.{name}]") + if needed { + format!("types.{name}") + } else { + format!("Optional[types.{name}]") + } } } FieldType::RecursiveTypeAlias(name) => { @@ -307,45 +333,70 @@ impl ToTypeReferenceInTypeDefinition for FieldType { format!("Optional[\"{name}\"]") } } - FieldType::Literal(value) => format!("Optional[{}]", to_python_literal(value)), - FieldType::List(inner) => format!("List[{}]", inner.to_partial_type_ref(ir, true)), + FieldType::Literal(value) => format!("Optional[{}]", to_python_literal(value)), // TODO: Handle `needed` here. + FieldType::List(inner) => format!("List[{}]", inner.to_partial_type_ref(ir, true, false)), FieldType::Map(key, value) => { format!( "Dict[{}, {}]", - key.to_type_ref(ir), - value.to_partial_type_ref(ir, false) + key.to_type_ref(ir, use_module_prefix), + value.to_partial_type_ref(ir, false, false) ) } - FieldType::Primitive(r#type) => format!("Optional[{}]", r#type.to_python()), - FieldType::Union(inner) => format!( - "Optional[Union[{}]]", + FieldType::Primitive(r#type) => { + if needed { + r#type.to_python() + } else { + format!("Optional[{}]", r#type.to_python()) + } + }, + FieldType::Union(inner) => { + let union_contents = inner .iter() - .map(|t| t.to_partial_type_ref(ir, true)) + .map(|t| t.to_partial_type_ref(ir, true, false)) .collect::>() - .join(", ") - ), - FieldType::Tuple(inner) => format!( - "Optional[Tuple[{}]]", + .join(", "); + if needed { + format!("Union[{union_contents}]") + } else { + format!("Optional[Union[{union_contents}]]") + } + }, + FieldType::Tuple(inner) => { + let tuple_contents = inner .iter() - .map(|t| t.to_partial_type_ref(ir, false)) + .map(|t| t.to_partial_type_ref(ir, false, false)) .collect::>() - .join(", ") - ), - FieldType::Optional(inner) => inner.to_partial_type_ref(ir, false), - FieldType::Constrained { base, .. } => { - let base_type_ref = base.to_partial_type_ref(ir, false); - match field_type_attributes(self) { - Some(checks) => { - let base_type_ref = base.to_partial_type_ref(ir, false); - let checks_type_ref = type_name_for_checks(&checks); - format!("Checked[{base_type_ref},{checks_type_ref}]") - } - None => base_type_ref, + .join(", "); + if needed { format!("Tuple[{tuple_contents}]") } else { format!("Optional[Tuple[{tuple_contents}]]") } + }, + FieldType::Optional(inner) => inner.to_partial_type_ref(ir, false, false), + FieldType::WithMetadata{..} => unreachable!("distribute_metadata makes this branch unreachable."), + }; + let base_type_ref = if is_partial_type { + base_rep + } else { + if needed { + base_type.to_type_ref(ir, use_module_prefix) + } else { + base_rep } - } + }; + let rep_with_checks = match field_type_attributes(self) { + Some(checks) => { + let checks_type_ref = type_name_for_checks(&checks); + format!("Checked[{base_type_ref},{checks_type_ref}]") + }, + None => base_type_ref + }; + let rep_with_stream_state = if with_state { + stream_state(&rep_with_checks) + } else { + rep_with_checks + }; + rep_with_stream_state } } @@ -355,3 +406,12 @@ fn render_docstring(d: &Docstring) -> String { let lines = d.0.as_str().replace("\n", "\n "); format!("\"\"\"{lines}\"\"\"") } + +fn optional(base: &str) -> String { + format!("Optional[{base}]") +} + +fn stream_state(base: &str) -> String { + format!("StreamState[{base}]") +} + diff --git a/engine/language_client_codegen/src/python/mod.rs b/engine/language_client_codegen/src/python/mod.rs index d2e0e7e91e..a9ee02ac97 100644 --- a/engine/language_client_codegen/src/python/mod.rs +++ b/engine/language_client_codegen/src/python/mod.rs @@ -231,7 +231,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { FieldType::Optional(inner) => { format!("Optional[{}]", inner.to_type_ref(ir, _with_checked)) } - FieldType::Constrained { base, .. } => match field_type_attributes(self) { + FieldType::WithMetadata { base, .. } => match field_type_attributes(self) { Some(checks) => { let base_type_ref = base.to_type_ref(ir, _with_checked); let checks_type_ref = type_name_for_checks(&checks); @@ -286,7 +286,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { .join(", ") ), FieldType::Optional(inner) => inner.to_partial_type_ref(ir, with_checked), - FieldType::Constrained { base, .. } => match field_type_attributes(self) { + FieldType::WithMetadata { base, .. } => match field_type_attributes(self) { Some(checks) => { let base_type_ref = base.to_partial_type_ref(ir, with_checked); let checks_type_ref = type_name_for_checks(&checks); @@ -297,3 +297,60 @@ impl ToTypeReferenceInClientDefinition for FieldType { } } } + +#[cfg(test)] +mod tests { + use internal_baml_core::ir::repr::make_test_ir; + + use crate::GeneratorArgs; + + use super::*; + + fn mk_ir() -> IntermediateRepr { + make_test_ir(r#" +class Greg { + inner Foo? @stream.not_null @stream.with_state @check(foo, {{ true }}) +} + +class Foo { + s string +} + +// class Foo { +// i int @stream.not_null @stream.with_state +// b Bar @stream.done +// } + +// class Foo { +// str string @stream.with_state +// } +// +// class Inner { +// inner_int int +// inner_string string @stream.not_null +// inner_string_2 string @stream.not_null @stream.done +// } +// +// class InnerDone { +// inner_done_inner Inner @stream.done +// inner_done_int int +// inner_done_str string +// @@stream.done +// } + "#).unwrap() + } + + fn mk_gen() -> GeneratorArgs { + GeneratorArgs::new("baml_client", "baml_src", vec![], "no_version".to_string(), true, GeneratorDefaultClientMode::Async, Vec::new()).unwrap() + } + + #[test] + fn generate_streaming_python() { + let ir = mk_ir(); + let generator_args = mk_gen(); + let res = generate(&ir, &generator_args).unwrap(); + let partial_types = res.get(&PathBuf::from("partial_types.py")).unwrap(); + eprintln!("{}", partial_types); + assert!(false); + } +} diff --git a/engine/language_client_codegen/src/python/templates/async_client.py.j2 b/engine/language_client_codegen/src/python/templates/async_client.py.j2 index eb904b9651..36c75fda0e 100644 --- a/engine/language_client_codegen/src/python/templates/async_client.py.j2 +++ b/engine/language_client_codegen/src/python/templates/async_client.py.j2 @@ -60,7 +60,7 @@ class BamlAsyncClient: tb, __cr__, ) - return cast({{fn.return_type}}, raw.cast_to(types, types)) + return cast({{fn.return_type}}, raw.cast_to(types, types, partial_types, False)) {% endfor %} @@ -102,8 +102,8 @@ class BamlStreamClient: return baml_py.BamlStream[{{ fn.partial_return_type }}, {{ fn.return_type }}]( raw, - lambda x: cast({{fn.partial_return_type}}, x.cast_to(types, partial_types)), - lambda x: cast({{fn.return_type}}, x.cast_to(types, types)), + lambda x: cast({{fn.partial_return_type}}, x.cast_to(types, types, partial_types, True)), + lambda x: cast({{fn.return_type}}, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) {% endfor %} diff --git a/engine/language_client_codegen/src/python/templates/partial_types.py.j2 b/engine/language_client_codegen/src/python/templates/partial_types.py.j2 index f3d75e3809..456938fcb7 100644 --- a/engine/language_client_codegen/src/python/templates/partial_types.py.j2 +++ b/engine/language_client_codegen/src/python/templates/partial_types.py.j2 @@ -2,7 +2,7 @@ import baml_py from enum import Enum from pydantic import BaseModel, ConfigDict -from typing import Dict, List, Optional, Union, Literal +from typing import Dict, Generic, List, Optional, TypeVar, Union, Literal from . import types from .types import Checked, Check @@ -14,6 +14,11 @@ from .types import Checked, Check # ############################################################################### +T = TypeVar('T') +class StreamState(BaseModel, Generic[T]): + value: T + state: Literal["Pending", "Incomplete", "Complete"] + {# Partial classes (used for streaming) -#} {% for cls in partial_classes %} class {{cls.name}}(BaseModel): diff --git a/engine/language_client_codegen/src/python/templates/sync_client.py.j2 b/engine/language_client_codegen/src/python/templates/sync_client.py.j2 index c5aa6de619..7d7f94d926 100644 --- a/engine/language_client_codegen/src/python/templates/sync_client.py.j2 +++ b/engine/language_client_codegen/src/python/templates/sync_client.py.j2 @@ -57,7 +57,7 @@ class BamlSyncClient: tb, __cr__, ) - return cast({{fn.return_type}}, raw.cast_to(types, types)) + return cast({{fn.return_type}}, raw.cast_to(types, types, partial_types, False)) {% endfor %} @@ -100,8 +100,8 @@ class BamlStreamClient: return baml_py.BamlSyncStream[{{ fn.partial_return_type }}, {{ fn.return_type }}]( raw, - lambda x: cast({{fn.partial_return_type}}, x.cast_to(types, partial_types)), - lambda x: cast({{fn.return_type}}, x.cast_to(types, types)), + lambda x: cast({{fn.partial_return_type}}, x.cast_to(types, types, partial_types, True)), + lambda x: cast({{fn.return_type}}, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) {% endfor %} diff --git a/engine/language_client_codegen/src/python/templates/types.py.j2 b/engine/language_client_codegen/src/python/templates/types.py.j2 index bbc7a7e9ab..997ad8f009 100644 --- a/engine/language_client_codegen/src/python/templates/types.py.j2 +++ b/engine/language_client_codegen/src/python/templates/types.py.j2 @@ -23,7 +23,6 @@ def get_checks(checks: Dict[CheckName, Check]) -> List[Check]: def all_succeeded(checks: Dict[CheckName, Check]) -> bool: return all(check.status == "succeeded" for check in get_checks(checks)) - {# Enums -#} {% for enum in enums %} class {{enum.name}}(str, Enum): diff --git a/engine/language_client_codegen/src/ruby/field_type.rs b/engine/language_client_codegen/src/ruby/field_type.rs index c6cdba5908..89dc7756f3 100644 --- a/engine/language_client_codegen/src/ruby/field_type.rs +++ b/engine/language_client_codegen/src/ruby/field_type.rs @@ -57,7 +57,7 @@ impl ToRuby for FieldType { .join(", ") ), FieldType::Optional(inner) => format!("T.nilable({})", inner.to_ruby()), - FieldType::Constrained { base, .. } => match field_type_attributes(self) { + FieldType::WithMetadata { base, .. } => match field_type_attributes(self) { Some(_) => { let base_type_ref = base.to_ruby(); format!("Baml::Checked[{base_type_ref}]") diff --git a/engine/language_client_codegen/src/ruby/generate_types.rs b/engine/language_client_codegen/src/ruby/generate_types.rs index 495b7e9242..cdcef8c681 100644 --- a/engine/language_client_codegen/src/ruby/generate_types.rs +++ b/engine/language_client_codegen/src/ruby/generate_types.rs @@ -10,7 +10,7 @@ use crate::{field_type_attributes, type_check_attributes, TypeCheckAttributes}; use super::ruby_language_features::ToRuby; use internal_baml_core::ir::{ repr::{Docstring, IntermediateRepr}, - ClassWalker, EnumWalker, FieldType, + ClassWalker, EnumWalker, FieldType, IRHelper, }; #[derive(askama::Template)] @@ -136,9 +136,21 @@ impl<'ir> From> for PartialRubyStruct<'ir> { .static_fields .iter() .map(|f| { + let not_null: bool = f.attributes.get("stream.not_null").is_some(); + let (_, metadata) = c.db.distribute_metadata(&f.elem.r#type.elem); + let done = metadata.1.done; + let field_type = f.elem.r#type.elem.clone(); + let generated_field_type = match (done, not_null) { + (false, false) => { + format!("{}", field_type.to_partial_type_ref(c.db, false)) + } + (true, false) => format!("T.nilable({})", field_type.to_type_ref()), + (false, true) => field_type.to_partial_type_ref(c.db, true), + (true, true) => field_type.to_type_ref(), + }; ( f.elem.name.as_str(), - f.elem.r#type.elem.to_partial_type_ref(), + generated_field_type, f.elem.docstring.as_ref().map(|d| render_docstring(d, true)), ) }) @@ -153,27 +165,44 @@ impl<'ir> From> for PartialRubyStruct<'ir> { } } -pub(super) trait ToTypeReferenceInTypeDefinition { +pub(super) trait ToTypeReferenceInTypeDefinition<'a> { fn to_type_ref(&self) -> String; - fn to_partial_type_ref(&self) -> String; + fn to_partial_type_ref(&self, ir: &'a IntermediateRepr, already_nilable: bool) -> String; } -impl ToTypeReferenceInTypeDefinition for FieldType { +impl ToTypeReferenceInTypeDefinition<'_> for FieldType { fn to_type_ref(&self) -> String { use ToRuby; self.to_ruby() } - fn to_partial_type_ref(&self) -> String { - match self { - FieldType::Class(name) => format!("Baml::PartialTypes::{}", name.clone()), - FieldType::Enum(name) => format!("T.nilable(Baml::Types::{})", name.clone()), + /// Render a type into a string for use in a partial-types context. + /// The `already_nilable` field indicates whether the caller will wrap + /// the returned string with `nilable`, and this function does not need + fn to_partial_type_ref(&self, ir: &IntermediateRepr, already_nilable: bool) -> String { + + let (field_type, metadata) = ir.distribute_metadata(self); + let inner = match field_type { + FieldType::Class(name) => if already_nilable { + format!("Baml::PartialTypes::{}", name.clone()) + } else { + format!("T.nilable(Baml::PartialTypes::{})", name.clone()) + }, + FieldType::Enum(name) => { + if already_nilable { + format!("T.nilable(Baml::Types::{})", name.clone()) + } else { + format!("T.nilable(Baml::Types::{})", name.clone()) + } + } // TODO: Can we define recursive aliases in Ruby with Sorbet? FieldType::RecursiveTypeAlias(_name) => "T.anything".to_string(), // TODO: Temporary solution until we figure out Ruby literals. - FieldType::Literal(value) => value.literal_base_type().to_partial_type_ref(), + FieldType::Literal(value) => value + .literal_base_type() + .to_partial_type_ref(ir, already_nilable), // https://sorbet.org/docs/stdlib-generics - FieldType::List(inner) => format!("T::Array[{}]", inner.to_partial_type_ref()), + FieldType::List(inner) => format!("T::Array[{}]", inner.to_partial_type_ref(ir, false)), FieldType::Map(key, value) => format!( "T::Hash[{}, {}]", match key.as_ref() { @@ -183,36 +212,51 @@ impl ToTypeReferenceInTypeDefinition for FieldType { | FieldType::Union(_) => FieldType::string().to_type_ref(), _ => key.to_type_ref(), }, - value.to_partial_type_ref() + value.to_partial_type_ref(ir, false) ), - FieldType::Primitive(_) => format!("T.nilable({})", self.to_type_ref()), - FieldType::Union(inner) => format!( + FieldType::Primitive(_) => { + if already_nilable { + self.to_type_ref() + } else { + format!("T.nilable({})", self.to_type_ref()) + } + } + FieldType::Union(inner) => { + let inner_string = // https://sorbet.org/docs/union-types - "T.nilable(T.any({}))", inner .iter() - .map(|t| t.to_partial_type_ref()) + .map(|t| t.to_partial_type_ref(ir, false)) .collect::>() .join(", ") - ), - FieldType::Tuple(inner) => format!( + ; + if already_nilable { format!("T.any({inner_string})") } else { + format!("T.nilable(T.any({}))", inner_string) + } + }, + FieldType::Tuple(inner) => { + let inner_string = // https://sorbet.org/docs/tuples - "T.nilable([{}])", inner .iter() - .map(|t| t.to_partial_type_ref()) + .map(|t| t.to_partial_type_ref(ir, false)) .collect::>() .join(", ") - ), - FieldType::Optional(inner) => inner.to_partial_type_ref(), - FieldType::Constrained { base, .. } => match field_type_attributes(self) { + ; + if already_nilable { format!("[{}]", inner_string )} else { format!("T.nilable([{}])", inner_string)} + }, + FieldType::Optional(inner) => inner.to_partial_type_ref(ir, false), + FieldType::WithMetadata { base, .. } => match field_type_attributes(self) { Some(checks) => { - let base_type_ref = base.to_partial_type_ref(); + let base_type_ref = base.to_partial_type_ref(ir, false); format!("Baml::Checked[{base_type_ref}]") } - None => base.to_partial_type_ref(), + None => base.to_partial_type_ref(ir, false), }, - } + }; + if metadata.1.state { + format!("Baml::StreamState[{inner}]") + } else { inner } } } diff --git a/engine/language_client_codegen/src/ruby/mod.rs b/engine/language_client_codegen/src/ruby/mod.rs index f0793610df..867ac72b94 100644 --- a/engine/language_client_codegen/src/ruby/mod.rs +++ b/engine/language_client_codegen/src/ruby/mod.rs @@ -9,7 +9,7 @@ use anyhow::Result; use indexmap::IndexMap; use ruby_language_features::ToRuby; -use internal_baml_core::ir::repr::IntermediateRepr; +use internal_baml_core::ir::{repr::IntermediateRepr, IRHelper}; use crate::dir_writer::FileCollector; @@ -62,9 +62,18 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'ir crate::GeneratorArgs)> for RubyCl let funcs = configs .map(|c| { let (_function, _impl_) = c.item; + let return_type = f.elem().output(); + let done = ir.distribute_metadata(&return_type).1.1.done; + let state = ir.distribute_metadata(&return_type).1.1.state; + let partial_return_type = match (done, state) { + (false, false) => return_type.to_partial_type_ref(ir, true), + (true, false) => format!("T.nilable({})", return_type.to_type_ref()), + (false, true) => format!("Baml::StreamState[T.nilable({})]", return_type.to_partial_type_ref(ir, true)), + (true, true) => format!("Baml::StreamState[T.nilable({})]", return_type.to_type_ref()), + }; Ok(RubyFunction { name: f.name().to_string(), - partial_return_type: f.elem().output().to_partial_type_ref(), + partial_return_type: f.elem().output().to_partial_type_ref(ir, false), return_type: f.elem().output().to_ruby(), args: f .inputs() diff --git a/engine/language_client_codegen/src/ruby/templates/client.rb.j2 b/engine/language_client_codegen/src/ruby/templates/client.rb.j2 index afc7e40622..0d92f310ed 100644 --- a/engine/language_client_codegen/src/ruby/templates/client.rb.j2 +++ b/engine/language_client_codegen/src/ruby/templates/client.rb.j2 @@ -72,7 +72,7 @@ module Baml baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end {% endfor %} diff --git a/engine/language_client_codegen/src/typescript/generate_types.rs b/engine/language_client_codegen/src/typescript/generate_types.rs index 195764bc5b..b65c5205e0 100644 --- a/engine/language_client_codegen/src/typescript/generate_types.rs +++ b/engine/language_client_codegen/src/typescript/generate_types.rs @@ -5,7 +5,7 @@ use itertools::Itertools; use internal_baml_core::ir::{ repr::{Docstring, IntermediateRepr, Walker}, - ClassWalker, EnumWalker, FieldType, + ClassWalker, EnumWalker, FieldType, IRHelper, }; use crate::{type_check_attributes, GeneratorArgs, TypeCheckAttributes}; @@ -27,6 +27,12 @@ pub(crate) struct TypescriptTypes<'ir> { structural_recursive_alias_cycles: Vec>, } +#[derive(askama::Template)] +#[template(path = "partial_types.ts.j2", escape = "none")] +pub(crate) struct TypescriptStreamTypes<'ir> { + partial_classes: Vec>, +} + struct TypescriptEnum<'ir> { pub name: &'ir str, pub values: Vec<(&'ir str, Option)>, @@ -46,6 +52,13 @@ struct TypescriptTypeAlias<'ir> { target: String, } +pub struct PartialTypescriptClass<'ir> { + name: Cow<'ir, str>, + fields: Vec<(Cow<'ir, str>, bool, String, Option)>, + dynamic: bool, + docstring: Option, +} + impl<'ir> TryFrom<(&'ir IntermediateRepr, &'ir GeneratorArgs)> for TypescriptTypes<'ir> { type Error = anyhow::Error; @@ -69,6 +82,21 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'ir GeneratorArgs)> for TypescriptTyp } } +impl<'ir> TryFrom<(&'ir IntermediateRepr, &'ir GeneratorArgs)> for TypescriptStreamTypes<'ir> { + type Error = anyhow::Error; + + fn try_from( + (ir, _): (&'ir IntermediateRepr, &'ir GeneratorArgs), + ) -> Result> { + Ok(TypescriptStreamTypes { + partial_classes: ir + .walk_classes() + .map(|e| Into::::into(e)) + .collect::>(), + }) + } +} + impl<'ir> TryFrom<(&'ir IntermediateRepr, &'ir GeneratorArgs)> for TypeBuilder<'ir> { type Error = anyhow::Error; @@ -127,7 +155,7 @@ impl<'ir> From<&ClassWalker<'ir>> for TypescriptClass<'ir> { ( Cow::Borrowed(f.elem.name.as_str()), f.elem.r#type.elem.is_optional(), - f.elem.r#type.elem.to_type_ref(c.db), + f.elem.r#type.elem.to_type_ref(c.db, false), f.elem.docstring.as_ref().map(|d| render_docstring(d, true)), ) }) @@ -152,7 +180,46 @@ impl<'ir> From> for TypescriptTypeAli ) -> Self { Self { name: Cow::Borrowed(name), - target: target.to_type_ref(db), + target: target.to_type_ref(db, false), + } + } +} + +impl<'ir> From> for PartialTypescriptClass<'ir> { + fn from(c: ClassWalker<'ir>) -> PartialTypescriptClass<'ir> { + PartialTypescriptClass { + name: Cow::Borrowed(c.name()), + dynamic: c.item.attributes.get("dynamic_type").is_some(), + fields: c + .item + .elem + .static_fields + .iter() + .map(|f| { + let ir = c.db; + let needed: bool = f.attributes.get("stream.not_null").is_some(); + let (_, metadata) = ir.distribute_metadata(&f.elem.r#type.elem); + let done: bool = metadata.1.done; + let (field, optional) = match (done, needed) { + (false, false) => f.elem.r#type.elem.to_partial_type_ref(c.db, false), + (true, false) => (f.elem.r#type.elem.to_type_ref(c.db, true), false), + (false, true) => f.elem.r#type.elem.to_partial_type_ref(c.db, true), + (true, true) => (f.elem.r#type.elem.to_type_ref(c.db, true), false), + }; + ( + Cow::Borrowed(f.elem.name.as_str()), + optional, + field, + f.elem.docstring.as_ref().map(|d| render_docstring(d, true)), + ) + }) + .collect(), + docstring: c + .item + .elem + .docstring + .as_ref() + .map(|d| render_docstring(d, false)), } } } @@ -169,7 +236,7 @@ pub fn type_name_for_checks(checks: &TypeCheckAttributes) -> String { /// Render the BAML documentation (a bare string with padding stripped) /// into a TS docstring. /// (Optionally indented and formatted as a TS block comment). -fn render_docstring(d: &Docstring, indented: bool) -> String { +pub fn render_docstring(d: &Docstring, indented: bool) -> String { if indented { let lines = d.0.as_str().replace("\n", "\n * "); format!("/**\n * {lines}\n */") diff --git a/engine/language_client_codegen/src/typescript/mod.rs b/engine/language_client_codegen/src/typescript/mod.rs index e3178534ff..7d02ae8d72 100644 --- a/engine/language_client_codegen/src/typescript/mod.rs +++ b/engine/language_client_codegen/src/typescript/mod.rs @@ -5,7 +5,7 @@ use std::path::PathBuf; use anyhow::Result; use baml_types::LiteralValue; -use generate_types::type_name_for_checks; +use generate_types::{render_docstring, type_name_for_checks}; use indexmap::IndexMap; use internal_baml_core::{ configuration::GeneratorDefaultClientMode, @@ -55,8 +55,8 @@ impl From for SyncTypescriptClient { #[derive(Debug)] struct TypescriptFunction { name: String, - // partial_return_type: String, return_type: String, + partial_return_type: String, args: Vec<(String, bool, String)>, } @@ -88,6 +88,10 @@ pub(crate) fn generate( ) -> Result> { let mut collector = FileCollector::::new(); collector.add_template::("types.ts", (ir, generator))?; + collector.add_template::( + "partial_types.ts", + (ir, generator), + )?; collector.add_template::("type_builder.ts", (ir, generator))?; collector.add_template::("async_client.ts", (ir, generator))?; collector.add_template::("sync_client.ts", (ir, generator))?; @@ -131,8 +135,8 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for TypescriptCli let (_function, _impl_) = c.item; Ok(TypescriptFunction { name: f.name().to_string(), - return_type: f.elem().output().to_type_ref(ir), - // partial_return_type: f.elem().output().to_partial_type_ref(ir), + return_type: f.elem().output().to_type_ref(ir, false), + partial_return_type: f.elem().output().to_partial_type_ref(ir, true).0, args: f .inputs() .iter() @@ -140,7 +144,7 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for TypescriptCli ( name.to_string(), r#type.is_optional(), - r#type.to_type_ref(ir), + r#type.to_type_ref(ir, false), ) }) .collect(), @@ -203,57 +207,132 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for TypescriptIni } trait ToTypeReferenceInClientDefinition { - fn to_type_ref(&self, ir: &IntermediateRepr) -> String; - - // fn to_partial_type_ref(&self, ir: &IntermediateRepr) -> String; + fn to_type_ref(&self, ir: &IntermediateRepr, use_module_prefix: bool) -> String; + /// The string representation of a field type, and whether the field is optional. + fn to_partial_type_ref(&self, ir: &IntermediateRepr, needed: bool) -> (String, bool); } impl ToTypeReferenceInClientDefinition for FieldType { - // fn to_partial_type_ref(&self, ir: &IntermediateRepr) -> String { - // match self { - // FieldType::Enum(name) => { - // if ir - // .find_enum(name) - // .map(|e| e.item.attributes.get("dynamic_type").is_some()) - // .unwrap_or(false) - // { - // format!("(string | {name} | null)") - // } else { - // format!("({name} | null)") - // } - // } - // FieldType::Class(name) => format!("(RecursivePartialNull<{name}>)"), - // FieldType::List(inner) => format!("{}[]", inner.to_partial_type_ref(ir)), - // FieldType::Map(key, value) => { - // format!( - // "(Record<{}, {}> | null)", - // key.to_type_ref(ir), - // value.to_partial_type_ref(ir) - // ) - // } - // FieldType::Literal(value) => value.to_string(), - // FieldType::Primitive(r#type) => format!("({} | null)", r#type.to_typescript()), - // FieldType::Union(inner) => format!( - // "({} | null)", - // inner - // .iter() - // .map(|t| t.to_partial_type_ref(ir)) - // .collect::>() - // .join(" | ") - // ), - // FieldType::Tuple(inner) => format!( - // "([{}] | null)", - // inner - // .iter() - // .map(|t| t.to_partial_type_ref(ir)) - // .collect::>() - // .join(", ") - // ), - // FieldType::Optional(inner) => format!("({} | null)", inner.to_partial_type_ref(ir)), - // } - // } - - fn to_type_ref(&self, ir: &IntermediateRepr) -> String { + /// How to serialize a type for use in a function's type signature. + fn to_partial_type_ref(&self, ir: &IntermediateRepr, needed: bool) -> (String, bool) { + let (base_type, metadata) = ir.distribute_metadata(self); + let is_partial_type = !metadata.1.done; + let use_module_prefix = !is_partial_type; + let with_state = metadata.1.state; + let constraints = metadata.0; + let module_prefix = if use_module_prefix { "types." } else { "partial_types." }; + let (base_rep, optional) = match base_type { + FieldType::Class(name) => { + if needed { + (format!("{module_prefix}{name}"), false) + } else { + (format!("{module_prefix}{name} | null"), true) + } + } + FieldType::RecursiveTypeAlias(name) => (name.to_owned(), !needed), + FieldType::Enum(name) => { + let res = if ir + .find_enum(name) + .map(|e| e.item.attributes.get("dynamic_type").is_some()) + .unwrap_or(false) + { + if needed { + (format!("(string | {name})"), false) + } else { + (format!("(string | {name} | null)"), true) + } + } else { + if needed { + (format!("types.{name}"), false) + } else { + (format!("({name} | null)"), true) + } + }; + res + } + FieldType::Literal(value) => { + (value.to_string(), false) + } + FieldType::List(inner) => ( + format!("{}[]", inner.to_partial_type_ref(ir, false).0), + true, + ), + FieldType::Map(key, value) => { + let or_null = if needed { "" } else { "| null" }; + ( + format!( + "(Record<{}, {}> {or_null})", + key.to_type_ref(ir, false), + value.to_partial_type_ref(ir, false).0 + ), + !needed, + ) + } + FieldType::Primitive(r#type) => { + if needed { + (r#type.to_typescript(), false) + } else { + (format!("({} | null)", r#type.to_typescript()), true) + } + } + FieldType::Union(inner) => { + let union_contents = inner + .iter() + .map(|t| t.to_partial_type_ref(ir, false).0) + .collect::>() + .join(" | "); + if needed { + (format!("({})", union_contents), false) + } else { + (format!("({} | null)", union_contents), true) + } + } + FieldType::Tuple(inner) => { + let tuple_contents = inner + .iter() + .map(|t| t.to_partial_type_ref(ir, false).0) + .collect::>() + .join(", "); + if needed { + (format!("[{tuple_contents}]"), false) + } else { + (format!("([{tuple_contents}] | null)"), true) + } + } + FieldType::Optional(inner) => ( + format!("({} | null)", inner.to_partial_type_ref(ir, false).0), + false, + ), + FieldType::WithMetadata { .. } => { + unreachable!("distribute_metadata makes this field unreachable.") + } + }; + let base_type_ref = if is_partial_type { + base_rep + } else { + if needed { + base_type.to_type_ref(ir, use_module_prefix) + } else { + base_rep + } + }; + let rep_with_checks = match field_type_attributes(self) { + Some(checks) => { + let checks_type_ref = type_name_for_checks(&checks); + format!("Checked<{base_type_ref},{checks_type_ref}>") + } + None => base_type_ref, + }; + let rep_with_stream_state = if with_state { + format!("StreamState<{rep_with_checks}>") + } else { + rep_with_checks + }; + (rep_with_stream_state, optional) + } + + fn to_type_ref(&self, ir: &IntermediateRepr, use_module_prefix: bool) -> String { + let module_prefix = if use_module_prefix { "types." } else { "" }; match self { FieldType::Enum(name) => { if ir @@ -261,22 +340,22 @@ impl ToTypeReferenceInClientDefinition for FieldType { .map(|e| e.item.attributes.get("dynamic_type").is_some()) .unwrap_or(false) { - format!("(string | {name})") + format!("(string | {module_prefix}{name})") } else { - name.to_string() + format!("{module_prefix}{name}") } } - FieldType::Class(name) => name.to_string(), FieldType::RecursiveTypeAlias(name) => name.to_owned(), + FieldType::Class(name) => format!("{module_prefix}{name}"), FieldType::List(inner) => match inner.as_ref() { FieldType::Union(_) | FieldType::Optional(_) => { - format!("({})[]", inner.to_type_ref(ir)) + format!("({})[]", inner.to_type_ref(ir, use_module_prefix)) } - _ => format!("{}[]", inner.to_type_ref(ir)), + _ => format!("{}[]", inner.to_type_ref(ir, use_module_prefix)), }, FieldType::Map(key, value) => { - let k = key.to_type_ref(ir); - let v = value.to_type_ref(ir); + let k = key.to_type_ref(ir, true); + let v = value.to_type_ref(ir, use_module_prefix); match key.as_ref() { FieldType::Enum(_) @@ -292,7 +371,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { FieldType::Literal(value) => value.to_string(), FieldType::Union(inner) => inner .iter() - .map(|t| t.to_type_ref(ir)) + .map(|t| t.to_type_ref(ir, use_module_prefix)) .collect::>() .join(" | ") .to_string(), @@ -300,19 +379,105 @@ impl ToTypeReferenceInClientDefinition for FieldType { "[{}]", inner .iter() - .map(|t| t.to_type_ref(ir)) + .map(|t| t.to_type_ref(ir, use_module_prefix)) .collect::>() .join(", ") ), - FieldType::Optional(inner) => format!("{} | null", inner.to_type_ref(ir)), - FieldType::Constrained { base, .. } => match field_type_attributes(self) { + FieldType::Optional(inner) => { + format!("{} | null", inner.to_type_ref(ir, use_module_prefix)) + } + FieldType::WithMetadata { base, .. } => match field_type_attributes(self) { Some(checks) => { - let base_type_ref = base.to_type_ref(ir); + let base_type_ref = base.to_type_ref(ir, use_module_prefix); let checks_type_ref = type_name_for_checks(&checks); format!("Checked<{base_type_ref},{checks_type_ref}>") } - None => base.to_type_ref(ir), + None => base.to_type_ref(ir, use_module_prefix), }, } } } + +#[cfg(test)] +mod tests { + use internal_baml_core::ir::repr::make_test_ir; + + use crate::GeneratorArgs; + + use super::*; + + fn mk_ir() -> IntermediateRepr { + make_test_ir( + r##" +class Greg { + inner Foo? @stream.done @stream.not_null @stream.with_state @check(foo, {{ true }}) +} + +class Foo { + s string +} + +client GPT4 { + provider openai + options { + model gpt-4o + api_key env.OPENAI_API_KEY + } +} + +function MkFoo() -> Foo { + client GPT4 + prompt #""# +} + +// class Foo { +// i int @stream.not_null @stream.with_state +// b Bar @stream.done +// } + +// class Foo { +// str string @stream.with_state +// } +// +// class Inner { +// inner_int int +// inner_string string @stream.not_null +// inner_string_2 string @stream.not_null @stream.done +// } +// +// class InnerDone { +// inner_done_inner Inner @stream.done +// inner_done_int int +// inner_done_str string +// @@stream.done +// } + "##, + ) + .unwrap() + } + + fn mk_gen() -> GeneratorArgs { + GeneratorArgs::new( + "baml_client", + "baml_src", + vec![], + "no_version".to_string(), + true, + GeneratorDefaultClientMode::Async, + Vec::new(), + ) + .unwrap() + } + + #[test] + fn generate_streaming_typescript() { + let ir = mk_ir(); + let generator_args = mk_gen(); + let res = generate(&ir, &generator_args).unwrap(); + let partial_types = res.get(&PathBuf::from("partial_types.ts")).unwrap(); + eprintln!("{}", partial_types); + let async_client = res.get(&PathBuf::from("async_client.ts")).unwrap(); + eprintln!("{}", async_client); + assert!(false); + } +} diff --git a/engine/language_client_codegen/src/typescript/templates/async_client.ts.j2 b/engine/language_client_codegen/src/typescript/templates/async_client.ts.j2 index b7c28ea107..a8d19eb4fc 100644 --- a/engine/language_client_codegen/src/typescript/templates/async_client.ts.j2 +++ b/engine/language_client_codegen/src/typescript/templates/async_client.ts.j2 @@ -1,17 +1,12 @@ import { BamlRuntime, FunctionResult, BamlCtxManager, BamlStream, Image, ClientRegistry, BamlValidationError, createBamlValidationError } from "@boundaryml/baml" import { Checked, Check } from "./types" +import "./partial_types" import { {%- for t in types %}{{ t }}{% if !loop.last %}, {% endif %}{% endfor -%} } from "./types" import TypeBuilder from "./type_builder" import { DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME } from "./globals" -export type RecursivePartialNull = T extends object - ? { - [P in keyof T]?: RecursivePartialNull; - } - : T | null; - export class BamlAsyncClient { private runtime: BamlRuntime private ctx_manager: BamlCtxManager @@ -46,7 +41,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as {{fn.return_type}} + return raw.parsed(false) as {{fn.return_type}} } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -68,7 +63,7 @@ class BamlStreamClient { {{name}}{% if optional %}?{% endif %}: {{type}}, {%- endfor %} __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, {{ fn.return_type }}> { + ): BamlStream<{{ fn.partial_return_type }}, {{ fn.return_type }}> { try { const raw = this.runtime.streamFunction( "{{fn.name}}", @@ -82,9 +77,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, {{ fn.return_type }}>( + return new BamlStream<{{ fn.partial_return_type }}, {{ fn.return_type }}>( raw, - (a): a is RecursivePartialNull<{{ fn.return_type }}> => a, + (a): a is {{ fn.partial_return_type }} => a, (a): a is {{ fn.return_type }} => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), diff --git a/engine/language_client_codegen/src/typescript/templates/partial_types.ts.j2 b/engine/language_client_codegen/src/typescript/templates/partial_types.ts.j2 new file mode 100644 index 0000000000..c79b876c3f --- /dev/null +++ b/engine/language_client_codegen/src/typescript/templates/partial_types.ts.j2 @@ -0,0 +1,35 @@ +import { Image } from "@boundaryml/baml" + +import * as types from "./types" + +/****************************************************************************** +* +* These types are used for streaming, for when an instance of a type +* is still being built up and any of its fields is not yet fully available. +* +******************************************************************************/ + +export interface StreamState { + value: T + state: "Pending" | "Incomplete" | "Complete" +} + +{# Partial classes (used for streaming) -#} +{% for cls in partial_classes %} +{%- if let Some(docstring) = cls.docstring %} +{{docstring}} +{%- endif %} +export interface {{cls.name}} { + + {%- for (name, optional, type, m_docstring) in cls.fields %} + {%- if let Some(docstring) = m_docstring %} + {{ docstring }} + {%- endif %} + {{name}}{% if optional %}?{% endif%}: {{type}} + {%- endfor %} + + {%- if cls.dynamic %} + [key: string]: any; + {%- endif %} +} +{% endfor %} diff --git a/engine/language_client_codegen/src/typescript/templates/sync_client.ts.j2 b/engine/language_client_codegen/src/typescript/templates/sync_client.ts.j2 index 1a6b1b41bf..170a411cab 100644 --- a/engine/language_client_codegen/src/typescript/templates/sync_client.ts.j2 +++ b/engine/language_client_codegen/src/typescript/templates/sync_client.ts.j2 @@ -6,12 +6,6 @@ import { import TypeBuilder from "./type_builder" import { DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME } from "./globals" -export type RecursivePartialNull = T extends object - ? { - [P in keyof T]?: RecursivePartialNull; - } - : T | null; - export class BamlSyncClient { private runtime: BamlRuntime private ctx_manager: BamlCtxManager @@ -46,7 +40,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as {{fn.return_type}} + return raw.parsed(false) as {{fn.return_type}} } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { diff --git a/engine/language_client_python/Cargo.toml b/engine/language_client_python/Cargo.toml index 03417b9374..c7abdf67cc 100644 --- a/engine/language_client_python/Cargo.toml +++ b/engine/language_client_python/Cargo.toml @@ -29,6 +29,7 @@ env_logger.workspace = true futures.workspace = true indexmap.workspace = true libc = "0.2" +jsonish = { path = "../baml-lib/jsonish" } log.workspace = true ctrlc = "3.4" # Consult https://pyo3.rs/main/migration for migration instructions diff --git a/engine/language_client_python/src/types/function_results.rs b/engine/language_client_python/src/types/function_results.rs index 0a52cb8313..e62937cc23 100644 --- a/engine/language_client_python/src/types/function_results.rs +++ b/engine/language_client_python/src/types/function_results.rs @@ -1,7 +1,8 @@ use baml_types::{BamlValueWithMeta, ResponseCheck}; +use jsonish::ResponseBamlValue; use pyo3::prelude::{pymethods, PyResult}; use pyo3::types::{PyAnyMethods, PyDict, PyModule, PyTuple, PyType}; -use pyo3::{Bound, IntoPyObject, IntoPyObjectExt, PyAny, PyObject, Python}; +use pyo3::{Bound, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyErr, PyObject, Python}; use crate::errors::BamlError; @@ -32,15 +33,18 @@ impl FunctionResult { py: Python<'_>, enum_module: Bound<'_, PyModule>, cls_module: Bound<'_, PyModule>, + partial_cls_module: Bound<'_, PyModule>, + allow_partials: bool, ) -> PyResult { let parsed = self .inner .result_with_constraints_content() .map_err(BamlError::from_anyhow)?; - let parsed = pythonize_strict(py, parsed.clone(), &enum_module, &cls_module)?; + let parsed = pythonize_strict(py, parsed.clone(), &enum_module, &cls_module, &partial_cls_module, allow_partials); + // eprintln!("parsed result: {:?}", parsed); - Ok(parsed) + Ok(parsed?) } } @@ -74,12 +78,16 @@ fn pythonize_checks<'a>( fn pythonize_strict( py: Python<'_>, - parsed: BamlValueWithMeta>, + parsed: ResponseBamlValue, enum_module: &Bound<'_, PyModule>, cls_module: &Bound<'_, PyModule>, + partial_cls_module: &Bound<'_, PyModule>, + allow_partials: bool, ) -> PyResult { - let meta = parsed.meta().clone(); - let py_value_without_constraints = match parsed { + let allow_partials = allow_partials && !parsed.0.meta().2.required_done; + // eprintln!("pythonize_strict parsed: {:?}", parsed); + let meta = parsed.0.meta().clone(); + let py_value_without_constraints = match parsed.0 { BamlValueWithMeta::String(val, _) => val.into_py_any(py), BamlValueWithMeta::Int(val, _) => val.into_py_any(py), BamlValueWithMeta::Float(val, _) => val.into_py_any(py), @@ -88,7 +96,7 @@ fn pythonize_strict( let dict = pyo3::types::PyDict::new(py); for (key, value) in index_map { let key = key.into_pyobject(py)?; - let value = pythonize_strict(py, value, enum_module, cls_module)?; + let value = pythonize_strict(py, ResponseBamlValue(value), enum_module, cls_module, partial_cls_module, allow_partials)?; dict.set_item(key, value)?; } Ok(dict.into()) @@ -96,7 +104,7 @@ fn pythonize_strict( BamlValueWithMeta::List(vec, _) => pyo3::types::PyList::new( py, vec.into_iter() - .map(|v| pythonize_strict(py, v, enum_module, cls_module)) + .map(|v| pythonize_strict(py, ResponseBamlValue(v), enum_module, cls_module, partial_cls_module, allow_partials)) .collect::>>()?, )? .into_py_any(py), @@ -138,7 +146,8 @@ fn pythonize_strict( let properties = index_map .into_iter() .map(|(key, value)| { - let value = pythonize_strict(py, value, enum_module, cls_module)?; + let subvalue_allow_partials = allow_partials && !value.meta().2.required_done; + let value = pythonize_strict(py, ResponseBamlValue(value), enum_module, cls_module, partial_cls_module, subvalue_allow_partials)?; Ok((key.clone(), value)) }) .collect::>>()?; @@ -169,7 +178,9 @@ fn pythonize_strict( } } - let class_type = match cls_module.getattr(class_name.as_str()) { + let target_class = if allow_partials { partial_cls_module } else { cls_module }; + let backup_class = if allow_partials { cls_module } else {partial_cls_module}; + let class_type = match target_class.getattr(class_name.as_str()) { Ok(class) => class, // This can be true in the case of dynamic types. /* @@ -179,52 +190,106 @@ fn pythonize_strict( Err(_) => return Ok(properties_dict.into()), }; - let instance = - class_type.call_method("model_validate", (properties_dict.clone(),), None)?; + let backup_class_type = match backup_class.getattr(class_name.as_str()) { + Ok(class) => class, + Err(_) => unreachable!("The return value for the Err case in class_type would have triggered before we reached this line."), + }; + + let instance = match class_type.call_method("model_validate", (properties_dict.clone(),), None) { + Ok(x) => Ok(x), + Err(original_error) => match backup_class_type.call_method("model_validate", (properties_dict.clone(),), None) { + Ok(x) => Ok(x), + Err(_) => Err(original_error) + } + }?; Ok(instance.into()) } BamlValueWithMeta::Null(_) => Ok(py.None()), }?; - if meta.is_empty() { + let (_, checks, completion_state) = meta; + if checks.is_empty() && !completion_state.display { + // eprintln!("ret1: {:?}", py_value_without_constraints); Ok(py_value_without_constraints) } else { - // Generate the Python checks - let python_checks = pythonize_checks(py, cls_module, &meta)?; - - // Get the type of the original value - let value_type = py_value_without_constraints.bind(py).get_type(); // Import the necessary modules and objects - let typing = py.import("typing")?; - let literal = typing.getattr("Literal")?; + let typing = py.import("typing").expect("typing"); + let literal = typing.getattr("Literal").expect("Literal"); + let value_with_possible_checks = if !checks.is_empty() { + + // Generate the Python checks + let python_checks = pythonize_checks(py, cls_module, &checks).expect("pythonize_checks"); + + // Get the type of the original value + let value_type = py_value_without_constraints.bind(py).get_type(); + + + // Collect check names as &str and turn them into a Python tuple + let check_names: Vec<&str> = checks.iter().map(|check| check.name.as_str()).collect(); + let literal_args = PyTuple::new_bound(py, check_names); + + // Call Literal[...] dynamically + let literal_check_names = literal.get_item(literal_args).expect("get_item"); - // Collect check names as &str and turn them into a Python tuple - let check_names: Vec<&str> = meta.iter().map(|check| check.name.as_str()).collect(); - let literal_args = PyTuple::new(py, check_names)?; - // Call Literal[...] dynamically - let literal_check_names = literal.get_item(literal_args)?; + let class_checked_type_constructor = cls_module.getattr("Checked").expect("getattr(Checked)"); - // Prepare the properties dictionary - let properties_dict = pyo3::types::PyDict::new(py); - properties_dict.set_item("value", py_value_without_constraints)?; - properties_dict.set_item("checks", python_checks)?; + // Prepare type parameters for Checked[...] + let type_parameters_tuple = PyTuple::new(py, [value_type.as_ref(), &literal_check_names]).expect("PyTuple::new"); + + // Create the Checked type using __class_getitem__ + let class_checked_type: Bound<'_, PyAny> = class_checked_type_constructor + .call_method1("__class_getitem__", (type_parameters_tuple,)).expect("__class_getitem__"); + + // Prepare the properties dictionary + let properties_dict = pyo3::types::PyDict::new(py); + properties_dict.set_item("value", py_value_without_constraints)?; + if !checks.is_empty() { + properties_dict.set_item("checks", python_checks)?; + } + + // Validate the model with the constructed type + let checked_instance = + class_checked_type.call_method("model_validate", (properties_dict.clone(),), None).expect("model_validate"); + + // eprintln!("ret2: {:?}", checked_instance); + + Ok::, PyErr>(checked_instance.into()) + } else { + Ok(py_value_without_constraints) + }?; + + let value_with_possible_completion_state = if completion_state.display && allow_partials { + let value_type = value_with_possible_checks.bind(py).get_type(); + // eprintln!("value_type: {:?}", value_type); + + // Prepare the properties dictionary + let properties_dict = pyo3::types::PyDict::new(py); + properties_dict.set_item("value", value_with_possible_checks)?; + properties_dict.set_item("state", format!("{:?}", completion_state.state))?; - let class_checked_type_constructor = cls_module.getattr("Checked")?; + // Prepare type parameters for StreamingState[...] + let type_parameters_tuple = PyTuple::new(py, [value_type.as_ref()]).expect("PyTuple::new"); + // dbg!(&type_parameters_tuple); - // Prepare type parameters for Checked[...] - let type_parameters_tuple = PyTuple::new(py, [value_type.as_ref(), &literal_check_names])?; + let class_streaming_state_type_constructor = partial_cls_module.getattr("StreamState").expect("getattr(StreamState)"); + let class_completion_state_type: Bound<'_, PyAny> = class_streaming_state_type_constructor + .call_method1("__class_getitem__", (type_parameters_tuple,)) + .expect("__class_getitem__ for streaming"); + // dbg!(&class_completion_state_type); - // Create the Checked type using __class_getitem__ - let class_checked_type: Bound<'_, PyAny> = class_checked_type_constructor - .call_method1("__class_getitem__", (type_parameters_tuple,))?; + // eprintln!("properties dict: {:?}", properties_dict); + let streaming_state_instance = class_completion_state_type + .call_method("model_validate", (properties_dict.clone(),), None) + .expect("model_validate for streaming"); - // Validate the model with the constructed type - let checked_instance = - class_checked_type.call_method("model_validate", (properties_dict.clone(),), None)?; + Ok::, PyErr>(streaming_state_instance.into()) + } else { + Ok(value_with_possible_checks) + }?; - Ok(checked_instance.into()) + Ok(value_with_possible_completion_state) } } diff --git a/engine/language_client_ruby/Gemfile.lock b/engine/language_client_ruby/Gemfile.lock index 0b8642e059..bf02ae224b 100644 --- a/engine/language_client_ruby/Gemfile.lock +++ b/engine/language_client_ruby/Gemfile.lock @@ -1,7 +1,7 @@ PATH remote: . specs: - baml (0.68.0) + baml (0.73.5) GEM remote: https://rubygems.org/ diff --git a/engine/language_client_ruby/ext/ruby_ffi/Cargo.toml b/engine/language_client_ruby/ext/ruby_ffi/Cargo.toml index 7899cfa80a..e71e289e0f 100644 --- a/engine/language_client_ruby/ext/ruby_ffi/Cargo.toml +++ b/engine/language_client_ruby/ext/ruby_ffi/Cargo.toml @@ -17,6 +17,7 @@ base64.workspace = true env_logger.workspace = true futures.workspace = true indexmap.workspace = true +jsonish = { path = "../../../baml-lib/jsonish" } log.workspace = true magnus = { version = "0.7.1", features = ["rb-sys"] } # Must be kept in sync with ../../Gemfile diff --git a/engine/language_client_ruby/ext/ruby_ffi/src/function_result.rs b/engine/language_client_ruby/ext/ruby_ffi/src/function_result.rs index 234980ec9f..24f3c58ca2 100644 --- a/engine/language_client_ruby/ext/ruby_ffi/src/function_result.rs +++ b/engine/language_client_ruby/ext/ruby_ffi/src/function_result.rs @@ -35,20 +35,29 @@ impl FunctionResult { ruby: &Ruby, rb_self: &FunctionResult, types: RModule, + partial_types: RModule, + allow_partials: bool, ) -> Result { - match rb_self.inner.result_with_constraints_content() { - Ok(parsed) => ruby_to_json::RubyToJson::serialize_baml(ruby, types, parsed.clone()) + dbg!(&types); + dbg!(&partial_types); + dbg!(&allow_partials); + let res = match rb_self.inner.result_with_constraints_content() { + Ok(parsed) => { + ruby_to_json::RubyToJson::serialize_baml(ruby, types, partial_types, allow_partials, parsed.clone()) .map_err(|e| { magnus::Error::new( ruby.exception_type_error(), format!("failing inside parsed_using_types: {:?}", e), ) - }), + }) + }, Err(_) => Err(Error::new( ruby.exception_runtime_error(), format!("Failed to parse LLM response: {}", rb_self.inner), )), - } + }; + dbg!(&res); + res } /// For usage in magnus::init @@ -59,7 +68,7 @@ impl FunctionResult { cls.define_method( "parsed_using_types", - method!(FunctionResult::parsed_using_types, 1), + method!(FunctionResult::parsed_using_types, 3), )?; Ok(()) diff --git a/engine/language_client_ruby/ext/ruby_ffi/src/lib.rs b/engine/language_client_ruby/ext/ruby_ffi/src/lib.rs index 0ec1bc3a02..640e45284b 100644 --- a/engine/language_client_ruby/ext/ruby_ffi/src/lib.rs +++ b/engine/language_client_ruby/ext/ruby_ffi/src/lib.rs @@ -280,7 +280,7 @@ fn init(ruby: &Ruby) -> Result<()> { )?; module.define_module_function( "serialize", - function!(ruby_to_json::RubyToJson::serialize, 2), + function!(ruby_to_json::RubyToJson::serialize, 4), )?; Ok(()) diff --git a/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs b/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs index 10fedc8c7b..fd2b910137 100644 --- a/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs +++ b/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs @@ -1,4 +1,4 @@ -use baml_types::{BamlMap, BamlValue, BamlValueWithMeta, ResponseCheck}; +use baml_types::{BamlMap, BamlValue, BamlValueWithMeta, CompletionState, ResponseCheck}; use indexmap::IndexMap; use magnus::{ prelude::*, typed_data::Obj, value::Value, Error, Float, Integer, IntoValue, RArray, RClass, @@ -10,6 +10,7 @@ use crate::types::{ self, media::{Audio, Image}, }; +use jsonish::{deserializer::deserialize_flags::Flag, ResponseBamlValue}; struct SerializationError { position: Vec, @@ -57,77 +58,127 @@ impl<'rb> RubyToJson<'rb> { pub fn serialize_baml( ruby: &Ruby, types: RModule, - mut from: BamlValueWithMeta>, + partial_types: RModule, + allow_partials: bool, + mut from: ResponseBamlValue, ) -> crate::Result { + + let allow_partials = allow_partials && !from.0.meta().2.required_done; // If we encounter a BamlValue node with check results, serialize it as // { value: T, checks: K }. To compute `value`, we strip the metadata // off the node and pass it back to `serialize_baml`. - if !from.meta().is_empty() { - let meta = from.meta().clone(); - let checks = Self::serialize_response_checks(ruby, &meta)?; - - *from.meta_mut() = vec![]; - let serialized_subvalue = Self::serialize_baml(ruby, types, from)?; + eprintln!("SERIALIZE: {:?}", &from); + let (_flags, checks, completion) = from.0.meta_mut(); - let checked_class = ruby.eval::("Baml::Checked")?; + if completion.display && allow_partials { + eprintln!("... with state"); let hash = ruby.hash_new(); + let stream_state_class = ruby.eval::("Baml::StreamState")?; + hash.aset(ruby.sym_new("state"), ruby.sym_new(serde_json::to_string(&completion.state).expect("Serializing CompletionState is safe.")))?; + completion.display = false; + let serialized_subvalue = RubyToJson::serialize_baml(ruby, types, partial_types, allow_partials, from)?; hash.aset(ruby.sym_new("value"), serialized_subvalue)?; - hash.aset(ruby.sym_new("checks"), checks)?; - Ok(checked_class.funcall("new", (hash,))?) + let res = stream_state_class.funcall("new", (hash,)); + eprintln!("with_state res: {res:?}"); + Ok(res?) } // Otherwise encode it directly. else { - match from { - BamlValueWithMeta::Class(class_name, class_fields, _) => { - let hash = ruby.hash_new(); - for (k, v) in class_fields.into_iter() { - let k = ruby.sym_new(k.as_str()); - let v = RubyToJson::serialize_baml(ruby, types, v)?; - hash.aset(k, v)?; - } - match types.const_get::<_, RClass>(class_name.as_str()) { - Ok(class_type) => class_type.funcall("new", (hash,)), - Err(_) => { - let dynamic_class_type = ruby.eval::("Baml::DynamicStruct")?; - dynamic_class_type.funcall("new", (hash,)) - } - } + + if !checks.is_empty() { + let serialized_checks = Self::serialize_response_checks(ruby, &checks)?; + + checks.clear(); + + let serialized_subvalue = Self::serialize_baml(ruby, types, partial_types, allow_partials, from)?; + + let checked_class = ruby.eval::("Baml::Checked")?; + let hash = ruby.hash_new(); + hash.aset(ruby.sym_new("value"), serialized_subvalue)?; + if !serialized_checks.is_empty() { + hash.aset(ruby.sym_new("checks"), serialized_checks)?; } - BamlValueWithMeta::Enum(enum_name, enum_value, _) => { - if let Ok(enum_type) = types.const_get::<_, RClass>(enum_name.as_str()) { - let enum_value = ruby.str_new(&enum_value); - if let Ok(enum_instance) = enum_type.funcall("deserialize", (enum_value,)) { - return Ok(enum_instance); + let res = checked_class.funcall("new", (hash,)); + dbg!(&res); + eprintln!("with_checks res: {res:?}"); + Ok(res?) + } + // Otherwise encode it directly. + else { + eprintln!("...without_state"); + let res = match from.0 { + BamlValueWithMeta::Class(class_name, class_fields, _) => { + let hash = ruby.hash_new(); + for (k, v) in class_fields.into_iter() { + let subvalue_allow_partials = allow_partials && !v.meta().2.required_done; + dbg!(&subvalue_allow_partials); + let k = ruby.sym_new(k.as_str()); + let v = RubyToJson::serialize_baml(ruby, types, partial_types, subvalue_allow_partials, ResponseBamlValue(v))?; + hash.aset(k, v)?; + } + + let (preferred_module, backup_module) = if allow_partials { (partial_types, types) } else { (types, partial_types) }; + let preferred_class = match preferred_module.const_get::<_, RClass>(class_name.as_str()) { + Ok(class_type) => class_type, + Err(_) => ruby.eval::("Baml::DynamicStruct")?, + }; + let backup_class = match backup_module.const_get::<_, RClass>(class_name.as_str()) { + Ok(class_type) => class_type, + // Err(_) => ruby.eval::("Baml::DynamicStruct")?, + Err(_) => unreachable!("trying to avoid these"), + }; + match preferred_class.funcall("new", (hash,)) { + Ok(res) => Ok(res), + Err(original_error) => { + eprintln!("preferred_class {:?}, failed: {:?}. falling back to {:?}", preferred_class, original_error, backup_class); + match backup_class.funcall("new", (hash,)) { + Ok(res) => Ok(res), + Err(e) => { + eprintln!("backup {:?} failed with {:?}", backup_class, e); + Err(original_error) + } + } + } } } + BamlValueWithMeta::Enum(enum_name, enum_value, _) => { + if let Ok(enum_type) = types.const_get::<_, RClass>(enum_name.as_str()) { + let enum_value = ruby.str_new(&enum_value); + if let Ok(enum_instance) = enum_type.funcall("deserialize", (enum_value,)) { + return Ok(enum_instance); + } + } - Ok(ruby.str_new(&enum_value).into_value_with(ruby)) - } - BamlValueWithMeta::Map(m, _) => { - let hash = ruby.hash_new(); - for (k, v) in m.into_iter() { - let k = ruby.str_new(&k); - let v = RubyToJson::serialize_baml(ruby, types, v)?; - hash.aset(k, v)?; + Ok(ruby.str_new(&enum_value).into_value_with(ruby)) } - Ok(hash.into_value_with(ruby)) - } - BamlValueWithMeta::List(l, _) => { - let arr = ruby.ary_new(); - for v in l.into_iter() { - let v = RubyToJson::serialize_baml(ruby, types, v)?; - arr.push(v)?; + BamlValueWithMeta::Map(m, _) => { + let hash = ruby.hash_new(); + for (k, v) in m.into_iter() { + let k = ruby.str_new(&k); + let v = RubyToJson::serialize_baml(ruby, types, partial_types, allow_partials, ResponseBamlValue(v))?; + hash.aset(k, v)?; + } + Ok(hash.into_value_with(ruby)) } - Ok(arr.into_value_with(ruby)) - } - _ => serde_magnus::serialize(&from), + BamlValueWithMeta::List(l, _) => { + let arr = ruby.ary_new(); + for v in l.into_iter() { + let v = RubyToJson::serialize_baml(ruby, types, partial_types, allow_partials, ResponseBamlValue(v))?; + arr.push(v)?; + } + Ok(arr.into_value_with(ruby)) + } + _ => serde_magnus::serialize(&from.0.value()), + }; + dbg!(&res); + res } } } - pub fn serialize(ruby: &Ruby, types: RModule, from: Value) -> crate::Result { + pub fn serialize(ruby: &Ruby, types: RModule, partial_types: RModule, allow_partials: bool, from: Value) -> crate::Result { let json = RubyToJson::convert(from)?; - RubyToJson::serialize_baml(ruby, types, BamlValueWithMeta::with_default_meta(&json)) + RubyToJson::serialize_baml(ruby, types, partial_types, allow_partials, ResponseBamlValue(BamlValueWithMeta::with_default_meta(&json))) } /// Convert a Ruby object to a JSON object. @@ -575,4 +626,5 @@ impl<'rb> RubyToJson<'rb> { Some(Ok(any)) => self.to_json(any, field_pos), } } + } diff --git a/engine/language_client_ruby/lib/baml.rb b/engine/language_client_ruby/lib/baml.rb index 430a5c1abd..052f31b61a 100644 --- a/engine/language_client_ruby/lib/baml.rb +++ b/engine/language_client_ruby/lib/baml.rb @@ -18,6 +18,9 @@ module Baml Checked = Baml::Checks::Checked Check = Baml::Checks::Check + # Reexport StreamState. + StreamState = Baml::StreamState::StreamState + # Dynamically + idempotently define Baml::TypeConverter # NB: this does not respect raise_coercion_error = false def self.convert_to(type) diff --git a/engine/language_client_ruby/lib/stream.rb b/engine/language_client_ruby/lib/stream.rb index caea9a2621..6bdc44b70f 100644 --- a/engine/language_client_ruby/lib/stream.rb +++ b/engine/language_client_ruby/lib/stream.rb @@ -15,6 +15,24 @@ module Baml # end # end + module StreamState + class StreamState < T::Struct + extend T::Sig + + extend T::Generic + + Value = type_member + + const :value, Value + const :state, Symbol + + def initialize(props) + super(value: props[:value], state: props[:state]) + end + + end + end + class BamlStream extend T::Sig extend T::Generic @@ -47,7 +65,7 @@ def each(&block) # collection: https://ruby-doc.org/3.1.6/Enumerable.html#module-Enumerable-label-Usage if @final_response == nil @final_response = @ffi_stream.done(@ctx_manager) do |event| - block.call event.parsed_using_types(Baml::PartialTypes) + block.call event.parsed_using_types(Baml::Types, Baml::PartialTypes, true) end end @@ -64,7 +82,8 @@ def get_final_response @final_response = @ffi_stream.done(@ctx_manager) end - @final_response.parsed_using_types(Baml::Types) + @final_response.parsed_using_types(Baml::Types, Baml::PartialTypes, false) end end + end \ No newline at end of file diff --git a/engine/language_client_typescript/native.d.ts b/engine/language_client_typescript/native.d.ts index 5b34ee358c..19dd516101 100644 --- a/engine/language_client_typescript/native.d.ts +++ b/engine/language_client_typescript/native.d.ts @@ -73,7 +73,7 @@ export declare class FieldType { export declare class FunctionResult { isOk(): boolean - parsed(): any + parsed(allowPartials: boolean): any } export declare class FunctionResultStream { diff --git a/engine/language_client_typescript/src/types/function_results.rs b/engine/language_client_typescript/src/types/function_results.rs index dd0c00af46..7490b5aaac 100644 --- a/engine/language_client_typescript/src/types/function_results.rs +++ b/engine/language_client_typescript/src/types/function_results.rs @@ -16,12 +16,18 @@ impl FunctionResult { } #[napi] - pub fn parsed(&self) -> napi::Result { + pub fn parsed(&self, allow_partials: bool) -> napi::Result { let parsed = self .inner .result_with_constraints_content() .map_err(from_anyhow_error)?; - Ok(serde_json::to_value(parsed)?) + let response = serde_json::to_value( if allow_partials { + parsed.serialize_partial() + } else { + parsed.serialize_final() + } + )?; + Ok(response) } } diff --git a/engine/language_client_typescript/stream.js b/engine/language_client_typescript/stream.js index a252e59a98..168cd054f6 100644 --- a/engine/language_client_typescript/stream.js +++ b/engine/language_client_typescript/stream.js @@ -49,13 +49,13 @@ class BamlStream { break; } if (event.isOk()) { - yield this.partialCoerce(event.parsed()); + yield this.partialCoerce(event.parsed(true)); } } } async getFinalResponse() { const final = await this.driveToCompletionInBg(); - return this.finalCoerce(final.parsed()); + return this.finalCoerce(final.parsed(false)); } } exports.BamlStream = BamlStream; diff --git a/engine/language_client_typescript/typescript_src/stream.ts b/engine/language_client_typescript/typescript_src/stream.ts index 554bcb2215..fd30645c3a 100644 --- a/engine/language_client_typescript/typescript_src/stream.ts +++ b/engine/language_client_typescript/typescript_src/stream.ts @@ -53,7 +53,7 @@ export class BamlStream { } if (event.isOk()) { - yield this.partialCoerce(event.parsed()) + yield this.partialCoerce(event.parsed(true)) } } } @@ -61,6 +61,6 @@ export class BamlStream { async getFinalResponse(): Promise { const final = await this.driveToCompletionInBg() - return this.finalCoerce(final.parsed()) + return this.finalCoerce(final.parsed(false)) } } diff --git a/fern/01-guide/04-baml-basics/streaming.mdx b/fern/01-guide/04-baml-basics/streaming.mdx index 6f04c6614a..d73bdc8b87 100644 --- a/fern/01-guide/04-baml-basics/streaming.mdx +++ b/fern/01-guide/04-baml-basics/streaming.mdx @@ -12,14 +12,15 @@ If you tried streaming in a JSON output from an LLM you'd see something like: {"items": [{"name": "Apple", "quantity": 2, "price": 1.50}], "total_cost": 3.00} # Completed ``` -BAML automatically fixes this partial JSON, and transforms all your types into `Partial` types with all `Optional` fields only during the stream. +BAML gives you fine-grained control of how it fixes this partial JSON and transforms +it into a series of semantically valid partial objects. You can check out more examples (including streaming in FastAPI and NextJS) in the [BAML Examples] repo. [call BAML functions]: /docs/calling-baml/calling-functions [BAML Examples]: https://github.com/BoundaryML/baml-examples/tree/main -Lets stream the output of this function `function ExtractReceiptInfo(email: string) -> ReceiptInfo` for our example: +Let's stream the output of this function `function ExtractReceiptInfo(email: string) -> ReceiptInfo` for our example: @@ -49,6 +50,14 @@ function ExtractReceiptInfo(email: string) -> ReceiptInfo { ``` +The BAML code generator creates a set of types in the `baml_client` library +in a module called `partial` in `baml_client`. These types are modified +from your original types to support streaming. + +By default, BAML will convert all Class fields into nullable fields, and +fill those fields with non-null values as much as possible given the tokens +received so far. + @@ -253,5 +262,245 @@ Streaming is not yet supported via OpenAPI, but it will be coming soon! -Number fields are always streamed in only when the LLM completes them. E.g. if the final number is 129.95, you'll only see null or 129.95 instead of partial numbers like 1, 12, 129.9, etc. - \ No newline at end of file +Number fields are always streamed in only when the LLM completes them. E.g. if +the final number is 129.95, you'll only see null or 129.95 instead of partial +numbers like 1, 12, 129.9, etc. + + +## Semantic Streaming + +The BAML language provides several attributes that can be attached to types +to control streaming behavior, ensuring that the partial values streamed to you +are always valid within your own semantics. + + - `@stream.done`: Marks a type that should only be streamed when it is + done being read from the LLM response. + - `@stream.not_null`: Marks a field to indicate that the class containing + that field should only be streamed if that field is present (the field needed + not be completed) + - `@stream.with_state`: Adds metadata to a type indicating whether types appearing + +### `@stream.done` + +To demonstrate the use of `@stream.done`, imagine that a `ReceiptItem` +must only be consider valid and can only reach the client when its `name`, +`description`, `quantity` and `price` fields are completely streamed in. +To achieve this we can annotate the `ReceiptItem` class with the +`@stream.done` attribute: + +```rust +class ReceiptItem { + name string + description string? + quantity int + price float + @@stream.done +} +``` + +When generating the client code for `ReceiptType` none of the fields of +`ReceiptItem` will be converted to optional. And when parsing an LLM response, +no `ReceiptItem` will be streamed out until all of its fields are done being +streamed in. + +### `@stream.not_null` + +Sometimes the presence of a value is important to the correct interpretation +of a containing value. This commonly occurs with tags used to determine which +part of a union is being used. + +For example, in this code block, `@stream.not_null` on each of the +`message_type` fields will ensure that an `Event` is never streamed until +enough tokens have been received to precisely know what the message type is, +allowing you to build UI appropriate to the message type before the other +fields have been completed. + +```rust +class Message { + message_type "greeting" | "within-convo" | "farewell" @stream.not_null + gesture ("gesticulate" | "wave" | "shake-hands" | "hug")? + message string +} + +class Event { + event_message: Message + speaker string +} + +function Chat(history: Event[]) -> Event { ... } +``` + + +You might wonder if it's sufficient to use `@stream.done` on the +`message_type` field. `@stream.done` applies to types, preventing them +from streaming out until they are completed. On the other hand, +`@stream.not_null` applies to fields and prevents a containing object +from streaming out until that field is present. + +A type with `@stream.done` on it will still be converted to a +nullable field in the generated partial types, so this change would not +produce the desired result of witholding a `Message` until its type is +known. Messages would be streamed with `message_type: null`. + +This is a subtle distinction between `@stream.done` and +`@stream.not_null`. As a rule of thumb, remember that `@stream.done` +is about the type itself, and `@stream.not_null` is about the type's +containing context. + + +### `@stream.with_state` + +It is often useful to know in client code whether a value is finished, or +could be updated in future messages. The `@stream.with_state` attribute lets +you attach metadata to types to indicate this state in client code. + + + +```rust +class Message { + message_type "greeting" | "within-convo" | "farewell" @stream.not_null + gesture ("gesticulate" | "wave" | "shake-hands" | "hug")? + message string @stream.with_state +} +``` + + +```python +class StreamState(BaseModel, Generic[T]): + value: T, + state: "incomplete" | "complete" + +class Message(BaseModel): + message_type: Union["greeting", "within-convo", "farewell"] + gesture: Option[Union["gesticulate", "wave", "shake-hands", "hug"]] + message: StreamState[String] +``` + + + +```typescript +interface StreamState { + value: T, + state: "incomplete" | "complete" +} + +interface Message { + message_type: "greeting" | "within-convo" | "farewell", + gesture: ("gesticulate" | "wave" | "shake-hands" | "hug")?, + message: StreamState, +} +``` + + + + +## Putting it all together + +Let's put all of these concepts together to design an application that +streams a conversation containing stock recommendations, using semantic +streaming to ensure that the streamed data obeys our domain's invariants. + +```rust +enum Stock { + APPL, + MSFT, + GOOG, + BAML, +} + +// Make recommendations atomic - we do not want a recommendation to be +// modified by streaming additional messages. +class Recommendation { + stock Stock + amount float + action "buy" | "sell" + @@stream.done +} + +class Message { + message_type "greeting" | "conversation" | "farewell" @stream.not_null + message string @stream.with_state @stream.not_null +} + +function Respond( + history: (Message | Recommendation | UserMessage)[] +) -> Message | Recommendation { ... } +``` + + + + +The above BAML code will generate the following Python definitions in the +`partial` module. The use of streaming attributes has several effects on +the generated code: + + - `Recommendation` does not have any partial field because it was marked + `@stream.done`. + - The `Message.message` `string` is wrapped in `StreamState`, allowing + runtime checking of its completion status. This status could be used + to render a spinner as the message streams in. + - The `Message.message_type` field may not be `null`, because it was marked + as `@stream.not_null`. + +```python +class StreamState(BaseModel, Generic[T]): + value: T, + state: Union[Literal["incomplete"] | Literal[]] + +class Stock(str, Enum): + APPL = "APPL" + MSFT = "MSFT" + GOOG = "GOOG" + BAML = "BAML" + +class Recommendation(BaseClass): + stock: Stock + amount: float + action: Union[Literal["buy"], Literal["sell"]] + +class Message(BaseClass): + message_type: Union[Literal["gretting"], Literal["conversation"], Literal["farewell"]] + message: StreamState[string] +``` + + + +The above BAML code will generate the following Typescript definitions in the +`partial` module. The use of streaming attributes has several effects on +the generated code: + + - `Recommendation` does not have any partial field because it was marked + `@stream.done`. + - The `Message.message` `string` is wrapped in `StreamState`, allowing + runtime checking of its completion status. This status could be used + to render a spinner as the message streams in. + - The `Message.message_type` field may not be `null`, because it was marked + as `@stream.not_null`. + +```typescript +export interface StreamState { + value: T, + state: "incomplete" | "complete" +} + +export enum Category { + APPL = "APPl", + MSFT = "MSFT", + GOOG = "GOOG", + BAML = "BAML", +} + +export interface Recommendation { + stock: Stock, + amount: float, + action: "buy" | "sell" +} + +export interface Message { + message_type: "gretting" | "conversation" | "farewell" + message: StreamState +} +``` + + + diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000000..2dab9fe68d --- /dev/null +++ b/flake.lock @@ -0,0 +1,117 @@ +{ + "nodes": { + "fenix": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ], + "rust-analyzer-src": "rust-analyzer-src" + }, + "locked": { + "lastModified": 1734935689, + "narHash": "sha256-yl/iko/0pvRN3PF6Z4FjQeb6AuGiavMENEisQWJ78h0=", + "owner": "nix-community", + "repo": "fenix", + "rev": "30616281e9bfe0883acb3369f2b89aad6850706f", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "fenix", + "type": "github" + } + }, + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1733328505, + "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1734649271, + "narHash": "sha256-4EVBRhOjMDuGtMaofAIqzJbg4Ql7Ai0PSeuVZTHjyKQ=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "d70bd19e0a38ad4790d3913bf08fcbfc9eeca507", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "fenix": "fenix", + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "rust-analyzer-src": { + "flake": false, + "locked": { + "lastModified": 1734874959, + "narHash": "sha256-NlsVD/fI32wsHFua9Xvc7IFHCUpQIOs6D6RS/3AhMT8=", + "owner": "rust-lang", + "repo": "rust-analyzer", + "rev": "fa4a40bbe867ed54f5a7c905b591fd7d60ba35eb", + "type": "github" + }, + "original": { + "owner": "rust-lang", + "ref": "nightly", + "repo": "rust-analyzer", + "type": "github" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/integ-tests/baml_src/test-files/functions/output/class.baml b/integ-tests/baml_src/test-files/functions/output/class.baml index 8803251553..ea29a779ca 100644 --- a/integ-tests/baml_src/test-files/functions/output/class.baml +++ b/integ-tests/baml_src/test-files/functions/output/class.baml @@ -1,5 +1,5 @@ class TestOutputClass { - prop1 string + prop1 string @description("A long string with about 200 words") prop2 int } diff --git a/integ-tests/baml_src/test-files/semantic_streaming/semantic_streaming.baml b/integ-tests/baml_src/test-files/semantic_streaming/semantic_streaming.baml new file mode 100644 index 0000000000..28352a1bc6 --- /dev/null +++ b/integ-tests/baml_src/test-files/semantic_streaming/semantic_streaming.baml @@ -0,0 +1,33 @@ +class SemanticContainer { + sixteen_digit_number int + string_with_twenty_words string @stream.done + class_1 ClassWithoutDone + class_2 ClassWithBlockDone + class_done_needed ClassWithBlockDone @stream.not_null + class_needed ClassWithoutDone @stream.not_null + three_small_things SmallThing[] @description("Should have three items.") + final_string string +} + +class ClassWithoutDone { + i_16_digits int + s_20_words string @description("A string with 20 words in it") @stream.with_state +} + +class ClassWithBlockDone { + i_16_digits int + s_20_words string + @@stream.done +} + +class SmallThing { + i_16_digits int @stream.not_null + i_8_digits int +} + +function MakeSemanticContainer() -> SemanticContainer { + client GPT35 + prompt #" + {{ ctx.output_format }} + "# +} \ No newline at end of file diff --git a/integ-tests/python/baml_client/async_client.py b/integ-tests/python/baml_client/async_client.py index e4b2545619..5d68abcad1 100644 --- a/integ-tests/python/baml_client/async_client.py +++ b/integ-tests/python/baml_client/async_client.py @@ -71,7 +71,7 @@ async def AaaSamOutputFormat( tb, __cr__, ) - return cast(types.Recipe, raw.cast_to(types, types)) + return cast(types.Recipe, raw.cast_to(types, types, partial_types, False)) async def AliasThatPointsToRecursiveType( self, @@ -94,7 +94,7 @@ async def AliasThatPointsToRecursiveType( tb, __cr__, ) - return cast(types.LinkedListAliasNode, raw.cast_to(types, types)) + return cast(types.LinkedListAliasNode, raw.cast_to(types, types, partial_types, False)) async def AliasWithMultipleAttrs( self, @@ -117,7 +117,7 @@ async def AliasWithMultipleAttrs( tb, __cr__, ) - return cast(Checked[int,types.Literal["gt_ten"]], raw.cast_to(types, types)) + return cast(Checked[int,types.Literal["gt_ten"]], raw.cast_to(types, types, partial_types, False)) async def AliasedInputClass( self, @@ -140,7 +140,7 @@ async def AliasedInputClass( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def AliasedInputClass2( self, @@ -163,7 +163,7 @@ async def AliasedInputClass2( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def AliasedInputClassNested( self, @@ -186,7 +186,7 @@ async def AliasedInputClassNested( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def AliasedInputEnum( self, @@ -209,7 +209,7 @@ async def AliasedInputEnum( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def AliasedInputList( self, @@ -232,7 +232,7 @@ async def AliasedInputList( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def AllowedOptionals( self, @@ -255,7 +255,7 @@ async def AllowedOptionals( tb, __cr__, ) - return cast(types.OptionalListAndMap, raw.cast_to(types, types)) + return cast(types.OptionalListAndMap, raw.cast_to(types, types, partial_types, False)) async def AssertFn( self, @@ -278,7 +278,7 @@ async def AssertFn( tb, __cr__, ) - return cast(int, raw.cast_to(types, types)) + return cast(int, raw.cast_to(types, types, partial_types, False)) async def AudioInput( self, @@ -301,7 +301,7 @@ async def AudioInput( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def BuildLinkedList( self, @@ -324,7 +324,7 @@ async def BuildLinkedList( tb, __cr__, ) - return cast(types.LinkedList, raw.cast_to(types, types)) + return cast(types.LinkedList, raw.cast_to(types, types, partial_types, False)) async def BuildTree( self, @@ -347,7 +347,7 @@ async def BuildTree( tb, __cr__, ) - return cast(types.Tree, raw.cast_to(types, types)) + return cast(types.Tree, raw.cast_to(types, types, partial_types, False)) async def ClassThatPointsToRecursiveClassThroughAlias( self, @@ -370,7 +370,7 @@ async def ClassThatPointsToRecursiveClassThroughAlias( tb, __cr__, ) - return cast(types.ClassToRecAlias, raw.cast_to(types, types)) + return cast(types.ClassToRecAlias, raw.cast_to(types, types, partial_types, False)) async def ClassifyDynEnumTwo( self, @@ -393,7 +393,7 @@ async def ClassifyDynEnumTwo( tb, __cr__, ) - return cast(Union[types.DynEnumTwo, str], raw.cast_to(types, types)) + return cast(Union[types.DynEnumTwo, str], raw.cast_to(types, types, partial_types, False)) async def ClassifyMessage( self, @@ -416,7 +416,7 @@ async def ClassifyMessage( tb, __cr__, ) - return cast(types.Category, raw.cast_to(types, types)) + return cast(types.Category, raw.cast_to(types, types, partial_types, False)) async def ClassifyMessage2( self, @@ -439,7 +439,7 @@ async def ClassifyMessage2( tb, __cr__, ) - return cast(types.Category, raw.cast_to(types, types)) + return cast(types.Category, raw.cast_to(types, types, partial_types, False)) async def ClassifyMessage3( self, @@ -462,7 +462,7 @@ async def ClassifyMessage3( tb, __cr__, ) - return cast(types.Category, raw.cast_to(types, types)) + return cast(types.Category, raw.cast_to(types, types, partial_types, False)) async def Completion( self, @@ -485,7 +485,7 @@ async def Completion( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def CustomTask( self, @@ -508,7 +508,7 @@ async def CustomTask( tb, __cr__, ) - return cast(Union[types.BookOrder, types.FlightConfirmation, types.GroceryReceipt], raw.cast_to(types, types)) + return cast(Union[types.BookOrder, types.FlightConfirmation, types.GroceryReceipt], raw.cast_to(types, types, partial_types, False)) async def DescribeImage( self, @@ -531,7 +531,7 @@ async def DescribeImage( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def DescribeImage2( self, @@ -554,7 +554,7 @@ async def DescribeImage2( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def DescribeImage3( self, @@ -577,7 +577,7 @@ async def DescribeImage3( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def DescribeImage4( self, @@ -600,7 +600,7 @@ async def DescribeImage4( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def DifferentiateUnions( self, @@ -623,7 +623,7 @@ async def DifferentiateUnions( tb, __cr__, ) - return cast(Union[types.OriginalA, types.OriginalB], raw.cast_to(types, types)) + return cast(Union[types.OriginalA, types.OriginalB], raw.cast_to(types, types, partial_types, False)) async def DummyOutputFunction( self, @@ -646,7 +646,7 @@ async def DummyOutputFunction( tb, __cr__, ) - return cast(types.DummyOutput, raw.cast_to(types, types)) + return cast(types.DummyOutput, raw.cast_to(types, types, partial_types, False)) async def DynamicFunc( self, @@ -669,7 +669,7 @@ async def DynamicFunc( tb, __cr__, ) - return cast(types.DynamicClassTwo, raw.cast_to(types, types)) + return cast(types.DynamicClassTwo, raw.cast_to(types, types, partial_types, False)) async def DynamicInputOutput( self, @@ -692,7 +692,7 @@ async def DynamicInputOutput( tb, __cr__, ) - return cast(types.DynInputOutput, raw.cast_to(types, types)) + return cast(types.DynInputOutput, raw.cast_to(types, types, partial_types, False)) async def DynamicListInputOutput( self, @@ -715,7 +715,7 @@ async def DynamicListInputOutput( tb, __cr__, ) - return cast(List[types.DynInputOutput], raw.cast_to(types, types)) + return cast(List[types.DynInputOutput], raw.cast_to(types, types, partial_types, False)) async def ExpectFailure( self, @@ -738,7 +738,7 @@ async def ExpectFailure( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def ExtractContactInfo( self, @@ -761,7 +761,7 @@ async def ExtractContactInfo( tb, __cr__, ) - return cast(types.ContactInfo, raw.cast_to(types, types)) + return cast(types.ContactInfo, raw.cast_to(types, types, partial_types, False)) async def ExtractHobby( self, @@ -784,7 +784,7 @@ async def ExtractHobby( tb, __cr__, ) - return cast(List[Union[types.Hobby, str]], raw.cast_to(types, types)) + return cast(List[Union[types.Hobby, str]], raw.cast_to(types, types, partial_types, False)) async def ExtractNames( self, @@ -807,7 +807,7 @@ async def ExtractNames( tb, __cr__, ) - return cast(List[str], raw.cast_to(types, types)) + return cast(List[str], raw.cast_to(types, types, partial_types, False)) async def ExtractPeople( self, @@ -830,7 +830,7 @@ async def ExtractPeople( tb, __cr__, ) - return cast(List[types.Person], raw.cast_to(types, types)) + return cast(List[types.Person], raw.cast_to(types, types, partial_types, False)) async def ExtractReceiptInfo( self, @@ -853,7 +853,7 @@ async def ExtractReceiptInfo( tb, __cr__, ) - return cast(types.ReceiptInfo, raw.cast_to(types, types)) + return cast(types.ReceiptInfo, raw.cast_to(types, types, partial_types, False)) async def ExtractResume( self, @@ -876,7 +876,7 @@ async def ExtractResume( tb, __cr__, ) - return cast(types.Resume, raw.cast_to(types, types)) + return cast(types.Resume, raw.cast_to(types, types, partial_types, False)) async def ExtractResume2( self, @@ -899,7 +899,7 @@ async def ExtractResume2( tb, __cr__, ) - return cast(types.Resume, raw.cast_to(types, types)) + return cast(types.Resume, raw.cast_to(types, types, partial_types, False)) async def FnClassOptionalOutput( self, @@ -922,7 +922,7 @@ async def FnClassOptionalOutput( tb, __cr__, ) - return cast(Optional[types.ClassOptionalOutput], raw.cast_to(types, types)) + return cast(Optional[types.ClassOptionalOutput], raw.cast_to(types, types, partial_types, False)) async def FnClassOptionalOutput2( self, @@ -945,7 +945,7 @@ async def FnClassOptionalOutput2( tb, __cr__, ) - return cast(Optional[types.ClassOptionalOutput2], raw.cast_to(types, types)) + return cast(Optional[types.ClassOptionalOutput2], raw.cast_to(types, types, partial_types, False)) async def FnEnumListOutput( self, @@ -968,7 +968,7 @@ async def FnEnumListOutput( tb, __cr__, ) - return cast(List[types.EnumOutput], raw.cast_to(types, types)) + return cast(List[types.EnumOutput], raw.cast_to(types, types, partial_types, False)) async def FnEnumOutput( self, @@ -991,7 +991,7 @@ async def FnEnumOutput( tb, __cr__, ) - return cast(types.EnumOutput, raw.cast_to(types, types)) + return cast(types.EnumOutput, raw.cast_to(types, types, partial_types, False)) async def FnLiteralClassInputOutput( self, @@ -1014,7 +1014,7 @@ async def FnLiteralClassInputOutput( tb, __cr__, ) - return cast(types.LiteralClassHello, raw.cast_to(types, types)) + return cast(types.LiteralClassHello, raw.cast_to(types, types, partial_types, False)) async def FnLiteralUnionClassInputOutput( self, @@ -1037,7 +1037,7 @@ async def FnLiteralUnionClassInputOutput( tb, __cr__, ) - return cast(Union[types.LiteralClassOne, types.LiteralClassTwo], raw.cast_to(types, types)) + return cast(Union[types.LiteralClassOne, types.LiteralClassTwo], raw.cast_to(types, types, partial_types, False)) async def FnNamedArgsSingleStringOptional( self, @@ -1060,7 +1060,7 @@ async def FnNamedArgsSingleStringOptional( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def FnOutputBool( self, @@ -1083,7 +1083,7 @@ async def FnOutputBool( tb, __cr__, ) - return cast(bool, raw.cast_to(types, types)) + return cast(bool, raw.cast_to(types, types, partial_types, False)) async def FnOutputClass( self, @@ -1106,7 +1106,7 @@ async def FnOutputClass( tb, __cr__, ) - return cast(types.TestOutputClass, raw.cast_to(types, types)) + return cast(types.TestOutputClass, raw.cast_to(types, types, partial_types, False)) async def FnOutputClassList( self, @@ -1129,7 +1129,7 @@ async def FnOutputClassList( tb, __cr__, ) - return cast(List[types.TestOutputClass], raw.cast_to(types, types)) + return cast(List[types.TestOutputClass], raw.cast_to(types, types, partial_types, False)) async def FnOutputClassNested( self, @@ -1152,7 +1152,7 @@ async def FnOutputClassNested( tb, __cr__, ) - return cast(types.TestClassNested, raw.cast_to(types, types)) + return cast(types.TestClassNested, raw.cast_to(types, types, partial_types, False)) async def FnOutputClassWithEnum( self, @@ -1175,7 +1175,7 @@ async def FnOutputClassWithEnum( tb, __cr__, ) - return cast(types.TestClassWithEnum, raw.cast_to(types, types)) + return cast(types.TestClassWithEnum, raw.cast_to(types, types, partial_types, False)) async def FnOutputInt( self, @@ -1198,7 +1198,7 @@ async def FnOutputInt( tb, __cr__, ) - return cast(int, raw.cast_to(types, types)) + return cast(int, raw.cast_to(types, types, partial_types, False)) async def FnOutputLiteralBool( self, @@ -1221,7 +1221,7 @@ async def FnOutputLiteralBool( tb, __cr__, ) - return cast(Literal[False], raw.cast_to(types, types)) + return cast(Literal[False], raw.cast_to(types, types, partial_types, False)) async def FnOutputLiteralInt( self, @@ -1244,7 +1244,7 @@ async def FnOutputLiteralInt( tb, __cr__, ) - return cast(Literal[5], raw.cast_to(types, types)) + return cast(Literal[5], raw.cast_to(types, types, partial_types, False)) async def FnOutputLiteralString( self, @@ -1267,7 +1267,7 @@ async def FnOutputLiteralString( tb, __cr__, ) - return cast(Literal["example output"], raw.cast_to(types, types)) + return cast(Literal["example output"], raw.cast_to(types, types, partial_types, False)) async def FnOutputStringList( self, @@ -1290,7 +1290,7 @@ async def FnOutputStringList( tb, __cr__, ) - return cast(List[str], raw.cast_to(types, types)) + return cast(List[str], raw.cast_to(types, types, partial_types, False)) async def FnTestAliasedEnumOutput( self, @@ -1313,7 +1313,7 @@ async def FnTestAliasedEnumOutput( tb, __cr__, ) - return cast(types.TestEnum, raw.cast_to(types, types)) + return cast(types.TestEnum, raw.cast_to(types, types, partial_types, False)) async def FnTestClassAlias( self, @@ -1336,7 +1336,7 @@ async def FnTestClassAlias( tb, __cr__, ) - return cast(types.TestClassAlias, raw.cast_to(types, types)) + return cast(types.TestClassAlias, raw.cast_to(types, types, partial_types, False)) async def FnTestNamedArgsSingleEnum( self, @@ -1359,7 +1359,7 @@ async def FnTestNamedArgsSingleEnum( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def GetDataType( self, @@ -1382,7 +1382,7 @@ async def GetDataType( tb, __cr__, ) - return cast(types.RaysData, raw.cast_to(types, types)) + return cast(types.RaysData, raw.cast_to(types, types, partial_types, False)) async def GetOrderInfo( self, @@ -1405,7 +1405,7 @@ async def GetOrderInfo( tb, __cr__, ) - return cast(types.OrderInfo, raw.cast_to(types, types)) + return cast(types.OrderInfo, raw.cast_to(types, types, partial_types, False)) async def GetQuery( self, @@ -1428,7 +1428,7 @@ async def GetQuery( tb, __cr__, ) - return cast(types.SearchParams, raw.cast_to(types, types)) + return cast(types.SearchParams, raw.cast_to(types, types, partial_types, False)) async def InOutEnumMapKey( self, @@ -1451,7 +1451,7 @@ async def InOutEnumMapKey( tb, __cr__, ) - return cast(Dict[types.MapKey, str], raw.cast_to(types, types)) + return cast(Dict[types.MapKey, str], raw.cast_to(types, types, partial_types, False)) async def InOutLiteralStringUnionMapKey( self, @@ -1474,7 +1474,7 @@ async def InOutLiteralStringUnionMapKey( tb, __cr__, ) - return cast(Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], str], raw.cast_to(types, types)) + return cast(Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], str], raw.cast_to(types, types, partial_types, False)) async def InOutSingleLiteralStringMapKey( self, @@ -1497,7 +1497,7 @@ async def InOutSingleLiteralStringMapKey( tb, __cr__, ) - return cast(Dict[Literal["key"], str], raw.cast_to(types, types)) + return cast(Dict[Literal["key"], str], raw.cast_to(types, types, partial_types, False)) async def JsonTypeAliasCycle( self, @@ -1520,7 +1520,7 @@ async def JsonTypeAliasCycle( tb, __cr__, ) - return cast(types.JsonValue, raw.cast_to(types, types)) + return cast(types.JsonValue, raw.cast_to(types, types, partial_types, False)) async def LiteralUnionsTest( self, @@ -1543,7 +1543,7 @@ async def LiteralUnionsTest( tb, __cr__, ) - return cast(Union[Literal[1], Literal[True], Literal["string output"]], raw.cast_to(types, types)) + return cast(Union[Literal[1], Literal[True], Literal["string output"]], raw.cast_to(types, types, partial_types, False)) async def MakeBlockConstraint( self, @@ -1566,7 +1566,7 @@ async def MakeBlockConstraint( tb, __cr__, ) - return cast(Checked[types.BlockConstraint,types.Literal["cross_field"]], raw.cast_to(types, types)) + return cast(Checked[types.BlockConstraint,types.Literal["cross_field"]], raw.cast_to(types, types, partial_types, False)) async def MakeNestedBlockConstraint( self, @@ -1589,7 +1589,30 @@ async def MakeNestedBlockConstraint( tb, __cr__, ) - return cast(types.NestedBlockConstraint, raw.cast_to(types, types)) + return cast(types.NestedBlockConstraint, raw.cast_to(types, types, partial_types, False)) + + async def MakeSemanticContainer( + self, + + baml_options: BamlCallOptions = {}, + ) -> types.SemanticContainer: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "MakeSemanticContainer", + { + + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.SemanticContainer, raw.cast_to(types, types, partial_types, False)) async def MapAlias( self, @@ -1612,7 +1635,7 @@ async def MapAlias( tb, __cr__, ) - return cast(Dict[str, List[str]], raw.cast_to(types, types)) + return cast(Dict[str, List[str]], raw.cast_to(types, types, partial_types, False)) async def MergeAliasAttributes( self, @@ -1635,7 +1658,7 @@ async def MergeAliasAttributes( tb, __cr__, ) - return cast(types.MergeAttrs, raw.cast_to(types, types)) + return cast(types.MergeAttrs, raw.cast_to(types, types, partial_types, False)) async def MyFunc( self, @@ -1658,7 +1681,7 @@ async def MyFunc( tb, __cr__, ) - return cast(types.DynamicOutput, raw.cast_to(types, types)) + return cast(types.DynamicOutput, raw.cast_to(types, types, partial_types, False)) async def NestedAlias( self, @@ -1681,7 +1704,7 @@ async def NestedAlias( tb, __cr__, ) - return cast(Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], raw.cast_to(types, types)) + return cast(Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], raw.cast_to(types, types, partial_types, False)) async def NullLiteralClassHello( self, @@ -1704,7 +1727,7 @@ async def NullLiteralClassHello( tb, __cr__, ) - return cast(types.ClassForNullLiteral, raw.cast_to(types, types)) + return cast(types.ClassForNullLiteral, raw.cast_to(types, types, partial_types, False)) async def OptionalTest_Function( self, @@ -1727,7 +1750,7 @@ async def OptionalTest_Function( tb, __cr__, ) - return cast(List[Optional[types.OptionalTest_ReturnType]], raw.cast_to(types, types)) + return cast(List[Optional[types.OptionalTest_ReturnType]], raw.cast_to(types, types, partial_types, False)) async def PredictAge( self, @@ -1750,7 +1773,7 @@ async def PredictAge( tb, __cr__, ) - return cast(types.FooAny, raw.cast_to(types, types)) + return cast(types.FooAny, raw.cast_to(types, types, partial_types, False)) async def PredictAgeBare( self, @@ -1773,7 +1796,7 @@ async def PredictAgeBare( tb, __cr__, ) - return cast(Checked[int,types.Literal["too_big"]], raw.cast_to(types, types)) + return cast(Checked[int,types.Literal["too_big"]], raw.cast_to(types, types, partial_types, False)) async def PrimitiveAlias( self, @@ -1796,7 +1819,7 @@ async def PrimitiveAlias( tb, __cr__, ) - return cast(Union[int, str, bool, float], raw.cast_to(types, types)) + return cast(Union[int, str, bool, float], raw.cast_to(types, types, partial_types, False)) async def PromptTestClaude( self, @@ -1819,7 +1842,7 @@ async def PromptTestClaude( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def PromptTestClaudeChat( self, @@ -1842,7 +1865,7 @@ async def PromptTestClaudeChat( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def PromptTestClaudeChatNoSystem( self, @@ -1865,7 +1888,7 @@ async def PromptTestClaudeChatNoSystem( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def PromptTestOpenAI( self, @@ -1888,7 +1911,7 @@ async def PromptTestOpenAI( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def PromptTestOpenAIChat( self, @@ -1911,7 +1934,7 @@ async def PromptTestOpenAIChat( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def PromptTestOpenAIChatNoSystem( self, @@ -1934,7 +1957,7 @@ async def PromptTestOpenAIChatNoSystem( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def PromptTestStreaming( self, @@ -1957,7 +1980,7 @@ async def PromptTestStreaming( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def RecursiveAliasCycle( self, @@ -1980,7 +2003,7 @@ async def RecursiveAliasCycle( tb, __cr__, ) - return cast(types.RecAliasOne, raw.cast_to(types, types)) + return cast(types.RecAliasOne, raw.cast_to(types, types, partial_types, False)) async def RecursiveClassWithAliasIndirection( self, @@ -2003,7 +2026,7 @@ async def RecursiveClassWithAliasIndirection( tb, __cr__, ) - return cast(types.NodeWithAliasIndirection, raw.cast_to(types, types)) + return cast(types.NodeWithAliasIndirection, raw.cast_to(types, types, partial_types, False)) async def ReturnAliasWithMergedAttributes( self, @@ -2026,7 +2049,7 @@ async def ReturnAliasWithMergedAttributes( tb, __cr__, ) - return cast(Checked[int,types.Literal["gt_ten"]], raw.cast_to(types, types)) + return cast(Checked[int,types.Literal["gt_ten"]], raw.cast_to(types, types, partial_types, False)) async def ReturnFailingAssert( self, @@ -2049,7 +2072,7 @@ async def ReturnFailingAssert( tb, __cr__, ) - return cast(int, raw.cast_to(types, types)) + return cast(int, raw.cast_to(types, types, partial_types, False)) async def ReturnMalformedConstraints( self, @@ -2072,7 +2095,7 @@ async def ReturnMalformedConstraints( tb, __cr__, ) - return cast(types.MalformedConstraints, raw.cast_to(types, types)) + return cast(types.MalformedConstraints, raw.cast_to(types, types, partial_types, False)) async def SchemaDescriptions( self, @@ -2095,7 +2118,7 @@ async def SchemaDescriptions( tb, __cr__, ) - return cast(types.Schema, raw.cast_to(types, types)) + return cast(types.Schema, raw.cast_to(types, types, partial_types, False)) async def SimpleRecursiveListAlias( self, @@ -2118,7 +2141,7 @@ async def SimpleRecursiveListAlias( tb, __cr__, ) - return cast(types.RecursiveListAlias, raw.cast_to(types, types)) + return cast(types.RecursiveListAlias, raw.cast_to(types, types, partial_types, False)) async def SimpleRecursiveMapAlias( self, @@ -2141,7 +2164,7 @@ async def SimpleRecursiveMapAlias( tb, __cr__, ) - return cast(types.RecursiveMapAlias, raw.cast_to(types, types)) + return cast(types.RecursiveMapAlias, raw.cast_to(types, types, partial_types, False)) async def StreamBigNumbers( self, @@ -2164,7 +2187,7 @@ async def StreamBigNumbers( tb, __cr__, ) - return cast(types.BigNumbers, raw.cast_to(types, types)) + return cast(types.BigNumbers, raw.cast_to(types, types, partial_types, False)) async def StreamFailingAssertion( self, @@ -2187,7 +2210,7 @@ async def StreamFailingAssertion( tb, __cr__, ) - return cast(types.TwoStoriesOneTitle, raw.cast_to(types, types)) + return cast(types.TwoStoriesOneTitle, raw.cast_to(types, types, partial_types, False)) async def StreamOneBigNumber( self, @@ -2210,7 +2233,7 @@ async def StreamOneBigNumber( tb, __cr__, ) - return cast(int, raw.cast_to(types, types)) + return cast(int, raw.cast_to(types, types, partial_types, False)) async def StreamUnionIntegers( self, @@ -2233,7 +2256,7 @@ async def StreamUnionIntegers( tb, __cr__, ) - return cast(List[Union[int, str]], raw.cast_to(types, types)) + return cast(List[Union[int, str]], raw.cast_to(types, types, partial_types, False)) async def StreamingCompoundNumbers( self, @@ -2256,7 +2279,7 @@ async def StreamingCompoundNumbers( tb, __cr__, ) - return cast(types.CompoundBigNumbers, raw.cast_to(types, types)) + return cast(types.CompoundBigNumbers, raw.cast_to(types, types, partial_types, False)) async def TestAnthropic( self, @@ -2279,7 +2302,7 @@ async def TestAnthropic( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestAnthropicShorthand( self, @@ -2302,7 +2325,7 @@ async def TestAnthropicShorthand( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestAws( self, @@ -2325,7 +2348,7 @@ async def TestAws( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestAwsInvalidAccessKey( self, @@ -2348,7 +2371,7 @@ async def TestAwsInvalidAccessKey( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestAwsInvalidProfile( self, @@ -2371,7 +2394,7 @@ async def TestAwsInvalidProfile( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestAwsInvalidRegion( self, @@ -2394,7 +2417,7 @@ async def TestAwsInvalidRegion( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestAwsInvalidSessionToken( self, @@ -2417,7 +2440,7 @@ async def TestAwsInvalidSessionToken( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestAzure( self, @@ -2440,7 +2463,7 @@ async def TestAzure( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestAzureFailure( self, @@ -2463,7 +2486,7 @@ async def TestAzureFailure( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestCaching( self, @@ -2486,7 +2509,7 @@ async def TestCaching( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestFallbackClient( self, @@ -2509,7 +2532,7 @@ async def TestFallbackClient( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestFallbackToShorthand( self, @@ -2532,7 +2555,7 @@ async def TestFallbackToShorthand( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestFnNamedArgsSingleBool( self, @@ -2555,7 +2578,7 @@ async def TestFnNamedArgsSingleBool( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestFnNamedArgsSingleClass( self, @@ -2578,7 +2601,7 @@ async def TestFnNamedArgsSingleClass( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestFnNamedArgsSingleEnumList( self, @@ -2601,7 +2624,7 @@ async def TestFnNamedArgsSingleEnumList( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestFnNamedArgsSingleFloat( self, @@ -2624,7 +2647,7 @@ async def TestFnNamedArgsSingleFloat( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestFnNamedArgsSingleInt( self, @@ -2647,7 +2670,7 @@ async def TestFnNamedArgsSingleInt( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestFnNamedArgsSingleMapStringToClass( self, @@ -2670,7 +2693,7 @@ async def TestFnNamedArgsSingleMapStringToClass( tb, __cr__, ) - return cast(Dict[str, types.StringToClassEntry], raw.cast_to(types, types)) + return cast(Dict[str, types.StringToClassEntry], raw.cast_to(types, types, partial_types, False)) async def TestFnNamedArgsSingleMapStringToMap( self, @@ -2693,7 +2716,7 @@ async def TestFnNamedArgsSingleMapStringToMap( tb, __cr__, ) - return cast(Dict[str, Dict[str, str]], raw.cast_to(types, types)) + return cast(Dict[str, Dict[str, str]], raw.cast_to(types, types, partial_types, False)) async def TestFnNamedArgsSingleMapStringToString( self, @@ -2716,7 +2739,7 @@ async def TestFnNamedArgsSingleMapStringToString( tb, __cr__, ) - return cast(Dict[str, str], raw.cast_to(types, types)) + return cast(Dict[str, str], raw.cast_to(types, types, partial_types, False)) async def TestFnNamedArgsSingleString( self, @@ -2739,7 +2762,7 @@ async def TestFnNamedArgsSingleString( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestFnNamedArgsSingleStringArray( self, @@ -2762,7 +2785,7 @@ async def TestFnNamedArgsSingleStringArray( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestFnNamedArgsSingleStringList( self, @@ -2785,7 +2808,7 @@ async def TestFnNamedArgsSingleStringList( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestGemini( self, @@ -2808,7 +2831,7 @@ async def TestGemini( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestGeminiOpenAiGeneric( self, @@ -2831,7 +2854,7 @@ async def TestGeminiOpenAiGeneric( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestGeminiSystem( self, @@ -2854,7 +2877,7 @@ async def TestGeminiSystem( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestGeminiSystemAsChat( self, @@ -2877,7 +2900,7 @@ async def TestGeminiSystemAsChat( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestImageInput( self, @@ -2900,7 +2923,7 @@ async def TestImageInput( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestImageInputAnthropic( self, @@ -2923,7 +2946,7 @@ async def TestImageInputAnthropic( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestImageListInput( self, @@ -2946,7 +2969,7 @@ async def TestImageListInput( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestMulticlassNamedArgs( self, @@ -2969,7 +2992,7 @@ async def TestMulticlassNamedArgs( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestNamedArgsLiteralBool( self, @@ -2992,7 +3015,7 @@ async def TestNamedArgsLiteralBool( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestNamedArgsLiteralInt( self, @@ -3015,7 +3038,7 @@ async def TestNamedArgsLiteralInt( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestNamedArgsLiteralString( self, @@ -3038,7 +3061,7 @@ async def TestNamedArgsLiteralString( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestOllama( self, @@ -3061,7 +3084,7 @@ async def TestOllama( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestOpenAILegacyProvider( self, @@ -3084,7 +3107,7 @@ async def TestOpenAILegacyProvider( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestOpenAIShorthand( self, @@ -3107,7 +3130,7 @@ async def TestOpenAIShorthand( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestRetryConstant( self, @@ -3130,7 +3153,7 @@ async def TestRetryConstant( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestRetryExponential( self, @@ -3153,7 +3176,7 @@ async def TestRetryExponential( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestSingleFallbackClient( self, @@ -3176,7 +3199,7 @@ async def TestSingleFallbackClient( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestVertex( self, @@ -3199,7 +3222,7 @@ async def TestVertex( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def TestVertexWithSystemInstructions( self, @@ -3222,7 +3245,7 @@ async def TestVertexWithSystemInstructions( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) async def UnionTest_Function( self, @@ -3245,7 +3268,7 @@ async def UnionTest_Function( tb, __cr__, ) - return cast(types.UnionTest_ReturnType, raw.cast_to(types, types)) + return cast(types.UnionTest_ReturnType, raw.cast_to(types, types, partial_types, False)) async def UseBlockConstraint( self, @@ -3268,7 +3291,7 @@ async def UseBlockConstraint( tb, __cr__, ) - return cast(int, raw.cast_to(types, types)) + return cast(int, raw.cast_to(types, types, partial_types, False)) async def UseMalformedConstraints( self, @@ -3291,7 +3314,7 @@ async def UseMalformedConstraints( tb, __cr__, ) - return cast(int, raw.cast_to(types, types)) + return cast(int, raw.cast_to(types, types, partial_types, False)) async def UseNestedBlockConstraint( self, @@ -3314,7 +3337,7 @@ async def UseNestedBlockConstraint( tb, __cr__, ) - return cast(int, raw.cast_to(types, types)) + return cast(int, raw.cast_to(types, types, partial_types, False)) @@ -3352,8 +3375,8 @@ def AaaSamOutputFormat( return baml_py.BamlStream[partial_types.Recipe, types.Recipe]( raw, - lambda x: cast(partial_types.Recipe, x.cast_to(types, partial_types)), - lambda x: cast(types.Recipe, x.cast_to(types, types)), + lambda x: cast(partial_types.Recipe, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Recipe, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3382,8 +3405,8 @@ def AliasThatPointsToRecursiveType( return baml_py.BamlStream[partial_types.LinkedListAliasNode, types.LinkedListAliasNode]( raw, - lambda x: cast(partial_types.LinkedListAliasNode, x.cast_to(types, partial_types)), - lambda x: cast(types.LinkedListAliasNode, x.cast_to(types, types)), + lambda x: cast(partial_types.LinkedListAliasNode, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.LinkedListAliasNode, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3412,8 +3435,8 @@ def AliasWithMultipleAttrs( return baml_py.BamlStream[Checked[Optional[int],types.Literal["gt_ten"]], Checked[int,types.Literal["gt_ten"]]]( raw, - lambda x: cast(Checked[Optional[int],types.Literal["gt_ten"]], x.cast_to(types, partial_types)), - lambda x: cast(Checked[int,types.Literal["gt_ten"]], x.cast_to(types, types)), + lambda x: cast(Checked[Optional[int],types.Literal["gt_ten"]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Checked[int,types.Literal["gt_ten"]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3442,8 +3465,8 @@ def AliasedInputClass( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3472,8 +3495,8 @@ def AliasedInputClass2( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3502,8 +3525,8 @@ def AliasedInputClassNested( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3532,8 +3555,8 @@ def AliasedInputEnum( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3562,8 +3585,8 @@ def AliasedInputList( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3592,8 +3615,8 @@ def AllowedOptionals( return baml_py.BamlStream[partial_types.OptionalListAndMap, types.OptionalListAndMap]( raw, - lambda x: cast(partial_types.OptionalListAndMap, x.cast_to(types, partial_types)), - lambda x: cast(types.OptionalListAndMap, x.cast_to(types, types)), + lambda x: cast(partial_types.OptionalListAndMap, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.OptionalListAndMap, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3622,8 +3645,8 @@ def AssertFn( return baml_py.BamlStream[Optional[int], int]( raw, - lambda x: cast(Optional[int], x.cast_to(types, partial_types)), - lambda x: cast(int, x.cast_to(types, types)), + lambda x: cast(Optional[int], x.cast_to(types, types, partial_types, True)), + lambda x: cast(int, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3652,8 +3675,8 @@ def AudioInput( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3682,8 +3705,8 @@ def BuildLinkedList( return baml_py.BamlStream[partial_types.LinkedList, types.LinkedList]( raw, - lambda x: cast(partial_types.LinkedList, x.cast_to(types, partial_types)), - lambda x: cast(types.LinkedList, x.cast_to(types, types)), + lambda x: cast(partial_types.LinkedList, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.LinkedList, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3712,8 +3735,8 @@ def BuildTree( return baml_py.BamlStream[partial_types.Tree, types.Tree]( raw, - lambda x: cast(partial_types.Tree, x.cast_to(types, partial_types)), - lambda x: cast(types.Tree, x.cast_to(types, types)), + lambda x: cast(partial_types.Tree, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Tree, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3742,8 +3765,8 @@ def ClassThatPointsToRecursiveClassThroughAlias( return baml_py.BamlStream[partial_types.ClassToRecAlias, types.ClassToRecAlias]( raw, - lambda x: cast(partial_types.ClassToRecAlias, x.cast_to(types, partial_types)), - lambda x: cast(types.ClassToRecAlias, x.cast_to(types, types)), + lambda x: cast(partial_types.ClassToRecAlias, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.ClassToRecAlias, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3772,8 +3795,8 @@ def ClassifyDynEnumTwo( return baml_py.BamlStream[Optional[Union[types.DynEnumTwo, str]], Union[types.DynEnumTwo, str]]( raw, - lambda x: cast(Optional[Union[types.DynEnumTwo, str]], x.cast_to(types, partial_types)), - lambda x: cast(Union[types.DynEnumTwo, str], x.cast_to(types, types)), + lambda x: cast(Optional[Union[types.DynEnumTwo, str]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Union[types.DynEnumTwo, str], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3802,8 +3825,8 @@ def ClassifyMessage( return baml_py.BamlStream[Optional[types.Category], types.Category]( raw, - lambda x: cast(Optional[types.Category], x.cast_to(types, partial_types)), - lambda x: cast(types.Category, x.cast_to(types, types)), + lambda x: cast(Optional[types.Category], x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Category, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3832,8 +3855,8 @@ def ClassifyMessage2( return baml_py.BamlStream[Optional[types.Category], types.Category]( raw, - lambda x: cast(Optional[types.Category], x.cast_to(types, partial_types)), - lambda x: cast(types.Category, x.cast_to(types, types)), + lambda x: cast(Optional[types.Category], x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Category, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3862,8 +3885,8 @@ def ClassifyMessage3( return baml_py.BamlStream[Optional[types.Category], types.Category]( raw, - lambda x: cast(Optional[types.Category], x.cast_to(types, partial_types)), - lambda x: cast(types.Category, x.cast_to(types, types)), + lambda x: cast(Optional[types.Category], x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Category, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3894,8 +3917,8 @@ def Completion( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3924,8 +3947,8 @@ def CustomTask( return baml_py.BamlStream[Optional[Union[partial_types.BookOrder, partial_types.FlightConfirmation, partial_types.GroceryReceipt]], Union[types.BookOrder, types.FlightConfirmation, types.GroceryReceipt]]( raw, - lambda x: cast(Optional[Union[partial_types.BookOrder, partial_types.FlightConfirmation, partial_types.GroceryReceipt]], x.cast_to(types, partial_types)), - lambda x: cast(Union[types.BookOrder, types.FlightConfirmation, types.GroceryReceipt], x.cast_to(types, types)), + lambda x: cast(Optional[Union[partial_types.BookOrder, partial_types.FlightConfirmation, partial_types.GroceryReceipt]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Union[types.BookOrder, types.FlightConfirmation, types.GroceryReceipt], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3954,8 +3977,8 @@ def DescribeImage( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3985,8 +4008,8 @@ def DescribeImage2( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4016,8 +4039,8 @@ def DescribeImage3( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4047,8 +4070,8 @@ def DescribeImage4( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4076,8 +4099,8 @@ def DifferentiateUnions( return baml_py.BamlStream[Optional[Union[partial_types.OriginalA, partial_types.OriginalB]], Union[types.OriginalA, types.OriginalB]]( raw, - lambda x: cast(Optional[Union[partial_types.OriginalA, partial_types.OriginalB]], x.cast_to(types, partial_types)), - lambda x: cast(Union[types.OriginalA, types.OriginalB], x.cast_to(types, types)), + lambda x: cast(Optional[Union[partial_types.OriginalA, partial_types.OriginalB]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Union[types.OriginalA, types.OriginalB], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4106,8 +4129,8 @@ def DummyOutputFunction( return baml_py.BamlStream[partial_types.DummyOutput, types.DummyOutput]( raw, - lambda x: cast(partial_types.DummyOutput, x.cast_to(types, partial_types)), - lambda x: cast(types.DummyOutput, x.cast_to(types, types)), + lambda x: cast(partial_types.DummyOutput, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.DummyOutput, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4136,8 +4159,8 @@ def DynamicFunc( return baml_py.BamlStream[partial_types.DynamicClassTwo, types.DynamicClassTwo]( raw, - lambda x: cast(partial_types.DynamicClassTwo, x.cast_to(types, partial_types)), - lambda x: cast(types.DynamicClassTwo, x.cast_to(types, types)), + lambda x: cast(partial_types.DynamicClassTwo, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.DynamicClassTwo, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4166,8 +4189,8 @@ def DynamicInputOutput( return baml_py.BamlStream[partial_types.DynInputOutput, types.DynInputOutput]( raw, - lambda x: cast(partial_types.DynInputOutput, x.cast_to(types, partial_types)), - lambda x: cast(types.DynInputOutput, x.cast_to(types, types)), + lambda x: cast(partial_types.DynInputOutput, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.DynInputOutput, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4196,8 +4219,8 @@ def DynamicListInputOutput( return baml_py.BamlStream[List[partial_types.DynInputOutput], List[types.DynInputOutput]]( raw, - lambda x: cast(List[partial_types.DynInputOutput], x.cast_to(types, partial_types)), - lambda x: cast(List[types.DynInputOutput], x.cast_to(types, types)), + lambda x: cast(List[partial_types.DynInputOutput], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[types.DynInputOutput], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4225,8 +4248,8 @@ def ExpectFailure( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4255,8 +4278,8 @@ def ExtractContactInfo( return baml_py.BamlStream[partial_types.ContactInfo, types.ContactInfo]( raw, - lambda x: cast(partial_types.ContactInfo, x.cast_to(types, partial_types)), - lambda x: cast(types.ContactInfo, x.cast_to(types, types)), + lambda x: cast(partial_types.ContactInfo, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.ContactInfo, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4285,8 +4308,8 @@ def ExtractHobby( return baml_py.BamlStream[List[Optional[Union[types.Hobby, str]]], List[Union[types.Hobby, str]]]( raw, - lambda x: cast(List[Optional[Union[types.Hobby, str]]], x.cast_to(types, partial_types)), - lambda x: cast(List[Union[types.Hobby, str]], x.cast_to(types, types)), + lambda x: cast(List[Optional[Union[types.Hobby, str]]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[Union[types.Hobby, str]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4315,8 +4338,8 @@ def ExtractNames( return baml_py.BamlStream[List[Optional[str]], List[str]]( raw, - lambda x: cast(List[Optional[str]], x.cast_to(types, partial_types)), - lambda x: cast(List[str], x.cast_to(types, types)), + lambda x: cast(List[Optional[str]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[str], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4345,8 +4368,8 @@ def ExtractPeople( return baml_py.BamlStream[List[partial_types.Person], List[types.Person]]( raw, - lambda x: cast(List[partial_types.Person], x.cast_to(types, partial_types)), - lambda x: cast(List[types.Person], x.cast_to(types, types)), + lambda x: cast(List[partial_types.Person], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[types.Person], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4376,8 +4399,8 @@ def ExtractReceiptInfo( return baml_py.BamlStream[partial_types.ReceiptInfo, types.ReceiptInfo]( raw, - lambda x: cast(partial_types.ReceiptInfo, x.cast_to(types, partial_types)), - lambda x: cast(types.ReceiptInfo, x.cast_to(types, types)), + lambda x: cast(partial_types.ReceiptInfo, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.ReceiptInfo, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4407,8 +4430,8 @@ def ExtractResume( return baml_py.BamlStream[partial_types.Resume, types.Resume]( raw, - lambda x: cast(partial_types.Resume, x.cast_to(types, partial_types)), - lambda x: cast(types.Resume, x.cast_to(types, types)), + lambda x: cast(partial_types.Resume, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Resume, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4437,8 +4460,8 @@ def ExtractResume2( return baml_py.BamlStream[partial_types.Resume, types.Resume]( raw, - lambda x: cast(partial_types.Resume, x.cast_to(types, partial_types)), - lambda x: cast(types.Resume, x.cast_to(types, types)), + lambda x: cast(partial_types.Resume, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Resume, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4467,8 +4490,8 @@ def FnClassOptionalOutput( return baml_py.BamlStream[partial_types.ClassOptionalOutput, Optional[types.ClassOptionalOutput]]( raw, - lambda x: cast(partial_types.ClassOptionalOutput, x.cast_to(types, partial_types)), - lambda x: cast(Optional[types.ClassOptionalOutput], x.cast_to(types, types)), + lambda x: cast(partial_types.ClassOptionalOutput, x.cast_to(types, types, partial_types, True)), + lambda x: cast(Optional[types.ClassOptionalOutput], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4497,8 +4520,8 @@ def FnClassOptionalOutput2( return baml_py.BamlStream[partial_types.ClassOptionalOutput2, Optional[types.ClassOptionalOutput2]]( raw, - lambda x: cast(partial_types.ClassOptionalOutput2, x.cast_to(types, partial_types)), - lambda x: cast(Optional[types.ClassOptionalOutput2], x.cast_to(types, types)), + lambda x: cast(partial_types.ClassOptionalOutput2, x.cast_to(types, types, partial_types, True)), + lambda x: cast(Optional[types.ClassOptionalOutput2], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4527,8 +4550,8 @@ def FnEnumListOutput( return baml_py.BamlStream[List[Optional[types.EnumOutput]], List[types.EnumOutput]]( raw, - lambda x: cast(List[Optional[types.EnumOutput]], x.cast_to(types, partial_types)), - lambda x: cast(List[types.EnumOutput], x.cast_to(types, types)), + lambda x: cast(List[Optional[types.EnumOutput]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[types.EnumOutput], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4557,8 +4580,8 @@ def FnEnumOutput( return baml_py.BamlStream[Optional[types.EnumOutput], types.EnumOutput]( raw, - lambda x: cast(Optional[types.EnumOutput], x.cast_to(types, partial_types)), - lambda x: cast(types.EnumOutput, x.cast_to(types, types)), + lambda x: cast(Optional[types.EnumOutput], x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.EnumOutput, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4587,8 +4610,8 @@ def FnLiteralClassInputOutput( return baml_py.BamlStream[partial_types.LiteralClassHello, types.LiteralClassHello]( raw, - lambda x: cast(partial_types.LiteralClassHello, x.cast_to(types, partial_types)), - lambda x: cast(types.LiteralClassHello, x.cast_to(types, types)), + lambda x: cast(partial_types.LiteralClassHello, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.LiteralClassHello, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4617,8 +4640,8 @@ def FnLiteralUnionClassInputOutput( return baml_py.BamlStream[Optional[Union[partial_types.LiteralClassOne, partial_types.LiteralClassTwo]], Union[types.LiteralClassOne, types.LiteralClassTwo]]( raw, - lambda x: cast(Optional[Union[partial_types.LiteralClassOne, partial_types.LiteralClassTwo]], x.cast_to(types, partial_types)), - lambda x: cast(Union[types.LiteralClassOne, types.LiteralClassTwo], x.cast_to(types, types)), + lambda x: cast(Optional[Union[partial_types.LiteralClassOne, partial_types.LiteralClassTwo]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Union[types.LiteralClassOne, types.LiteralClassTwo], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4647,8 +4670,8 @@ def FnNamedArgsSingleStringOptional( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4677,8 +4700,8 @@ def FnOutputBool( return baml_py.BamlStream[Optional[bool], bool]( raw, - lambda x: cast(Optional[bool], x.cast_to(types, partial_types)), - lambda x: cast(bool, x.cast_to(types, types)), + lambda x: cast(Optional[bool], x.cast_to(types, types, partial_types, True)), + lambda x: cast(bool, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4707,8 +4730,8 @@ def FnOutputClass( return baml_py.BamlStream[partial_types.TestOutputClass, types.TestOutputClass]( raw, - lambda x: cast(partial_types.TestOutputClass, x.cast_to(types, partial_types)), - lambda x: cast(types.TestOutputClass, x.cast_to(types, types)), + lambda x: cast(partial_types.TestOutputClass, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.TestOutputClass, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4737,8 +4760,8 @@ def FnOutputClassList( return baml_py.BamlStream[List[partial_types.TestOutputClass], List[types.TestOutputClass]]( raw, - lambda x: cast(List[partial_types.TestOutputClass], x.cast_to(types, partial_types)), - lambda x: cast(List[types.TestOutputClass], x.cast_to(types, types)), + lambda x: cast(List[partial_types.TestOutputClass], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[types.TestOutputClass], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4767,8 +4790,8 @@ def FnOutputClassNested( return baml_py.BamlStream[partial_types.TestClassNested, types.TestClassNested]( raw, - lambda x: cast(partial_types.TestClassNested, x.cast_to(types, partial_types)), - lambda x: cast(types.TestClassNested, x.cast_to(types, types)), + lambda x: cast(partial_types.TestClassNested, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.TestClassNested, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4797,8 +4820,8 @@ def FnOutputClassWithEnum( return baml_py.BamlStream[partial_types.TestClassWithEnum, types.TestClassWithEnum]( raw, - lambda x: cast(partial_types.TestClassWithEnum, x.cast_to(types, partial_types)), - lambda x: cast(types.TestClassWithEnum, x.cast_to(types, types)), + lambda x: cast(partial_types.TestClassWithEnum, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.TestClassWithEnum, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4827,8 +4850,8 @@ def FnOutputInt( return baml_py.BamlStream[Optional[int], int]( raw, - lambda x: cast(Optional[int], x.cast_to(types, partial_types)), - lambda x: cast(int, x.cast_to(types, types)), + lambda x: cast(Optional[int], x.cast_to(types, types, partial_types, True)), + lambda x: cast(int, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4857,8 +4880,8 @@ def FnOutputLiteralBool( return baml_py.BamlStream[Optional[Literal[False]], Literal[False]]( raw, - lambda x: cast(Optional[Literal[False]], x.cast_to(types, partial_types)), - lambda x: cast(Literal[False], x.cast_to(types, types)), + lambda x: cast(Optional[Literal[False]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Literal[False], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4887,8 +4910,8 @@ def FnOutputLiteralInt( return baml_py.BamlStream[Optional[Literal[5]], Literal[5]]( raw, - lambda x: cast(Optional[Literal[5]], x.cast_to(types, partial_types)), - lambda x: cast(Literal[5], x.cast_to(types, types)), + lambda x: cast(Optional[Literal[5]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Literal[5], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4917,8 +4940,8 @@ def FnOutputLiteralString( return baml_py.BamlStream[Optional[Literal["example output"]], Literal["example output"]]( raw, - lambda x: cast(Optional[Literal["example output"]], x.cast_to(types, partial_types)), - lambda x: cast(Literal["example output"], x.cast_to(types, types)), + lambda x: cast(Optional[Literal["example output"]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Literal["example output"], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4947,8 +4970,8 @@ def FnOutputStringList( return baml_py.BamlStream[List[Optional[str]], List[str]]( raw, - lambda x: cast(List[Optional[str]], x.cast_to(types, partial_types)), - lambda x: cast(List[str], x.cast_to(types, types)), + lambda x: cast(List[Optional[str]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[str], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4977,8 +5000,8 @@ def FnTestAliasedEnumOutput( return baml_py.BamlStream[Optional[types.TestEnum], types.TestEnum]( raw, - lambda x: cast(Optional[types.TestEnum], x.cast_to(types, partial_types)), - lambda x: cast(types.TestEnum, x.cast_to(types, types)), + lambda x: cast(Optional[types.TestEnum], x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.TestEnum, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5007,8 +5030,8 @@ def FnTestClassAlias( return baml_py.BamlStream[partial_types.TestClassAlias, types.TestClassAlias]( raw, - lambda x: cast(partial_types.TestClassAlias, x.cast_to(types, partial_types)), - lambda x: cast(types.TestClassAlias, x.cast_to(types, types)), + lambda x: cast(partial_types.TestClassAlias, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.TestClassAlias, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5037,8 +5060,8 @@ def FnTestNamedArgsSingleEnum( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5067,8 +5090,8 @@ def GetDataType( return baml_py.BamlStream[partial_types.RaysData, types.RaysData]( raw, - lambda x: cast(partial_types.RaysData, x.cast_to(types, partial_types)), - lambda x: cast(types.RaysData, x.cast_to(types, types)), + lambda x: cast(partial_types.RaysData, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.RaysData, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5097,8 +5120,8 @@ def GetOrderInfo( return baml_py.BamlStream[partial_types.OrderInfo, types.OrderInfo]( raw, - lambda x: cast(partial_types.OrderInfo, x.cast_to(types, partial_types)), - lambda x: cast(types.OrderInfo, x.cast_to(types, types)), + lambda x: cast(partial_types.OrderInfo, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.OrderInfo, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5127,8 +5150,8 @@ def GetQuery( return baml_py.BamlStream[partial_types.SearchParams, types.SearchParams]( raw, - lambda x: cast(partial_types.SearchParams, x.cast_to(types, partial_types)), - lambda x: cast(types.SearchParams, x.cast_to(types, types)), + lambda x: cast(partial_types.SearchParams, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.SearchParams, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5158,8 +5181,8 @@ def InOutEnumMapKey( return baml_py.BamlStream[Dict[types.MapKey, Optional[str]], Dict[types.MapKey, str]]( raw, - lambda x: cast(Dict[types.MapKey, Optional[str]], x.cast_to(types, partial_types)), - lambda x: cast(Dict[types.MapKey, str], x.cast_to(types, types)), + lambda x: cast(Dict[types.MapKey, Optional[str]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Dict[types.MapKey, str], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5189,8 +5212,8 @@ def InOutLiteralStringUnionMapKey( return baml_py.BamlStream[Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], Optional[str]], Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], str]]( raw, - lambda x: cast(Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], Optional[str]], x.cast_to(types, partial_types)), - lambda x: cast(Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], str], x.cast_to(types, types)), + lambda x: cast(Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], Optional[str]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], str], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5219,8 +5242,8 @@ def InOutSingleLiteralStringMapKey( return baml_py.BamlStream[Dict[Literal["key"], Optional[str]], Dict[Literal["key"], str]]( raw, - lambda x: cast(Dict[Literal["key"], Optional[str]], x.cast_to(types, partial_types)), - lambda x: cast(Dict[Literal["key"], str], x.cast_to(types, types)), + lambda x: cast(Dict[Literal["key"], Optional[str]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Dict[Literal["key"], str], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5249,8 +5272,8 @@ def JsonTypeAliasCycle( return baml_py.BamlStream[types.JsonValue, types.JsonValue]( raw, - lambda x: cast(types.JsonValue, x.cast_to(types, partial_types)), - lambda x: cast(types.JsonValue, x.cast_to(types, types)), + lambda x: cast(types.JsonValue, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.JsonValue, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5279,8 +5302,8 @@ def LiteralUnionsTest( return baml_py.BamlStream[Optional[Union[Optional[Literal[1]], Optional[Literal[True]], Optional[Literal["string output"]]]], Union[Literal[1], Literal[True], Literal["string output"]]]( raw, - lambda x: cast(Optional[Union[Optional[Literal[1]], Optional[Literal[True]], Optional[Literal["string output"]]]], x.cast_to(types, partial_types)), - lambda x: cast(Union[Literal[1], Literal[True], Literal["string output"]], x.cast_to(types, types)), + lambda x: cast(Optional[Union[Optional[Literal[1]], Optional[Literal[True]], Optional[Literal["string output"]]]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Union[Literal[1], Literal[True], Literal["string output"]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5308,8 +5331,8 @@ def MakeBlockConstraint( return baml_py.BamlStream[Checked[partial_types.BlockConstraint,types.Literal["cross_field"]], Checked[types.BlockConstraint,types.Literal["cross_field"]]]( raw, - lambda x: cast(Checked[partial_types.BlockConstraint,types.Literal["cross_field"]], x.cast_to(types, partial_types)), - lambda x: cast(Checked[types.BlockConstraint,types.Literal["cross_field"]], x.cast_to(types, types)), + lambda x: cast(Checked[partial_types.BlockConstraint,types.Literal["cross_field"]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Checked[types.BlockConstraint,types.Literal["cross_field"]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5337,8 +5360,37 @@ def MakeNestedBlockConstraint( return baml_py.BamlStream[partial_types.NestedBlockConstraint, types.NestedBlockConstraint]( raw, - lambda x: cast(partial_types.NestedBlockConstraint, x.cast_to(types, partial_types)), - lambda x: cast(types.NestedBlockConstraint, x.cast_to(types, types)), + lambda x: cast(partial_types.NestedBlockConstraint, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.NestedBlockConstraint, x.cast_to(types, types, partial_types, False)), + self.__ctx_manager.get(), + ) + + def MakeSemanticContainer( + self, + + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[partial_types.SemanticContainer, types.SemanticContainer]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "MakeSemanticContainer", + { + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlStream[partial_types.SemanticContainer, types.SemanticContainer]( + raw, + lambda x: cast(partial_types.SemanticContainer, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.SemanticContainer, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5367,8 +5419,8 @@ def MapAlias( return baml_py.BamlStream[Dict[str, List[Optional[str]]], Dict[str, List[str]]]( raw, - lambda x: cast(Dict[str, List[Optional[str]]], x.cast_to(types, partial_types)), - lambda x: cast(Dict[str, List[str]], x.cast_to(types, types)), + lambda x: cast(Dict[str, List[Optional[str]]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Dict[str, List[str]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5397,8 +5449,8 @@ def MergeAliasAttributes( return baml_py.BamlStream[partial_types.MergeAttrs, types.MergeAttrs]( raw, - lambda x: cast(partial_types.MergeAttrs, x.cast_to(types, partial_types)), - lambda x: cast(types.MergeAttrs, x.cast_to(types, types)), + lambda x: cast(partial_types.MergeAttrs, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.MergeAttrs, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5427,8 +5479,8 @@ def MyFunc( return baml_py.BamlStream[partial_types.DynamicOutput, types.DynamicOutput]( raw, - lambda x: cast(partial_types.DynamicOutput, x.cast_to(types, partial_types)), - lambda x: cast(types.DynamicOutput, x.cast_to(types, types)), + lambda x: cast(partial_types.DynamicOutput, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.DynamicOutput, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5457,8 +5509,8 @@ def NestedAlias( return baml_py.BamlStream[Optional[Union[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], List[Optional[str]], Dict[str, List[Optional[str]]]]], Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]]]( raw, - lambda x: cast(Optional[Union[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], List[Optional[str]], Dict[str, List[Optional[str]]]]], x.cast_to(types, partial_types)), - lambda x: cast(Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], x.cast_to(types, types)), + lambda x: cast(Optional[Union[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], List[Optional[str]], Dict[str, List[Optional[str]]]]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5487,8 +5539,8 @@ def NullLiteralClassHello( return baml_py.BamlStream[partial_types.ClassForNullLiteral, types.ClassForNullLiteral]( raw, - lambda x: cast(partial_types.ClassForNullLiteral, x.cast_to(types, partial_types)), - lambda x: cast(types.ClassForNullLiteral, x.cast_to(types, types)), + lambda x: cast(partial_types.ClassForNullLiteral, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.ClassForNullLiteral, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5517,8 +5569,8 @@ def OptionalTest_Function( return baml_py.BamlStream[List[partial_types.OptionalTest_ReturnType], List[Optional[types.OptionalTest_ReturnType]]]( raw, - lambda x: cast(List[partial_types.OptionalTest_ReturnType], x.cast_to(types, partial_types)), - lambda x: cast(List[Optional[types.OptionalTest_ReturnType]], x.cast_to(types, types)), + lambda x: cast(List[partial_types.OptionalTest_ReturnType], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[Optional[types.OptionalTest_ReturnType]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5547,8 +5599,8 @@ def PredictAge( return baml_py.BamlStream[partial_types.FooAny, types.FooAny]( raw, - lambda x: cast(partial_types.FooAny, x.cast_to(types, partial_types)), - lambda x: cast(types.FooAny, x.cast_to(types, types)), + lambda x: cast(partial_types.FooAny, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.FooAny, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5577,8 +5629,8 @@ def PredictAgeBare( return baml_py.BamlStream[Checked[Optional[int],types.Literal["too_big"]], Checked[int,types.Literal["too_big"]]]( raw, - lambda x: cast(Checked[Optional[int],types.Literal["too_big"]], x.cast_to(types, partial_types)), - lambda x: cast(Checked[int,types.Literal["too_big"]], x.cast_to(types, types)), + lambda x: cast(Checked[Optional[int],types.Literal["too_big"]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Checked[int,types.Literal["too_big"]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5607,8 +5659,8 @@ def PrimitiveAlias( return baml_py.BamlStream[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], Union[int, str, bool, float]]( raw, - lambda x: cast(Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], x.cast_to(types, partial_types)), - lambda x: cast(Union[int, str, bool, float], x.cast_to(types, types)), + lambda x: cast(Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Union[int, str, bool, float], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5637,8 +5689,8 @@ def PromptTestClaude( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5667,8 +5719,8 @@ def PromptTestClaudeChat( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5697,8 +5749,8 @@ def PromptTestClaudeChatNoSystem( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5727,8 +5779,8 @@ def PromptTestOpenAI( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5757,8 +5809,8 @@ def PromptTestOpenAIChat( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5787,8 +5839,8 @@ def PromptTestOpenAIChatNoSystem( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5817,8 +5869,8 @@ def PromptTestStreaming( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5847,8 +5899,8 @@ def RecursiveAliasCycle( return baml_py.BamlStream[types.RecAliasOne, types.RecAliasOne]( raw, - lambda x: cast(types.RecAliasOne, x.cast_to(types, partial_types)), - lambda x: cast(types.RecAliasOne, x.cast_to(types, types)), + lambda x: cast(types.RecAliasOne, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.RecAliasOne, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5877,8 +5929,8 @@ def RecursiveClassWithAliasIndirection( return baml_py.BamlStream[partial_types.NodeWithAliasIndirection, types.NodeWithAliasIndirection]( raw, - lambda x: cast(partial_types.NodeWithAliasIndirection, x.cast_to(types, partial_types)), - lambda x: cast(types.NodeWithAliasIndirection, x.cast_to(types, types)), + lambda x: cast(partial_types.NodeWithAliasIndirection, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.NodeWithAliasIndirection, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5907,8 +5959,8 @@ def ReturnAliasWithMergedAttributes( return baml_py.BamlStream[Checked[Optional[int],types.Literal["gt_ten"]], Checked[int,types.Literal["gt_ten"]]]( raw, - lambda x: cast(Checked[Optional[int],types.Literal["gt_ten"]], x.cast_to(types, partial_types)), - lambda x: cast(Checked[int,types.Literal["gt_ten"]], x.cast_to(types, types)), + lambda x: cast(Checked[Optional[int],types.Literal["gt_ten"]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Checked[int,types.Literal["gt_ten"]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5937,8 +5989,8 @@ def ReturnFailingAssert( return baml_py.BamlStream[Optional[int], int]( raw, - lambda x: cast(Optional[int], x.cast_to(types, partial_types)), - lambda x: cast(int, x.cast_to(types, types)), + lambda x: cast(Optional[int], x.cast_to(types, types, partial_types, True)), + lambda x: cast(int, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5967,8 +6019,8 @@ def ReturnMalformedConstraints( return baml_py.BamlStream[partial_types.MalformedConstraints, types.MalformedConstraints]( raw, - lambda x: cast(partial_types.MalformedConstraints, x.cast_to(types, partial_types)), - lambda x: cast(types.MalformedConstraints, x.cast_to(types, types)), + lambda x: cast(partial_types.MalformedConstraints, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.MalformedConstraints, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5997,8 +6049,8 @@ def SchemaDescriptions( return baml_py.BamlStream[partial_types.Schema, types.Schema]( raw, - lambda x: cast(partial_types.Schema, x.cast_to(types, partial_types)), - lambda x: cast(types.Schema, x.cast_to(types, types)), + lambda x: cast(partial_types.Schema, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Schema, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6027,8 +6079,8 @@ def SimpleRecursiveListAlias( return baml_py.BamlStream[types.RecursiveListAlias, types.RecursiveListAlias]( raw, - lambda x: cast(types.RecursiveListAlias, x.cast_to(types, partial_types)), - lambda x: cast(types.RecursiveListAlias, x.cast_to(types, types)), + lambda x: cast(types.RecursiveListAlias, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.RecursiveListAlias, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6057,8 +6109,8 @@ def SimpleRecursiveMapAlias( return baml_py.BamlStream[types.RecursiveMapAlias, types.RecursiveMapAlias]( raw, - lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, partial_types)), - lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, types)), + lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6087,8 +6139,8 @@ def StreamBigNumbers( return baml_py.BamlStream[partial_types.BigNumbers, types.BigNumbers]( raw, - lambda x: cast(partial_types.BigNumbers, x.cast_to(types, partial_types)), - lambda x: cast(types.BigNumbers, x.cast_to(types, types)), + lambda x: cast(partial_types.BigNumbers, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.BigNumbers, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6118,8 +6170,8 @@ def StreamFailingAssertion( return baml_py.BamlStream[partial_types.TwoStoriesOneTitle, types.TwoStoriesOneTitle]( raw, - lambda x: cast(partial_types.TwoStoriesOneTitle, x.cast_to(types, partial_types)), - lambda x: cast(types.TwoStoriesOneTitle, x.cast_to(types, types)), + lambda x: cast(partial_types.TwoStoriesOneTitle, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.TwoStoriesOneTitle, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6148,8 +6200,8 @@ def StreamOneBigNumber( return baml_py.BamlStream[Optional[int], int]( raw, - lambda x: cast(Optional[int], x.cast_to(types, partial_types)), - lambda x: cast(int, x.cast_to(types, types)), + lambda x: cast(Optional[int], x.cast_to(types, types, partial_types, True)), + lambda x: cast(int, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6178,8 +6230,8 @@ def StreamUnionIntegers( return baml_py.BamlStream[List[Optional[Union[Optional[int], Optional[str]]]], List[Union[int, str]]]( raw, - lambda x: cast(List[Optional[Union[Optional[int], Optional[str]]]], x.cast_to(types, partial_types)), - lambda x: cast(List[Union[int, str]], x.cast_to(types, types)), + lambda x: cast(List[Optional[Union[Optional[int], Optional[str]]]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[Union[int, str]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6209,8 +6261,8 @@ def StreamingCompoundNumbers( return baml_py.BamlStream[partial_types.CompoundBigNumbers, types.CompoundBigNumbers]( raw, - lambda x: cast(partial_types.CompoundBigNumbers, x.cast_to(types, partial_types)), - lambda x: cast(types.CompoundBigNumbers, x.cast_to(types, types)), + lambda x: cast(partial_types.CompoundBigNumbers, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.CompoundBigNumbers, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6239,8 +6291,8 @@ def TestAnthropic( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6269,8 +6321,8 @@ def TestAnthropicShorthand( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6299,8 +6351,8 @@ def TestAws( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6329,8 +6381,8 @@ def TestAwsInvalidAccessKey( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6359,8 +6411,8 @@ def TestAwsInvalidProfile( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6389,8 +6441,8 @@ def TestAwsInvalidRegion( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6419,8 +6471,8 @@ def TestAwsInvalidSessionToken( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6449,8 +6501,8 @@ def TestAzure( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6479,8 +6531,8 @@ def TestAzureFailure( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6510,8 +6562,8 @@ def TestCaching( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6539,8 +6591,8 @@ def TestFallbackClient( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6569,8 +6621,8 @@ def TestFallbackToShorthand( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6599,8 +6651,8 @@ def TestFnNamedArgsSingleBool( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6629,8 +6681,8 @@ def TestFnNamedArgsSingleClass( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6659,8 +6711,8 @@ def TestFnNamedArgsSingleEnumList( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6689,8 +6741,8 @@ def TestFnNamedArgsSingleFloat( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6719,8 +6771,8 @@ def TestFnNamedArgsSingleInt( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6749,8 +6801,8 @@ def TestFnNamedArgsSingleMapStringToClass( return baml_py.BamlStream[Dict[str, partial_types.StringToClassEntry], Dict[str, types.StringToClassEntry]]( raw, - lambda x: cast(Dict[str, partial_types.StringToClassEntry], x.cast_to(types, partial_types)), - lambda x: cast(Dict[str, types.StringToClassEntry], x.cast_to(types, types)), + lambda x: cast(Dict[str, partial_types.StringToClassEntry], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Dict[str, types.StringToClassEntry], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6779,8 +6831,8 @@ def TestFnNamedArgsSingleMapStringToMap( return baml_py.BamlStream[Dict[str, Dict[str, Optional[str]]], Dict[str, Dict[str, str]]]( raw, - lambda x: cast(Dict[str, Dict[str, Optional[str]]], x.cast_to(types, partial_types)), - lambda x: cast(Dict[str, Dict[str, str]], x.cast_to(types, types)), + lambda x: cast(Dict[str, Dict[str, Optional[str]]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Dict[str, Dict[str, str]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6809,8 +6861,8 @@ def TestFnNamedArgsSingleMapStringToString( return baml_py.BamlStream[Dict[str, Optional[str]], Dict[str, str]]( raw, - lambda x: cast(Dict[str, Optional[str]], x.cast_to(types, partial_types)), - lambda x: cast(Dict[str, str], x.cast_to(types, types)), + lambda x: cast(Dict[str, Optional[str]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Dict[str, str], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6839,8 +6891,8 @@ def TestFnNamedArgsSingleString( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6869,8 +6921,8 @@ def TestFnNamedArgsSingleStringArray( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6899,8 +6951,8 @@ def TestFnNamedArgsSingleStringList( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6929,8 +6981,8 @@ def TestGemini( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6958,8 +7010,8 @@ def TestGeminiOpenAiGeneric( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6988,8 +7040,8 @@ def TestGeminiSystem( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7018,8 +7070,8 @@ def TestGeminiSystemAsChat( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7048,8 +7100,8 @@ def TestImageInput( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7078,8 +7130,8 @@ def TestImageInputAnthropic( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7108,8 +7160,8 @@ def TestImageListInput( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7139,8 +7191,8 @@ def TestMulticlassNamedArgs( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7169,8 +7221,8 @@ def TestNamedArgsLiteralBool( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7199,8 +7251,8 @@ def TestNamedArgsLiteralInt( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7229,8 +7281,8 @@ def TestNamedArgsLiteralString( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7259,8 +7311,8 @@ def TestOllama( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7289,8 +7341,8 @@ def TestOpenAILegacyProvider( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7319,8 +7371,8 @@ def TestOpenAIShorthand( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7348,8 +7400,8 @@ def TestRetryConstant( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7377,8 +7429,8 @@ def TestRetryExponential( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7406,8 +7458,8 @@ def TestSingleFallbackClient( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7436,8 +7488,8 @@ def TestVertex( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7465,8 +7517,8 @@ def TestVertexWithSystemInstructions( return baml_py.BamlStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7495,8 +7547,8 @@ def UnionTest_Function( return baml_py.BamlStream[partial_types.UnionTest_ReturnType, types.UnionTest_ReturnType]( raw, - lambda x: cast(partial_types.UnionTest_ReturnType, x.cast_to(types, partial_types)), - lambda x: cast(types.UnionTest_ReturnType, x.cast_to(types, types)), + lambda x: cast(partial_types.UnionTest_ReturnType, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.UnionTest_ReturnType, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7525,8 +7577,8 @@ def UseBlockConstraint( return baml_py.BamlStream[Optional[int], int]( raw, - lambda x: cast(Optional[int], x.cast_to(types, partial_types)), - lambda x: cast(int, x.cast_to(types, types)), + lambda x: cast(Optional[int], x.cast_to(types, types, partial_types, True)), + lambda x: cast(int, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7555,8 +7607,8 @@ def UseMalformedConstraints( return baml_py.BamlStream[Optional[int], int]( raw, - lambda x: cast(Optional[int], x.cast_to(types, partial_types)), - lambda x: cast(int, x.cast_to(types, types)), + lambda x: cast(Optional[int], x.cast_to(types, types, partial_types, True)), + lambda x: cast(int, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7585,8 +7637,8 @@ def UseNestedBlockConstraint( return baml_py.BamlStream[Optional[int], int]( raw, - lambda x: cast(Optional[int], x.cast_to(types, partial_types)), - lambda x: cast(int, x.cast_to(types, types)), + lambda x: cast(Optional[int], x.cast_to(types, types, partial_types, True)), + lambda x: cast(int, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) diff --git a/integ-tests/python/baml_client/inlinedbaml.py b/integ-tests/python/baml_client/inlinedbaml.py index b9bdfc4b5e..9388bdecdd 100644 --- a/integ-tests/python/baml_client/inlinedbaml.py +++ b/integ-tests/python/baml_client/inlinedbaml.py @@ -68,7 +68,7 @@ "test-files/functions/output/class-list.baml": "function FnOutputClassList(input: string) -> TestOutputClass[] {\n client GPT35\n prompt #\"\n Return a JSON array that follows this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest FnOutputClassList {\n functions [FnOutputClassList]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/class-nested.baml": "class TestClassNested {\n prop1 string\n prop2 InnerClass\n}\n\nclass InnerClass {\n prop1 string\n prop2 string\n inner InnerClass2\n}\n\nclass InnerClass2 {\n prop2 int\n prop3 float\n}\n\nfunction FnOutputClassNested(input: string) -> TestClassNested {\n client GPT35\n prompt #\"\n Return a made up json blob that matches this schema:\n {{ctx.output_format}}\n ---\n\n JSON:\n \"#\n}\n\ntest FnOutputClassNested {\n functions [FnOutputClassNested]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/class-with-enum.baml": "enum EnumInClass {\n ONE\n TWO\n}\n\nclass TestClassWithEnum {\n prop1 string\n prop2 EnumInClass\n}\n\nfunction FnOutputClassWithEnum(input: string) -> TestClassWithEnum {\n client GPT35\n prompt #\"\n Return a made up json blob that matches this schema:\n {{ctx.output_format}}\n ---\n\n JSON:\n \"#\n}\n\ntest FnOutputClassWithEnum {\n functions [FnOutputClassWithEnum]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/class.baml": "class TestOutputClass {\n prop1 string\n prop2 int\n}\n\nfunction FnOutputClass(input: string) -> TestOutputClass {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n For the prop2, always return a 540\n\n JSON:\n \"#\n}\n\ntest TestClass {\n functions [FnOutputClass]\n args {\n input \"example input\"\n }\n}\n", + "test-files/functions/output/class.baml": "class TestOutputClass {\n prop1 string @description(\"A long string with about 200 words\")\n prop2 int\n}\n\nfunction FnOutputClass(input: string) -> TestOutputClass {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n For the prop2, always return a 540\n\n JSON:\n \"#\n}\n\ntest TestClass {\n functions [FnOutputClass]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/enum-list.baml": "function FnEnumListOutput(input: string) -> EnumOutput[] {\n client GPT35\n prompt #\"\n Print out two of these values randomly selected from the list below in a json array.\n\n {{ctx.output_format}}\n\n Answer:\n \"#\n} \n\ntest FnEnumListOutput {\n functions [FnEnumListOutput]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/enum.baml": "/// An enum with three values,\n/// ONE, TWO and THREE.\nenum EnumOutput {\n\n /// The first enum.\n ONE\n\n /// The second enum.\n TWO\n THREE\n\n @@alias(\"VALUE_ENUM\")\n}\n\nfunction FnEnumOutput(input: string) -> EnumOutput {\n client GPT35\n prompt #\"\n Choose one of these values randomly. Before you give the answer, write out an unrelated haiku about the ocean.\n\n {{ctx.output_format(prefix=null)}}\n \"#\n}\n\ntest FnEnumOutput {\n functions [FnEnumOutput]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/int.baml": "function FnOutputInt(input: string) -> int {\n client GPT35\n prompt #\"\n Return the integer 5 with no additional context.\n \"#\n}\n\ntest FnOutputInt {\n functions [FnOutputInt]\n args {\n input \"example input\"\n }\n}\n", @@ -99,6 +99,7 @@ "test-files/providers/openai.baml": "function PromptTestOpenAI(input: string) -> string {\n client GPT35\n prompt #\"\n Write a nice haiku about {{ input }}\n \"#\n}\n\nfunction TestOpenAILegacyProvider(input: string) -> string {\n client GPT35LegacyProvider\n prompt #\"\n Write a nice haiku about {{ input }}\n \"#\n}\n\nfunction TestOpenAIShorthand(input: string) -> string {\n client GPT35\n prompt #\"\n Write a nice short story about {{ input }}\n \"#\n}", "test-files/providers/tests.baml": "test TestOpenAIShorthand {\n functions [TestOpenAIShorthand]\n args {\n input \"Donkey kong and peanut butter\"\n }\n}\n\ntest TestAWS {\n functions [\n TestAws\n ]\n args {\n input \"Donkey kong and peanut butter\"\n }\n}\n\ntest TestProvider {\n functions [\n TestAnthropic, TestVertex, PromptTestOpenAI, TestAzure, TestOllama, TestGemini, TestAws,\n TestAwsInvalidRegion,\n TestOpenAIShorthand,\n TestAnthropicShorthand,\n TestAwsInvalidAccessKey,\n TestAwsInvalidProfile,\n TestAwsInvalidSessionToken\n ]\n args {\n input \"Donkey kong and peanut butter\"\n }\n}\n\ntest TestName {\n functions [TestCaching]\n args {\n input #\"\nIn a near-future society where dreams have become a tradable commodity and shared experience, a lonely and socially awkward teenager named Alex discovers they possess a rare and powerful ability to not only view but also manipulate the dreams of others. Initially thrilled by this newfound power, Alex begins subtly altering the dreams of classmates and family members, helping them overcome fears, boost confidence, or experience fantastical adventures. As Alex's skills grow, so does their influence. They start selling premium dream experiences on the black market, crafting intricate and addictive dreamscapes for wealthy clients. However, the line between dream and reality begins to blur for those exposed to Alex's creations. Some clients struggle to differentiate between their true memories and the artificial ones implanted by Alex's dream manipulation.\n\nComplications arise when a mysterious government agency takes notice of Alex's unique abilities. They offer Alex a chance to use their gift for \"the greater good,\" hinting at applications in therapy, criminal rehabilitation, and even national security. Simultaneously, an underground resistance movement reaches out, warning Alex about the dangers of dream manipulation and the potential for mass control and exploitation. Caught between these opposing forces, Alex must navigate a complex web of ethical dilemmas. They grapple with questions of free will, the nature of consciousness, and the responsibility that comes with having power over people's minds. As the consequences of their actions spiral outward, affecting the lives of loved ones and strangers alike, Alex is forced to confront the true nature of their ability and decide how—or if—it should be used.\n\nThe story explores themes of identity, the subconscious mind, the ethics of technology, and the power of imagination. It delves into the potential consequences of a world where our most private thoughts and experiences are no longer truly our own, and examines the fine line between helping others and manipulating them for personal gain or a perceived greater good. The narrative further expands on the societal implications of such abilities, questioning the moral boundaries of altering consciousness and the potential for abuse in a world where dreams can be commodified. It challenges the reader to consider the impact of technology on personal autonomy and the ethical responsibilities of those who wield such power.\n\nAs Alex's journey unfolds, they encounter various individuals whose lives have been touched by their dream manipulations, each presenting a unique perspective on the ethical quandaries at hand. From a classmate who gains newfound confidence to a wealthy client who becomes addicted to the dreamscapes, the ripple effects of Alex's actions are profound and far-reaching. The government agency's interest in Alex's abilities raises questions about the potential for state control and surveillance, while the resistance movement highlights the dangers of unchecked power and the importance of safeguarding individual freedoms.\n\nUltimately, Alex's story is one of self-discovery and moral reckoning, as they must decide whether to embrace their abilities for personal gain, align with the government's vision of a controlled utopia, or join the resistance in their fight for freedom and autonomy. The narrative invites readers to reflect on the nature of reality, the boundaries of human experience, and the ethical implications of a world where dreams are no longer private sanctuaries but shared and manipulated commodities. It also explores the psychological impact on Alex, who must deal with the burden of knowing the intimate fears and desires of others, and the isolation that comes from being unable to share their own dreams without altering them.\n\nThe story further examines the technological advancements that have made dream manipulation possible, questioning the role of innovation in society and the potential for both progress and peril. It considers the societal divide between those who can afford to buy enhanced dream experiences and those who cannot, highlighting issues of inequality and access. As Alex becomes more entangled in the web of their own making, they must confront the possibility that their actions could lead to unintended consequences, not just for themselves but for the fabric of society as a whole.\n\nIn the end, Alex's journey is a cautionary tale about the power of dreams and the responsibilities that come with wielding such influence. It serves as a reminder of the importance of ethical considerations in the face of technological advancement and the need to balance innovation with humanity. The story leaves readers pondering the true cost of a world where dreams are no longer sacred, and the potential for both wonder and danger in the uncharted territories of the mind. But it's also a story about the power of imagination and the potential for change, even in a world where our deepest thoughts are no longer our own. And it's a story about the power of choice, and the importance of fighting for the freedom to dream.\n\nIn conclusion, this story is a reflection on the power of dreams and the responsibilities that come with wielding such influence. It serves as a reminder of the importance of ethical considerations in the face of technological advancement and the need to balance innovation with humanity. The story leaves readers pondering the true cost of a world where dreams are no longer sacred, and the potential for both wonder and danger in the uncharted territories of the mind. But it's also a story about the power of imagination and the potential for change, even in a world where our deepest thoughts are no longer our own. And it's a story about the power of choice, and the importance of fighting for the freedom to dream.\n \"#\n not_cached #\"\n hello world\n \"#\n }\n}", "test-files/providers/vertex.baml": "function TestVertex(input: string) -> string {\n client Vertex\n prompt #\"\n Write a nice short story about {{ input }}\n \"#\n}\n\nfunction TestVertexWithSystemInstructions() -> string {\n client Vertex\n prompt #\"{{_.role(\"system\")}} You are a helpful assistant\n {{_.role(\"user\")}} Write a poem about llamas\n \"#\n}\n\ntest TestVertex {\n functions [TestVertex, TestVertexWithSystemInstructions]\n args {\n input \"a cat\"\n\n }\n}\n", + "test-files/semantic_streaming/semantic_streaming.baml": "class SemanticContainer {\n sixteen_digit_number int\n string_with_twenty_words string @stream.done\n class_1 ClassWithoutDone\n class_2 ClassWithBlockDone\n class_done_needed ClassWithBlockDone @stream.not_null\n class_needed ClassWithoutDone @stream.not_null\n three_small_things SmallThing[] @description(\"Should have three items.\")\n final_string string\n}\n\nclass ClassWithoutDone {\n i_16_digits int\n s_20_words string @description(\"A string with 20 words in it\") @stream.with_state\n}\n\nclass ClassWithBlockDone {\n i_16_digits int\n s_20_words string\n @@stream.done\n}\n\nclass SmallThing {\n i_16_digits int @stream.not_null\n i_8_digits int\n}\n\nfunction MakeSemanticContainer() -> SemanticContainer {\n client GPT35\n prompt #\"\n {{ ctx.output_format }}\n \"#\n}", "test-files/strategies/fallback-shorthand.baml": "\nclient FallbackToShorthand {\n provider fallback\n options {\n strategy [\n \"openai/does-not-exist\",\n \"openai/gpt-4o-mini\"\n ]\n }\n}\n\n\nfunction TestFallbackToShorthand(input: string) -> string {\n client FallbackToShorthand\n // TODO make it return the client name instead\n prompt #\"\n Say a haiku about {{input}}.\n \"#\n}\n\ntest TestProvider_FallbackToShorthand {\n functions [\n TestFallbackToShorthand\n ]\n args {\n input \"Donkey kong and peanut butter\"\n }\n}\n", "test-files/strategies/fallback.baml": "// Happy path fallbacks.\nclient FaultyClient {\n provider openai\n options {\n model unknown-model\n api_key env.OPENAI_API_KEY\n }\n}\n\n\nclient FallbackClient {\n provider fallback\n options {\n // first 2 clients are expected to fail.\n strategy [\n FaultyClient,\n RetryClientConstant,\n GPT35\n Gemini\n\n ]\n }\n}\n\nfunction TestFallbackClient() -> string {\n client FallbackClient\n // TODO make it return the client name instead\n prompt #\"\n Say a haiku about mexico.\n \"#\n}\n\n// Fallbacks should fail gracefully.\nclient FaultyAzureClient {\n provider azure-openai\n options {\n model unknown-model\n resource_name \"unknown-resource-id\"\n deployment_id \"unknown-deployment-id\"\n }\n}\n\nclient SingleFallbackClient {\n provider fallback\n options {\n // first 2 clients are expected to fail.\n strategy [\n FaultyAzureClient\n ]\n }\n}\n\nfunction TestSingleFallbackClient() -> string {\n client SingleFallbackClient\n // TODO make it return the client name instead\n prompt #\"\n Say a haiku about mexico.\n \"#\n}\n", "test-files/strategies/retry.baml": "\nretry_policy Exponential {\n max_retries 3\n strategy {\n type exponential_backoff\n }\n}\n\nretry_policy Constant {\n max_retries 3\n strategy {\n type constant_delay\n delay_ms 100\n }\n}\n\nclient RetryClientConstant {\n provider openai\n retry_policy Constant\n options {\n model \"gpt-3.5-turbo\"\n api_key \"blah\"\n }\n}\n\nclient RetryClientExponential {\n provider openai\n retry_policy Exponential\n options {\n model \"gpt-3.5-turbo\"\n api_key \"blahh\"\n }\n}\n\nfunction TestRetryConstant() -> string {\n client RetryClientConstant\n prompt #\"\n Say a haiku\n \"#\n}\n\nfunction TestRetryExponential() -> string {\n client RetryClientExponential\n prompt #\"\n Say a haiku\n \"#\n}\n", diff --git a/integ-tests/python/baml_client/partial_types.py b/integ-tests/python/baml_client/partial_types.py index 5651736e37..dd971e0b13 100644 --- a/integ-tests/python/baml_client/partial_types.py +++ b/integ-tests/python/baml_client/partial_types.py @@ -16,7 +16,7 @@ import baml_py from enum import Enum from pydantic import BaseModel, ConfigDict -from typing import Dict, List, Optional, Union, Literal +from typing import Dict, Generic, List, Optional, TypeVar, Union, Literal from . import types from .types import Checked, Check @@ -28,6 +28,11 @@ # ############################################################################### +T = TypeVar('T') +class StreamState(BaseModel, Generic[T]): + value: T + state: Literal["Pending", "Incomplete", "Complete"] + class BigNumbers(BaseModel): a: Optional[int] = None @@ -70,11 +75,19 @@ class ClassOptionalOutput2(BaseModel): class ClassToRecAlias(BaseModel): list: Optional["LinkedListAliasNode"] = None +class ClassWithBlockDone(BaseModel): + i_16_digits: Optional[int] = None + s_20_words: Optional[str] = None + class ClassWithImage(BaseModel): myImage: Optional[baml_py.Image] = None param2: Optional[str] = None fake_image: Optional["FakeImage"] = None +class ClassWithoutDone(BaseModel): + i_16_digits: Optional[int] = None + s_20_words: StreamState[Optional[str]] + class CompoundBigNumbers(BaseModel): big: Optional["BigNumbers"] = None big_nums: List["BigNumbers"] @@ -333,6 +346,20 @@ class SearchParams(BaseModel): description: List["WithReasoning"] tags: List[Optional[Union[Optional[types.Tag], Optional[str]]]] +class SemanticContainer(BaseModel): + sixteen_digit_number: Optional[int] = None + string_with_twenty_words: Optional[str] = None + class_1: Optional["ClassWithoutDone"] = None + class_2: Optional["types.ClassWithBlockDone"] = None + class_done_needed: "types.ClassWithBlockDone" + class_needed: "ClassWithoutDone" + three_small_things: List["SmallThing"] + final_string: Optional[str] = None + +class SmallThing(BaseModel): + i_16_digits: int + i_8_digits: Optional[int] = None + class SomeClassNestedDynamic(BaseModel): model_config = ConfigDict(extra='allow') hi: Optional[str] = None diff --git a/integ-tests/python/baml_client/sync_client.py b/integ-tests/python/baml_client/sync_client.py index 19e0d198ad..b99217e93f 100644 --- a/integ-tests/python/baml_client/sync_client.py +++ b/integ-tests/python/baml_client/sync_client.py @@ -68,7 +68,7 @@ def AaaSamOutputFormat( tb, __cr__, ) - return cast(types.Recipe, raw.cast_to(types, types)) + return cast(types.Recipe, raw.cast_to(types, types, partial_types, False)) def AliasThatPointsToRecursiveType( self, @@ -91,7 +91,7 @@ def AliasThatPointsToRecursiveType( tb, __cr__, ) - return cast(types.LinkedListAliasNode, raw.cast_to(types, types)) + return cast(types.LinkedListAliasNode, raw.cast_to(types, types, partial_types, False)) def AliasWithMultipleAttrs( self, @@ -114,7 +114,7 @@ def AliasWithMultipleAttrs( tb, __cr__, ) - return cast(Checked[int,types.Literal["gt_ten"]], raw.cast_to(types, types)) + return cast(Checked[int,types.Literal["gt_ten"]], raw.cast_to(types, types, partial_types, False)) def AliasedInputClass( self, @@ -137,7 +137,7 @@ def AliasedInputClass( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def AliasedInputClass2( self, @@ -160,7 +160,7 @@ def AliasedInputClass2( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def AliasedInputClassNested( self, @@ -183,7 +183,7 @@ def AliasedInputClassNested( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def AliasedInputEnum( self, @@ -206,7 +206,7 @@ def AliasedInputEnum( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def AliasedInputList( self, @@ -229,7 +229,7 @@ def AliasedInputList( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def AllowedOptionals( self, @@ -252,7 +252,7 @@ def AllowedOptionals( tb, __cr__, ) - return cast(types.OptionalListAndMap, raw.cast_to(types, types)) + return cast(types.OptionalListAndMap, raw.cast_to(types, types, partial_types, False)) def AssertFn( self, @@ -275,7 +275,7 @@ def AssertFn( tb, __cr__, ) - return cast(int, raw.cast_to(types, types)) + return cast(int, raw.cast_to(types, types, partial_types, False)) def AudioInput( self, @@ -298,7 +298,7 @@ def AudioInput( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def BuildLinkedList( self, @@ -321,7 +321,7 @@ def BuildLinkedList( tb, __cr__, ) - return cast(types.LinkedList, raw.cast_to(types, types)) + return cast(types.LinkedList, raw.cast_to(types, types, partial_types, False)) def BuildTree( self, @@ -344,7 +344,7 @@ def BuildTree( tb, __cr__, ) - return cast(types.Tree, raw.cast_to(types, types)) + return cast(types.Tree, raw.cast_to(types, types, partial_types, False)) def ClassThatPointsToRecursiveClassThroughAlias( self, @@ -367,7 +367,7 @@ def ClassThatPointsToRecursiveClassThroughAlias( tb, __cr__, ) - return cast(types.ClassToRecAlias, raw.cast_to(types, types)) + return cast(types.ClassToRecAlias, raw.cast_to(types, types, partial_types, False)) def ClassifyDynEnumTwo( self, @@ -390,7 +390,7 @@ def ClassifyDynEnumTwo( tb, __cr__, ) - return cast(Union[types.DynEnumTwo, str], raw.cast_to(types, types)) + return cast(Union[types.DynEnumTwo, str], raw.cast_to(types, types, partial_types, False)) def ClassifyMessage( self, @@ -413,7 +413,7 @@ def ClassifyMessage( tb, __cr__, ) - return cast(types.Category, raw.cast_to(types, types)) + return cast(types.Category, raw.cast_to(types, types, partial_types, False)) def ClassifyMessage2( self, @@ -436,7 +436,7 @@ def ClassifyMessage2( tb, __cr__, ) - return cast(types.Category, raw.cast_to(types, types)) + return cast(types.Category, raw.cast_to(types, types, partial_types, False)) def ClassifyMessage3( self, @@ -459,7 +459,7 @@ def ClassifyMessage3( tb, __cr__, ) - return cast(types.Category, raw.cast_to(types, types)) + return cast(types.Category, raw.cast_to(types, types, partial_types, False)) def Completion( self, @@ -482,7 +482,7 @@ def Completion( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def CustomTask( self, @@ -505,7 +505,7 @@ def CustomTask( tb, __cr__, ) - return cast(Union[types.BookOrder, types.FlightConfirmation, types.GroceryReceipt], raw.cast_to(types, types)) + return cast(Union[types.BookOrder, types.FlightConfirmation, types.GroceryReceipt], raw.cast_to(types, types, partial_types, False)) def DescribeImage( self, @@ -528,7 +528,7 @@ def DescribeImage( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def DescribeImage2( self, @@ -551,7 +551,7 @@ def DescribeImage2( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def DescribeImage3( self, @@ -574,7 +574,7 @@ def DescribeImage3( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def DescribeImage4( self, @@ -597,7 +597,7 @@ def DescribeImage4( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def DifferentiateUnions( self, @@ -620,7 +620,7 @@ def DifferentiateUnions( tb, __cr__, ) - return cast(Union[types.OriginalA, types.OriginalB], raw.cast_to(types, types)) + return cast(Union[types.OriginalA, types.OriginalB], raw.cast_to(types, types, partial_types, False)) def DummyOutputFunction( self, @@ -643,7 +643,7 @@ def DummyOutputFunction( tb, __cr__, ) - return cast(types.DummyOutput, raw.cast_to(types, types)) + return cast(types.DummyOutput, raw.cast_to(types, types, partial_types, False)) def DynamicFunc( self, @@ -666,7 +666,7 @@ def DynamicFunc( tb, __cr__, ) - return cast(types.DynamicClassTwo, raw.cast_to(types, types)) + return cast(types.DynamicClassTwo, raw.cast_to(types, types, partial_types, False)) def DynamicInputOutput( self, @@ -689,7 +689,7 @@ def DynamicInputOutput( tb, __cr__, ) - return cast(types.DynInputOutput, raw.cast_to(types, types)) + return cast(types.DynInputOutput, raw.cast_to(types, types, partial_types, False)) def DynamicListInputOutput( self, @@ -712,7 +712,7 @@ def DynamicListInputOutput( tb, __cr__, ) - return cast(List[types.DynInputOutput], raw.cast_to(types, types)) + return cast(List[types.DynInputOutput], raw.cast_to(types, types, partial_types, False)) def ExpectFailure( self, @@ -735,7 +735,7 @@ def ExpectFailure( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def ExtractContactInfo( self, @@ -758,7 +758,7 @@ def ExtractContactInfo( tb, __cr__, ) - return cast(types.ContactInfo, raw.cast_to(types, types)) + return cast(types.ContactInfo, raw.cast_to(types, types, partial_types, False)) def ExtractHobby( self, @@ -781,7 +781,7 @@ def ExtractHobby( tb, __cr__, ) - return cast(List[Union[types.Hobby, str]], raw.cast_to(types, types)) + return cast(List[Union[types.Hobby, str]], raw.cast_to(types, types, partial_types, False)) def ExtractNames( self, @@ -804,7 +804,7 @@ def ExtractNames( tb, __cr__, ) - return cast(List[str], raw.cast_to(types, types)) + return cast(List[str], raw.cast_to(types, types, partial_types, False)) def ExtractPeople( self, @@ -827,7 +827,7 @@ def ExtractPeople( tb, __cr__, ) - return cast(List[types.Person], raw.cast_to(types, types)) + return cast(List[types.Person], raw.cast_to(types, types, partial_types, False)) def ExtractReceiptInfo( self, @@ -850,7 +850,7 @@ def ExtractReceiptInfo( tb, __cr__, ) - return cast(types.ReceiptInfo, raw.cast_to(types, types)) + return cast(types.ReceiptInfo, raw.cast_to(types, types, partial_types, False)) def ExtractResume( self, @@ -873,7 +873,7 @@ def ExtractResume( tb, __cr__, ) - return cast(types.Resume, raw.cast_to(types, types)) + return cast(types.Resume, raw.cast_to(types, types, partial_types, False)) def ExtractResume2( self, @@ -896,7 +896,7 @@ def ExtractResume2( tb, __cr__, ) - return cast(types.Resume, raw.cast_to(types, types)) + return cast(types.Resume, raw.cast_to(types, types, partial_types, False)) def FnClassOptionalOutput( self, @@ -919,7 +919,7 @@ def FnClassOptionalOutput( tb, __cr__, ) - return cast(Optional[types.ClassOptionalOutput], raw.cast_to(types, types)) + return cast(Optional[types.ClassOptionalOutput], raw.cast_to(types, types, partial_types, False)) def FnClassOptionalOutput2( self, @@ -942,7 +942,7 @@ def FnClassOptionalOutput2( tb, __cr__, ) - return cast(Optional[types.ClassOptionalOutput2], raw.cast_to(types, types)) + return cast(Optional[types.ClassOptionalOutput2], raw.cast_to(types, types, partial_types, False)) def FnEnumListOutput( self, @@ -965,7 +965,7 @@ def FnEnumListOutput( tb, __cr__, ) - return cast(List[types.EnumOutput], raw.cast_to(types, types)) + return cast(List[types.EnumOutput], raw.cast_to(types, types, partial_types, False)) def FnEnumOutput( self, @@ -988,7 +988,7 @@ def FnEnumOutput( tb, __cr__, ) - return cast(types.EnumOutput, raw.cast_to(types, types)) + return cast(types.EnumOutput, raw.cast_to(types, types, partial_types, False)) def FnLiteralClassInputOutput( self, @@ -1011,7 +1011,7 @@ def FnLiteralClassInputOutput( tb, __cr__, ) - return cast(types.LiteralClassHello, raw.cast_to(types, types)) + return cast(types.LiteralClassHello, raw.cast_to(types, types, partial_types, False)) def FnLiteralUnionClassInputOutput( self, @@ -1034,7 +1034,7 @@ def FnLiteralUnionClassInputOutput( tb, __cr__, ) - return cast(Union[types.LiteralClassOne, types.LiteralClassTwo], raw.cast_to(types, types)) + return cast(Union[types.LiteralClassOne, types.LiteralClassTwo], raw.cast_to(types, types, partial_types, False)) def FnNamedArgsSingleStringOptional( self, @@ -1057,7 +1057,7 @@ def FnNamedArgsSingleStringOptional( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def FnOutputBool( self, @@ -1080,7 +1080,7 @@ def FnOutputBool( tb, __cr__, ) - return cast(bool, raw.cast_to(types, types)) + return cast(bool, raw.cast_to(types, types, partial_types, False)) def FnOutputClass( self, @@ -1103,7 +1103,7 @@ def FnOutputClass( tb, __cr__, ) - return cast(types.TestOutputClass, raw.cast_to(types, types)) + return cast(types.TestOutputClass, raw.cast_to(types, types, partial_types, False)) def FnOutputClassList( self, @@ -1126,7 +1126,7 @@ def FnOutputClassList( tb, __cr__, ) - return cast(List[types.TestOutputClass], raw.cast_to(types, types)) + return cast(List[types.TestOutputClass], raw.cast_to(types, types, partial_types, False)) def FnOutputClassNested( self, @@ -1149,7 +1149,7 @@ def FnOutputClassNested( tb, __cr__, ) - return cast(types.TestClassNested, raw.cast_to(types, types)) + return cast(types.TestClassNested, raw.cast_to(types, types, partial_types, False)) def FnOutputClassWithEnum( self, @@ -1172,7 +1172,7 @@ def FnOutputClassWithEnum( tb, __cr__, ) - return cast(types.TestClassWithEnum, raw.cast_to(types, types)) + return cast(types.TestClassWithEnum, raw.cast_to(types, types, partial_types, False)) def FnOutputInt( self, @@ -1195,7 +1195,7 @@ def FnOutputInt( tb, __cr__, ) - return cast(int, raw.cast_to(types, types)) + return cast(int, raw.cast_to(types, types, partial_types, False)) def FnOutputLiteralBool( self, @@ -1218,7 +1218,7 @@ def FnOutputLiteralBool( tb, __cr__, ) - return cast(Literal[False], raw.cast_to(types, types)) + return cast(Literal[False], raw.cast_to(types, types, partial_types, False)) def FnOutputLiteralInt( self, @@ -1241,7 +1241,7 @@ def FnOutputLiteralInt( tb, __cr__, ) - return cast(Literal[5], raw.cast_to(types, types)) + return cast(Literal[5], raw.cast_to(types, types, partial_types, False)) def FnOutputLiteralString( self, @@ -1264,7 +1264,7 @@ def FnOutputLiteralString( tb, __cr__, ) - return cast(Literal["example output"], raw.cast_to(types, types)) + return cast(Literal["example output"], raw.cast_to(types, types, partial_types, False)) def FnOutputStringList( self, @@ -1287,7 +1287,7 @@ def FnOutputStringList( tb, __cr__, ) - return cast(List[str], raw.cast_to(types, types)) + return cast(List[str], raw.cast_to(types, types, partial_types, False)) def FnTestAliasedEnumOutput( self, @@ -1310,7 +1310,7 @@ def FnTestAliasedEnumOutput( tb, __cr__, ) - return cast(types.TestEnum, raw.cast_to(types, types)) + return cast(types.TestEnum, raw.cast_to(types, types, partial_types, False)) def FnTestClassAlias( self, @@ -1333,7 +1333,7 @@ def FnTestClassAlias( tb, __cr__, ) - return cast(types.TestClassAlias, raw.cast_to(types, types)) + return cast(types.TestClassAlias, raw.cast_to(types, types, partial_types, False)) def FnTestNamedArgsSingleEnum( self, @@ -1356,7 +1356,7 @@ def FnTestNamedArgsSingleEnum( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def GetDataType( self, @@ -1379,7 +1379,7 @@ def GetDataType( tb, __cr__, ) - return cast(types.RaysData, raw.cast_to(types, types)) + return cast(types.RaysData, raw.cast_to(types, types, partial_types, False)) def GetOrderInfo( self, @@ -1402,7 +1402,7 @@ def GetOrderInfo( tb, __cr__, ) - return cast(types.OrderInfo, raw.cast_to(types, types)) + return cast(types.OrderInfo, raw.cast_to(types, types, partial_types, False)) def GetQuery( self, @@ -1425,7 +1425,7 @@ def GetQuery( tb, __cr__, ) - return cast(types.SearchParams, raw.cast_to(types, types)) + return cast(types.SearchParams, raw.cast_to(types, types, partial_types, False)) def InOutEnumMapKey( self, @@ -1448,7 +1448,7 @@ def InOutEnumMapKey( tb, __cr__, ) - return cast(Dict[types.MapKey, str], raw.cast_to(types, types)) + return cast(Dict[types.MapKey, str], raw.cast_to(types, types, partial_types, False)) def InOutLiteralStringUnionMapKey( self, @@ -1471,7 +1471,7 @@ def InOutLiteralStringUnionMapKey( tb, __cr__, ) - return cast(Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], str], raw.cast_to(types, types)) + return cast(Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], str], raw.cast_to(types, types, partial_types, False)) def InOutSingleLiteralStringMapKey( self, @@ -1494,7 +1494,7 @@ def InOutSingleLiteralStringMapKey( tb, __cr__, ) - return cast(Dict[Literal["key"], str], raw.cast_to(types, types)) + return cast(Dict[Literal["key"], str], raw.cast_to(types, types, partial_types, False)) def JsonTypeAliasCycle( self, @@ -1517,7 +1517,7 @@ def JsonTypeAliasCycle( tb, __cr__, ) - return cast(types.JsonValue, raw.cast_to(types, types)) + return cast(types.JsonValue, raw.cast_to(types, types, partial_types, False)) def LiteralUnionsTest( self, @@ -1540,7 +1540,7 @@ def LiteralUnionsTest( tb, __cr__, ) - return cast(Union[Literal[1], Literal[True], Literal["string output"]], raw.cast_to(types, types)) + return cast(Union[Literal[1], Literal[True], Literal["string output"]], raw.cast_to(types, types, partial_types, False)) def MakeBlockConstraint( self, @@ -1563,7 +1563,7 @@ def MakeBlockConstraint( tb, __cr__, ) - return cast(Checked[types.BlockConstraint,types.Literal["cross_field"]], raw.cast_to(types, types)) + return cast(Checked[types.BlockConstraint,types.Literal["cross_field"]], raw.cast_to(types, types, partial_types, False)) def MakeNestedBlockConstraint( self, @@ -1586,7 +1586,30 @@ def MakeNestedBlockConstraint( tb, __cr__, ) - return cast(types.NestedBlockConstraint, raw.cast_to(types, types)) + return cast(types.NestedBlockConstraint, raw.cast_to(types, types, partial_types, False)) + + def MakeSemanticContainer( + self, + + baml_options: BamlCallOptions = {}, + ) -> types.SemanticContainer: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "MakeSemanticContainer", + { + + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + return cast(types.SemanticContainer, raw.cast_to(types, types, partial_types, False)) def MapAlias( self, @@ -1609,7 +1632,7 @@ def MapAlias( tb, __cr__, ) - return cast(Dict[str, List[str]], raw.cast_to(types, types)) + return cast(Dict[str, List[str]], raw.cast_to(types, types, partial_types, False)) def MergeAliasAttributes( self, @@ -1632,7 +1655,7 @@ def MergeAliasAttributes( tb, __cr__, ) - return cast(types.MergeAttrs, raw.cast_to(types, types)) + return cast(types.MergeAttrs, raw.cast_to(types, types, partial_types, False)) def MyFunc( self, @@ -1655,7 +1678,7 @@ def MyFunc( tb, __cr__, ) - return cast(types.DynamicOutput, raw.cast_to(types, types)) + return cast(types.DynamicOutput, raw.cast_to(types, types, partial_types, False)) def NestedAlias( self, @@ -1678,7 +1701,7 @@ def NestedAlias( tb, __cr__, ) - return cast(Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], raw.cast_to(types, types)) + return cast(Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], raw.cast_to(types, types, partial_types, False)) def NullLiteralClassHello( self, @@ -1701,7 +1724,7 @@ def NullLiteralClassHello( tb, __cr__, ) - return cast(types.ClassForNullLiteral, raw.cast_to(types, types)) + return cast(types.ClassForNullLiteral, raw.cast_to(types, types, partial_types, False)) def OptionalTest_Function( self, @@ -1724,7 +1747,7 @@ def OptionalTest_Function( tb, __cr__, ) - return cast(List[Optional[types.OptionalTest_ReturnType]], raw.cast_to(types, types)) + return cast(List[Optional[types.OptionalTest_ReturnType]], raw.cast_to(types, types, partial_types, False)) def PredictAge( self, @@ -1747,7 +1770,7 @@ def PredictAge( tb, __cr__, ) - return cast(types.FooAny, raw.cast_to(types, types)) + return cast(types.FooAny, raw.cast_to(types, types, partial_types, False)) def PredictAgeBare( self, @@ -1770,7 +1793,7 @@ def PredictAgeBare( tb, __cr__, ) - return cast(Checked[int,types.Literal["too_big"]], raw.cast_to(types, types)) + return cast(Checked[int,types.Literal["too_big"]], raw.cast_to(types, types, partial_types, False)) def PrimitiveAlias( self, @@ -1793,7 +1816,7 @@ def PrimitiveAlias( tb, __cr__, ) - return cast(Union[int, str, bool, float], raw.cast_to(types, types)) + return cast(Union[int, str, bool, float], raw.cast_to(types, types, partial_types, False)) def PromptTestClaude( self, @@ -1816,7 +1839,7 @@ def PromptTestClaude( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def PromptTestClaudeChat( self, @@ -1839,7 +1862,7 @@ def PromptTestClaudeChat( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def PromptTestClaudeChatNoSystem( self, @@ -1862,7 +1885,7 @@ def PromptTestClaudeChatNoSystem( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def PromptTestOpenAI( self, @@ -1885,7 +1908,7 @@ def PromptTestOpenAI( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def PromptTestOpenAIChat( self, @@ -1908,7 +1931,7 @@ def PromptTestOpenAIChat( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def PromptTestOpenAIChatNoSystem( self, @@ -1931,7 +1954,7 @@ def PromptTestOpenAIChatNoSystem( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def PromptTestStreaming( self, @@ -1954,7 +1977,7 @@ def PromptTestStreaming( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def RecursiveAliasCycle( self, @@ -1977,7 +2000,7 @@ def RecursiveAliasCycle( tb, __cr__, ) - return cast(types.RecAliasOne, raw.cast_to(types, types)) + return cast(types.RecAliasOne, raw.cast_to(types, types, partial_types, False)) def RecursiveClassWithAliasIndirection( self, @@ -2000,7 +2023,7 @@ def RecursiveClassWithAliasIndirection( tb, __cr__, ) - return cast(types.NodeWithAliasIndirection, raw.cast_to(types, types)) + return cast(types.NodeWithAliasIndirection, raw.cast_to(types, types, partial_types, False)) def ReturnAliasWithMergedAttributes( self, @@ -2023,7 +2046,7 @@ def ReturnAliasWithMergedAttributes( tb, __cr__, ) - return cast(Checked[int,types.Literal["gt_ten"]], raw.cast_to(types, types)) + return cast(Checked[int,types.Literal["gt_ten"]], raw.cast_to(types, types, partial_types, False)) def ReturnFailingAssert( self, @@ -2046,7 +2069,7 @@ def ReturnFailingAssert( tb, __cr__, ) - return cast(int, raw.cast_to(types, types)) + return cast(int, raw.cast_to(types, types, partial_types, False)) def ReturnMalformedConstraints( self, @@ -2069,7 +2092,7 @@ def ReturnMalformedConstraints( tb, __cr__, ) - return cast(types.MalformedConstraints, raw.cast_to(types, types)) + return cast(types.MalformedConstraints, raw.cast_to(types, types, partial_types, False)) def SchemaDescriptions( self, @@ -2092,7 +2115,7 @@ def SchemaDescriptions( tb, __cr__, ) - return cast(types.Schema, raw.cast_to(types, types)) + return cast(types.Schema, raw.cast_to(types, types, partial_types, False)) def SimpleRecursiveListAlias( self, @@ -2115,7 +2138,7 @@ def SimpleRecursiveListAlias( tb, __cr__, ) - return cast(types.RecursiveListAlias, raw.cast_to(types, types)) + return cast(types.RecursiveListAlias, raw.cast_to(types, types, partial_types, False)) def SimpleRecursiveMapAlias( self, @@ -2138,7 +2161,7 @@ def SimpleRecursiveMapAlias( tb, __cr__, ) - return cast(types.RecursiveMapAlias, raw.cast_to(types, types)) + return cast(types.RecursiveMapAlias, raw.cast_to(types, types, partial_types, False)) def StreamBigNumbers( self, @@ -2161,7 +2184,7 @@ def StreamBigNumbers( tb, __cr__, ) - return cast(types.BigNumbers, raw.cast_to(types, types)) + return cast(types.BigNumbers, raw.cast_to(types, types, partial_types, False)) def StreamFailingAssertion( self, @@ -2184,7 +2207,7 @@ def StreamFailingAssertion( tb, __cr__, ) - return cast(types.TwoStoriesOneTitle, raw.cast_to(types, types)) + return cast(types.TwoStoriesOneTitle, raw.cast_to(types, types, partial_types, False)) def StreamOneBigNumber( self, @@ -2207,7 +2230,7 @@ def StreamOneBigNumber( tb, __cr__, ) - return cast(int, raw.cast_to(types, types)) + return cast(int, raw.cast_to(types, types, partial_types, False)) def StreamUnionIntegers( self, @@ -2230,7 +2253,7 @@ def StreamUnionIntegers( tb, __cr__, ) - return cast(List[Union[int, str]], raw.cast_to(types, types)) + return cast(List[Union[int, str]], raw.cast_to(types, types, partial_types, False)) def StreamingCompoundNumbers( self, @@ -2253,7 +2276,7 @@ def StreamingCompoundNumbers( tb, __cr__, ) - return cast(types.CompoundBigNumbers, raw.cast_to(types, types)) + return cast(types.CompoundBigNumbers, raw.cast_to(types, types, partial_types, False)) def TestAnthropic( self, @@ -2276,7 +2299,7 @@ def TestAnthropic( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestAnthropicShorthand( self, @@ -2299,7 +2322,7 @@ def TestAnthropicShorthand( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestAws( self, @@ -2322,7 +2345,7 @@ def TestAws( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestAwsInvalidAccessKey( self, @@ -2345,7 +2368,7 @@ def TestAwsInvalidAccessKey( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestAwsInvalidProfile( self, @@ -2368,7 +2391,7 @@ def TestAwsInvalidProfile( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestAwsInvalidRegion( self, @@ -2391,7 +2414,7 @@ def TestAwsInvalidRegion( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestAwsInvalidSessionToken( self, @@ -2414,7 +2437,7 @@ def TestAwsInvalidSessionToken( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestAzure( self, @@ -2437,7 +2460,7 @@ def TestAzure( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestAzureFailure( self, @@ -2460,7 +2483,7 @@ def TestAzureFailure( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestCaching( self, @@ -2483,7 +2506,7 @@ def TestCaching( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestFallbackClient( self, @@ -2506,7 +2529,7 @@ def TestFallbackClient( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestFallbackToShorthand( self, @@ -2529,7 +2552,7 @@ def TestFallbackToShorthand( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestFnNamedArgsSingleBool( self, @@ -2552,7 +2575,7 @@ def TestFnNamedArgsSingleBool( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestFnNamedArgsSingleClass( self, @@ -2575,7 +2598,7 @@ def TestFnNamedArgsSingleClass( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestFnNamedArgsSingleEnumList( self, @@ -2598,7 +2621,7 @@ def TestFnNamedArgsSingleEnumList( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestFnNamedArgsSingleFloat( self, @@ -2621,7 +2644,7 @@ def TestFnNamedArgsSingleFloat( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestFnNamedArgsSingleInt( self, @@ -2644,7 +2667,7 @@ def TestFnNamedArgsSingleInt( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestFnNamedArgsSingleMapStringToClass( self, @@ -2667,7 +2690,7 @@ def TestFnNamedArgsSingleMapStringToClass( tb, __cr__, ) - return cast(Dict[str, types.StringToClassEntry], raw.cast_to(types, types)) + return cast(Dict[str, types.StringToClassEntry], raw.cast_to(types, types, partial_types, False)) def TestFnNamedArgsSingleMapStringToMap( self, @@ -2690,7 +2713,7 @@ def TestFnNamedArgsSingleMapStringToMap( tb, __cr__, ) - return cast(Dict[str, Dict[str, str]], raw.cast_to(types, types)) + return cast(Dict[str, Dict[str, str]], raw.cast_to(types, types, partial_types, False)) def TestFnNamedArgsSingleMapStringToString( self, @@ -2713,7 +2736,7 @@ def TestFnNamedArgsSingleMapStringToString( tb, __cr__, ) - return cast(Dict[str, str], raw.cast_to(types, types)) + return cast(Dict[str, str], raw.cast_to(types, types, partial_types, False)) def TestFnNamedArgsSingleString( self, @@ -2736,7 +2759,7 @@ def TestFnNamedArgsSingleString( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestFnNamedArgsSingleStringArray( self, @@ -2759,7 +2782,7 @@ def TestFnNamedArgsSingleStringArray( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestFnNamedArgsSingleStringList( self, @@ -2782,7 +2805,7 @@ def TestFnNamedArgsSingleStringList( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestGemini( self, @@ -2805,7 +2828,7 @@ def TestGemini( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestGeminiOpenAiGeneric( self, @@ -2828,7 +2851,7 @@ def TestGeminiOpenAiGeneric( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestGeminiSystem( self, @@ -2851,7 +2874,7 @@ def TestGeminiSystem( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestGeminiSystemAsChat( self, @@ -2874,7 +2897,7 @@ def TestGeminiSystemAsChat( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestImageInput( self, @@ -2897,7 +2920,7 @@ def TestImageInput( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestImageInputAnthropic( self, @@ -2920,7 +2943,7 @@ def TestImageInputAnthropic( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestImageListInput( self, @@ -2943,7 +2966,7 @@ def TestImageListInput( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestMulticlassNamedArgs( self, @@ -2966,7 +2989,7 @@ def TestMulticlassNamedArgs( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestNamedArgsLiteralBool( self, @@ -2989,7 +3012,7 @@ def TestNamedArgsLiteralBool( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestNamedArgsLiteralInt( self, @@ -3012,7 +3035,7 @@ def TestNamedArgsLiteralInt( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestNamedArgsLiteralString( self, @@ -3035,7 +3058,7 @@ def TestNamedArgsLiteralString( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestOllama( self, @@ -3058,7 +3081,7 @@ def TestOllama( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestOpenAILegacyProvider( self, @@ -3081,7 +3104,7 @@ def TestOpenAILegacyProvider( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestOpenAIShorthand( self, @@ -3104,7 +3127,7 @@ def TestOpenAIShorthand( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestRetryConstant( self, @@ -3127,7 +3150,7 @@ def TestRetryConstant( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestRetryExponential( self, @@ -3150,7 +3173,7 @@ def TestRetryExponential( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestSingleFallbackClient( self, @@ -3173,7 +3196,7 @@ def TestSingleFallbackClient( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestVertex( self, @@ -3196,7 +3219,7 @@ def TestVertex( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def TestVertexWithSystemInstructions( self, @@ -3219,7 +3242,7 @@ def TestVertexWithSystemInstructions( tb, __cr__, ) - return cast(str, raw.cast_to(types, types)) + return cast(str, raw.cast_to(types, types, partial_types, False)) def UnionTest_Function( self, @@ -3242,7 +3265,7 @@ def UnionTest_Function( tb, __cr__, ) - return cast(types.UnionTest_ReturnType, raw.cast_to(types, types)) + return cast(types.UnionTest_ReturnType, raw.cast_to(types, types, partial_types, False)) def UseBlockConstraint( self, @@ -3265,7 +3288,7 @@ def UseBlockConstraint( tb, __cr__, ) - return cast(int, raw.cast_to(types, types)) + return cast(int, raw.cast_to(types, types, partial_types, False)) def UseMalformedConstraints( self, @@ -3288,7 +3311,7 @@ def UseMalformedConstraints( tb, __cr__, ) - return cast(int, raw.cast_to(types, types)) + return cast(int, raw.cast_to(types, types, partial_types, False)) def UseNestedBlockConstraint( self, @@ -3311,7 +3334,7 @@ def UseNestedBlockConstraint( tb, __cr__, ) - return cast(int, raw.cast_to(types, types)) + return cast(int, raw.cast_to(types, types, partial_types, False)) @@ -3350,8 +3373,8 @@ def AaaSamOutputFormat( return baml_py.BamlSyncStream[partial_types.Recipe, types.Recipe]( raw, - lambda x: cast(partial_types.Recipe, x.cast_to(types, partial_types)), - lambda x: cast(types.Recipe, x.cast_to(types, types)), + lambda x: cast(partial_types.Recipe, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Recipe, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3380,8 +3403,8 @@ def AliasThatPointsToRecursiveType( return baml_py.BamlSyncStream[partial_types.LinkedListAliasNode, types.LinkedListAliasNode]( raw, - lambda x: cast(partial_types.LinkedListAliasNode, x.cast_to(types, partial_types)), - lambda x: cast(types.LinkedListAliasNode, x.cast_to(types, types)), + lambda x: cast(partial_types.LinkedListAliasNode, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.LinkedListAliasNode, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3410,8 +3433,8 @@ def AliasWithMultipleAttrs( return baml_py.BamlSyncStream[Checked[Optional[int],types.Literal["gt_ten"]], Checked[int,types.Literal["gt_ten"]]]( raw, - lambda x: cast(Checked[Optional[int],types.Literal["gt_ten"]], x.cast_to(types, partial_types)), - lambda x: cast(Checked[int,types.Literal["gt_ten"]], x.cast_to(types, types)), + lambda x: cast(Checked[Optional[int],types.Literal["gt_ten"]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Checked[int,types.Literal["gt_ten"]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3440,8 +3463,8 @@ def AliasedInputClass( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3470,8 +3493,8 @@ def AliasedInputClass2( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3500,8 +3523,8 @@ def AliasedInputClassNested( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3530,8 +3553,8 @@ def AliasedInputEnum( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3560,8 +3583,8 @@ def AliasedInputList( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3590,8 +3613,8 @@ def AllowedOptionals( return baml_py.BamlSyncStream[partial_types.OptionalListAndMap, types.OptionalListAndMap]( raw, - lambda x: cast(partial_types.OptionalListAndMap, x.cast_to(types, partial_types)), - lambda x: cast(types.OptionalListAndMap, x.cast_to(types, types)), + lambda x: cast(partial_types.OptionalListAndMap, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.OptionalListAndMap, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3620,8 +3643,8 @@ def AssertFn( return baml_py.BamlSyncStream[Optional[int], int]( raw, - lambda x: cast(Optional[int], x.cast_to(types, partial_types)), - lambda x: cast(int, x.cast_to(types, types)), + lambda x: cast(Optional[int], x.cast_to(types, types, partial_types, True)), + lambda x: cast(int, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3650,8 +3673,8 @@ def AudioInput( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3680,8 +3703,8 @@ def BuildLinkedList( return baml_py.BamlSyncStream[partial_types.LinkedList, types.LinkedList]( raw, - lambda x: cast(partial_types.LinkedList, x.cast_to(types, partial_types)), - lambda x: cast(types.LinkedList, x.cast_to(types, types)), + lambda x: cast(partial_types.LinkedList, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.LinkedList, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3710,8 +3733,8 @@ def BuildTree( return baml_py.BamlSyncStream[partial_types.Tree, types.Tree]( raw, - lambda x: cast(partial_types.Tree, x.cast_to(types, partial_types)), - lambda x: cast(types.Tree, x.cast_to(types, types)), + lambda x: cast(partial_types.Tree, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Tree, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3740,8 +3763,8 @@ def ClassThatPointsToRecursiveClassThroughAlias( return baml_py.BamlSyncStream[partial_types.ClassToRecAlias, types.ClassToRecAlias]( raw, - lambda x: cast(partial_types.ClassToRecAlias, x.cast_to(types, partial_types)), - lambda x: cast(types.ClassToRecAlias, x.cast_to(types, types)), + lambda x: cast(partial_types.ClassToRecAlias, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.ClassToRecAlias, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3770,8 +3793,8 @@ def ClassifyDynEnumTwo( return baml_py.BamlSyncStream[Optional[Union[types.DynEnumTwo, str]], Union[types.DynEnumTwo, str]]( raw, - lambda x: cast(Optional[Union[types.DynEnumTwo, str]], x.cast_to(types, partial_types)), - lambda x: cast(Union[types.DynEnumTwo, str], x.cast_to(types, types)), + lambda x: cast(Optional[Union[types.DynEnumTwo, str]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Union[types.DynEnumTwo, str], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3800,8 +3823,8 @@ def ClassifyMessage( return baml_py.BamlSyncStream[Optional[types.Category], types.Category]( raw, - lambda x: cast(Optional[types.Category], x.cast_to(types, partial_types)), - lambda x: cast(types.Category, x.cast_to(types, types)), + lambda x: cast(Optional[types.Category], x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Category, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3830,8 +3853,8 @@ def ClassifyMessage2( return baml_py.BamlSyncStream[Optional[types.Category], types.Category]( raw, - lambda x: cast(Optional[types.Category], x.cast_to(types, partial_types)), - lambda x: cast(types.Category, x.cast_to(types, types)), + lambda x: cast(Optional[types.Category], x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Category, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3860,8 +3883,8 @@ def ClassifyMessage3( return baml_py.BamlSyncStream[Optional[types.Category], types.Category]( raw, - lambda x: cast(Optional[types.Category], x.cast_to(types, partial_types)), - lambda x: cast(types.Category, x.cast_to(types, types)), + lambda x: cast(Optional[types.Category], x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Category, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3892,8 +3915,8 @@ def Completion( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3922,8 +3945,8 @@ def CustomTask( return baml_py.BamlSyncStream[Optional[Union[partial_types.BookOrder, partial_types.FlightConfirmation, partial_types.GroceryReceipt]], Union[types.BookOrder, types.FlightConfirmation, types.GroceryReceipt]]( raw, - lambda x: cast(Optional[Union[partial_types.BookOrder, partial_types.FlightConfirmation, partial_types.GroceryReceipt]], x.cast_to(types, partial_types)), - lambda x: cast(Union[types.BookOrder, types.FlightConfirmation, types.GroceryReceipt], x.cast_to(types, types)), + lambda x: cast(Optional[Union[partial_types.BookOrder, partial_types.FlightConfirmation, partial_types.GroceryReceipt]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Union[types.BookOrder, types.FlightConfirmation, types.GroceryReceipt], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3952,8 +3975,8 @@ def DescribeImage( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -3983,8 +4006,8 @@ def DescribeImage2( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4014,8 +4037,8 @@ def DescribeImage3( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4045,8 +4068,8 @@ def DescribeImage4( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4074,8 +4097,8 @@ def DifferentiateUnions( return baml_py.BamlSyncStream[Optional[Union[partial_types.OriginalA, partial_types.OriginalB]], Union[types.OriginalA, types.OriginalB]]( raw, - lambda x: cast(Optional[Union[partial_types.OriginalA, partial_types.OriginalB]], x.cast_to(types, partial_types)), - lambda x: cast(Union[types.OriginalA, types.OriginalB], x.cast_to(types, types)), + lambda x: cast(Optional[Union[partial_types.OriginalA, partial_types.OriginalB]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Union[types.OriginalA, types.OriginalB], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4104,8 +4127,8 @@ def DummyOutputFunction( return baml_py.BamlSyncStream[partial_types.DummyOutput, types.DummyOutput]( raw, - lambda x: cast(partial_types.DummyOutput, x.cast_to(types, partial_types)), - lambda x: cast(types.DummyOutput, x.cast_to(types, types)), + lambda x: cast(partial_types.DummyOutput, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.DummyOutput, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4134,8 +4157,8 @@ def DynamicFunc( return baml_py.BamlSyncStream[partial_types.DynamicClassTwo, types.DynamicClassTwo]( raw, - lambda x: cast(partial_types.DynamicClassTwo, x.cast_to(types, partial_types)), - lambda x: cast(types.DynamicClassTwo, x.cast_to(types, types)), + lambda x: cast(partial_types.DynamicClassTwo, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.DynamicClassTwo, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4164,8 +4187,8 @@ def DynamicInputOutput( return baml_py.BamlSyncStream[partial_types.DynInputOutput, types.DynInputOutput]( raw, - lambda x: cast(partial_types.DynInputOutput, x.cast_to(types, partial_types)), - lambda x: cast(types.DynInputOutput, x.cast_to(types, types)), + lambda x: cast(partial_types.DynInputOutput, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.DynInputOutput, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4194,8 +4217,8 @@ def DynamicListInputOutput( return baml_py.BamlSyncStream[List[partial_types.DynInputOutput], List[types.DynInputOutput]]( raw, - lambda x: cast(List[partial_types.DynInputOutput], x.cast_to(types, partial_types)), - lambda x: cast(List[types.DynInputOutput], x.cast_to(types, types)), + lambda x: cast(List[partial_types.DynInputOutput], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[types.DynInputOutput], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4223,8 +4246,8 @@ def ExpectFailure( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4253,8 +4276,8 @@ def ExtractContactInfo( return baml_py.BamlSyncStream[partial_types.ContactInfo, types.ContactInfo]( raw, - lambda x: cast(partial_types.ContactInfo, x.cast_to(types, partial_types)), - lambda x: cast(types.ContactInfo, x.cast_to(types, types)), + lambda x: cast(partial_types.ContactInfo, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.ContactInfo, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4283,8 +4306,8 @@ def ExtractHobby( return baml_py.BamlSyncStream[List[Optional[Union[types.Hobby, str]]], List[Union[types.Hobby, str]]]( raw, - lambda x: cast(List[Optional[Union[types.Hobby, str]]], x.cast_to(types, partial_types)), - lambda x: cast(List[Union[types.Hobby, str]], x.cast_to(types, types)), + lambda x: cast(List[Optional[Union[types.Hobby, str]]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[Union[types.Hobby, str]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4313,8 +4336,8 @@ def ExtractNames( return baml_py.BamlSyncStream[List[Optional[str]], List[str]]( raw, - lambda x: cast(List[Optional[str]], x.cast_to(types, partial_types)), - lambda x: cast(List[str], x.cast_to(types, types)), + lambda x: cast(List[Optional[str]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[str], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4343,8 +4366,8 @@ def ExtractPeople( return baml_py.BamlSyncStream[List[partial_types.Person], List[types.Person]]( raw, - lambda x: cast(List[partial_types.Person], x.cast_to(types, partial_types)), - lambda x: cast(List[types.Person], x.cast_to(types, types)), + lambda x: cast(List[partial_types.Person], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[types.Person], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4374,8 +4397,8 @@ def ExtractReceiptInfo( return baml_py.BamlSyncStream[partial_types.ReceiptInfo, types.ReceiptInfo]( raw, - lambda x: cast(partial_types.ReceiptInfo, x.cast_to(types, partial_types)), - lambda x: cast(types.ReceiptInfo, x.cast_to(types, types)), + lambda x: cast(partial_types.ReceiptInfo, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.ReceiptInfo, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4405,8 +4428,8 @@ def ExtractResume( return baml_py.BamlSyncStream[partial_types.Resume, types.Resume]( raw, - lambda x: cast(partial_types.Resume, x.cast_to(types, partial_types)), - lambda x: cast(types.Resume, x.cast_to(types, types)), + lambda x: cast(partial_types.Resume, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Resume, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4435,8 +4458,8 @@ def ExtractResume2( return baml_py.BamlSyncStream[partial_types.Resume, types.Resume]( raw, - lambda x: cast(partial_types.Resume, x.cast_to(types, partial_types)), - lambda x: cast(types.Resume, x.cast_to(types, types)), + lambda x: cast(partial_types.Resume, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Resume, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4465,8 +4488,8 @@ def FnClassOptionalOutput( return baml_py.BamlSyncStream[partial_types.ClassOptionalOutput, Optional[types.ClassOptionalOutput]]( raw, - lambda x: cast(partial_types.ClassOptionalOutput, x.cast_to(types, partial_types)), - lambda x: cast(Optional[types.ClassOptionalOutput], x.cast_to(types, types)), + lambda x: cast(partial_types.ClassOptionalOutput, x.cast_to(types, types, partial_types, True)), + lambda x: cast(Optional[types.ClassOptionalOutput], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4495,8 +4518,8 @@ def FnClassOptionalOutput2( return baml_py.BamlSyncStream[partial_types.ClassOptionalOutput2, Optional[types.ClassOptionalOutput2]]( raw, - lambda x: cast(partial_types.ClassOptionalOutput2, x.cast_to(types, partial_types)), - lambda x: cast(Optional[types.ClassOptionalOutput2], x.cast_to(types, types)), + lambda x: cast(partial_types.ClassOptionalOutput2, x.cast_to(types, types, partial_types, True)), + lambda x: cast(Optional[types.ClassOptionalOutput2], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4525,8 +4548,8 @@ def FnEnumListOutput( return baml_py.BamlSyncStream[List[Optional[types.EnumOutput]], List[types.EnumOutput]]( raw, - lambda x: cast(List[Optional[types.EnumOutput]], x.cast_to(types, partial_types)), - lambda x: cast(List[types.EnumOutput], x.cast_to(types, types)), + lambda x: cast(List[Optional[types.EnumOutput]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[types.EnumOutput], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4555,8 +4578,8 @@ def FnEnumOutput( return baml_py.BamlSyncStream[Optional[types.EnumOutput], types.EnumOutput]( raw, - lambda x: cast(Optional[types.EnumOutput], x.cast_to(types, partial_types)), - lambda x: cast(types.EnumOutput, x.cast_to(types, types)), + lambda x: cast(Optional[types.EnumOutput], x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.EnumOutput, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4585,8 +4608,8 @@ def FnLiteralClassInputOutput( return baml_py.BamlSyncStream[partial_types.LiteralClassHello, types.LiteralClassHello]( raw, - lambda x: cast(partial_types.LiteralClassHello, x.cast_to(types, partial_types)), - lambda x: cast(types.LiteralClassHello, x.cast_to(types, types)), + lambda x: cast(partial_types.LiteralClassHello, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.LiteralClassHello, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4615,8 +4638,8 @@ def FnLiteralUnionClassInputOutput( return baml_py.BamlSyncStream[Optional[Union[partial_types.LiteralClassOne, partial_types.LiteralClassTwo]], Union[types.LiteralClassOne, types.LiteralClassTwo]]( raw, - lambda x: cast(Optional[Union[partial_types.LiteralClassOne, partial_types.LiteralClassTwo]], x.cast_to(types, partial_types)), - lambda x: cast(Union[types.LiteralClassOne, types.LiteralClassTwo], x.cast_to(types, types)), + lambda x: cast(Optional[Union[partial_types.LiteralClassOne, partial_types.LiteralClassTwo]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Union[types.LiteralClassOne, types.LiteralClassTwo], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4645,8 +4668,8 @@ def FnNamedArgsSingleStringOptional( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4675,8 +4698,8 @@ def FnOutputBool( return baml_py.BamlSyncStream[Optional[bool], bool]( raw, - lambda x: cast(Optional[bool], x.cast_to(types, partial_types)), - lambda x: cast(bool, x.cast_to(types, types)), + lambda x: cast(Optional[bool], x.cast_to(types, types, partial_types, True)), + lambda x: cast(bool, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4705,8 +4728,8 @@ def FnOutputClass( return baml_py.BamlSyncStream[partial_types.TestOutputClass, types.TestOutputClass]( raw, - lambda x: cast(partial_types.TestOutputClass, x.cast_to(types, partial_types)), - lambda x: cast(types.TestOutputClass, x.cast_to(types, types)), + lambda x: cast(partial_types.TestOutputClass, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.TestOutputClass, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4735,8 +4758,8 @@ def FnOutputClassList( return baml_py.BamlSyncStream[List[partial_types.TestOutputClass], List[types.TestOutputClass]]( raw, - lambda x: cast(List[partial_types.TestOutputClass], x.cast_to(types, partial_types)), - lambda x: cast(List[types.TestOutputClass], x.cast_to(types, types)), + lambda x: cast(List[partial_types.TestOutputClass], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[types.TestOutputClass], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4765,8 +4788,8 @@ def FnOutputClassNested( return baml_py.BamlSyncStream[partial_types.TestClassNested, types.TestClassNested]( raw, - lambda x: cast(partial_types.TestClassNested, x.cast_to(types, partial_types)), - lambda x: cast(types.TestClassNested, x.cast_to(types, types)), + lambda x: cast(partial_types.TestClassNested, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.TestClassNested, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4795,8 +4818,8 @@ def FnOutputClassWithEnum( return baml_py.BamlSyncStream[partial_types.TestClassWithEnum, types.TestClassWithEnum]( raw, - lambda x: cast(partial_types.TestClassWithEnum, x.cast_to(types, partial_types)), - lambda x: cast(types.TestClassWithEnum, x.cast_to(types, types)), + lambda x: cast(partial_types.TestClassWithEnum, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.TestClassWithEnum, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4825,8 +4848,8 @@ def FnOutputInt( return baml_py.BamlSyncStream[Optional[int], int]( raw, - lambda x: cast(Optional[int], x.cast_to(types, partial_types)), - lambda x: cast(int, x.cast_to(types, types)), + lambda x: cast(Optional[int], x.cast_to(types, types, partial_types, True)), + lambda x: cast(int, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4855,8 +4878,8 @@ def FnOutputLiteralBool( return baml_py.BamlSyncStream[Optional[Literal[False]], Literal[False]]( raw, - lambda x: cast(Optional[Literal[False]], x.cast_to(types, partial_types)), - lambda x: cast(Literal[False], x.cast_to(types, types)), + lambda x: cast(Optional[Literal[False]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Literal[False], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4885,8 +4908,8 @@ def FnOutputLiteralInt( return baml_py.BamlSyncStream[Optional[Literal[5]], Literal[5]]( raw, - lambda x: cast(Optional[Literal[5]], x.cast_to(types, partial_types)), - lambda x: cast(Literal[5], x.cast_to(types, types)), + lambda x: cast(Optional[Literal[5]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Literal[5], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4915,8 +4938,8 @@ def FnOutputLiteralString( return baml_py.BamlSyncStream[Optional[Literal["example output"]], Literal["example output"]]( raw, - lambda x: cast(Optional[Literal["example output"]], x.cast_to(types, partial_types)), - lambda x: cast(Literal["example output"], x.cast_to(types, types)), + lambda x: cast(Optional[Literal["example output"]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Literal["example output"], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4945,8 +4968,8 @@ def FnOutputStringList( return baml_py.BamlSyncStream[List[Optional[str]], List[str]]( raw, - lambda x: cast(List[Optional[str]], x.cast_to(types, partial_types)), - lambda x: cast(List[str], x.cast_to(types, types)), + lambda x: cast(List[Optional[str]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[str], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -4975,8 +4998,8 @@ def FnTestAliasedEnumOutput( return baml_py.BamlSyncStream[Optional[types.TestEnum], types.TestEnum]( raw, - lambda x: cast(Optional[types.TestEnum], x.cast_to(types, partial_types)), - lambda x: cast(types.TestEnum, x.cast_to(types, types)), + lambda x: cast(Optional[types.TestEnum], x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.TestEnum, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5005,8 +5028,8 @@ def FnTestClassAlias( return baml_py.BamlSyncStream[partial_types.TestClassAlias, types.TestClassAlias]( raw, - lambda x: cast(partial_types.TestClassAlias, x.cast_to(types, partial_types)), - lambda x: cast(types.TestClassAlias, x.cast_to(types, types)), + lambda x: cast(partial_types.TestClassAlias, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.TestClassAlias, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5035,8 +5058,8 @@ def FnTestNamedArgsSingleEnum( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5065,8 +5088,8 @@ def GetDataType( return baml_py.BamlSyncStream[partial_types.RaysData, types.RaysData]( raw, - lambda x: cast(partial_types.RaysData, x.cast_to(types, partial_types)), - lambda x: cast(types.RaysData, x.cast_to(types, types)), + lambda x: cast(partial_types.RaysData, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.RaysData, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5095,8 +5118,8 @@ def GetOrderInfo( return baml_py.BamlSyncStream[partial_types.OrderInfo, types.OrderInfo]( raw, - lambda x: cast(partial_types.OrderInfo, x.cast_to(types, partial_types)), - lambda x: cast(types.OrderInfo, x.cast_to(types, types)), + lambda x: cast(partial_types.OrderInfo, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.OrderInfo, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5125,8 +5148,8 @@ def GetQuery( return baml_py.BamlSyncStream[partial_types.SearchParams, types.SearchParams]( raw, - lambda x: cast(partial_types.SearchParams, x.cast_to(types, partial_types)), - lambda x: cast(types.SearchParams, x.cast_to(types, types)), + lambda x: cast(partial_types.SearchParams, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.SearchParams, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5156,8 +5179,8 @@ def InOutEnumMapKey( return baml_py.BamlSyncStream[Dict[types.MapKey, Optional[str]], Dict[types.MapKey, str]]( raw, - lambda x: cast(Dict[types.MapKey, Optional[str]], x.cast_to(types, partial_types)), - lambda x: cast(Dict[types.MapKey, str], x.cast_to(types, types)), + lambda x: cast(Dict[types.MapKey, Optional[str]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Dict[types.MapKey, str], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5187,8 +5210,8 @@ def InOutLiteralStringUnionMapKey( return baml_py.BamlSyncStream[Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], Optional[str]], Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], str]]( raw, - lambda x: cast(Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], Optional[str]], x.cast_to(types, partial_types)), - lambda x: cast(Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], str], x.cast_to(types, types)), + lambda x: cast(Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], Optional[str]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Dict[Union[Literal["one"], Literal["two"], Union[Literal["three"], Literal["four"]]], str], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5217,8 +5240,8 @@ def InOutSingleLiteralStringMapKey( return baml_py.BamlSyncStream[Dict[Literal["key"], Optional[str]], Dict[Literal["key"], str]]( raw, - lambda x: cast(Dict[Literal["key"], Optional[str]], x.cast_to(types, partial_types)), - lambda x: cast(Dict[Literal["key"], str], x.cast_to(types, types)), + lambda x: cast(Dict[Literal["key"], Optional[str]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Dict[Literal["key"], str], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5247,8 +5270,8 @@ def JsonTypeAliasCycle( return baml_py.BamlSyncStream[types.JsonValue, types.JsonValue]( raw, - lambda x: cast(types.JsonValue, x.cast_to(types, partial_types)), - lambda x: cast(types.JsonValue, x.cast_to(types, types)), + lambda x: cast(types.JsonValue, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.JsonValue, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5277,8 +5300,8 @@ def LiteralUnionsTest( return baml_py.BamlSyncStream[Optional[Union[Optional[Literal[1]], Optional[Literal[True]], Optional[Literal["string output"]]]], Union[Literal[1], Literal[True], Literal["string output"]]]( raw, - lambda x: cast(Optional[Union[Optional[Literal[1]], Optional[Literal[True]], Optional[Literal["string output"]]]], x.cast_to(types, partial_types)), - lambda x: cast(Union[Literal[1], Literal[True], Literal["string output"]], x.cast_to(types, types)), + lambda x: cast(Optional[Union[Optional[Literal[1]], Optional[Literal[True]], Optional[Literal["string output"]]]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Union[Literal[1], Literal[True], Literal["string output"]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5306,8 +5329,8 @@ def MakeBlockConstraint( return baml_py.BamlSyncStream[Checked[partial_types.BlockConstraint,types.Literal["cross_field"]], Checked[types.BlockConstraint,types.Literal["cross_field"]]]( raw, - lambda x: cast(Checked[partial_types.BlockConstraint,types.Literal["cross_field"]], x.cast_to(types, partial_types)), - lambda x: cast(Checked[types.BlockConstraint,types.Literal["cross_field"]], x.cast_to(types, types)), + lambda x: cast(Checked[partial_types.BlockConstraint,types.Literal["cross_field"]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Checked[types.BlockConstraint,types.Literal["cross_field"]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5335,8 +5358,37 @@ def MakeNestedBlockConstraint( return baml_py.BamlSyncStream[partial_types.NestedBlockConstraint, types.NestedBlockConstraint]( raw, - lambda x: cast(partial_types.NestedBlockConstraint, x.cast_to(types, partial_types)), - lambda x: cast(types.NestedBlockConstraint, x.cast_to(types, types)), + lambda x: cast(partial_types.NestedBlockConstraint, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.NestedBlockConstraint, x.cast_to(types, types, partial_types, False)), + self.__ctx_manager.get(), + ) + + def MakeSemanticContainer( + self, + + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[partial_types.SemanticContainer, types.SemanticContainer]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb # type: ignore (we know how to use this private attribute) + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "MakeSemanticContainer", + { + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + return baml_py.BamlSyncStream[partial_types.SemanticContainer, types.SemanticContainer]( + raw, + lambda x: cast(partial_types.SemanticContainer, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.SemanticContainer, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5365,8 +5417,8 @@ def MapAlias( return baml_py.BamlSyncStream[Dict[str, List[Optional[str]]], Dict[str, List[str]]]( raw, - lambda x: cast(Dict[str, List[Optional[str]]], x.cast_to(types, partial_types)), - lambda x: cast(Dict[str, List[str]], x.cast_to(types, types)), + lambda x: cast(Dict[str, List[Optional[str]]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Dict[str, List[str]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5395,8 +5447,8 @@ def MergeAliasAttributes( return baml_py.BamlSyncStream[partial_types.MergeAttrs, types.MergeAttrs]( raw, - lambda x: cast(partial_types.MergeAttrs, x.cast_to(types, partial_types)), - lambda x: cast(types.MergeAttrs, x.cast_to(types, types)), + lambda x: cast(partial_types.MergeAttrs, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.MergeAttrs, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5425,8 +5477,8 @@ def MyFunc( return baml_py.BamlSyncStream[partial_types.DynamicOutput, types.DynamicOutput]( raw, - lambda x: cast(partial_types.DynamicOutput, x.cast_to(types, partial_types)), - lambda x: cast(types.DynamicOutput, x.cast_to(types, types)), + lambda x: cast(partial_types.DynamicOutput, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.DynamicOutput, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5455,8 +5507,8 @@ def NestedAlias( return baml_py.BamlSyncStream[Optional[Union[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], List[Optional[str]], Dict[str, List[Optional[str]]]]], Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]]]( raw, - lambda x: cast(Optional[Union[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], List[Optional[str]], Dict[str, List[Optional[str]]]]], x.cast_to(types, partial_types)), - lambda x: cast(Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], x.cast_to(types, types)), + lambda x: cast(Optional[Union[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], List[Optional[str]], Dict[str, List[Optional[str]]]]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5485,8 +5537,8 @@ def NullLiteralClassHello( return baml_py.BamlSyncStream[partial_types.ClassForNullLiteral, types.ClassForNullLiteral]( raw, - lambda x: cast(partial_types.ClassForNullLiteral, x.cast_to(types, partial_types)), - lambda x: cast(types.ClassForNullLiteral, x.cast_to(types, types)), + lambda x: cast(partial_types.ClassForNullLiteral, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.ClassForNullLiteral, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5515,8 +5567,8 @@ def OptionalTest_Function( return baml_py.BamlSyncStream[List[partial_types.OptionalTest_ReturnType], List[Optional[types.OptionalTest_ReturnType]]]( raw, - lambda x: cast(List[partial_types.OptionalTest_ReturnType], x.cast_to(types, partial_types)), - lambda x: cast(List[Optional[types.OptionalTest_ReturnType]], x.cast_to(types, types)), + lambda x: cast(List[partial_types.OptionalTest_ReturnType], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[Optional[types.OptionalTest_ReturnType]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5545,8 +5597,8 @@ def PredictAge( return baml_py.BamlSyncStream[partial_types.FooAny, types.FooAny]( raw, - lambda x: cast(partial_types.FooAny, x.cast_to(types, partial_types)), - lambda x: cast(types.FooAny, x.cast_to(types, types)), + lambda x: cast(partial_types.FooAny, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.FooAny, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5575,8 +5627,8 @@ def PredictAgeBare( return baml_py.BamlSyncStream[Checked[Optional[int],types.Literal["too_big"]], Checked[int,types.Literal["too_big"]]]( raw, - lambda x: cast(Checked[Optional[int],types.Literal["too_big"]], x.cast_to(types, partial_types)), - lambda x: cast(Checked[int,types.Literal["too_big"]], x.cast_to(types, types)), + lambda x: cast(Checked[Optional[int],types.Literal["too_big"]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Checked[int,types.Literal["too_big"]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5605,8 +5657,8 @@ def PrimitiveAlias( return baml_py.BamlSyncStream[Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], Union[int, str, bool, float]]( raw, - lambda x: cast(Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], x.cast_to(types, partial_types)), - lambda x: cast(Union[int, str, bool, float], x.cast_to(types, types)), + lambda x: cast(Optional[Union[Optional[int], Optional[str], Optional[bool], Optional[float]]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Union[int, str, bool, float], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5635,8 +5687,8 @@ def PromptTestClaude( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5665,8 +5717,8 @@ def PromptTestClaudeChat( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5695,8 +5747,8 @@ def PromptTestClaudeChatNoSystem( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5725,8 +5777,8 @@ def PromptTestOpenAI( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5755,8 +5807,8 @@ def PromptTestOpenAIChat( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5785,8 +5837,8 @@ def PromptTestOpenAIChatNoSystem( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5815,8 +5867,8 @@ def PromptTestStreaming( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5845,8 +5897,8 @@ def RecursiveAliasCycle( return baml_py.BamlSyncStream[types.RecAliasOne, types.RecAliasOne]( raw, - lambda x: cast(types.RecAliasOne, x.cast_to(types, partial_types)), - lambda x: cast(types.RecAliasOne, x.cast_to(types, types)), + lambda x: cast(types.RecAliasOne, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.RecAliasOne, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5875,8 +5927,8 @@ def RecursiveClassWithAliasIndirection( return baml_py.BamlSyncStream[partial_types.NodeWithAliasIndirection, types.NodeWithAliasIndirection]( raw, - lambda x: cast(partial_types.NodeWithAliasIndirection, x.cast_to(types, partial_types)), - lambda x: cast(types.NodeWithAliasIndirection, x.cast_to(types, types)), + lambda x: cast(partial_types.NodeWithAliasIndirection, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.NodeWithAliasIndirection, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5905,8 +5957,8 @@ def ReturnAliasWithMergedAttributes( return baml_py.BamlSyncStream[Checked[Optional[int],types.Literal["gt_ten"]], Checked[int,types.Literal["gt_ten"]]]( raw, - lambda x: cast(Checked[Optional[int],types.Literal["gt_ten"]], x.cast_to(types, partial_types)), - lambda x: cast(Checked[int,types.Literal["gt_ten"]], x.cast_to(types, types)), + lambda x: cast(Checked[Optional[int],types.Literal["gt_ten"]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Checked[int,types.Literal["gt_ten"]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5935,8 +5987,8 @@ def ReturnFailingAssert( return baml_py.BamlSyncStream[Optional[int], int]( raw, - lambda x: cast(Optional[int], x.cast_to(types, partial_types)), - lambda x: cast(int, x.cast_to(types, types)), + lambda x: cast(Optional[int], x.cast_to(types, types, partial_types, True)), + lambda x: cast(int, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5965,8 +6017,8 @@ def ReturnMalformedConstraints( return baml_py.BamlSyncStream[partial_types.MalformedConstraints, types.MalformedConstraints]( raw, - lambda x: cast(partial_types.MalformedConstraints, x.cast_to(types, partial_types)), - lambda x: cast(types.MalformedConstraints, x.cast_to(types, types)), + lambda x: cast(partial_types.MalformedConstraints, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.MalformedConstraints, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -5995,8 +6047,8 @@ def SchemaDescriptions( return baml_py.BamlSyncStream[partial_types.Schema, types.Schema]( raw, - lambda x: cast(partial_types.Schema, x.cast_to(types, partial_types)), - lambda x: cast(types.Schema, x.cast_to(types, types)), + lambda x: cast(partial_types.Schema, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.Schema, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6025,8 +6077,8 @@ def SimpleRecursiveListAlias( return baml_py.BamlSyncStream[types.RecursiveListAlias, types.RecursiveListAlias]( raw, - lambda x: cast(types.RecursiveListAlias, x.cast_to(types, partial_types)), - lambda x: cast(types.RecursiveListAlias, x.cast_to(types, types)), + lambda x: cast(types.RecursiveListAlias, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.RecursiveListAlias, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6055,8 +6107,8 @@ def SimpleRecursiveMapAlias( return baml_py.BamlSyncStream[types.RecursiveMapAlias, types.RecursiveMapAlias]( raw, - lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, partial_types)), - lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, types)), + lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.RecursiveMapAlias, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6085,8 +6137,8 @@ def StreamBigNumbers( return baml_py.BamlSyncStream[partial_types.BigNumbers, types.BigNumbers]( raw, - lambda x: cast(partial_types.BigNumbers, x.cast_to(types, partial_types)), - lambda x: cast(types.BigNumbers, x.cast_to(types, types)), + lambda x: cast(partial_types.BigNumbers, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.BigNumbers, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6116,8 +6168,8 @@ def StreamFailingAssertion( return baml_py.BamlSyncStream[partial_types.TwoStoriesOneTitle, types.TwoStoriesOneTitle]( raw, - lambda x: cast(partial_types.TwoStoriesOneTitle, x.cast_to(types, partial_types)), - lambda x: cast(types.TwoStoriesOneTitle, x.cast_to(types, types)), + lambda x: cast(partial_types.TwoStoriesOneTitle, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.TwoStoriesOneTitle, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6146,8 +6198,8 @@ def StreamOneBigNumber( return baml_py.BamlSyncStream[Optional[int], int]( raw, - lambda x: cast(Optional[int], x.cast_to(types, partial_types)), - lambda x: cast(int, x.cast_to(types, types)), + lambda x: cast(Optional[int], x.cast_to(types, types, partial_types, True)), + lambda x: cast(int, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6176,8 +6228,8 @@ def StreamUnionIntegers( return baml_py.BamlSyncStream[List[Optional[Union[Optional[int], Optional[str]]]], List[Union[int, str]]]( raw, - lambda x: cast(List[Optional[Union[Optional[int], Optional[str]]]], x.cast_to(types, partial_types)), - lambda x: cast(List[Union[int, str]], x.cast_to(types, types)), + lambda x: cast(List[Optional[Union[Optional[int], Optional[str]]]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(List[Union[int, str]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6207,8 +6259,8 @@ def StreamingCompoundNumbers( return baml_py.BamlSyncStream[partial_types.CompoundBigNumbers, types.CompoundBigNumbers]( raw, - lambda x: cast(partial_types.CompoundBigNumbers, x.cast_to(types, partial_types)), - lambda x: cast(types.CompoundBigNumbers, x.cast_to(types, types)), + lambda x: cast(partial_types.CompoundBigNumbers, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.CompoundBigNumbers, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6237,8 +6289,8 @@ def TestAnthropic( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6267,8 +6319,8 @@ def TestAnthropicShorthand( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6297,8 +6349,8 @@ def TestAws( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6327,8 +6379,8 @@ def TestAwsInvalidAccessKey( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6357,8 +6409,8 @@ def TestAwsInvalidProfile( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6387,8 +6439,8 @@ def TestAwsInvalidRegion( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6417,8 +6469,8 @@ def TestAwsInvalidSessionToken( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6447,8 +6499,8 @@ def TestAzure( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6477,8 +6529,8 @@ def TestAzureFailure( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6508,8 +6560,8 @@ def TestCaching( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6537,8 +6589,8 @@ def TestFallbackClient( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6567,8 +6619,8 @@ def TestFallbackToShorthand( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6597,8 +6649,8 @@ def TestFnNamedArgsSingleBool( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6627,8 +6679,8 @@ def TestFnNamedArgsSingleClass( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6657,8 +6709,8 @@ def TestFnNamedArgsSingleEnumList( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6687,8 +6739,8 @@ def TestFnNamedArgsSingleFloat( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6717,8 +6769,8 @@ def TestFnNamedArgsSingleInt( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6747,8 +6799,8 @@ def TestFnNamedArgsSingleMapStringToClass( return baml_py.BamlSyncStream[Dict[str, partial_types.StringToClassEntry], Dict[str, types.StringToClassEntry]]( raw, - lambda x: cast(Dict[str, partial_types.StringToClassEntry], x.cast_to(types, partial_types)), - lambda x: cast(Dict[str, types.StringToClassEntry], x.cast_to(types, types)), + lambda x: cast(Dict[str, partial_types.StringToClassEntry], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Dict[str, types.StringToClassEntry], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6777,8 +6829,8 @@ def TestFnNamedArgsSingleMapStringToMap( return baml_py.BamlSyncStream[Dict[str, Dict[str, Optional[str]]], Dict[str, Dict[str, str]]]( raw, - lambda x: cast(Dict[str, Dict[str, Optional[str]]], x.cast_to(types, partial_types)), - lambda x: cast(Dict[str, Dict[str, str]], x.cast_to(types, types)), + lambda x: cast(Dict[str, Dict[str, Optional[str]]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Dict[str, Dict[str, str]], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6807,8 +6859,8 @@ def TestFnNamedArgsSingleMapStringToString( return baml_py.BamlSyncStream[Dict[str, Optional[str]], Dict[str, str]]( raw, - lambda x: cast(Dict[str, Optional[str]], x.cast_to(types, partial_types)), - lambda x: cast(Dict[str, str], x.cast_to(types, types)), + lambda x: cast(Dict[str, Optional[str]], x.cast_to(types, types, partial_types, True)), + lambda x: cast(Dict[str, str], x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6837,8 +6889,8 @@ def TestFnNamedArgsSingleString( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6867,8 +6919,8 @@ def TestFnNamedArgsSingleStringArray( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6897,8 +6949,8 @@ def TestFnNamedArgsSingleStringList( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6927,8 +6979,8 @@ def TestGemini( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6956,8 +7008,8 @@ def TestGeminiOpenAiGeneric( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -6986,8 +7038,8 @@ def TestGeminiSystem( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7016,8 +7068,8 @@ def TestGeminiSystemAsChat( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7046,8 +7098,8 @@ def TestImageInput( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7076,8 +7128,8 @@ def TestImageInputAnthropic( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7106,8 +7158,8 @@ def TestImageListInput( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7137,8 +7189,8 @@ def TestMulticlassNamedArgs( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7167,8 +7219,8 @@ def TestNamedArgsLiteralBool( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7197,8 +7249,8 @@ def TestNamedArgsLiteralInt( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7227,8 +7279,8 @@ def TestNamedArgsLiteralString( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7257,8 +7309,8 @@ def TestOllama( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7287,8 +7339,8 @@ def TestOpenAILegacyProvider( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7317,8 +7369,8 @@ def TestOpenAIShorthand( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7346,8 +7398,8 @@ def TestRetryConstant( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7375,8 +7427,8 @@ def TestRetryExponential( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7404,8 +7456,8 @@ def TestSingleFallbackClient( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7434,8 +7486,8 @@ def TestVertex( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7463,8 +7515,8 @@ def TestVertexWithSystemInstructions( return baml_py.BamlSyncStream[Optional[str], str]( raw, - lambda x: cast(Optional[str], x.cast_to(types, partial_types)), - lambda x: cast(str, x.cast_to(types, types)), + lambda x: cast(Optional[str], x.cast_to(types, types, partial_types, True)), + lambda x: cast(str, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7493,8 +7545,8 @@ def UnionTest_Function( return baml_py.BamlSyncStream[partial_types.UnionTest_ReturnType, types.UnionTest_ReturnType]( raw, - lambda x: cast(partial_types.UnionTest_ReturnType, x.cast_to(types, partial_types)), - lambda x: cast(types.UnionTest_ReturnType, x.cast_to(types, types)), + lambda x: cast(partial_types.UnionTest_ReturnType, x.cast_to(types, types, partial_types, True)), + lambda x: cast(types.UnionTest_ReturnType, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7523,8 +7575,8 @@ def UseBlockConstraint( return baml_py.BamlSyncStream[Optional[int], int]( raw, - lambda x: cast(Optional[int], x.cast_to(types, partial_types)), - lambda x: cast(int, x.cast_to(types, types)), + lambda x: cast(Optional[int], x.cast_to(types, types, partial_types, True)), + lambda x: cast(int, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7553,8 +7605,8 @@ def UseMalformedConstraints( return baml_py.BamlSyncStream[Optional[int], int]( raw, - lambda x: cast(Optional[int], x.cast_to(types, partial_types)), - lambda x: cast(int, x.cast_to(types, types)), + lambda x: cast(Optional[int], x.cast_to(types, types, partial_types, True)), + lambda x: cast(int, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) @@ -7583,8 +7635,8 @@ def UseNestedBlockConstraint( return baml_py.BamlSyncStream[Optional[int], int]( raw, - lambda x: cast(Optional[int], x.cast_to(types, partial_types)), - lambda x: cast(int, x.cast_to(types, types)), + lambda x: cast(Optional[int], x.cast_to(types, types, partial_types, True)), + lambda x: cast(int, x.cast_to(types, types, partial_types, False)), self.__ctx_manager.get(), ) diff --git a/integ-tests/python/baml_client/type_builder.py b/integ-tests/python/baml_client/type_builder.py index 91661856e1..485b83edf7 100644 --- a/integ-tests/python/baml_client/type_builder.py +++ b/integ-tests/python/baml_client/type_builder.py @@ -20,7 +20,7 @@ class TypeBuilder(_TypeBuilder): def __init__(self): super().__init__(classes=set( - ["BigNumbers","BinaryNode","Blah","BlockConstraint","BlockConstraintForParam","BookOrder","ClassForNullLiteral","ClassOptionalOutput","ClassOptionalOutput2","ClassToRecAlias","ClassWithImage","CompoundBigNumbers","ContactInfo","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Earthling","Education","Email","EmailAddress","Event","FakeImage","FlightConfirmation","FooAny","Forest","FormatterTest0","FormatterTest1","FormatterTest2","FormatterTest3","GroceryReceipt","InnerClass","InnerClass2","InputClass","InputClassNested","LinkedList","LinkedListAliasNode","LiteralClassHello","LiteralClassOne","LiteralClassTwo","MalformedConstraints","MalformedConstraints2","Martian","MergeAttrs","NamedArgsSingleClass","Nested","Nested2","NestedBlockConstraint","NestedBlockConstraintForParam","Node","NodeWithAliasIndirection","OptionalListAndMap","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","OriginalA","OriginalB","Person","PhoneNumber","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","Tree","TwoStoriesOneTitle","UnionTest_ReturnType","WithReasoning",] + ["BigNumbers","BinaryNode","Blah","BlockConstraint","BlockConstraintForParam","BookOrder","ClassForNullLiteral","ClassOptionalOutput","ClassOptionalOutput2","ClassToRecAlias","ClassWithBlockDone","ClassWithImage","ClassWithoutDone","CompoundBigNumbers","ContactInfo","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Earthling","Education","Email","EmailAddress","Event","FakeImage","FlightConfirmation","FooAny","Forest","FormatterTest0","FormatterTest1","FormatterTest2","FormatterTest3","GroceryReceipt","InnerClass","InnerClass2","InputClass","InputClassNested","LinkedList","LinkedListAliasNode","LiteralClassHello","LiteralClassOne","LiteralClassTwo","MalformedConstraints","MalformedConstraints2","Martian","MergeAttrs","NamedArgsSingleClass","Nested","Nested2","NestedBlockConstraint","NestedBlockConstraintForParam","Node","NodeWithAliasIndirection","OptionalListAndMap","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","OriginalA","OriginalB","Person","PhoneNumber","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SemanticContainer","SmallThing","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","Tree","TwoStoriesOneTitle","UnionTest_ReturnType","WithReasoning",] ), enums=set( ["AliasedEnum","Category","Category2","Category3","Color","DataType","DynEnumOne","DynEnumTwo","EnumInClass","EnumOutput","Hobby","MapKey","NamedArgsSingleEnum","NamedArgsSingleEnumList","OptionalTest_CategoryType","OrderStatus","Tag","TestEnum",] )) diff --git a/integ-tests/python/baml_client/types.py b/integ-tests/python/baml_client/types.py index 8e5a9e39c3..e54edaa290 100644 --- a/integ-tests/python/baml_client/types.py +++ b/integ-tests/python/baml_client/types.py @@ -38,7 +38,6 @@ def all_succeeded(checks: Dict[CheckName, Check]) -> bool: return all(check.status == "succeeded" for check in get_checks(checks)) - class AliasedEnum(str, Enum): KEY_ONE = "KEY_ONE" @@ -195,11 +194,19 @@ class ClassOptionalOutput2(BaseModel): class ClassToRecAlias(BaseModel): list: "LinkedListAliasNode" +class ClassWithBlockDone(BaseModel): + i_16_digits: int + s_20_words: str + class ClassWithImage(BaseModel): myImage: baml_py.Image param2: str fake_image: "FakeImage" +class ClassWithoutDone(BaseModel): + i_16_digits: int + s_20_words: str + class CompoundBigNumbers(BaseModel): big: "BigNumbers" big_nums: List["BigNumbers"] @@ -458,6 +465,20 @@ class SearchParams(BaseModel): description: List["WithReasoning"] tags: List[Union["Tag", str]] +class SemanticContainer(BaseModel): + sixteen_digit_number: int + string_with_twenty_words: str + class_1: "ClassWithoutDone" + class_2: "ClassWithBlockDone" + class_done_needed: "ClassWithBlockDone" + class_needed: "ClassWithoutDone" + three_small_things: List["SmallThing"] + final_string: str + +class SmallThing(BaseModel): + i_16_digits: int + i_8_digits: int + class SomeClassNestedDynamic(BaseModel): model_config = ConfigDict(extra='allow') hi: str diff --git a/integ-tests/python/tests/test_functions.py b/integ-tests/python/tests/test_functions.py index 3a9131a635..bd7cb870f2 100644 --- a/integ-tests/python/tests/test_functions.py +++ b/integ-tests/python/tests/test_functions.py @@ -2,7 +2,7 @@ import json import os import time -from typing import List +from typing import List, Optional import pytest from assertpy import assert_that from dotenv import load_dotenv @@ -397,7 +397,7 @@ async def test_should_work_for_all_outputs(): literal_string = await b.FnOutputLiteralString(a) assert literal_string == "example output" - list = await b.FnOutputClassList(a) + list = await b.FnOutputClassList(a) # Broken assert len(list) > 0 assert len(list[0].prop1) > 0 @@ -1481,7 +1481,11 @@ async def test_no_stream_big_integer(): msgs: List[int | None] = [] async for msg in stream: msgs.append(msg) + print("msgs:") + print(msgs) res = await stream.get_final_response() + print("res:") + print(res) for msg in msgs: assert True if msg is None else msg == res @@ -1668,3 +1672,51 @@ async def test_null_literal_class_hello(): stream = b.stream.NullLiteralClassHello(s="unused") async for msg in stream: msg.a is None + + +@pytest.mark.asyncio +async def test_semantic_streaming(): + stream = b.stream.MakeSemanticContainer() + + # We will use these to store streaming fields and check them + # for stability. + reference_string: Optional[str] = None + reference_int: Optional[int] = None + + async for msg in stream: + assert "string_with_twenty_words" in dict(msg) + assert "sixteen_digit_number" in dict(msg) + + # Checks for stability of numeric and @stream.done fields. + if msg.sixteen_digit_number is not None: + if reference_int is None: + # Set the reference if it hasn't been set yet. + reference_int = msg.sixteen_digit_number + else: + # If the reference has been set, check that the + # current value matches it. + assert reference_int == msg.sixteen_digit_number + if msg.string_with_twenty_words is not None: + if reference_string is None: + # Set the reference if it hasn't been set yet. + reference_string = msg.string_with_twenty_words + else: + # If the reference has been set, check that the + # current value matches it. + assert reference_string == msg.string_with_twenty_words + + # Checks for @stream.with_state. + if msg.class_needed is not None: + if msg.class_needed.s_20_words.value is not None: + if len(msg.class_needed.s_20_words.value.split(" ")) < 3 and msg.final_string is None: + print(msg) + assert msg.class_needed.s_20_words.state == "Incomplete" + if msg.final_string is not None: + assert msg.class_needed.s_20_words.state == "Complete" + + # Checks for @stream.not_null. + for sub in msg.three_small_things: + assert sub.i_16_digits is not None + + final = await stream.get_final_response() + print(final) diff --git a/integ-tests/ruby/baml_client/client.rb b/integ-tests/ruby/baml_client/client.rb index edde31e32c..81b9bc5bf6 100644 --- a/integ-tests/ruby/baml_client/client.rb +++ b/integ-tests/ruby/baml_client/client.rb @@ -79,7 +79,7 @@ def AaaSamOutputFormat( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -111,7 +111,7 @@ def AliasThatPointsToRecursiveType( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -143,7 +143,7 @@ def AliasWithMultipleAttrs( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -175,7 +175,7 @@ def AliasedInputClass( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -207,7 +207,7 @@ def AliasedInputClass2( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -239,7 +239,7 @@ def AliasedInputClassNested( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -271,7 +271,7 @@ def AliasedInputEnum( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -303,7 +303,7 @@ def AliasedInputList( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -335,7 +335,7 @@ def AllowedOptionals( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -367,7 +367,7 @@ def AssertFn( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -399,7 +399,7 @@ def AudioInput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -431,7 +431,7 @@ def BuildLinkedList( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -463,7 +463,7 @@ def BuildTree( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -495,7 +495,7 @@ def ClassThatPointsToRecursiveClassThroughAlias( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -527,7 +527,7 @@ def ClassifyDynEnumTwo( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -559,7 +559,7 @@ def ClassifyMessage( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -591,7 +591,7 @@ def ClassifyMessage2( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -623,7 +623,7 @@ def ClassifyMessage3( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -655,7 +655,7 @@ def Completion( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -687,7 +687,7 @@ def CustomTask( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -719,7 +719,7 @@ def DescribeImage( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -751,7 +751,7 @@ def DescribeImage2( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -783,7 +783,7 @@ def DescribeImage3( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -815,7 +815,7 @@ def DescribeImage4( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -847,7 +847,7 @@ def DifferentiateUnions( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -879,7 +879,7 @@ def DummyOutputFunction( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -911,7 +911,7 @@ def DynamicFunc( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -943,7 +943,7 @@ def DynamicInputOutput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -975,7 +975,7 @@ def DynamicListInputOutput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1007,7 +1007,7 @@ def ExpectFailure( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1039,7 +1039,7 @@ def ExtractContactInfo( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1071,7 +1071,7 @@ def ExtractHobby( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1103,7 +1103,7 @@ def ExtractNames( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1135,7 +1135,7 @@ def ExtractPeople( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1167,7 +1167,7 @@ def ExtractReceiptInfo( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1199,7 +1199,7 @@ def ExtractResume( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1231,7 +1231,7 @@ def ExtractResume2( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1263,7 +1263,7 @@ def FnClassOptionalOutput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1295,7 +1295,7 @@ def FnClassOptionalOutput2( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1327,7 +1327,7 @@ def FnEnumListOutput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1359,7 +1359,7 @@ def FnEnumOutput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1391,7 +1391,7 @@ def FnLiteralClassInputOutput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1423,7 +1423,7 @@ def FnLiteralUnionClassInputOutput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1455,7 +1455,7 @@ def FnNamedArgsSingleStringOptional( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1487,7 +1487,7 @@ def FnOutputBool( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1519,7 +1519,7 @@ def FnOutputClass( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1551,7 +1551,7 @@ def FnOutputClassList( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1583,7 +1583,7 @@ def FnOutputClassNested( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1615,7 +1615,7 @@ def FnOutputClassWithEnum( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1647,7 +1647,7 @@ def FnOutputInt( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1679,7 +1679,7 @@ def FnOutputLiteralBool( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1711,7 +1711,7 @@ def FnOutputLiteralInt( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1743,7 +1743,7 @@ def FnOutputLiteralString( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1775,7 +1775,7 @@ def FnOutputStringList( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1807,7 +1807,7 @@ def FnTestAliasedEnumOutput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1839,7 +1839,7 @@ def FnTestClassAlias( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1871,7 +1871,7 @@ def FnTestNamedArgsSingleEnum( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1903,7 +1903,7 @@ def GetDataType( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1935,7 +1935,7 @@ def GetOrderInfo( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1967,7 +1967,7 @@ def GetQuery( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -1999,7 +1999,7 @@ def InOutEnumMapKey( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2031,7 +2031,7 @@ def InOutLiteralStringUnionMapKey( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2063,7 +2063,7 @@ def InOutSingleLiteralStringMapKey( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2095,7 +2095,7 @@ def JsonTypeAliasCycle( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2127,7 +2127,7 @@ def LiteralUnionsTest( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2159,7 +2159,7 @@ def MakeBlockConstraint( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2191,7 +2191,39 @@ def MakeNestedBlockConstraint( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) + end + + sig { + params( + varargs: T.untyped, + + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::Types::SemanticContainer) + } + def MakeSemanticContainer( + *varargs, + + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("MakeSemanticContainer may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "MakeSemanticContainer", + { + + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2223,7 +2255,7 @@ def MapAlias( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2255,7 +2287,7 @@ def MergeAliasAttributes( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2287,7 +2319,7 @@ def MyFunc( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2319,7 +2351,7 @@ def NestedAlias( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2351,7 +2383,7 @@ def NullLiteralClassHello( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2383,7 +2415,7 @@ def OptionalTest_Function( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2415,7 +2447,7 @@ def PredictAge( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2447,7 +2479,7 @@ def PredictAgeBare( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2479,7 +2511,7 @@ def PrimitiveAlias( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2511,7 +2543,7 @@ def PromptTestClaude( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2543,7 +2575,7 @@ def PromptTestClaudeChat( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2575,7 +2607,7 @@ def PromptTestClaudeChatNoSystem( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2607,7 +2639,7 @@ def PromptTestOpenAI( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2639,7 +2671,7 @@ def PromptTestOpenAIChat( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2671,7 +2703,7 @@ def PromptTestOpenAIChatNoSystem( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2703,7 +2735,7 @@ def PromptTestStreaming( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2735,7 +2767,7 @@ def RecursiveAliasCycle( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2767,7 +2799,7 @@ def RecursiveClassWithAliasIndirection( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2799,7 +2831,7 @@ def ReturnAliasWithMergedAttributes( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2831,7 +2863,7 @@ def ReturnFailingAssert( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2863,7 +2895,7 @@ def ReturnMalformedConstraints( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2895,7 +2927,7 @@ def SchemaDescriptions( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2927,7 +2959,7 @@ def SimpleRecursiveListAlias( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2959,7 +2991,7 @@ def SimpleRecursiveMapAlias( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -2991,7 +3023,7 @@ def StreamBigNumbers( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3023,7 +3055,7 @@ def StreamFailingAssertion( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3055,7 +3087,7 @@ def StreamOneBigNumber( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3087,7 +3119,7 @@ def StreamUnionIntegers( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3119,7 +3151,7 @@ def StreamingCompoundNumbers( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3151,7 +3183,7 @@ def TestAnthropic( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3183,7 +3215,7 @@ def TestAnthropicShorthand( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3215,7 +3247,7 @@ def TestAws( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3247,7 +3279,7 @@ def TestAwsInvalidAccessKey( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3279,7 +3311,7 @@ def TestAwsInvalidProfile( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3311,7 +3343,7 @@ def TestAwsInvalidRegion( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3343,7 +3375,7 @@ def TestAwsInvalidSessionToken( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3375,7 +3407,7 @@ def TestAzure( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3407,7 +3439,7 @@ def TestAzureFailure( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3439,7 +3471,7 @@ def TestCaching( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3471,7 +3503,7 @@ def TestFallbackClient( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3503,7 +3535,7 @@ def TestFallbackToShorthand( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3535,7 +3567,7 @@ def TestFnNamedArgsSingleBool( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3567,7 +3599,7 @@ def TestFnNamedArgsSingleClass( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3599,7 +3631,7 @@ def TestFnNamedArgsSingleEnumList( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3631,7 +3663,7 @@ def TestFnNamedArgsSingleFloat( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3663,7 +3695,7 @@ def TestFnNamedArgsSingleInt( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3695,7 +3727,7 @@ def TestFnNamedArgsSingleMapStringToClass( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3727,7 +3759,7 @@ def TestFnNamedArgsSingleMapStringToMap( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3759,7 +3791,7 @@ def TestFnNamedArgsSingleMapStringToString( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3791,7 +3823,7 @@ def TestFnNamedArgsSingleString( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3823,7 +3855,7 @@ def TestFnNamedArgsSingleStringArray( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3855,7 +3887,7 @@ def TestFnNamedArgsSingleStringList( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3887,7 +3919,7 @@ def TestGemini( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3919,7 +3951,7 @@ def TestGeminiOpenAiGeneric( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3951,7 +3983,7 @@ def TestGeminiSystem( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -3983,7 +4015,7 @@ def TestGeminiSystemAsChat( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4015,7 +4047,7 @@ def TestImageInput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4047,7 +4079,7 @@ def TestImageInputAnthropic( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4079,7 +4111,7 @@ def TestImageListInput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4111,7 +4143,7 @@ def TestMulticlassNamedArgs( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4143,7 +4175,7 @@ def TestNamedArgsLiteralBool( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4175,7 +4207,7 @@ def TestNamedArgsLiteralInt( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4207,7 +4239,7 @@ def TestNamedArgsLiteralString( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4239,7 +4271,7 @@ def TestOllama( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4271,7 +4303,7 @@ def TestOpenAILegacyProvider( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4303,7 +4335,7 @@ def TestOpenAIShorthand( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4335,7 +4367,7 @@ def TestRetryConstant( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4367,7 +4399,7 @@ def TestRetryExponential( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4399,7 +4431,7 @@ def TestSingleFallbackClient( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4431,7 +4463,7 @@ def TestVertex( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4463,7 +4495,7 @@ def TestVertexWithSystemInstructions( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4495,7 +4527,7 @@ def UnionTest_Function( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4527,7 +4559,7 @@ def UseBlockConstraint( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4559,7 +4591,7 @@ def UseMalformedConstraints( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end sig { @@ -4591,7 +4623,7 @@ def UseNestedBlockConstraint( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - (raw.parsed_using_types(Baml::Types)) + (raw.parsed_using_types(Baml::Types, Baml::PartialTypes, false)) end @@ -4636,7 +4668,7 @@ def AaaSamOutputFormat( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::Recipe, Baml::Types::Recipe].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::Recipe), Baml::Types::Recipe].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -4671,7 +4703,7 @@ def AliasThatPointsToRecursiveType( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::LinkedListAliasNode, Baml::Types::LinkedListAliasNode].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::LinkedListAliasNode), Baml::Types::LinkedListAliasNode].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -4706,7 +4738,7 @@ def AliasWithMultipleAttrs( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::Checked[T.nilable(Integer)], Baml::Checked[Integer]].new( + Baml::BamlStream[T.nilable(Baml::Checked[Integer]), Baml::Checked[Integer]].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -4916,7 +4948,7 @@ def AllowedOptionals( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::OptionalListAndMap, Baml::Types::OptionalListAndMap].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::OptionalListAndMap), Baml::Types::OptionalListAndMap].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5021,7 +5053,7 @@ def BuildLinkedList( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::LinkedList, Baml::Types::LinkedList].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::LinkedList), Baml::Types::LinkedList].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5056,7 +5088,7 @@ def BuildTree( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::Tree, Baml::Types::Tree].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::Tree), Baml::Types::Tree].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5091,7 +5123,7 @@ def ClassThatPointsToRecursiveClassThroughAlias( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::ClassToRecAlias, Baml::Types::ClassToRecAlias].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::ClassToRecAlias), Baml::Types::ClassToRecAlias].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5301,7 +5333,7 @@ def CustomTask( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[T.nilable(T.any(Baml::PartialTypes::BookOrder, Baml::PartialTypes::FlightConfirmation, Baml::PartialTypes::GroceryReceipt)), T.any(Baml::Types::BookOrder, Baml::Types::FlightConfirmation, Baml::Types::GroceryReceipt)].new( + Baml::BamlStream[T.nilable(T.any(T.nilable(Baml::PartialTypes::BookOrder), T.nilable(Baml::PartialTypes::FlightConfirmation), T.nilable(Baml::PartialTypes::GroceryReceipt))), T.any(Baml::Types::BookOrder, Baml::Types::FlightConfirmation, Baml::Types::GroceryReceipt)].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5476,7 +5508,7 @@ def DifferentiateUnions( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[T.nilable(T.any(Baml::PartialTypes::OriginalA, Baml::PartialTypes::OriginalB)), T.any(Baml::Types::OriginalA, Baml::Types::OriginalB)].new( + Baml::BamlStream[T.nilable(T.any(T.nilable(Baml::PartialTypes::OriginalA), T.nilable(Baml::PartialTypes::OriginalB))), T.any(Baml::Types::OriginalA, Baml::Types::OriginalB)].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5511,7 +5543,7 @@ def DummyOutputFunction( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::DummyOutput, Baml::Types::DummyOutput].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::DummyOutput), Baml::Types::DummyOutput].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5546,7 +5578,7 @@ def DynamicFunc( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::DynamicClassTwo, Baml::Types::DynamicClassTwo].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::DynamicClassTwo), Baml::Types::DynamicClassTwo].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5581,7 +5613,7 @@ def DynamicInputOutput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::DynInputOutput, Baml::Types::DynInputOutput].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::DynInputOutput), Baml::Types::DynInputOutput].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5616,7 +5648,7 @@ def DynamicListInputOutput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[T::Array[Baml::PartialTypes::DynInputOutput], T::Array[Baml::Types::DynInputOutput]].new( + Baml::BamlStream[T::Array[T.nilable(Baml::PartialTypes::DynInputOutput)], T::Array[Baml::Types::DynInputOutput]].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5686,7 +5718,7 @@ def ExtractContactInfo( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::ContactInfo, Baml::Types::ContactInfo].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::ContactInfo), Baml::Types::ContactInfo].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5791,7 +5823,7 @@ def ExtractPeople( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[T::Array[Baml::PartialTypes::Person], T::Array[Baml::Types::Person]].new( + Baml::BamlStream[T::Array[T.nilable(Baml::PartialTypes::Person)], T::Array[Baml::Types::Person]].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5826,7 +5858,7 @@ def ExtractReceiptInfo( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::ReceiptInfo, Baml::Types::ReceiptInfo].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::ReceiptInfo), Baml::Types::ReceiptInfo].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5861,7 +5893,7 @@ def ExtractResume( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::Resume, Baml::Types::Resume].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::Resume), Baml::Types::Resume].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5896,7 +5928,7 @@ def ExtractResume2( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::Resume, Baml::Types::Resume].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::Resume), Baml::Types::Resume].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5931,7 +5963,7 @@ def FnClassOptionalOutput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::ClassOptionalOutput, T.nilable(Baml::Types::ClassOptionalOutput)].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::ClassOptionalOutput), T.nilable(Baml::Types::ClassOptionalOutput)].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -5966,7 +5998,7 @@ def FnClassOptionalOutput2( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::ClassOptionalOutput2, T.nilable(Baml::Types::ClassOptionalOutput2)].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::ClassOptionalOutput2), T.nilable(Baml::Types::ClassOptionalOutput2)].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -6071,7 +6103,7 @@ def FnLiteralClassInputOutput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::LiteralClassHello, Baml::Types::LiteralClassHello].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::LiteralClassHello), Baml::Types::LiteralClassHello].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -6106,7 +6138,7 @@ def FnLiteralUnionClassInputOutput( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[T.nilable(T.any(Baml::PartialTypes::LiteralClassOne, Baml::PartialTypes::LiteralClassTwo)), T.any(Baml::Types::LiteralClassOne, Baml::Types::LiteralClassTwo)].new( + Baml::BamlStream[T.nilable(T.any(T.nilable(Baml::PartialTypes::LiteralClassOne), T.nilable(Baml::PartialTypes::LiteralClassTwo))), T.any(Baml::Types::LiteralClassOne, Baml::Types::LiteralClassTwo)].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -6211,7 +6243,7 @@ def FnOutputClass( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::TestOutputClass, Baml::Types::TestOutputClass].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::TestOutputClass), Baml::Types::TestOutputClass].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -6246,7 +6278,7 @@ def FnOutputClassList( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[T::Array[Baml::PartialTypes::TestOutputClass], T::Array[Baml::Types::TestOutputClass]].new( + Baml::BamlStream[T::Array[T.nilable(Baml::PartialTypes::TestOutputClass)], T::Array[Baml::Types::TestOutputClass]].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -6281,7 +6313,7 @@ def FnOutputClassNested( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::TestClassNested, Baml::Types::TestClassNested].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::TestClassNested), Baml::Types::TestClassNested].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -6316,7 +6348,7 @@ def FnOutputClassWithEnum( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::TestClassWithEnum, Baml::Types::TestClassWithEnum].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::TestClassWithEnum), Baml::Types::TestClassWithEnum].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -6561,7 +6593,7 @@ def FnTestClassAlias( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::TestClassAlias, Baml::Types::TestClassAlias].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::TestClassAlias), Baml::Types::TestClassAlias].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -6631,7 +6663,7 @@ def GetDataType( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::RaysData, Baml::Types::RaysData].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::RaysData), Baml::Types::RaysData].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -6666,7 +6698,7 @@ def GetOrderInfo( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::OrderInfo, Baml::Types::OrderInfo].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::OrderInfo), Baml::Types::OrderInfo].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -6701,7 +6733,7 @@ def GetQuery( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::SearchParams, Baml::Types::SearchParams].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::SearchParams), Baml::Types::SearchParams].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -6911,7 +6943,7 @@ def MakeBlockConstraint( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::Checked[Baml::PartialTypes::BlockConstraint], Baml::Checked[Baml::Types::BlockConstraint]].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::BlockConstraint), Baml::Checked[Baml::Types::BlockConstraint]].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -6946,7 +6978,42 @@ def MakeNestedBlockConstraint( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::NestedBlockConstraint, Baml::Types::NestedBlockConstraint].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::NestedBlockConstraint), Baml::Types::NestedBlockConstraint].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + + sig { + params( + varargs: T.untyped, + + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Baml::Types::SemanticContainer]) + } + def MakeSemanticContainer( + *varargs, + + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("MakeSemanticContainer may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "MakeSemanticContainer", + { + + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[T.nilable(Baml::PartialTypes::SemanticContainer), Baml::Types::SemanticContainer].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -7016,7 +7083,7 @@ def MergeAliasAttributes( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::MergeAttrs, Baml::Types::MergeAttrs].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::MergeAttrs), Baml::Types::MergeAttrs].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -7051,7 +7118,7 @@ def MyFunc( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::DynamicOutput, Baml::Types::DynamicOutput].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::DynamicOutput), Baml::Types::DynamicOutput].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -7121,7 +7188,7 @@ def NullLiteralClassHello( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::ClassForNullLiteral, Baml::Types::ClassForNullLiteral].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::ClassForNullLiteral), Baml::Types::ClassForNullLiteral].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -7156,7 +7223,7 @@ def OptionalTest_Function( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[T::Array[Baml::PartialTypes::OptionalTest_ReturnType], T::Array[T.nilable(Baml::Types::OptionalTest_ReturnType)]].new( + Baml::BamlStream[T::Array[T.nilable(Baml::PartialTypes::OptionalTest_ReturnType)], T::Array[T.nilable(Baml::Types::OptionalTest_ReturnType)]].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -7191,7 +7258,7 @@ def PredictAge( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::FooAny, Baml::Types::FooAny].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::FooAny), Baml::Types::FooAny].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -7226,7 +7293,7 @@ def PredictAgeBare( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::Checked[T.nilable(Integer)], Baml::Checked[Integer]].new( + Baml::BamlStream[T.nilable(Baml::Checked[Integer]), Baml::Checked[Integer]].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -7576,7 +7643,7 @@ def RecursiveClassWithAliasIndirection( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::NodeWithAliasIndirection, Baml::Types::NodeWithAliasIndirection].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::NodeWithAliasIndirection), Baml::Types::NodeWithAliasIndirection].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -7611,7 +7678,7 @@ def ReturnAliasWithMergedAttributes( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::Checked[T.nilable(Integer)], Baml::Checked[Integer]].new( + Baml::BamlStream[T.nilable(Baml::Checked[Integer]), Baml::Checked[Integer]].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -7681,7 +7748,7 @@ def ReturnMalformedConstraints( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::MalformedConstraints, Baml::Types::MalformedConstraints].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::MalformedConstraints), Baml::Types::MalformedConstraints].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -7716,7 +7783,7 @@ def SchemaDescriptions( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::Schema, Baml::Types::Schema].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::Schema), Baml::Types::Schema].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -7821,7 +7888,7 @@ def StreamBigNumbers( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::BigNumbers, Baml::Types::BigNumbers].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::BigNumbers), Baml::Types::BigNumbers].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -7856,7 +7923,7 @@ def StreamFailingAssertion( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::TwoStoriesOneTitle, Baml::Types::TwoStoriesOneTitle].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::TwoStoriesOneTitle), Baml::Types::TwoStoriesOneTitle].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -7961,7 +8028,7 @@ def StreamingCompoundNumbers( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::CompoundBigNumbers, Baml::Types::CompoundBigNumbers].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::CompoundBigNumbers), Baml::Types::CompoundBigNumbers].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -8591,7 +8658,7 @@ def TestFnNamedArgsSingleMapStringToClass( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[T::Hash[String, Baml::PartialTypes::StringToClassEntry], T::Hash[String, Baml::Types::StringToClassEntry]].new( + Baml::BamlStream[T::Hash[String, T.nilable(Baml::PartialTypes::StringToClassEntry)], T::Hash[String, Baml::Types::StringToClassEntry]].new( ffi_stream: raw, ctx_manager: @ctx_manager ) @@ -9466,7 +9533,7 @@ def UnionTest_Function( baml_options[:tb]&.instance_variable_get(:@registry), baml_options[:client_registry], ) - Baml::BamlStream[Baml::PartialTypes::UnionTest_ReturnType, Baml::Types::UnionTest_ReturnType].new( + Baml::BamlStream[T.nilable(Baml::PartialTypes::UnionTest_ReturnType), Baml::Types::UnionTest_ReturnType].new( ffi_stream: raw, ctx_manager: @ctx_manager ) diff --git a/integ-tests/ruby/baml_client/inlined.rb b/integ-tests/ruby/baml_client/inlined.rb index c5dd5b6f3b..d339b286ff 100644 --- a/integ-tests/ruby/baml_client/inlined.rb +++ b/integ-tests/ruby/baml_client/inlined.rb @@ -68,7 +68,7 @@ module Inlined "test-files/functions/output/class-list.baml" => "function FnOutputClassList(input: string) -> TestOutputClass[] {\n client GPT35\n prompt #\"\n Return a JSON array that follows this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest FnOutputClassList {\n functions [FnOutputClassList]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/class-nested.baml" => "class TestClassNested {\n prop1 string\n prop2 InnerClass\n}\n\nclass InnerClass {\n prop1 string\n prop2 string\n inner InnerClass2\n}\n\nclass InnerClass2 {\n prop2 int\n prop3 float\n}\n\nfunction FnOutputClassNested(input: string) -> TestClassNested {\n client GPT35\n prompt #\"\n Return a made up json blob that matches this schema:\n {{ctx.output_format}}\n ---\n\n JSON:\n \"#\n}\n\ntest FnOutputClassNested {\n functions [FnOutputClassNested]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/class-with-enum.baml" => "enum EnumInClass {\n ONE\n TWO\n}\n\nclass TestClassWithEnum {\n prop1 string\n prop2 EnumInClass\n}\n\nfunction FnOutputClassWithEnum(input: string) -> TestClassWithEnum {\n client GPT35\n prompt #\"\n Return a made up json blob that matches this schema:\n {{ctx.output_format}}\n ---\n\n JSON:\n \"#\n}\n\ntest FnOutputClassWithEnum {\n functions [FnOutputClassWithEnum]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/class.baml" => "class TestOutputClass {\n prop1 string\n prop2 int\n}\n\nfunction FnOutputClass(input: string) -> TestOutputClass {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n For the prop2, always return a 540\n\n JSON:\n \"#\n}\n\ntest TestClass {\n functions [FnOutputClass]\n args {\n input \"example input\"\n }\n}\n", + "test-files/functions/output/class.baml" => "class TestOutputClass {\n prop1 string @description(\"A long string with about 200 words\")\n prop2 int\n}\n\nfunction FnOutputClass(input: string) -> TestOutputClass {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n For the prop2, always return a 540\n\n JSON:\n \"#\n}\n\ntest TestClass {\n functions [FnOutputClass]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/enum-list.baml" => "function FnEnumListOutput(input: string) -> EnumOutput[] {\n client GPT35\n prompt #\"\n Print out two of these values randomly selected from the list below in a json array.\n\n {{ctx.output_format}}\n\n Answer:\n \"#\n} \n\ntest FnEnumListOutput {\n functions [FnEnumListOutput]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/enum.baml" => "/// An enum with three values,\n/// ONE, TWO and THREE.\nenum EnumOutput {\n\n /// The first enum.\n ONE\n\n /// The second enum.\n TWO\n THREE\n\n @@alias(\"VALUE_ENUM\")\n}\n\nfunction FnEnumOutput(input: string) -> EnumOutput {\n client GPT35\n prompt #\"\n Choose one of these values randomly. Before you give the answer, write out an unrelated haiku about the ocean.\n\n {{ctx.output_format(prefix=null)}}\n \"#\n}\n\ntest FnEnumOutput {\n functions [FnEnumOutput]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/int.baml" => "function FnOutputInt(input: string) -> int {\n client GPT35\n prompt #\"\n Return the integer 5 with no additional context.\n \"#\n}\n\ntest FnOutputInt {\n functions [FnOutputInt]\n args {\n input \"example input\"\n }\n}\n", @@ -99,6 +99,7 @@ module Inlined "test-files/providers/openai.baml" => "function PromptTestOpenAI(input: string) -> string {\n client GPT35\n prompt #\"\n Write a nice haiku about {{ input }}\n \"#\n}\n\nfunction TestOpenAILegacyProvider(input: string) -> string {\n client GPT35LegacyProvider\n prompt #\"\n Write a nice haiku about {{ input }}\n \"#\n}\n\nfunction TestOpenAIShorthand(input: string) -> string {\n client GPT35\n prompt #\"\n Write a nice short story about {{ input }}\n \"#\n}", "test-files/providers/tests.baml" => "test TestOpenAIShorthand {\n functions [TestOpenAIShorthand]\n args {\n input \"Donkey kong and peanut butter\"\n }\n}\n\ntest TestAWS {\n functions [\n TestAws\n ]\n args {\n input \"Donkey kong and peanut butter\"\n }\n}\n\ntest TestProvider {\n functions [\n TestAnthropic, TestVertex, PromptTestOpenAI, TestAzure, TestOllama, TestGemini, TestAws,\n TestAwsInvalidRegion,\n TestOpenAIShorthand,\n TestAnthropicShorthand,\n TestAwsInvalidAccessKey,\n TestAwsInvalidProfile,\n TestAwsInvalidSessionToken\n ]\n args {\n input \"Donkey kong and peanut butter\"\n }\n}\n\ntest TestName {\n functions [TestCaching]\n args {\n input #\"\nIn a near-future society where dreams have become a tradable commodity and shared experience, a lonely and socially awkward teenager named Alex discovers they possess a rare and powerful ability to not only view but also manipulate the dreams of others. Initially thrilled by this newfound power, Alex begins subtly altering the dreams of classmates and family members, helping them overcome fears, boost confidence, or experience fantastical adventures. As Alex's skills grow, so does their influence. They start selling premium dream experiences on the black market, crafting intricate and addictive dreamscapes for wealthy clients. However, the line between dream and reality begins to blur for those exposed to Alex's creations. Some clients struggle to differentiate between their true memories and the artificial ones implanted by Alex's dream manipulation.\n\nComplications arise when a mysterious government agency takes notice of Alex's unique abilities. They offer Alex a chance to use their gift for \"the greater good,\" hinting at applications in therapy, criminal rehabilitation, and even national security. Simultaneously, an underground resistance movement reaches out, warning Alex about the dangers of dream manipulation and the potential for mass control and exploitation. Caught between these opposing forces, Alex must navigate a complex web of ethical dilemmas. They grapple with questions of free will, the nature of consciousness, and the responsibility that comes with having power over people's minds. As the consequences of their actions spiral outward, affecting the lives of loved ones and strangers alike, Alex is forced to confront the true nature of their ability and decide how—or if—it should be used.\n\nThe story explores themes of identity, the subconscious mind, the ethics of technology, and the power of imagination. It delves into the potential consequences of a world where our most private thoughts and experiences are no longer truly our own, and examines the fine line between helping others and manipulating them for personal gain or a perceived greater good. The narrative further expands on the societal implications of such abilities, questioning the moral boundaries of altering consciousness and the potential for abuse in a world where dreams can be commodified. It challenges the reader to consider the impact of technology on personal autonomy and the ethical responsibilities of those who wield such power.\n\nAs Alex's journey unfolds, they encounter various individuals whose lives have been touched by their dream manipulations, each presenting a unique perspective on the ethical quandaries at hand. From a classmate who gains newfound confidence to a wealthy client who becomes addicted to the dreamscapes, the ripple effects of Alex's actions are profound and far-reaching. The government agency's interest in Alex's abilities raises questions about the potential for state control and surveillance, while the resistance movement highlights the dangers of unchecked power and the importance of safeguarding individual freedoms.\n\nUltimately, Alex's story is one of self-discovery and moral reckoning, as they must decide whether to embrace their abilities for personal gain, align with the government's vision of a controlled utopia, or join the resistance in their fight for freedom and autonomy. The narrative invites readers to reflect on the nature of reality, the boundaries of human experience, and the ethical implications of a world where dreams are no longer private sanctuaries but shared and manipulated commodities. It also explores the psychological impact on Alex, who must deal with the burden of knowing the intimate fears and desires of others, and the isolation that comes from being unable to share their own dreams without altering them.\n\nThe story further examines the technological advancements that have made dream manipulation possible, questioning the role of innovation in society and the potential for both progress and peril. It considers the societal divide between those who can afford to buy enhanced dream experiences and those who cannot, highlighting issues of inequality and access. As Alex becomes more entangled in the web of their own making, they must confront the possibility that their actions could lead to unintended consequences, not just for themselves but for the fabric of society as a whole.\n\nIn the end, Alex's journey is a cautionary tale about the power of dreams and the responsibilities that come with wielding such influence. It serves as a reminder of the importance of ethical considerations in the face of technological advancement and the need to balance innovation with humanity. The story leaves readers pondering the true cost of a world where dreams are no longer sacred, and the potential for both wonder and danger in the uncharted territories of the mind. But it's also a story about the power of imagination and the potential for change, even in a world where our deepest thoughts are no longer our own. And it's a story about the power of choice, and the importance of fighting for the freedom to dream.\n\nIn conclusion, this story is a reflection on the power of dreams and the responsibilities that come with wielding such influence. It serves as a reminder of the importance of ethical considerations in the face of technological advancement and the need to balance innovation with humanity. The story leaves readers pondering the true cost of a world where dreams are no longer sacred, and the potential for both wonder and danger in the uncharted territories of the mind. But it's also a story about the power of imagination and the potential for change, even in a world where our deepest thoughts are no longer our own. And it's a story about the power of choice, and the importance of fighting for the freedom to dream.\n \"#\n not_cached #\"\n hello world\n \"#\n }\n}", "test-files/providers/vertex.baml" => "function TestVertex(input: string) -> string {\n client Vertex\n prompt #\"\n Write a nice short story about {{ input }}\n \"#\n}\n\nfunction TestVertexWithSystemInstructions() -> string {\n client Vertex\n prompt #\"{{_.role(\"system\")}} You are a helpful assistant\n {{_.role(\"user\")}} Write a poem about llamas\n \"#\n}\n\ntest TestVertex {\n functions [TestVertex, TestVertexWithSystemInstructions]\n args {\n input \"a cat\"\n\n }\n}\n", + "test-files/semantic_streaming/semantic_streaming.baml" => "class SemanticContainer {\n sixteen_digit_number int\n string_with_twenty_words string @stream.done\n class_1 ClassWithoutDone\n class_2 ClassWithBlockDone\n class_done_needed ClassWithBlockDone @stream.not_null\n class_needed ClassWithoutDone @stream.not_null\n three_small_things SmallThing[] @description(\"Should have three items.\")\n final_string string\n}\n\nclass ClassWithoutDone {\n i_16_digits int\n s_20_words string @description(\"A string with 20 words in it\") @stream.with_state\n}\n\nclass ClassWithBlockDone {\n i_16_digits int\n s_20_words string\n @@stream.done\n}\n\nclass SmallThing {\n i_16_digits int @stream.not_null\n i_8_digits int\n}\n\nfunction MakeSemanticContainer() -> SemanticContainer {\n client GPT35\n prompt #\"\n {{ ctx.output_format }}\n \"#\n}", "test-files/strategies/fallback-shorthand.baml" => "\nclient FallbackToShorthand {\n provider fallback\n options {\n strategy [\n \"openai/does-not-exist\",\n \"openai/gpt-4o-mini\"\n ]\n }\n}\n\n\nfunction TestFallbackToShorthand(input: string) -> string {\n client FallbackToShorthand\n // TODO make it return the client name instead\n prompt #\"\n Say a haiku about {{input}}.\n \"#\n}\n\ntest TestProvider_FallbackToShorthand {\n functions [\n TestFallbackToShorthand\n ]\n args {\n input \"Donkey kong and peanut butter\"\n }\n}\n", "test-files/strategies/fallback.baml" => "// Happy path fallbacks.\nclient FaultyClient {\n provider openai\n options {\n model unknown-model\n api_key env.OPENAI_API_KEY\n }\n}\n\n\nclient FallbackClient {\n provider fallback\n options {\n // first 2 clients are expected to fail.\n strategy [\n FaultyClient,\n RetryClientConstant,\n GPT35\n Gemini\n\n ]\n }\n}\n\nfunction TestFallbackClient() -> string {\n client FallbackClient\n // TODO make it return the client name instead\n prompt #\"\n Say a haiku about mexico.\n \"#\n}\n\n// Fallbacks should fail gracefully.\nclient FaultyAzureClient {\n provider azure-openai\n options {\n model unknown-model\n resource_name \"unknown-resource-id\"\n deployment_id \"unknown-deployment-id\"\n }\n}\n\nclient SingleFallbackClient {\n provider fallback\n options {\n // first 2 clients are expected to fail.\n strategy [\n FaultyAzureClient\n ]\n }\n}\n\nfunction TestSingleFallbackClient() -> string {\n client SingleFallbackClient\n // TODO make it return the client name instead\n prompt #\"\n Say a haiku about mexico.\n \"#\n}\n", "test-files/strategies/retry.baml" => "\nretry_policy Exponential {\n max_retries 3\n strategy {\n type exponential_backoff\n }\n}\n\nretry_policy Constant {\n max_retries 3\n strategy {\n type constant_delay\n delay_ms 100\n }\n}\n\nclient RetryClientConstant {\n provider openai\n retry_policy Constant\n options {\n model \"gpt-3.5-turbo\"\n api_key \"blah\"\n }\n}\n\nclient RetryClientExponential {\n provider openai\n retry_policy Exponential\n options {\n model \"gpt-3.5-turbo\"\n api_key \"blahh\"\n }\n}\n\nfunction TestRetryConstant() -> string {\n client RetryClientConstant\n prompt #\"\n Say a haiku\n \"#\n}\n\nfunction TestRetryExponential() -> string {\n client RetryClientExponential\n prompt #\"\n Say a haiku\n \"#\n}\n", diff --git a/integ-tests/ruby/baml_client/partial-types.rb b/integ-tests/ruby/baml_client/partial-types.rb index df9e1293b5..8002af55f8 100644 --- a/integ-tests/ruby/baml_client/partial-types.rb +++ b/integ-tests/ruby/baml_client/partial-types.rb @@ -30,7 +30,9 @@ class ClassForNullLiteral < T::Struct; end class ClassOptionalOutput < T::Struct; end class ClassOptionalOutput2 < T::Struct; end class ClassToRecAlias < T::Struct; end + class ClassWithBlockDone < T::Struct; end class ClassWithImage < T::Struct; end + class ClassWithoutDone < T::Struct; end class CompoundBigNumbers < T::Struct; end class ContactInfo < T::Struct; end class CustomTaskResult < T::Struct; end @@ -89,6 +91,8 @@ class Recipe < T::Struct; end class Resume < T::Struct; end class Schema < T::Struct; end class SearchParams < T::Struct; end + class SemanticContainer < T::Struct; end + class SmallThing < T::Struct; end class SomeClassNestedDynamic < T::Struct; end class StringToClassEntry < T::Struct; end class TestClassAlias < T::Struct; end @@ -116,8 +120,8 @@ def initialize(props) class BinaryNode < T::Struct include Baml::Sorbet::Struct const :data, T.nilable(Integer) - const :left, Baml::PartialTypes::BinaryNode - const :right, Baml::PartialTypes::BinaryNode + const :left, T.nilable(Baml::PartialTypes::BinaryNode) + const :right, T.nilable(Baml::PartialTypes::BinaryNode) def initialize(props) super( @@ -217,7 +221,7 @@ class ClassOptionalOutput2 < T::Struct include Baml::Sorbet::Struct const :prop1, T.nilable(String) const :prop2, T.nilable(String) - const :prop3, Baml::PartialTypes::Blah + const :prop3, T.nilable(Baml::PartialTypes::Blah) def initialize(props) super( @@ -231,7 +235,7 @@ def initialize(props) end class ClassToRecAlias < T::Struct include Baml::Sorbet::Struct - const :list, Baml::PartialTypes::LinkedListAliasNode + const :list, T.nilable(Baml::PartialTypes::LinkedListAliasNode) def initialize(props) super( @@ -241,11 +245,25 @@ def initialize(props) @props = props end end + class ClassWithBlockDone < T::Struct + include Baml::Sorbet::Struct + const :i_16_digits, T.nilable(Integer) + const :s_20_words, T.nilable(String) + + def initialize(props) + super( + i_16_digits: props[:i_16_digits], + s_20_words: props[:s_20_words], + ) + + @props = props + end + end class ClassWithImage < T::Struct include Baml::Sorbet::Struct const :myImage, T.nilable(Baml::Image) const :param2, T.nilable(String) - const :fake_image, Baml::PartialTypes::FakeImage + const :fake_image, T.nilable(Baml::PartialTypes::FakeImage) def initialize(props) super( @@ -257,11 +275,25 @@ def initialize(props) @props = props end end + class ClassWithoutDone < T::Struct + include Baml::Sorbet::Struct + const :i_16_digits, T.nilable(Integer) + const :s_20_words, Baml::StreamState[T.nilable(String)] + + def initialize(props) + super( + i_16_digits: props[:i_16_digits], + s_20_words: props[:s_20_words], + ) + + @props = props + end + end class CompoundBigNumbers < T::Struct include Baml::Sorbet::Struct - const :big, Baml::PartialTypes::BigNumbers - const :big_nums, T::Array[Baml::PartialTypes::BigNumbers] - const :another, Baml::PartialTypes::BigNumbers + const :big, T.nilable(Baml::PartialTypes::BigNumbers) + const :big_nums, T::Array[T.nilable(Baml::PartialTypes::BigNumbers)] + const :another, T.nilable(Baml::PartialTypes::BigNumbers) def initialize(props) super( @@ -275,8 +307,8 @@ def initialize(props) end class ContactInfo < T::Struct include Baml::Sorbet::Struct - const :primary, T.nilable(T.any(Baml::PartialTypes::PhoneNumber, Baml::PartialTypes::EmailAddress)) - const :secondary, T.nilable(T.any(Baml::PartialTypes::PhoneNumber, Baml::PartialTypes::EmailAddress, T.nilable(NilClass))) + const :primary, T.nilable(T.any(T.nilable(Baml::PartialTypes::PhoneNumber), T.nilable(Baml::PartialTypes::EmailAddress))) + const :secondary, T.nilable(T.any(T.nilable(Baml::PartialTypes::PhoneNumber), T.nilable(Baml::PartialTypes::EmailAddress), T.nilable(NilClass))) def initialize(props) super( @@ -289,9 +321,9 @@ def initialize(props) end class CustomTaskResult < T::Struct include Baml::Sorbet::Struct - const :bookOrder, T.nilable(T.any(Baml::PartialTypes::BookOrder, T.nilable(NilClass))) - const :flightConfirmation, T.nilable(T.any(Baml::PartialTypes::FlightConfirmation, T.nilable(NilClass))) - const :groceryReceipt, T.nilable(T.any(Baml::PartialTypes::GroceryReceipt, T.nilable(NilClass))) + const :bookOrder, T.nilable(T.any(T.nilable(Baml::PartialTypes::BookOrder), T.nilable(NilClass))) + const :flightConfirmation, T.nilable(T.any(T.nilable(Baml::PartialTypes::FlightConfirmation), T.nilable(NilClass))) + const :groceryReceipt, T.nilable(T.any(T.nilable(Baml::PartialTypes::GroceryReceipt), T.nilable(NilClass))) def initialize(props) super( @@ -342,7 +374,7 @@ def initialize(props) class DynamicClassTwo < T::Struct include Baml::Sorbet::Struct const :hi, T.nilable(String) - const :some_class, Baml::PartialTypes::SomeClassNestedDynamic + const :some_class, T.nilable(Baml::PartialTypes::SomeClassNestedDynamic) const :status, T.nilable(Baml::Types::DynEnumOne) def initialize(props) @@ -367,7 +399,7 @@ def initialize(props) end class Earthling < T::Struct include Baml::Sorbet::Struct - const :age, Baml::Checked[T.nilable(Integer)] + const :age, T.nilable(Baml::Checked[Integer]) def initialize(props) super( @@ -477,9 +509,9 @@ def initialize(props) end class FooAny < T::Struct include Baml::Sorbet::Struct - const :planetary_age, T.nilable(T.any(Baml::PartialTypes::Martian, Baml::PartialTypes::Earthling)) - const :certainty, Baml::Checked[T.nilable(Integer)] - const :species, Baml::Checked[T.nilable(String)] + const :planetary_age, T.nilable(T.any(T.nilable(Baml::PartialTypes::Martian), T.nilable(Baml::PartialTypes::Earthling))) + const :certainty, T.nilable(Baml::Checked[Integer]) + const :species, T.nilable(Baml::Checked[String]) def initialize(props) super( @@ -493,7 +525,7 @@ def initialize(props) end class Forest < T::Struct include Baml::Sorbet::Struct - const :trees, T::Array[Baml::PartialTypes::Tree] + const :trees, T::Array[T.nilable(Baml::PartialTypes::Tree)] def initialize(props) super( @@ -581,7 +613,7 @@ class InnerClass < T::Struct include Baml::Sorbet::Struct const :prop1, T.nilable(String) const :prop2, T.nilable(String) - const :inner, Baml::PartialTypes::InnerClass2 + const :inner, T.nilable(Baml::PartialTypes::InnerClass2) def initialize(props) super( @@ -624,7 +656,7 @@ def initialize(props) class InputClassNested < T::Struct include Baml::Sorbet::Struct const :key, T.nilable(String) - const :nested, Baml::PartialTypes::InputClass + const :nested, T.nilable(Baml::PartialTypes::InputClass) def initialize(props) super( @@ -637,7 +669,7 @@ def initialize(props) end class LinkedList < T::Struct include Baml::Sorbet::Struct - const :head, Baml::PartialTypes::Node + const :head, T.nilable(Baml::PartialTypes::Node) const :len, T.nilable(Integer) def initialize(props) @@ -652,7 +684,7 @@ def initialize(props) class LinkedListAliasNode < T::Struct include Baml::Sorbet::Struct const :value, T.nilable(Integer) - const :next, Baml::PartialTypes::LinkedListAliasNode + const :next, T.nilable(Baml::PartialTypes::LinkedListAliasNode) def initialize(props) super( @@ -701,7 +733,7 @@ def initialize(props) end class MalformedConstraints < T::Struct include Baml::Sorbet::Struct - const :foo, Baml::Checked[T.nilable(Integer)] + const :foo, T.nilable(Baml::Checked[Integer]) def initialize(props) super( @@ -729,7 +761,7 @@ class Martian < T::Struct include Baml::Sorbet::Struct # The age of the Martian in Mars years. # So many Mars years. - const :age, Baml::Checked[T.nilable(Integer)] + const :age, T.nilable(Baml::Checked[Integer]) def initialize(props) super( @@ -741,7 +773,7 @@ def initialize(props) end class MergeAttrs < T::Struct include Baml::Sorbet::Struct - const :amount, Baml::Checked[T.nilable(Integer)] + const :amount, T.nilable(Baml::Checked[Integer]) def initialize(props) super( @@ -771,7 +803,7 @@ class Nested < T::Struct include Baml::Sorbet::Struct const :prop3, T.nilable(T.any(T.nilable(String), T.nilable(NilClass))) const :prop4, T.nilable(T.any(T.nilable(String), T.nilable(NilClass))) - const :prop20, Baml::PartialTypes::Nested2 + const :prop20, T.nilable(Baml::PartialTypes::Nested2) def initialize(props) super( @@ -799,7 +831,7 @@ def initialize(props) end class NestedBlockConstraint < T::Struct include Baml::Sorbet::Struct - const :nbc, Baml::Checked[Baml::PartialTypes::BlockConstraint] + const :nbc, T.nilable(Baml::PartialTypes::BlockConstraint) def initialize(props) super( @@ -811,7 +843,7 @@ def initialize(props) end class NestedBlockConstraintForParam < T::Struct include Baml::Sorbet::Struct - const :nbcfp, Baml::PartialTypes::BlockConstraintForParam + const :nbcfp, T.nilable(Baml::PartialTypes::BlockConstraintForParam) def initialize(props) super( @@ -824,7 +856,7 @@ def initialize(props) class Node < T::Struct include Baml::Sorbet::Struct const :data, T.nilable(Integer) - const :next, Baml::PartialTypes::Node + const :next, T.nilable(Baml::PartialTypes::Node) def initialize(props) super( @@ -838,7 +870,7 @@ def initialize(props) class NodeWithAliasIndirection < T::Struct include Baml::Sorbet::Struct const :value, T.nilable(Integer) - const :next, Baml::PartialTypes::NodeWithAliasIndirection + const :next, T.nilable(Baml::PartialTypes::NodeWithAliasIndirection) def initialize(props) super( @@ -879,7 +911,7 @@ def initialize(props) end class OptionalTest_ReturnType < T::Struct include Baml::Sorbet::Struct - const :omega_1, Baml::PartialTypes::OptionalTest_Prop1 + const :omega_1, T.nilable(Baml::PartialTypes::OptionalTest_Prop1) const :omega_2, T.nilable(String) const :omega_3, T::Array[T.nilable(Baml::Types::OptionalTest_CategoryType)] @@ -976,7 +1008,7 @@ def initialize(props) class RaysData < T::Struct include Baml::Sorbet::Struct const :dataType, T.nilable(Baml::Types::DataType) - const :value, T.nilable(T.any(Baml::PartialTypes::Resume, Baml::PartialTypes::Event)) + const :value, T.nilable(T.any(T.nilable(Baml::PartialTypes::Resume), T.nilable(Baml::PartialTypes::Event))) def initialize(props) super( @@ -989,7 +1021,7 @@ def initialize(props) end class ReceiptInfo < T::Struct include Baml::Sorbet::Struct - const :items, T::Array[Baml::PartialTypes::ReceiptItem] + const :items, T::Array[T.nilable(Baml::PartialTypes::ReceiptItem)] const :total_cost, T.nilable(Float) const :venue, T.nilable(T.any(T.nilable(String), T.nilable(String))) @@ -1023,7 +1055,7 @@ def initialize(props) end class Recipe < T::Struct include Baml::Sorbet::Struct - const :ingredients, T::Hash[String, Baml::PartialTypes::Quantity] + const :ingredients, T::Hash[String, T.nilable(Baml::PartialTypes::Quantity)] const :recipe_type, T.nilable(T.any(T.nilable(String), T.nilable(String))) def initialize(props) @@ -1040,7 +1072,7 @@ class Resume < T::Struct const :name, T.nilable(String) const :email, T.nilable(String) const :phone, T.nilable(String) - const :experience, T::Array[Baml::PartialTypes::Education] + const :experience, T::Array[T.nilable(Baml::PartialTypes::Education)] const :education, T::Array[T.nilable(String)] const :skills, T::Array[T.nilable(String)] @@ -1060,10 +1092,10 @@ def initialize(props) class Schema < T::Struct include Baml::Sorbet::Struct const :prop1, T.nilable(T.any(T.nilable(String), T.nilable(NilClass))) - const :prop2, T.nilable(T.any(Baml::PartialTypes::Nested, T.nilable(String))) + const :prop2, T.nilable(T.any(T.nilable(Baml::PartialTypes::Nested), T.nilable(String))) const :prop5, T::Array[T.nilable(T.any(T.nilable(String), T.nilable(NilClass)))] - const :prop6, T.nilable(T.any(T.nilable(String), T::Array[Baml::PartialTypes::Nested])) - const :nested_attrs, T::Array[T.nilable(T.any(T.nilable(String), T.nilable(NilClass), Baml::PartialTypes::Nested))] + const :prop6, T.nilable(T.any(T.nilable(String), T::Array[T.nilable(Baml::PartialTypes::Nested)])) + const :nested_attrs, T::Array[T.nilable(T.any(T.nilable(String), T.nilable(NilClass), T.nilable(Baml::PartialTypes::Nested)))] const :parens, T.nilable(T.any(T.nilable(String), T.nilable(NilClass))) const :other_group, T.nilable(T.any(T.nilable(String), T.nilable(T.any(T.nilable(Integer), T.nilable(String))))) @@ -1085,9 +1117,9 @@ class SearchParams < T::Struct include Baml::Sorbet::Struct const :dateRange, T.nilable(Integer) const :location, T::Array[T.nilable(String)] - const :jobTitle, Baml::PartialTypes::WithReasoning - const :company, Baml::PartialTypes::WithReasoning - const :description, T::Array[Baml::PartialTypes::WithReasoning] + const :jobTitle, T.nilable(Baml::PartialTypes::WithReasoning) + const :company, T.nilable(Baml::PartialTypes::WithReasoning) + const :description, T::Array[T.nilable(Baml::PartialTypes::WithReasoning)] const :tags, T::Array[T.nilable(T.any(T.nilable(Baml::Types::Tag), T.nilable(String)))] def initialize(props) @@ -1103,6 +1135,46 @@ def initialize(props) @props = props end end + class SemanticContainer < T::Struct + include Baml::Sorbet::Struct + const :sixteen_digit_number, T.nilable(Integer) + const :string_with_twenty_words, T.nilable(String) + const :class_1, T.nilable(Baml::PartialTypes::ClassWithoutDone) + const :class_2, T.nilable(Baml::Types::ClassWithBlockDone) + const :class_done_needed, Baml::Types::ClassWithBlockDone + const :class_needed, Baml::PartialTypes::ClassWithoutDone + const :three_small_things, T::Array[T.nilable(Baml::PartialTypes::SmallThing)] + const :final_string, T.nilable(String) + + def initialize(props) + super( + sixteen_digit_number: props[:sixteen_digit_number], + string_with_twenty_words: props[:string_with_twenty_words], + class_1: props[:class_1], + class_2: props[:class_2], + class_done_needed: props[:class_done_needed], + class_needed: props[:class_needed], + three_small_things: props[:three_small_things], + final_string: props[:final_string], + ) + + @props = props + end + end + class SmallThing < T::Struct + include Baml::Sorbet::Struct + const :i_16_digits, Integer + const :i_8_digits, T.nilable(Integer) + + def initialize(props) + super( + i_16_digits: props[:i_16_digits], + i_8_digits: props[:i_8_digits], + ) + + @props = props + end + end class SomeClassNestedDynamic < T::Struct include Baml::Sorbet::Struct const :hi, T.nilable(String) @@ -1150,7 +1222,7 @@ def initialize(props) class TestClassNested < T::Struct include Baml::Sorbet::Struct const :prop1, T.nilable(String) - const :prop2, Baml::PartialTypes::InnerClass + const :prop2, T.nilable(Baml::PartialTypes::InnerClass) def initialize(props) super( @@ -1192,7 +1264,7 @@ def initialize(props) class Tree < T::Struct include Baml::Sorbet::Struct const :data, T.nilable(Integer) - const :children, Baml::PartialTypes::Forest + const :children, T.nilable(Baml::PartialTypes::Forest) def initialize(props) super( diff --git a/integ-tests/ruby/baml_client/type-registry.rb b/integ-tests/ruby/baml_client/type-registry.rb index 713abe9efe..4f0821908a 100644 --- a/integ-tests/ruby/baml_client/type-registry.rb +++ b/integ-tests/ruby/baml_client/type-registry.rb @@ -18,7 +18,7 @@ module Baml class TypeBuilder def initialize @registry = Baml::Ffi::TypeBuilder.new - @classes = Set[ "BigNumbers", "BinaryNode", "Blah", "BlockConstraint", "BlockConstraintForParam", "BookOrder", "ClassForNullLiteral", "ClassOptionalOutput", "ClassOptionalOutput2", "ClassToRecAlias", "ClassWithImage", "CompoundBigNumbers", "ContactInfo", "CustomTaskResult", "DummyOutput", "DynInputOutput", "DynamicClassOne", "DynamicClassTwo", "DynamicOutput", "Earthling", "Education", "Email", "EmailAddress", "Event", "FakeImage", "FlightConfirmation", "FooAny", "Forest", "FormatterTest0", "FormatterTest1", "FormatterTest2", "FormatterTest3", "GroceryReceipt", "InnerClass", "InnerClass2", "InputClass", "InputClassNested", "LinkedList", "LinkedListAliasNode", "LiteralClassHello", "LiteralClassOne", "LiteralClassTwo", "MalformedConstraints", "MalformedConstraints2", "Martian", "MergeAttrs", "NamedArgsSingleClass", "Nested", "Nested2", "NestedBlockConstraint", "NestedBlockConstraintForParam", "Node", "NodeWithAliasIndirection", "OptionalListAndMap", "OptionalTest_Prop1", "OptionalTest_ReturnType", "OrderInfo", "OriginalA", "OriginalB", "Person", "PhoneNumber", "Quantity", "RaysData", "ReceiptInfo", "ReceiptItem", "Recipe", "Resume", "Schema", "SearchParams", "SomeClassNestedDynamic", "StringToClassEntry", "TestClassAlias", "TestClassNested", "TestClassWithEnum", "TestOutputClass", "Tree", "TwoStoriesOneTitle", "UnionTest_ReturnType", "WithReasoning", ] + @classes = Set[ "BigNumbers", "BinaryNode", "Blah", "BlockConstraint", "BlockConstraintForParam", "BookOrder", "ClassForNullLiteral", "ClassOptionalOutput", "ClassOptionalOutput2", "ClassToRecAlias", "ClassWithBlockDone", "ClassWithImage", "ClassWithoutDone", "CompoundBigNumbers", "ContactInfo", "CustomTaskResult", "DummyOutput", "DynInputOutput", "DynamicClassOne", "DynamicClassTwo", "DynamicOutput", "Earthling", "Education", "Email", "EmailAddress", "Event", "FakeImage", "FlightConfirmation", "FooAny", "Forest", "FormatterTest0", "FormatterTest1", "FormatterTest2", "FormatterTest3", "GroceryReceipt", "InnerClass", "InnerClass2", "InputClass", "InputClassNested", "LinkedList", "LinkedListAliasNode", "LiteralClassHello", "LiteralClassOne", "LiteralClassTwo", "MalformedConstraints", "MalformedConstraints2", "Martian", "MergeAttrs", "NamedArgsSingleClass", "Nested", "Nested2", "NestedBlockConstraint", "NestedBlockConstraintForParam", "Node", "NodeWithAliasIndirection", "OptionalListAndMap", "OptionalTest_Prop1", "OptionalTest_ReturnType", "OrderInfo", "OriginalA", "OriginalB", "Person", "PhoneNumber", "Quantity", "RaysData", "ReceiptInfo", "ReceiptItem", "Recipe", "Resume", "Schema", "SearchParams", "SemanticContainer", "SmallThing", "SomeClassNestedDynamic", "StringToClassEntry", "TestClassAlias", "TestClassNested", "TestClassWithEnum", "TestOutputClass", "Tree", "TwoStoriesOneTitle", "UnionTest_ReturnType", "WithReasoning", ] @enums = Set[ "AliasedEnum", "Category", "Category2", "Category3", "Color", "DataType", "DynEnumOne", "DynEnumTwo", "EnumInClass", "EnumOutput", "Hobby", "MapKey", "NamedArgsSingleEnum", "NamedArgsSingleEnumList", "OptionalTest_CategoryType", "OrderStatus", "Tag", "TestEnum", ] end diff --git a/integ-tests/ruby/baml_client/types.rb b/integ-tests/ruby/baml_client/types.rb index 2ca2fcf9cc..56e4abcbb3 100644 --- a/integ-tests/ruby/baml_client/types.rb +++ b/integ-tests/ruby/baml_client/types.rb @@ -155,7 +155,9 @@ class ClassForNullLiteral < T::Struct; end class ClassOptionalOutput < T::Struct; end class ClassOptionalOutput2 < T::Struct; end class ClassToRecAlias < T::Struct; end + class ClassWithBlockDone < T::Struct; end class ClassWithImage < T::Struct; end + class ClassWithoutDone < T::Struct; end class CompoundBigNumbers < T::Struct; end class ContactInfo < T::Struct; end class CustomTaskResult < T::Struct; end @@ -214,6 +216,8 @@ class Recipe < T::Struct; end class Resume < T::Struct; end class Schema < T::Struct; end class SearchParams < T::Struct; end + class SemanticContainer < T::Struct; end + class SmallThing < T::Struct; end class SomeClassNestedDynamic < T::Struct; end class StringToClassEntry < T::Struct; end class TestClassAlias < T::Struct; end @@ -366,6 +370,20 @@ def initialize(props) @props = props end end + class ClassWithBlockDone < T::Struct + include Baml::Sorbet::Struct + const :i_16_digits, Integer + const :s_20_words, String + + def initialize(props) + super( + i_16_digits: props[:i_16_digits], + s_20_words: props[:s_20_words], + ) + + @props = props + end + end class ClassWithImage < T::Struct include Baml::Sorbet::Struct const :myImage, Baml::Image @@ -382,6 +400,20 @@ def initialize(props) @props = props end end + class ClassWithoutDone < T::Struct + include Baml::Sorbet::Struct + const :i_16_digits, Integer + const :s_20_words, String + + def initialize(props) + super( + i_16_digits: props[:i_16_digits], + s_20_words: props[:s_20_words], + ) + + @props = props + end + end class CompoundBigNumbers < T::Struct include Baml::Sorbet::Struct const :big, Baml::Types::BigNumbers @@ -1228,6 +1260,46 @@ def initialize(props) @props = props end end + class SemanticContainer < T::Struct + include Baml::Sorbet::Struct + const :sixteen_digit_number, Integer + const :string_with_twenty_words, String + const :class_1, Baml::Types::ClassWithoutDone + const :class_2, Baml::Types::ClassWithBlockDone + const :class_done_needed, Baml::Types::ClassWithBlockDone + const :class_needed, Baml::Types::ClassWithoutDone + const :three_small_things, T::Array[Baml::Types::SmallThing] + const :final_string, String + + def initialize(props) + super( + sixteen_digit_number: props[:sixteen_digit_number], + string_with_twenty_words: props[:string_with_twenty_words], + class_1: props[:class_1], + class_2: props[:class_2], + class_done_needed: props[:class_done_needed], + class_needed: props[:class_needed], + three_small_things: props[:three_small_things], + final_string: props[:final_string], + ) + + @props = props + end + end + class SmallThing < T::Struct + include Baml::Sorbet::Struct + const :i_16_digits, Integer + const :i_8_digits, Integer + + def initialize(props) + super( + i_16_digits: props[:i_16_digits], + i_8_digits: props[:i_8_digits], + ) + + @props = props + end + end class SomeClassNestedDynamic < T::Struct include Baml::Sorbet::Struct const :hi, String diff --git a/integ-tests/ruby/test_functions.rb b/integ-tests/ruby/test_functions.rb index 988ae98652..c166245254 100644 --- a/integ-tests/ruby/test_functions.rb +++ b/integ-tests/ruby/test_functions.rb @@ -110,7 +110,7 @@ ) ) - res = b.RecursiveClassWithAliasIndirection.new( + res = b.RecursiveClassWithAliasIndirection( cls: Baml::Types::NodeWithAliasIndirection.new( value: 1, next: Baml::Types::NodeWithAliasIndirection.new( @@ -257,14 +257,17 @@ it "allows streaming of nested" do stream = b.stream.FnOutputClassNested(input: "a") msgs = [] + puts "TEST" stream.each do |msg| + print("INNER") msgs << msg end final = stream.get_final_response - puts final + puts msgs.last.to_json + puts final.to_json assert msgs.size > 0, "Expected at least one streamed response but got none." - assert msgs.last == final, "Expected last stream message to match final response." + assert msgs.last.to_json == final.to_json, "Expected last stream message to match final response." end it "tests dynamic" do @@ -410,4 +413,61 @@ end end + it "uses semantic_container" do + stream = b.stream.MakeSemanticContainer() + stream.each do |msg| + puts msg.to_json + end + end + + it "uses semantic_streaming" do + stream = b.stream.MakeSemanticContainer() + + reference_string = nil + reference_int = nil + + msgs = [] + puts "HELLO'" + + stream.each do |msg| + puts "THERE" + puts msg.to_json + + msgs << msg + + # Check value stability. + if !msg.sixteen_digit_number.nil? + if reference_int.nil? + reference_int = msg.sixteen_digit_number + else + assert_equal reference_int, msg.sixteen_digit_number + end + end + if !msg.string_with_twenty_words.nil? + if reference_string.nil? + reference_string = msg.string_with_twenty_words + else + assert_equal reference_string, msg.string_with_twenty_words + end + end + + # Check for @stream.with_state. + if !msg.class_needed.nil? + if !msg.class_needed.s_20_words.value.nil? + if len(msg.class_needed.s_20_words.value.split(" ")) < 3 && msg.final_string.nil? + puts(msg) + assert msg.class_needed.s_20_words.state == "Incomplete" + end + end + end + if !msg.final_string.nil? + assert msg.class_needed.s_20_words.state == "Complete" + end + end + + puts "TRY FINAL" + final = stream.get_final_response + puts final.to_json + end + end diff --git a/integ-tests/typescript/baml_client/async_client.ts b/integ-tests/typescript/baml_client/async_client.ts index 11562bc611..cd3a9ac7d8 100644 --- a/integ-tests/typescript/baml_client/async_client.ts +++ b/integ-tests/typescript/baml_client/async_client.ts @@ -17,16 +17,11 @@ $ pnpm add @boundaryml/baml // biome-ignore format: autogenerated code import { BamlRuntime, FunctionResult, BamlCtxManager, BamlStream, Image, ClientRegistry, BamlValidationError, createBamlValidationError } from "@boundaryml/baml" import { Checked, Check } from "./types" -import {BigNumbers, BinaryNode, Blah, BlockConstraint, BlockConstraintForParam, BookOrder, ClassForNullLiteral, ClassOptionalOutput, ClassOptionalOutput2, ClassToRecAlias, ClassWithImage, CompoundBigNumbers, ContactInfo, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Earthling, Education, Email, EmailAddress, Event, FakeImage, FlightConfirmation, FooAny, Forest, FormatterTest0, FormatterTest1, FormatterTest2, FormatterTest3, GroceryReceipt, InnerClass, InnerClass2, InputClass, InputClassNested, LinkedList, LinkedListAliasNode, LiteralClassHello, LiteralClassOne, LiteralClassTwo, MalformedConstraints, MalformedConstraints2, Martian, MergeAttrs, NamedArgsSingleClass, Nested, Nested2, NestedBlockConstraint, NestedBlockConstraintForParam, Node, NodeWithAliasIndirection, OptionalListAndMap, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, OriginalA, OriginalB, Person, PhoneNumber, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, SearchParams, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, Tree, TwoStoriesOneTitle, UnionTest_ReturnType, WithReasoning, AliasedEnum, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, MapKey, NamedArgsSingleEnum, NamedArgsSingleEnumList, OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" +import "./partial_types" +import {BigNumbers, BinaryNode, Blah, BlockConstraint, BlockConstraintForParam, BookOrder, ClassForNullLiteral, ClassOptionalOutput, ClassOptionalOutput2, ClassToRecAlias, ClassWithBlockDone, ClassWithImage, ClassWithoutDone, CompoundBigNumbers, ContactInfo, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Earthling, Education, Email, EmailAddress, Event, FakeImage, FlightConfirmation, FooAny, Forest, FormatterTest0, FormatterTest1, FormatterTest2, FormatterTest3, GroceryReceipt, InnerClass, InnerClass2, InputClass, InputClassNested, LinkedList, LinkedListAliasNode, LiteralClassHello, LiteralClassOne, LiteralClassTwo, MalformedConstraints, MalformedConstraints2, Martian, MergeAttrs, NamedArgsSingleClass, Nested, Nested2, NestedBlockConstraint, NestedBlockConstraintForParam, Node, NodeWithAliasIndirection, OptionalListAndMap, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, OriginalA, OriginalB, Person, PhoneNumber, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, SearchParams, SemanticContainer, SmallThing, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, Tree, TwoStoriesOneTitle, UnionTest_ReturnType, WithReasoning, AliasedEnum, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, MapKey, NamedArgsSingleEnum, NamedArgsSingleEnumList, OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" import TypeBuilder from "./type_builder" import { DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME } from "./globals" -export type RecursivePartialNull = T extends object - ? { - [P in keyof T]?: RecursivePartialNull; - } - : T | null; - export class BamlAsyncClient { private runtime: BamlRuntime private ctx_manager: BamlCtxManager @@ -57,7 +52,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Recipe + return raw.parsed(false) as Recipe } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -82,7 +77,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as LinkedListAliasNode + return raw.parsed(false) as LinkedListAliasNode } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -107,7 +102,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Checked + return raw.parsed(false) as Checked } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -132,7 +127,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -157,7 +152,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -182,7 +177,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -207,7 +202,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -232,7 +227,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -257,7 +252,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as OptionalListAndMap + return raw.parsed(false) as OptionalListAndMap } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -282,7 +277,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number + return raw.parsed(false) as number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -307,7 +302,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -332,7 +327,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as LinkedList + return raw.parsed(false) as LinkedList } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -357,7 +352,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Tree + return raw.parsed(false) as Tree } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -382,7 +377,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as ClassToRecAlias + return raw.parsed(false) as ClassToRecAlias } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -407,7 +402,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as (string | DynEnumTwo) + return raw.parsed(false) as (string | DynEnumTwo) } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -432,7 +427,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Category + return raw.parsed(false) as Category } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -457,7 +452,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Category + return raw.parsed(false) as Category } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -482,7 +477,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Category + return raw.parsed(false) as Category } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -507,7 +502,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -532,7 +527,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as BookOrder | FlightConfirmation | GroceryReceipt + return raw.parsed(false) as BookOrder | FlightConfirmation | GroceryReceipt } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -557,7 +552,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -582,7 +577,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -607,7 +602,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -632,7 +627,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -657,7 +652,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as OriginalA | OriginalB + return raw.parsed(false) as OriginalA | OriginalB } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -682,7 +677,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as DummyOutput + return raw.parsed(false) as DummyOutput } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -707,7 +702,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as DynamicClassTwo + return raw.parsed(false) as DynamicClassTwo } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -732,7 +727,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as DynInputOutput + return raw.parsed(false) as DynInputOutput } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -757,7 +752,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as DynInputOutput[] + return raw.parsed(false) as DynInputOutput[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -782,7 +777,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -807,7 +802,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as ContactInfo + return raw.parsed(false) as ContactInfo } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -832,7 +827,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as (string | Hobby)[] + return raw.parsed(false) as (string | Hobby)[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -857,7 +852,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string[] + return raw.parsed(false) as string[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -882,7 +877,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Person[] + return raw.parsed(false) as Person[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -907,7 +902,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as ReceiptInfo + return raw.parsed(false) as ReceiptInfo } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -932,7 +927,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Resume + return raw.parsed(false) as Resume } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -957,7 +952,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Resume + return raw.parsed(false) as Resume } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -982,7 +977,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as ClassOptionalOutput | null + return raw.parsed(false) as ClassOptionalOutput | null } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1007,7 +1002,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as ClassOptionalOutput2 | null + return raw.parsed(false) as ClassOptionalOutput2 | null } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1032,7 +1027,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as EnumOutput[] + return raw.parsed(false) as EnumOutput[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1057,7 +1052,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as EnumOutput + return raw.parsed(false) as EnumOutput } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1082,7 +1077,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as LiteralClassHello + return raw.parsed(false) as LiteralClassHello } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1107,7 +1102,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as LiteralClassOne | LiteralClassTwo + return raw.parsed(false) as LiteralClassOne | LiteralClassTwo } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1132,7 +1127,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1157,7 +1152,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as boolean + return raw.parsed(false) as boolean } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1182,7 +1177,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as TestOutputClass + return raw.parsed(false) as TestOutputClass } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1207,7 +1202,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as TestOutputClass[] + return raw.parsed(false) as TestOutputClass[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1232,7 +1227,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as TestClassNested + return raw.parsed(false) as TestClassNested } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1257,7 +1252,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as TestClassWithEnum + return raw.parsed(false) as TestClassWithEnum } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1282,7 +1277,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number + return raw.parsed(false) as number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1307,7 +1302,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as false + return raw.parsed(false) as false } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1332,7 +1327,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as 5 + return raw.parsed(false) as 5 } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1357,7 +1352,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as "example output" + return raw.parsed(false) as "example output" } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1382,7 +1377,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string[] + return raw.parsed(false) as string[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1407,7 +1402,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as TestEnum + return raw.parsed(false) as TestEnum } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1432,7 +1427,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as TestClassAlias + return raw.parsed(false) as TestClassAlias } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1457,7 +1452,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1482,7 +1477,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as RaysData + return raw.parsed(false) as RaysData } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1507,7 +1502,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as OrderInfo + return raw.parsed(false) as OrderInfo } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1532,7 +1527,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as SearchParams + return raw.parsed(false) as SearchParams } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1544,9 +1539,9 @@ export class BamlAsyncClient { } async InOutEnumMapKey( - i1: Partial>,i2: Partial>, + i1: Partial>,i2: Partial>, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): Promise>> { + ): Promise>> { try { const raw = await this.runtime.callFunction( "InOutEnumMapKey", @@ -1557,7 +1552,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Partial> + return raw.parsed(false) as Partial> } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1582,7 +1577,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Partial> + return raw.parsed(false) as Partial> } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1607,7 +1602,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Partial> + return raw.parsed(false) as Partial> } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1632,7 +1627,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as JsonValue + return raw.parsed(false) as JsonValue } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1657,7 +1652,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as 1 | true | "string output" + return raw.parsed(false) as 1 | true | "string output" } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1682,7 +1677,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Checked + return raw.parsed(false) as Checked } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1707,7 +1702,32 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as NestedBlockConstraint + return raw.parsed(false) as NestedBlockConstraint + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + async MakeSemanticContainer( + + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "MakeSemanticContainer", + { + + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed(false) as SemanticContainer } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1732,7 +1752,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Record + return raw.parsed(false) as Record } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1757,7 +1777,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as MergeAttrs + return raw.parsed(false) as MergeAttrs } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1782,7 +1802,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as DynamicOutput + return raw.parsed(false) as DynamicOutput } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1807,7 +1827,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number | string | boolean | number | string[] | Record + return raw.parsed(false) as number | string | boolean | number | string[] | Record } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1832,7 +1852,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as ClassForNullLiteral + return raw.parsed(false) as ClassForNullLiteral } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1857,7 +1877,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as (OptionalTest_ReturnType | null)[] + return raw.parsed(false) as (OptionalTest_ReturnType | null)[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1882,7 +1902,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as FooAny + return raw.parsed(false) as FooAny } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1907,7 +1927,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Checked + return raw.parsed(false) as Checked } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1932,7 +1952,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number | string | boolean | number + return raw.parsed(false) as number | string | boolean | number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1957,7 +1977,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1982,7 +2002,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2007,7 +2027,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2032,7 +2052,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2057,7 +2077,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2082,7 +2102,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2107,7 +2127,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2132,7 +2152,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as RecAliasOne + return raw.parsed(false) as RecAliasOne } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2157,7 +2177,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as NodeWithAliasIndirection + return raw.parsed(false) as NodeWithAliasIndirection } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2182,7 +2202,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Checked + return raw.parsed(false) as Checked } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2207,7 +2227,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number + return raw.parsed(false) as number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2232,7 +2252,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as MalformedConstraints + return raw.parsed(false) as MalformedConstraints } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2257,7 +2277,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Schema + return raw.parsed(false) as Schema } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2282,7 +2302,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as RecursiveListAlias + return raw.parsed(false) as RecursiveListAlias } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2307,7 +2327,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as RecursiveMapAlias + return raw.parsed(false) as RecursiveMapAlias } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2332,7 +2352,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as BigNumbers + return raw.parsed(false) as BigNumbers } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2357,7 +2377,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as TwoStoriesOneTitle + return raw.parsed(false) as TwoStoriesOneTitle } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2382,7 +2402,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number + return raw.parsed(false) as number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2407,7 +2427,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as (number | string)[] + return raw.parsed(false) as (number | string)[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2432,7 +2452,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as CompoundBigNumbers + return raw.parsed(false) as CompoundBigNumbers } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2457,7 +2477,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2482,7 +2502,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2507,7 +2527,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2532,7 +2552,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2557,7 +2577,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2582,7 +2602,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2607,7 +2627,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2632,7 +2652,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2657,7 +2677,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2682,7 +2702,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2707,7 +2727,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2732,7 +2752,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2757,7 +2777,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2782,7 +2802,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2807,7 +2827,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2832,7 +2852,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2857,7 +2877,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2882,7 +2902,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Record + return raw.parsed(false) as Record } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2907,7 +2927,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Record> + return raw.parsed(false) as Record> } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2932,7 +2952,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Record + return raw.parsed(false) as Record } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2957,7 +2977,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2982,7 +3002,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3007,7 +3027,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3032,7 +3052,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3057,7 +3077,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3082,7 +3102,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3107,7 +3127,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3132,7 +3152,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3157,7 +3177,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3182,7 +3202,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3207,7 +3227,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3232,7 +3252,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3257,7 +3277,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3282,7 +3302,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3307,7 +3327,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3332,7 +3352,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3357,7 +3377,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3382,7 +3402,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3407,7 +3427,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3432,7 +3452,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3457,7 +3477,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3482,7 +3502,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3507,7 +3527,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as UnionTest_ReturnType + return raw.parsed(false) as UnionTest_ReturnType } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3532,7 +3552,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number + return raw.parsed(false) as number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3557,7 +3577,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number + return raw.parsed(false) as number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3582,7 +3602,7 @@ export class BamlAsyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number + return raw.parsed(false) as number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3602,7 +3622,7 @@ class BamlStreamClient { AaaSamOutputFormat( recipe: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, Recipe> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "AaaSamOutputFormat", @@ -3614,9 +3634,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, Recipe>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.Recipe => a, (a): a is Recipe => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -3635,7 +3655,7 @@ class BamlStreamClient { AliasThatPointsToRecursiveType( list: LinkedListAliasNode, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, LinkedListAliasNode> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "AliasThatPointsToRecursiveType", @@ -3647,9 +3667,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, LinkedListAliasNode>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.LinkedListAliasNode => a, (a): a is LinkedListAliasNode => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -3668,7 +3688,7 @@ class BamlStreamClient { AliasWithMultipleAttrs( money: Checked, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream>, Checked> { + ): BamlStream, Checked> { try { const raw = this.runtime.streamFunction( "AliasWithMultipleAttrs", @@ -3680,9 +3700,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream>, Checked>( + return new BamlStream, Checked>( raw, - (a): a is RecursivePartialNull> => a, + (a): a is Checked => a, (a): a is Checked => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -3701,7 +3721,7 @@ class BamlStreamClient { AliasedInputClass( input: InputClass, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "AliasedInputClass", @@ -3713,9 +3733,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -3734,7 +3754,7 @@ class BamlStreamClient { AliasedInputClass2( input: InputClass, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "AliasedInputClass2", @@ -3746,9 +3766,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -3767,7 +3787,7 @@ class BamlStreamClient { AliasedInputClassNested( input: InputClassNested, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "AliasedInputClassNested", @@ -3779,9 +3799,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -3800,7 +3820,7 @@ class BamlStreamClient { AliasedInputEnum( input: AliasedEnum, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "AliasedInputEnum", @@ -3812,9 +3832,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -3833,7 +3853,7 @@ class BamlStreamClient { AliasedInputList( input: AliasedEnum[], __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "AliasedInputList", @@ -3845,9 +3865,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -3866,7 +3886,7 @@ class BamlStreamClient { AllowedOptionals( optionals: OptionalListAndMap, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, OptionalListAndMap> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "AllowedOptionals", @@ -3878,9 +3898,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, OptionalListAndMap>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.OptionalListAndMap => a, (a): a is OptionalListAndMap => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -3899,7 +3919,7 @@ class BamlStreamClient { AssertFn( a: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, number> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "AssertFn", @@ -3911,9 +3931,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, number>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is number => a, (a): a is number => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -3932,7 +3952,7 @@ class BamlStreamClient { AudioInput( aud: Audio, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "AudioInput", @@ -3944,9 +3964,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -3965,7 +3985,7 @@ class BamlStreamClient { BuildLinkedList( input: number[], __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, LinkedList> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "BuildLinkedList", @@ -3977,9 +3997,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, LinkedList>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.LinkedList => a, (a): a is LinkedList => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -3998,7 +4018,7 @@ class BamlStreamClient { BuildTree( input: BinaryNode, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, Tree> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "BuildTree", @@ -4010,9 +4030,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, Tree>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.Tree => a, (a): a is Tree => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4031,7 +4051,7 @@ class BamlStreamClient { ClassThatPointsToRecursiveClassThroughAlias( cls: ClassToRecAlias, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, ClassToRecAlias> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "ClassThatPointsToRecursiveClassThroughAlias", @@ -4043,9 +4063,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, ClassToRecAlias>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.ClassToRecAlias => a, (a): a is ClassToRecAlias => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4064,7 +4084,7 @@ class BamlStreamClient { ClassifyDynEnumTwo( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, (string | DynEnumTwo)> { + ): BamlStream<(string | DynEnumTwo), (string | DynEnumTwo)> { try { const raw = this.runtime.streamFunction( "ClassifyDynEnumTwo", @@ -4076,9 +4096,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, (string | DynEnumTwo)>( + return new BamlStream<(string | DynEnumTwo), (string | DynEnumTwo)>( raw, - (a): a is RecursivePartialNull<(string | DynEnumTwo)> => a, + (a): a is (string | DynEnumTwo) => a, (a): a is (string | DynEnumTwo) => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4097,7 +4117,7 @@ class BamlStreamClient { ClassifyMessage( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, Category> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "ClassifyMessage", @@ -4109,9 +4129,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, Category>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is types.Category => a, (a): a is Category => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4130,7 +4150,7 @@ class BamlStreamClient { ClassifyMessage2( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, Category> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "ClassifyMessage2", @@ -4142,9 +4162,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, Category>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is types.Category => a, (a): a is Category => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4163,7 +4183,7 @@ class BamlStreamClient { ClassifyMessage3( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, Category> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "ClassifyMessage3", @@ -4175,9 +4195,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, Category>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is types.Category => a, (a): a is Category => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4196,7 +4216,7 @@ class BamlStreamClient { Completion( prefix: string,suffix: string,language: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "Completion", @@ -4208,9 +4228,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4229,7 +4249,7 @@ class BamlStreamClient { CustomTask( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, BookOrder | FlightConfirmation | GroceryReceipt> { + ): BamlStream<(partial_types.BookOrder | null | partial_types.FlightConfirmation | null | partial_types.GroceryReceipt | null), BookOrder | FlightConfirmation | GroceryReceipt> { try { const raw = this.runtime.streamFunction( "CustomTask", @@ -4241,9 +4261,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, BookOrder | FlightConfirmation | GroceryReceipt>( + return new BamlStream<(partial_types.BookOrder | null | partial_types.FlightConfirmation | null | partial_types.GroceryReceipt | null), BookOrder | FlightConfirmation | GroceryReceipt>( raw, - (a): a is RecursivePartialNull => a, + (a): a is (partial_types.BookOrder | null | partial_types.FlightConfirmation | null | partial_types.GroceryReceipt | null) => a, (a): a is BookOrder | FlightConfirmation | GroceryReceipt => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4262,7 +4282,7 @@ class BamlStreamClient { DescribeImage( img: Image, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "DescribeImage", @@ -4274,9 +4294,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4295,7 +4315,7 @@ class BamlStreamClient { DescribeImage2( classWithImage: ClassWithImage,img2: Image, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "DescribeImage2", @@ -4307,9 +4327,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4328,7 +4348,7 @@ class BamlStreamClient { DescribeImage3( classWithImage: ClassWithImage,img2: Image, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "DescribeImage3", @@ -4340,9 +4360,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4361,7 +4381,7 @@ class BamlStreamClient { DescribeImage4( classWithImage: ClassWithImage,img2: Image, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "DescribeImage4", @@ -4373,9 +4393,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4394,7 +4414,7 @@ class BamlStreamClient { DifferentiateUnions( __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, OriginalA | OriginalB> { + ): BamlStream<(partial_types.OriginalA | null | partial_types.OriginalB | null), OriginalA | OriginalB> { try { const raw = this.runtime.streamFunction( "DifferentiateUnions", @@ -4406,9 +4426,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, OriginalA | OriginalB>( + return new BamlStream<(partial_types.OriginalA | null | partial_types.OriginalB | null), OriginalA | OriginalB>( raw, - (a): a is RecursivePartialNull => a, + (a): a is (partial_types.OriginalA | null | partial_types.OriginalB | null) => a, (a): a is OriginalA | OriginalB => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4427,7 +4447,7 @@ class BamlStreamClient { DummyOutputFunction( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, DummyOutput> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "DummyOutputFunction", @@ -4439,9 +4459,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, DummyOutput>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.DummyOutput => a, (a): a is DummyOutput => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4460,7 +4480,7 @@ class BamlStreamClient { DynamicFunc( input: DynamicClassOne, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, DynamicClassTwo> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "DynamicFunc", @@ -4472,9 +4492,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, DynamicClassTwo>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.DynamicClassTwo => a, (a): a is DynamicClassTwo => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4493,7 +4513,7 @@ class BamlStreamClient { DynamicInputOutput( input: DynInputOutput, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, DynInputOutput> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "DynamicInputOutput", @@ -4505,9 +4525,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, DynInputOutput>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.DynInputOutput => a, (a): a is DynInputOutput => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4526,7 +4546,7 @@ class BamlStreamClient { DynamicListInputOutput( input: DynInputOutput[], __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, DynInputOutput[]> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "DynamicListInputOutput", @@ -4538,9 +4558,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, DynInputOutput[]>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.DynInputOutput | null[] => a, (a): a is DynInputOutput[] => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4559,7 +4579,7 @@ class BamlStreamClient { ExpectFailure( __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "ExpectFailure", @@ -4571,9 +4591,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4592,7 +4612,7 @@ class BamlStreamClient { ExtractContactInfo( document: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, ContactInfo> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "ExtractContactInfo", @@ -4604,9 +4624,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, ContactInfo>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.ContactInfo => a, (a): a is ContactInfo => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4625,7 +4645,7 @@ class BamlStreamClient { ExtractHobby( text: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, (string | Hobby)[]> { + ): BamlStream<(string | Hobby | null)[], (string | Hobby)[]> { try { const raw = this.runtime.streamFunction( "ExtractHobby", @@ -4637,9 +4657,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, (string | Hobby)[]>( + return new BamlStream<(string | Hobby | null)[], (string | Hobby)[]>( raw, - (a): a is RecursivePartialNull<(string | Hobby)[]> => a, + (a): a is (string | Hobby | null)[] => a, (a): a is (string | Hobby)[] => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4658,7 +4678,7 @@ class BamlStreamClient { ExtractNames( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string[]> { + ): BamlStream<(string | null)[], string[]> { try { const raw = this.runtime.streamFunction( "ExtractNames", @@ -4670,9 +4690,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string[]>( + return new BamlStream<(string | null)[], string[]>( raw, - (a): a is RecursivePartialNull => a, + (a): a is (string | null)[] => a, (a): a is string[] => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4691,7 +4711,7 @@ class BamlStreamClient { ExtractPeople( text: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, Person[]> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "ExtractPeople", @@ -4703,9 +4723,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, Person[]>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.Person | null[] => a, (a): a is Person[] => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4724,7 +4744,7 @@ class BamlStreamClient { ExtractReceiptInfo( email: string,reason: "curiosity" | "personal_finance", __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, ReceiptInfo> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "ExtractReceiptInfo", @@ -4736,9 +4756,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, ReceiptInfo>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.ReceiptInfo => a, (a): a is ReceiptInfo => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4757,7 +4777,7 @@ class BamlStreamClient { ExtractResume( resume: string,img?: Image | null, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, Resume> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "ExtractResume", @@ -4769,9 +4789,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, Resume>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.Resume => a, (a): a is Resume => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4790,7 +4810,7 @@ class BamlStreamClient { ExtractResume2( resume: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, Resume> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "ExtractResume2", @@ -4802,9 +4822,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, Resume>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.Resume => a, (a): a is Resume => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4823,7 +4843,7 @@ class BamlStreamClient { FnClassOptionalOutput( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, ClassOptionalOutput | null> { + ): BamlStream<(partial_types.ClassOptionalOutput | null | null), ClassOptionalOutput | null> { try { const raw = this.runtime.streamFunction( "FnClassOptionalOutput", @@ -4835,9 +4855,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, ClassOptionalOutput | null>( + return new BamlStream<(partial_types.ClassOptionalOutput | null | null), ClassOptionalOutput | null>( raw, - (a): a is RecursivePartialNull => a, + (a): a is (partial_types.ClassOptionalOutput | null | null) => a, (a): a is ClassOptionalOutput | null => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4856,7 +4876,7 @@ class BamlStreamClient { FnClassOptionalOutput2( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, ClassOptionalOutput2 | null> { + ): BamlStream<(partial_types.ClassOptionalOutput2 | null | null), ClassOptionalOutput2 | null> { try { const raw = this.runtime.streamFunction( "FnClassOptionalOutput2", @@ -4868,9 +4888,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, ClassOptionalOutput2 | null>( + return new BamlStream<(partial_types.ClassOptionalOutput2 | null | null), ClassOptionalOutput2 | null>( raw, - (a): a is RecursivePartialNull => a, + (a): a is (partial_types.ClassOptionalOutput2 | null | null) => a, (a): a is ClassOptionalOutput2 | null => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4889,7 +4909,7 @@ class BamlStreamClient { FnEnumListOutput( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, EnumOutput[]> { + ): BamlStream<(EnumOutput | null)[], EnumOutput[]> { try { const raw = this.runtime.streamFunction( "FnEnumListOutput", @@ -4901,9 +4921,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, EnumOutput[]>( + return new BamlStream<(EnumOutput | null)[], EnumOutput[]>( raw, - (a): a is RecursivePartialNull => a, + (a): a is (EnumOutput | null)[] => a, (a): a is EnumOutput[] => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4922,7 +4942,7 @@ class BamlStreamClient { FnEnumOutput( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, EnumOutput> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "FnEnumOutput", @@ -4934,9 +4954,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, EnumOutput>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is types.EnumOutput => a, (a): a is EnumOutput => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4955,7 +4975,7 @@ class BamlStreamClient { FnLiteralClassInputOutput( input: LiteralClassHello, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, LiteralClassHello> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "FnLiteralClassInputOutput", @@ -4967,9 +4987,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, LiteralClassHello>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.LiteralClassHello => a, (a): a is LiteralClassHello => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -4988,7 +5008,7 @@ class BamlStreamClient { FnLiteralUnionClassInputOutput( input: LiteralClassOne | LiteralClassTwo, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, LiteralClassOne | LiteralClassTwo> { + ): BamlStream<(partial_types.LiteralClassOne | null | partial_types.LiteralClassTwo | null), LiteralClassOne | LiteralClassTwo> { try { const raw = this.runtime.streamFunction( "FnLiteralUnionClassInputOutput", @@ -5000,9 +5020,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, LiteralClassOne | LiteralClassTwo>( + return new BamlStream<(partial_types.LiteralClassOne | null | partial_types.LiteralClassTwo | null), LiteralClassOne | LiteralClassTwo>( raw, - (a): a is RecursivePartialNull => a, + (a): a is (partial_types.LiteralClassOne | null | partial_types.LiteralClassTwo | null) => a, (a): a is LiteralClassOne | LiteralClassTwo => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5021,7 +5041,7 @@ class BamlStreamClient { FnNamedArgsSingleStringOptional( myString?: string | null, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "FnNamedArgsSingleStringOptional", @@ -5033,9 +5053,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5054,7 +5074,7 @@ class BamlStreamClient { FnOutputBool( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, boolean> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "FnOutputBool", @@ -5066,9 +5086,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, boolean>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is boolean => a, (a): a is boolean => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5087,7 +5107,7 @@ class BamlStreamClient { FnOutputClass( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, TestOutputClass> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "FnOutputClass", @@ -5099,9 +5119,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, TestOutputClass>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.TestOutputClass => a, (a): a is TestOutputClass => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5120,7 +5140,7 @@ class BamlStreamClient { FnOutputClassList( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, TestOutputClass[]> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "FnOutputClassList", @@ -5132,9 +5152,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, TestOutputClass[]>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.TestOutputClass | null[] => a, (a): a is TestOutputClass[] => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5153,7 +5173,7 @@ class BamlStreamClient { FnOutputClassNested( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, TestClassNested> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "FnOutputClassNested", @@ -5165,9 +5185,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, TestClassNested>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.TestClassNested => a, (a): a is TestClassNested => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5186,7 +5206,7 @@ class BamlStreamClient { FnOutputClassWithEnum( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, TestClassWithEnum> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "FnOutputClassWithEnum", @@ -5198,9 +5218,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, TestClassWithEnum>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.TestClassWithEnum => a, (a): a is TestClassWithEnum => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5219,7 +5239,7 @@ class BamlStreamClient { FnOutputInt( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, number> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "FnOutputInt", @@ -5231,9 +5251,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, number>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is number => a, (a): a is number => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5252,7 +5272,7 @@ class BamlStreamClient { FnOutputLiteralBool( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, false> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "FnOutputLiteralBool", @@ -5264,9 +5284,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, false>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is false => a, (a): a is false => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5285,7 +5305,7 @@ class BamlStreamClient { FnOutputLiteralInt( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, 5> { + ): BamlStream<5, 5> { try { const raw = this.runtime.streamFunction( "FnOutputLiteralInt", @@ -5297,9 +5317,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, 5>( + return new BamlStream<5, 5>( raw, - (a): a is RecursivePartialNull<5> => a, + (a): a is 5 => a, (a): a is 5 => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5318,7 +5338,7 @@ class BamlStreamClient { FnOutputLiteralString( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, "example output"> { + ): BamlStream<"example output", "example output"> { try { const raw = this.runtime.streamFunction( "FnOutputLiteralString", @@ -5330,9 +5350,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, "example output">( + return new BamlStream<"example output", "example output">( raw, - (a): a is RecursivePartialNull<"example output"> => a, + (a): a is "example output" => a, (a): a is "example output" => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5351,7 +5371,7 @@ class BamlStreamClient { FnOutputStringList( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string[]> { + ): BamlStream<(string | null)[], string[]> { try { const raw = this.runtime.streamFunction( "FnOutputStringList", @@ -5363,9 +5383,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string[]>( + return new BamlStream<(string | null)[], string[]>( raw, - (a): a is RecursivePartialNull => a, + (a): a is (string | null)[] => a, (a): a is string[] => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5384,7 +5404,7 @@ class BamlStreamClient { FnTestAliasedEnumOutput( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, TestEnum> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "FnTestAliasedEnumOutput", @@ -5396,9 +5416,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, TestEnum>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is types.TestEnum => a, (a): a is TestEnum => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5417,7 +5437,7 @@ class BamlStreamClient { FnTestClassAlias( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, TestClassAlias> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "FnTestClassAlias", @@ -5429,9 +5449,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, TestClassAlias>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.TestClassAlias => a, (a): a is TestClassAlias => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5450,7 +5470,7 @@ class BamlStreamClient { FnTestNamedArgsSingleEnum( myArg: NamedArgsSingleEnum, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "FnTestNamedArgsSingleEnum", @@ -5462,9 +5482,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5483,7 +5503,7 @@ class BamlStreamClient { GetDataType( text: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, RaysData> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "GetDataType", @@ -5495,9 +5515,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, RaysData>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.RaysData => a, (a): a is RaysData => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5516,7 +5536,7 @@ class BamlStreamClient { GetOrderInfo( email: Email, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, OrderInfo> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "GetOrderInfo", @@ -5528,9 +5548,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, OrderInfo>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.OrderInfo => a, (a): a is OrderInfo => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5549,7 +5569,7 @@ class BamlStreamClient { GetQuery( query: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, SearchParams> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "GetQuery", @@ -5561,9 +5581,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, SearchParams>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.SearchParams => a, (a): a is SearchParams => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5580,9 +5600,9 @@ class BamlStreamClient { } InOutEnumMapKey( - i1: Partial>,i2: Partial>, + i1: Partial>,i2: Partial>, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream>>, Partial>> { + ): BamlStream<(Record ), Partial>> { try { const raw = this.runtime.streamFunction( "InOutEnumMapKey", @@ -5594,10 +5614,10 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream>>, Partial>>( + return new BamlStream<(Record ), Partial>>( raw, - (a): a is RecursivePartialNull>> => a, - (a): a is Partial> => a, + (a): a is (Record ) => a, + (a): a is Partial> => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), ) @@ -5615,7 +5635,7 @@ class BamlStreamClient { InOutLiteralStringUnionMapKey( i1: Partial>,i2: Partial>, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream>>, Partial>> { + ): BamlStream<(Record<"one" | "two" | "three" | "four", (string | null)> ), Partial>> { try { const raw = this.runtime.streamFunction( "InOutLiteralStringUnionMapKey", @@ -5627,9 +5647,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream>>, Partial>>( + return new BamlStream<(Record<"one" | "two" | "three" | "four", (string | null)> ), Partial>>( raw, - (a): a is RecursivePartialNull>> => a, + (a): a is (Record<"one" | "two" | "three" | "four", (string | null)> ) => a, (a): a is Partial> => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5648,7 +5668,7 @@ class BamlStreamClient { InOutSingleLiteralStringMapKey( m: Partial>, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream>>, Partial>> { + ): BamlStream<(Record<"key", (string | null)> ), Partial>> { try { const raw = this.runtime.streamFunction( "InOutSingleLiteralStringMapKey", @@ -5660,9 +5680,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream>>, Partial>>( + return new BamlStream<(Record<"key", (string | null)> ), Partial>>( raw, - (a): a is RecursivePartialNull>> => a, + (a): a is (Record<"key", (string | null)> ) => a, (a): a is Partial> => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5681,7 +5701,7 @@ class BamlStreamClient { JsonTypeAliasCycle( input: JsonValue, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, JsonValue> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "JsonTypeAliasCycle", @@ -5693,9 +5713,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, JsonValue>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is JsonValue => a, (a): a is JsonValue => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5714,7 +5734,7 @@ class BamlStreamClient { LiteralUnionsTest( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, 1 | true | "string output"> { + ): BamlStream<(1 | true | "string output"), 1 | true | "string output"> { try { const raw = this.runtime.streamFunction( "LiteralUnionsTest", @@ -5726,9 +5746,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, 1 | true | "string output">( + return new BamlStream<(1 | true | "string output"), 1 | true | "string output">( raw, - (a): a is RecursivePartialNull<1 | true | "string output"> => a, + (a): a is (1 | true | "string output") => a, (a): a is 1 | true | "string output" => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5747,7 +5767,7 @@ class BamlStreamClient { MakeBlockConstraint( __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream>, Checked> { + ): BamlStream, Checked> { try { const raw = this.runtime.streamFunction( "MakeBlockConstraint", @@ -5759,9 +5779,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream>, Checked>( + return new BamlStream, Checked>( raw, - (a): a is RecursivePartialNull> => a, + (a): a is Checked => a, (a): a is Checked => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5780,7 +5800,7 @@ class BamlStreamClient { MakeNestedBlockConstraint( __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, NestedBlockConstraint> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "MakeNestedBlockConstraint", @@ -5792,9 +5812,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, NestedBlockConstraint>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.NestedBlockConstraint => a, (a): a is NestedBlockConstraint => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5810,10 +5830,43 @@ class BamlStreamClient { } } + MakeSemanticContainer( + + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream { + try { + const raw = this.runtime.streamFunction( + "MakeSemanticContainer", + { + + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream( + raw, + (a): a is partial_types.SemanticContainer => a, + (a): a is SemanticContainer => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + MapAlias( m: Record, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream>, Record> { + ): BamlStream<(Record ), Record> { try { const raw = this.runtime.streamFunction( "MapAlias", @@ -5825,9 +5878,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream>, Record>( + return new BamlStream<(Record ), Record>( raw, - (a): a is RecursivePartialNull> => a, + (a): a is (Record ) => a, (a): a is Record => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5846,7 +5899,7 @@ class BamlStreamClient { MergeAliasAttributes( money: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, MergeAttrs> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "MergeAliasAttributes", @@ -5858,9 +5911,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, MergeAttrs>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.MergeAttrs => a, (a): a is MergeAttrs => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5879,7 +5932,7 @@ class BamlStreamClient { MyFunc( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, DynamicOutput> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "MyFunc", @@ -5891,9 +5944,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, DynamicOutput>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.DynamicOutput => a, (a): a is DynamicOutput => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5912,7 +5965,7 @@ class BamlStreamClient { NestedAlias( c: number | string | boolean | number | string[] | Record, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream>, number | string | boolean | number | string[] | Record> { + ): BamlStream<(((number | null) | (string | null) | (boolean | null) | (number | null) | null) | (string | null)[] | (Record | null)), number | string | boolean | number | string[] | Record> { try { const raw = this.runtime.streamFunction( "NestedAlias", @@ -5924,9 +5977,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream>, number | string | boolean | number | string[] | Record>( + return new BamlStream<(((number | null) | (string | null) | (boolean | null) | (number | null) | null) | (string | null)[] | (Record | null)), number | string | boolean | number | string[] | Record>( raw, - (a): a is RecursivePartialNull> => a, + (a): a is (((number | null) | (string | null) | (boolean | null) | (number | null) | null) | (string | null)[] | (Record | null)) => a, (a): a is number | string | boolean | number | string[] | Record => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5945,7 +5998,7 @@ class BamlStreamClient { NullLiteralClassHello( s: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, ClassForNullLiteral> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "NullLiteralClassHello", @@ -5957,9 +6010,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, ClassForNullLiteral>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.ClassForNullLiteral => a, (a): a is ClassForNullLiteral => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -5978,7 +6031,7 @@ class BamlStreamClient { OptionalTest_Function( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, (OptionalTest_ReturnType | null)[]> { + ): BamlStream<(partial_types.OptionalTest_ReturnType | null | null)[], (OptionalTest_ReturnType | null)[]> { try { const raw = this.runtime.streamFunction( "OptionalTest_Function", @@ -5990,9 +6043,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, (OptionalTest_ReturnType | null)[]>( + return new BamlStream<(partial_types.OptionalTest_ReturnType | null | null)[], (OptionalTest_ReturnType | null)[]>( raw, - (a): a is RecursivePartialNull<(OptionalTest_ReturnType | null)[]> => a, + (a): a is (partial_types.OptionalTest_ReturnType | null | null)[] => a, (a): a is (OptionalTest_ReturnType | null)[] => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6011,7 +6064,7 @@ class BamlStreamClient { PredictAge( name: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, FooAny> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "PredictAge", @@ -6023,9 +6076,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, FooAny>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.FooAny => a, (a): a is FooAny => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6044,7 +6097,7 @@ class BamlStreamClient { PredictAgeBare( inp: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream>, Checked> { + ): BamlStream, Checked> { try { const raw = this.runtime.streamFunction( "PredictAgeBare", @@ -6056,9 +6109,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream>, Checked>( + return new BamlStream, Checked>( raw, - (a): a is RecursivePartialNull> => a, + (a): a is Checked => a, (a): a is Checked => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6077,7 +6130,7 @@ class BamlStreamClient { PrimitiveAlias( p: number | string | boolean | number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, number | string | boolean | number> { + ): BamlStream<((number | null) | (string | null) | (boolean | null) | (number | null)), number | string | boolean | number> { try { const raw = this.runtime.streamFunction( "PrimitiveAlias", @@ -6089,9 +6142,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, number | string | boolean | number>( + return new BamlStream<((number | null) | (string | null) | (boolean | null) | (number | null)), number | string | boolean | number>( raw, - (a): a is RecursivePartialNull => a, + (a): a is ((number | null) | (string | null) | (boolean | null) | (number | null)) => a, (a): a is number | string | boolean | number => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6110,7 +6163,7 @@ class BamlStreamClient { PromptTestClaude( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "PromptTestClaude", @@ -6122,9 +6175,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6143,7 +6196,7 @@ class BamlStreamClient { PromptTestClaudeChat( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "PromptTestClaudeChat", @@ -6155,9 +6208,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6176,7 +6229,7 @@ class BamlStreamClient { PromptTestClaudeChatNoSystem( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "PromptTestClaudeChatNoSystem", @@ -6188,9 +6241,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6209,7 +6262,7 @@ class BamlStreamClient { PromptTestOpenAI( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "PromptTestOpenAI", @@ -6221,9 +6274,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6242,7 +6295,7 @@ class BamlStreamClient { PromptTestOpenAIChat( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "PromptTestOpenAIChat", @@ -6254,9 +6307,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6275,7 +6328,7 @@ class BamlStreamClient { PromptTestOpenAIChatNoSystem( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "PromptTestOpenAIChatNoSystem", @@ -6287,9 +6340,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6308,7 +6361,7 @@ class BamlStreamClient { PromptTestStreaming( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "PromptTestStreaming", @@ -6320,9 +6373,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6341,7 +6394,7 @@ class BamlStreamClient { RecursiveAliasCycle( input: RecAliasOne, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, RecAliasOne> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "RecursiveAliasCycle", @@ -6353,9 +6406,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, RecAliasOne>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is RecAliasOne => a, (a): a is RecAliasOne => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6374,7 +6427,7 @@ class BamlStreamClient { RecursiveClassWithAliasIndirection( cls: NodeWithAliasIndirection, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, NodeWithAliasIndirection> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "RecursiveClassWithAliasIndirection", @@ -6386,9 +6439,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, NodeWithAliasIndirection>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.NodeWithAliasIndirection => a, (a): a is NodeWithAliasIndirection => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6407,7 +6460,7 @@ class BamlStreamClient { ReturnAliasWithMergedAttributes( money: Checked, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream>, Checked> { + ): BamlStream, Checked> { try { const raw = this.runtime.streamFunction( "ReturnAliasWithMergedAttributes", @@ -6419,9 +6472,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream>, Checked>( + return new BamlStream, Checked>( raw, - (a): a is RecursivePartialNull> => a, + (a): a is Checked => a, (a): a is Checked => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6440,7 +6493,7 @@ class BamlStreamClient { ReturnFailingAssert( inp: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, number> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "ReturnFailingAssert", @@ -6452,9 +6505,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, number>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is number => a, (a): a is number => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6473,7 +6526,7 @@ class BamlStreamClient { ReturnMalformedConstraints( a: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, MalformedConstraints> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "ReturnMalformedConstraints", @@ -6485,9 +6538,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, MalformedConstraints>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.MalformedConstraints => a, (a): a is MalformedConstraints => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6506,7 +6559,7 @@ class BamlStreamClient { SchemaDescriptions( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, Schema> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "SchemaDescriptions", @@ -6518,9 +6571,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, Schema>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.Schema => a, (a): a is Schema => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6539,7 +6592,7 @@ class BamlStreamClient { SimpleRecursiveListAlias( input: RecursiveListAlias, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, RecursiveListAlias> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "SimpleRecursiveListAlias", @@ -6551,9 +6604,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, RecursiveListAlias>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is RecursiveListAlias => a, (a): a is RecursiveListAlias => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6572,7 +6625,7 @@ class BamlStreamClient { SimpleRecursiveMapAlias( input: RecursiveMapAlias, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, RecursiveMapAlias> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "SimpleRecursiveMapAlias", @@ -6584,9 +6637,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, RecursiveMapAlias>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is RecursiveMapAlias => a, (a): a is RecursiveMapAlias => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6605,7 +6658,7 @@ class BamlStreamClient { StreamBigNumbers( digits: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, BigNumbers> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "StreamBigNumbers", @@ -6617,9 +6670,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, BigNumbers>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.BigNumbers => a, (a): a is BigNumbers => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6638,7 +6691,7 @@ class BamlStreamClient { StreamFailingAssertion( theme: string,length: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, TwoStoriesOneTitle> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "StreamFailingAssertion", @@ -6650,9 +6703,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, TwoStoriesOneTitle>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.TwoStoriesOneTitle => a, (a): a is TwoStoriesOneTitle => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6671,7 +6724,7 @@ class BamlStreamClient { StreamOneBigNumber( digits: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, number> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "StreamOneBigNumber", @@ -6683,9 +6736,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, number>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is number => a, (a): a is number => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6704,7 +6757,7 @@ class BamlStreamClient { StreamUnionIntegers( digits: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, (number | string)[]> { + ): BamlStream<((number | null) | (string | null) | null)[], (number | string)[]> { try { const raw = this.runtime.streamFunction( "StreamUnionIntegers", @@ -6716,9 +6769,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, (number | string)[]>( + return new BamlStream<((number | null) | (string | null) | null)[], (number | string)[]>( raw, - (a): a is RecursivePartialNull<(number | string)[]> => a, + (a): a is ((number | null) | (string | null) | null)[] => a, (a): a is (number | string)[] => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6737,7 +6790,7 @@ class BamlStreamClient { StreamingCompoundNumbers( digits: number,yapping: boolean, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, CompoundBigNumbers> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "StreamingCompoundNumbers", @@ -6749,9 +6802,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, CompoundBigNumbers>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.CompoundBigNumbers => a, (a): a is CompoundBigNumbers => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6770,7 +6823,7 @@ class BamlStreamClient { TestAnthropic( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestAnthropic", @@ -6782,9 +6835,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6803,7 +6856,7 @@ class BamlStreamClient { TestAnthropicShorthand( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestAnthropicShorthand", @@ -6815,9 +6868,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6836,7 +6889,7 @@ class BamlStreamClient { TestAws( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestAws", @@ -6848,9 +6901,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6869,7 +6922,7 @@ class BamlStreamClient { TestAwsInvalidAccessKey( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestAwsInvalidAccessKey", @@ -6881,9 +6934,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6902,7 +6955,7 @@ class BamlStreamClient { TestAwsInvalidProfile( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestAwsInvalidProfile", @@ -6914,9 +6967,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6935,7 +6988,7 @@ class BamlStreamClient { TestAwsInvalidRegion( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestAwsInvalidRegion", @@ -6947,9 +7000,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -6968,7 +7021,7 @@ class BamlStreamClient { TestAwsInvalidSessionToken( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestAwsInvalidSessionToken", @@ -6980,9 +7033,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7001,7 +7054,7 @@ class BamlStreamClient { TestAzure( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestAzure", @@ -7013,9 +7066,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7034,7 +7087,7 @@ class BamlStreamClient { TestAzureFailure( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestAzureFailure", @@ -7046,9 +7099,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7067,7 +7120,7 @@ class BamlStreamClient { TestCaching( input: string,not_cached: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestCaching", @@ -7079,9 +7132,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7100,7 +7153,7 @@ class BamlStreamClient { TestFallbackClient( __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestFallbackClient", @@ -7112,9 +7165,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7133,7 +7186,7 @@ class BamlStreamClient { TestFallbackToShorthand( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestFallbackToShorthand", @@ -7145,9 +7198,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7166,7 +7219,7 @@ class BamlStreamClient { TestFnNamedArgsSingleBool( myBool: boolean, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestFnNamedArgsSingleBool", @@ -7178,9 +7231,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7199,7 +7252,7 @@ class BamlStreamClient { TestFnNamedArgsSingleClass( myArg: NamedArgsSingleClass, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestFnNamedArgsSingleClass", @@ -7211,9 +7264,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7232,7 +7285,7 @@ class BamlStreamClient { TestFnNamedArgsSingleEnumList( myArg: NamedArgsSingleEnumList[], __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestFnNamedArgsSingleEnumList", @@ -7244,9 +7297,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7265,7 +7318,7 @@ class BamlStreamClient { TestFnNamedArgsSingleFloat( myFloat: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestFnNamedArgsSingleFloat", @@ -7277,9 +7330,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7298,7 +7351,7 @@ class BamlStreamClient { TestFnNamedArgsSingleInt( myInt: number, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestFnNamedArgsSingleInt", @@ -7310,9 +7363,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7331,7 +7384,7 @@ class BamlStreamClient { TestFnNamedArgsSingleMapStringToClass( myMap: Record, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream>, Record> { + ): BamlStream<(Record ), Record> { try { const raw = this.runtime.streamFunction( "TestFnNamedArgsSingleMapStringToClass", @@ -7343,9 +7396,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream>, Record>( + return new BamlStream<(Record ), Record>( raw, - (a): a is RecursivePartialNull> => a, + (a): a is (Record ) => a, (a): a is Record => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7364,7 +7417,7 @@ class BamlStreamClient { TestFnNamedArgsSingleMapStringToMap( myMap: Record>, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream>>, Record>> { + ): BamlStream<(Record | null)> ), Record>> { try { const raw = this.runtime.streamFunction( "TestFnNamedArgsSingleMapStringToMap", @@ -7376,9 +7429,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream>>, Record>>( + return new BamlStream<(Record | null)> ), Record>>( raw, - (a): a is RecursivePartialNull>> => a, + (a): a is (Record | null)> ) => a, (a): a is Record> => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7397,7 +7450,7 @@ class BamlStreamClient { TestFnNamedArgsSingleMapStringToString( myMap: Record, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream>, Record> { + ): BamlStream<(Record ), Record> { try { const raw = this.runtime.streamFunction( "TestFnNamedArgsSingleMapStringToString", @@ -7409,9 +7462,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream>, Record>( + return new BamlStream<(Record ), Record>( raw, - (a): a is RecursivePartialNull> => a, + (a): a is (Record ) => a, (a): a is Record => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7430,7 +7483,7 @@ class BamlStreamClient { TestFnNamedArgsSingleString( myString: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestFnNamedArgsSingleString", @@ -7442,9 +7495,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7463,7 +7516,7 @@ class BamlStreamClient { TestFnNamedArgsSingleStringArray( myStringArray: string[], __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestFnNamedArgsSingleStringArray", @@ -7475,9 +7528,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7496,7 +7549,7 @@ class BamlStreamClient { TestFnNamedArgsSingleStringList( myArg: string[], __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestFnNamedArgsSingleStringList", @@ -7508,9 +7561,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7529,7 +7582,7 @@ class BamlStreamClient { TestGemini( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestGemini", @@ -7541,9 +7594,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7562,7 +7615,7 @@ class BamlStreamClient { TestGeminiOpenAiGeneric( __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestGeminiOpenAiGeneric", @@ -7574,9 +7627,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7595,7 +7648,7 @@ class BamlStreamClient { TestGeminiSystem( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestGeminiSystem", @@ -7607,9 +7660,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7628,7 +7681,7 @@ class BamlStreamClient { TestGeminiSystemAsChat( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestGeminiSystemAsChat", @@ -7640,9 +7693,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7661,7 +7714,7 @@ class BamlStreamClient { TestImageInput( img: Image, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestImageInput", @@ -7673,9 +7726,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7694,7 +7747,7 @@ class BamlStreamClient { TestImageInputAnthropic( img: Image, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestImageInputAnthropic", @@ -7706,9 +7759,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7727,7 +7780,7 @@ class BamlStreamClient { TestImageListInput( imgs: Image[], __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestImageListInput", @@ -7739,9 +7792,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7760,7 +7813,7 @@ class BamlStreamClient { TestMulticlassNamedArgs( myArg: NamedArgsSingleClass,myArg2: NamedArgsSingleClass, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestMulticlassNamedArgs", @@ -7772,9 +7825,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7793,7 +7846,7 @@ class BamlStreamClient { TestNamedArgsLiteralBool( myBool: true, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestNamedArgsLiteralBool", @@ -7805,9 +7858,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7826,7 +7879,7 @@ class BamlStreamClient { TestNamedArgsLiteralInt( myInt: 1, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestNamedArgsLiteralInt", @@ -7838,9 +7891,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7859,7 +7912,7 @@ class BamlStreamClient { TestNamedArgsLiteralString( myString: "My String", __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestNamedArgsLiteralString", @@ -7871,9 +7924,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7892,7 +7945,7 @@ class BamlStreamClient { TestOllama( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestOllama", @@ -7904,9 +7957,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7925,7 +7978,7 @@ class BamlStreamClient { TestOpenAILegacyProvider( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestOpenAILegacyProvider", @@ -7937,9 +7990,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7958,7 +8011,7 @@ class BamlStreamClient { TestOpenAIShorthand( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestOpenAIShorthand", @@ -7970,9 +8023,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -7991,7 +8044,7 @@ class BamlStreamClient { TestRetryConstant( __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestRetryConstant", @@ -8003,9 +8056,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -8024,7 +8077,7 @@ class BamlStreamClient { TestRetryExponential( __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestRetryExponential", @@ -8036,9 +8089,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -8057,7 +8110,7 @@ class BamlStreamClient { TestSingleFallbackClient( __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestSingleFallbackClient", @@ -8069,9 +8122,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -8090,7 +8143,7 @@ class BamlStreamClient { TestVertex( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestVertex", @@ -8102,9 +8155,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -8123,7 +8176,7 @@ class BamlStreamClient { TestVertexWithSystemInstructions( __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, string> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "TestVertexWithSystemInstructions", @@ -8135,9 +8188,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, string>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is string => a, (a): a is string => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -8156,7 +8209,7 @@ class BamlStreamClient { UnionTest_Function( input: string | boolean, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, UnionTest_ReturnType> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "UnionTest_Function", @@ -8168,9 +8221,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, UnionTest_ReturnType>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is partial_types.UnionTest_ReturnType => a, (a): a is UnionTest_ReturnType => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -8189,7 +8242,7 @@ class BamlStreamClient { UseBlockConstraint( inp: BlockConstraintForParam, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, number> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "UseBlockConstraint", @@ -8201,9 +8254,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, number>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is number => a, (a): a is number => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -8222,7 +8275,7 @@ class BamlStreamClient { UseMalformedConstraints( a: MalformedConstraints2, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, number> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "UseMalformedConstraints", @@ -8234,9 +8287,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, number>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is number => a, (a): a is number => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), @@ -8255,7 +8308,7 @@ class BamlStreamClient { UseNestedBlockConstraint( inp: NestedBlockConstraintForParam, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): BamlStream, number> { + ): BamlStream { try { const raw = this.runtime.streamFunction( "UseNestedBlockConstraint", @@ -8267,9 +8320,9 @@ class BamlStreamClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return new BamlStream, number>( + return new BamlStream( raw, - (a): a is RecursivePartialNull => a, + (a): a is number => a, (a): a is number => a, this.ctx_manager.cloneContext(), __baml_options__?.tb?.__tb(), diff --git a/integ-tests/typescript/baml_client/inlinedbaml.ts b/integ-tests/typescript/baml_client/inlinedbaml.ts index e8f8fbe94c..7fe3f6ee9c 100644 --- a/integ-tests/typescript/baml_client/inlinedbaml.ts +++ b/integ-tests/typescript/baml_client/inlinedbaml.ts @@ -69,7 +69,7 @@ const fileMap = { "test-files/functions/output/class-list.baml": "function FnOutputClassList(input: string) -> TestOutputClass[] {\n client GPT35\n prompt #\"\n Return a JSON array that follows this schema: \n {{ctx.output_format}}\n\n JSON:\n \"#\n}\n\ntest FnOutputClassList {\n functions [FnOutputClassList]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/class-nested.baml": "class TestClassNested {\n prop1 string\n prop2 InnerClass\n}\n\nclass InnerClass {\n prop1 string\n prop2 string\n inner InnerClass2\n}\n\nclass InnerClass2 {\n prop2 int\n prop3 float\n}\n\nfunction FnOutputClassNested(input: string) -> TestClassNested {\n client GPT35\n prompt #\"\n Return a made up json blob that matches this schema:\n {{ctx.output_format}}\n ---\n\n JSON:\n \"#\n}\n\ntest FnOutputClassNested {\n functions [FnOutputClassNested]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/class-with-enum.baml": "enum EnumInClass {\n ONE\n TWO\n}\n\nclass TestClassWithEnum {\n prop1 string\n prop2 EnumInClass\n}\n\nfunction FnOutputClassWithEnum(input: string) -> TestClassWithEnum {\n client GPT35\n prompt #\"\n Return a made up json blob that matches this schema:\n {{ctx.output_format}}\n ---\n\n JSON:\n \"#\n}\n\ntest FnOutputClassWithEnum {\n functions [FnOutputClassWithEnum]\n args {\n input \"example input\"\n }\n}\n", - "test-files/functions/output/class.baml": "class TestOutputClass {\n prop1 string\n prop2 int\n}\n\nfunction FnOutputClass(input: string) -> TestOutputClass {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n For the prop2, always return a 540\n\n JSON:\n \"#\n}\n\ntest TestClass {\n functions [FnOutputClass]\n args {\n input \"example input\"\n }\n}\n", + "test-files/functions/output/class.baml": "class TestOutputClass {\n prop1 string @description(\"A long string with about 200 words\")\n prop2 int\n}\n\nfunction FnOutputClass(input: string) -> TestOutputClass {\n client GPT35\n prompt #\"\n Return a JSON blob with this schema: \n {{ctx.output_format}}\n\n For the prop2, always return a 540\n\n JSON:\n \"#\n}\n\ntest TestClass {\n functions [FnOutputClass]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/enum-list.baml": "function FnEnumListOutput(input: string) -> EnumOutput[] {\n client GPT35\n prompt #\"\n Print out two of these values randomly selected from the list below in a json array.\n\n {{ctx.output_format}}\n\n Answer:\n \"#\n} \n\ntest FnEnumListOutput {\n functions [FnEnumListOutput]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/enum.baml": "/// An enum with three values,\n/// ONE, TWO and THREE.\nenum EnumOutput {\n\n /// The first enum.\n ONE\n\n /// The second enum.\n TWO\n THREE\n\n @@alias(\"VALUE_ENUM\")\n}\n\nfunction FnEnumOutput(input: string) -> EnumOutput {\n client GPT35\n prompt #\"\n Choose one of these values randomly. Before you give the answer, write out an unrelated haiku about the ocean.\n\n {{ctx.output_format(prefix=null)}}\n \"#\n}\n\ntest FnEnumOutput {\n functions [FnEnumOutput]\n args {\n input \"example input\"\n }\n}\n", "test-files/functions/output/int.baml": "function FnOutputInt(input: string) -> int {\n client GPT35\n prompt #\"\n Return the integer 5 with no additional context.\n \"#\n}\n\ntest FnOutputInt {\n functions [FnOutputInt]\n args {\n input \"example input\"\n }\n}\n", @@ -100,6 +100,7 @@ const fileMap = { "test-files/providers/openai.baml": "function PromptTestOpenAI(input: string) -> string {\n client GPT35\n prompt #\"\n Write a nice haiku about {{ input }}\n \"#\n}\n\nfunction TestOpenAILegacyProvider(input: string) -> string {\n client GPT35LegacyProvider\n prompt #\"\n Write a nice haiku about {{ input }}\n \"#\n}\n\nfunction TestOpenAIShorthand(input: string) -> string {\n client GPT35\n prompt #\"\n Write a nice short story about {{ input }}\n \"#\n}", "test-files/providers/tests.baml": "test TestOpenAIShorthand {\n functions [TestOpenAIShorthand]\n args {\n input \"Donkey kong and peanut butter\"\n }\n}\n\ntest TestAWS {\n functions [\n TestAws\n ]\n args {\n input \"Donkey kong and peanut butter\"\n }\n}\n\ntest TestProvider {\n functions [\n TestAnthropic, TestVertex, PromptTestOpenAI, TestAzure, TestOllama, TestGemini, TestAws,\n TestAwsInvalidRegion,\n TestOpenAIShorthand,\n TestAnthropicShorthand,\n TestAwsInvalidAccessKey,\n TestAwsInvalidProfile,\n TestAwsInvalidSessionToken\n ]\n args {\n input \"Donkey kong and peanut butter\"\n }\n}\n\ntest TestName {\n functions [TestCaching]\n args {\n input #\"\nIn a near-future society where dreams have become a tradable commodity and shared experience, a lonely and socially awkward teenager named Alex discovers they possess a rare and powerful ability to not only view but also manipulate the dreams of others. Initially thrilled by this newfound power, Alex begins subtly altering the dreams of classmates and family members, helping them overcome fears, boost confidence, or experience fantastical adventures. As Alex's skills grow, so does their influence. They start selling premium dream experiences on the black market, crafting intricate and addictive dreamscapes for wealthy clients. However, the line between dream and reality begins to blur for those exposed to Alex's creations. Some clients struggle to differentiate between their true memories and the artificial ones implanted by Alex's dream manipulation.\n\nComplications arise when a mysterious government agency takes notice of Alex's unique abilities. They offer Alex a chance to use their gift for \"the greater good,\" hinting at applications in therapy, criminal rehabilitation, and even national security. Simultaneously, an underground resistance movement reaches out, warning Alex about the dangers of dream manipulation and the potential for mass control and exploitation. Caught between these opposing forces, Alex must navigate a complex web of ethical dilemmas. They grapple with questions of free will, the nature of consciousness, and the responsibility that comes with having power over people's minds. As the consequences of their actions spiral outward, affecting the lives of loved ones and strangers alike, Alex is forced to confront the true nature of their ability and decide how—or if—it should be used.\n\nThe story explores themes of identity, the subconscious mind, the ethics of technology, and the power of imagination. It delves into the potential consequences of a world where our most private thoughts and experiences are no longer truly our own, and examines the fine line between helping others and manipulating them for personal gain or a perceived greater good. The narrative further expands on the societal implications of such abilities, questioning the moral boundaries of altering consciousness and the potential for abuse in a world where dreams can be commodified. It challenges the reader to consider the impact of technology on personal autonomy and the ethical responsibilities of those who wield such power.\n\nAs Alex's journey unfolds, they encounter various individuals whose lives have been touched by their dream manipulations, each presenting a unique perspective on the ethical quandaries at hand. From a classmate who gains newfound confidence to a wealthy client who becomes addicted to the dreamscapes, the ripple effects of Alex's actions are profound and far-reaching. The government agency's interest in Alex's abilities raises questions about the potential for state control and surveillance, while the resistance movement highlights the dangers of unchecked power and the importance of safeguarding individual freedoms.\n\nUltimately, Alex's story is one of self-discovery and moral reckoning, as they must decide whether to embrace their abilities for personal gain, align with the government's vision of a controlled utopia, or join the resistance in their fight for freedom and autonomy. The narrative invites readers to reflect on the nature of reality, the boundaries of human experience, and the ethical implications of a world where dreams are no longer private sanctuaries but shared and manipulated commodities. It also explores the psychological impact on Alex, who must deal with the burden of knowing the intimate fears and desires of others, and the isolation that comes from being unable to share their own dreams without altering them.\n\nThe story further examines the technological advancements that have made dream manipulation possible, questioning the role of innovation in society and the potential for both progress and peril. It considers the societal divide between those who can afford to buy enhanced dream experiences and those who cannot, highlighting issues of inequality and access. As Alex becomes more entangled in the web of their own making, they must confront the possibility that their actions could lead to unintended consequences, not just for themselves but for the fabric of society as a whole.\n\nIn the end, Alex's journey is a cautionary tale about the power of dreams and the responsibilities that come with wielding such influence. It serves as a reminder of the importance of ethical considerations in the face of technological advancement and the need to balance innovation with humanity. The story leaves readers pondering the true cost of a world where dreams are no longer sacred, and the potential for both wonder and danger in the uncharted territories of the mind. But it's also a story about the power of imagination and the potential for change, even in a world where our deepest thoughts are no longer our own. And it's a story about the power of choice, and the importance of fighting for the freedom to dream.\n\nIn conclusion, this story is a reflection on the power of dreams and the responsibilities that come with wielding such influence. It serves as a reminder of the importance of ethical considerations in the face of technological advancement and the need to balance innovation with humanity. The story leaves readers pondering the true cost of a world where dreams are no longer sacred, and the potential for both wonder and danger in the uncharted territories of the mind. But it's also a story about the power of imagination and the potential for change, even in a world where our deepest thoughts are no longer our own. And it's a story about the power of choice, and the importance of fighting for the freedom to dream.\n \"#\n not_cached #\"\n hello world\n \"#\n }\n}", "test-files/providers/vertex.baml": "function TestVertex(input: string) -> string {\n client Vertex\n prompt #\"\n Write a nice short story about {{ input }}\n \"#\n}\n\nfunction TestVertexWithSystemInstructions() -> string {\n client Vertex\n prompt #\"{{_.role(\"system\")}} You are a helpful assistant\n {{_.role(\"user\")}} Write a poem about llamas\n \"#\n}\n\ntest TestVertex {\n functions [TestVertex, TestVertexWithSystemInstructions]\n args {\n input \"a cat\"\n\n }\n}\n", + "test-files/semantic_streaming/semantic_streaming.baml": "class SemanticContainer {\n sixteen_digit_number int\n string_with_twenty_words string @stream.done\n class_1 ClassWithoutDone\n class_2 ClassWithBlockDone\n class_done_needed ClassWithBlockDone @stream.not_null\n class_needed ClassWithoutDone @stream.not_null\n three_small_things SmallThing[] @description(\"Should have three items.\")\n final_string string\n}\n\nclass ClassWithoutDone {\n i_16_digits int\n s_20_words string @description(\"A string with 20 words in it\") @stream.with_state\n}\n\nclass ClassWithBlockDone {\n i_16_digits int\n s_20_words string\n @@stream.done\n}\n\nclass SmallThing {\n i_16_digits int @stream.not_null\n i_8_digits int\n}\n\nfunction MakeSemanticContainer() -> SemanticContainer {\n client GPT35\n prompt #\"\n {{ ctx.output_format }}\n \"#\n}", "test-files/strategies/fallback-shorthand.baml": "\nclient FallbackToShorthand {\n provider fallback\n options {\n strategy [\n \"openai/does-not-exist\",\n \"openai/gpt-4o-mini\"\n ]\n }\n}\n\n\nfunction TestFallbackToShorthand(input: string) -> string {\n client FallbackToShorthand\n // TODO make it return the client name instead\n prompt #\"\n Say a haiku about {{input}}.\n \"#\n}\n\ntest TestProvider_FallbackToShorthand {\n functions [\n TestFallbackToShorthand\n ]\n args {\n input \"Donkey kong and peanut butter\"\n }\n}\n", "test-files/strategies/fallback.baml": "// Happy path fallbacks.\nclient FaultyClient {\n provider openai\n options {\n model unknown-model\n api_key env.OPENAI_API_KEY\n }\n}\n\n\nclient FallbackClient {\n provider fallback\n options {\n // first 2 clients are expected to fail.\n strategy [\n FaultyClient,\n RetryClientConstant,\n GPT35\n Gemini\n\n ]\n }\n}\n\nfunction TestFallbackClient() -> string {\n client FallbackClient\n // TODO make it return the client name instead\n prompt #\"\n Say a haiku about mexico.\n \"#\n}\n\n// Fallbacks should fail gracefully.\nclient FaultyAzureClient {\n provider azure-openai\n options {\n model unknown-model\n resource_name \"unknown-resource-id\"\n deployment_id \"unknown-deployment-id\"\n }\n}\n\nclient SingleFallbackClient {\n provider fallback\n options {\n // first 2 clients are expected to fail.\n strategy [\n FaultyAzureClient\n ]\n }\n}\n\nfunction TestSingleFallbackClient() -> string {\n client SingleFallbackClient\n // TODO make it return the client name instead\n prompt #\"\n Say a haiku about mexico.\n \"#\n}\n", "test-files/strategies/retry.baml": "\nretry_policy Exponential {\n max_retries 3\n strategy {\n type exponential_backoff\n }\n}\n\nretry_policy Constant {\n max_retries 3\n strategy {\n type constant_delay\n delay_ms 100\n }\n}\n\nclient RetryClientConstant {\n provider openai\n retry_policy Constant\n options {\n model \"gpt-3.5-turbo\"\n api_key \"blah\"\n }\n}\n\nclient RetryClientExponential {\n provider openai\n retry_policy Exponential\n options {\n model \"gpt-3.5-turbo\"\n api_key \"blahh\"\n }\n}\n\nfunction TestRetryConstant() -> string {\n client RetryClientConstant\n prompt #\"\n Say a haiku\n \"#\n}\n\nfunction TestRetryExponential() -> string {\n client RetryClientExponential\n prompt #\"\n Say a haiku\n \"#\n}\n", diff --git a/integ-tests/typescript/baml_client/partial_types.ts b/integ-tests/typescript/baml_client/partial_types.ts new file mode 100644 index 0000000000..657ee165af --- /dev/null +++ b/integ-tests/typescript/baml_client/partial_types.ts @@ -0,0 +1,490 @@ +/************************************************************************************************* + +Welcome to Baml! To use this generated code, please run one of the following: + +$ npm install @boundaryml/baml +$ yarn add @boundaryml/baml +$ pnpm add @boundaryml/baml + +*************************************************************************************************/ + +// This file was generated by BAML: do not edit it. Instead, edit the BAML +// files and re-generate this code. +// +/* eslint-disable */ +// tslint:disable +// @ts-nocheck +// biome-ignore format: autogenerated code +import { Image } from "@boundaryml/baml" + +import * as types from "./types" + +/****************************************************************************** +* +* These types are used for streaming, for when an instance of a type +* is still being built up and any of its fields is not yet fully available. +* +******************************************************************************/ + +export interface StreamState { + value: T + state: "Pending" | "Incomplete" | "Complete" +} + + +export interface BigNumbers { + a?: (number | null) + b?: (number | null) +} + +export interface BinaryNode { + data?: (number | null) + left: (partial_types.BinaryNode | null | null) + right: (partial_types.BinaryNode | null | null) +} + +export interface Blah { + prop4: ((string | null) | null) +} + +export interface BlockConstraint { + foo?: (number | null) + bar?: (string | null) +} + +export interface BlockConstraintForParam { + bcfp?: (number | null) + bcfp2?: (string | null) +} + +export interface BookOrder { + orderId?: (string | null) + title?: (string | null) + quantity?: (number | null) + price?: (number | null) +} + +export interface ClassForNullLiteral { + a: "hi" +} + +export interface ClassOptionalOutput { + prop1?: (string | null) + prop2?: (string | null) +} + +export interface ClassOptionalOutput2 { + prop1: ((string | null) | null) + prop2: ((string | null) | null) + prop3: (partial_types.Blah | null | null) +} + +export interface ClassToRecAlias { + list?: partial_types.LinkedListAliasNode | null +} + +export interface ClassWithBlockDone { + i_16_digits?: (number | null) + s_20_words?: (string | null) +} + +export interface ClassWithImage { + myImage?: (Image | null) + param2?: (string | null) + fake_image?: partial_types.FakeImage | null +} + +export interface ClassWithoutDone { + i_16_digits?: (number | null) + s_20_words?: StreamState<(string | null)> +} + +export interface CompoundBigNumbers { + big?: partial_types.BigNumbers | null + big_nums?: partial_types.BigNumbers | null[] + another?: partial_types.BigNumbers | null +} + +export interface ContactInfo { + primary?: (partial_types.PhoneNumber | null | partial_types.EmailAddress | null | null) + secondary?: (partial_types.PhoneNumber | null | partial_types.EmailAddress | null | (null | null) | null) +} + +export interface CustomTaskResult { + bookOrder?: (partial_types.BookOrder | null | ((null | null) | null) | null) + flightConfirmation?: (partial_types.FlightConfirmation | null | ((null | null) | null) | null) + groceryReceipt?: (partial_types.GroceryReceipt | null | ((null | null) | null) | null) +} + +export interface DummyOutput { + nonce?: (string | null) + nonce2?: (string | null) + [key: string]: any; +} + +export interface DynInputOutput { + testKey?: (string | null) + [key: string]: any; +} + +export interface DynamicClassOne { + [key: string]: any; +} + +export interface DynamicClassTwo { + hi?: (string | null) + some_class?: partial_types.SomeClassNestedDynamic | null + status?: (string | DynEnumOne | null) + [key: string]: any; +} + +export interface DynamicOutput { + [key: string]: any; +} + +export interface Earthling { + age?: Checked<(number | null),"earth_aged" | "no_infants"> +} + +export interface Education { + institution?: (string | null) + location?: (string | null) + degree?: (string | null) + major?: (string | null)[] + graduation_date: ((string | null) | null) +} + +export interface Email { + subject?: (string | null) + body?: (string | null) + from_address?: (string | null) +} + +export interface EmailAddress { + value?: (string | null) +} + +export interface Event { + title?: (string | null) + date?: (string | null) + location?: (string | null) + description?: (string | null) +} + +export interface FakeImage { + url?: (string | null) +} + +export interface FlightConfirmation { + confirmationNumber?: (string | null) + flightNumber?: (string | null) + departureTime?: (string | null) + arrivalTime?: (string | null) + seatNumber?: (string | null) +} + +export interface FooAny { + planetary_age?: (partial_types.Martian | null | partial_types.Earthling | null | null) + certainty?: Checked<(number | null),"unreasonably_certain"> + species?: Checked<(string | null),"regex_bad" | "regex_good" | "trivial"> +} + +export interface Forest { + trees?: partial_types.Tree | null[] +} + +export interface FormatterTest0 { + lorem?: (string | null) + ipsum?: (string | null) +} + +export interface FormatterTest1 { + lorem?: (string | null) + ipsum?: (string | null) +} + +export interface FormatterTest2 { + lorem?: (string | null) + ipsum?: (string | null) +} + +export interface FormatterTest3 { + lorem?: (string | null) + ipsum?: (string | null) +} + +export interface GroceryReceipt { + receiptId?: (string | null) + storeName?: (string | null) + items?: ((string | null) | (number | null) | (number | null) | null)[] + totalAmount?: (number | null) +} + +export interface InnerClass { + prop1?: (string | null) + prop2?: (string | null) + inner?: partial_types.InnerClass2 | null +} + +export interface InnerClass2 { + prop2?: (number | null) + prop3?: (number | null) +} + +export interface InputClass { + key?: (string | null) + key2?: (string | null) +} + +export interface InputClassNested { + key?: (string | null) + nested?: partial_types.InputClass | null +} + +export interface LinkedList { + head: (partial_types.Node | null | null) + len?: (number | null) +} + +export interface LinkedListAliasNode { + value?: (number | null) + next: (partial_types.LinkedListAliasNode | null | null) +} + +export interface LiteralClassHello { + prop: "hello" +} + +export interface LiteralClassOne { + prop: "one" +} + +export interface LiteralClassTwo { + prop: "two" +} + +export interface MalformedConstraints { + foo?: Checked<(number | null),"foo_check"> +} + +export interface MalformedConstraints2 { + foo?: (number | null) +} + +/** + * A Martian organism with an age. + * Such a nice type. + */ +export interface Martian { + /** + * The age of the Martian in Mars years. + * So many Mars years. + */ + age?: Checked<(number | null),"young_enough"> +} + +export interface MergeAttrs { + amount?: Checked<(number | null),"gt_ten"> +} + +export interface NamedArgsSingleClass { + key?: (string | null) + key_two?: (boolean | null) + key_three?: (number | null) +} + +export interface Nested { + prop3?: ((string | null) | ((null | null) | null) | null) + prop4?: ((string | null) | ((null | null) | null) | null) + prop20?: partial_types.Nested2 | null +} + +export interface Nested2 { + prop11?: ((string | null) | ((null | null) | null) | null) + prop12?: ((string | null) | ((null | null) | null) | null) +} + +export interface NestedBlockConstraint { + nbc?: Checked +} + +export interface NestedBlockConstraintForParam { + nbcfp?: partial_types.BlockConstraintForParam | null +} + +export interface Node { + data?: (number | null) + next: (partial_types.Node | null | null) +} + +export interface NodeWithAliasIndirection { + value?: (number | null) + next: (partial_types.NodeWithAliasIndirection | null | null) +} + +export interface OptionalListAndMap { + p: ((string | null)[] | null) + q: ((Record | null) | null) +} + +export interface OptionalTest_Prop1 { + omega_a?: (string | null) + omega_b?: (number | null) +} + +export interface OptionalTest_ReturnType { + omega_1: (partial_types.OptionalTest_Prop1 | null | null) + omega_2: ((string | null) | null) + omega_3?: ((OptionalTest_CategoryType | null) | null)[] +} + +export interface OrderInfo { + order_status?: (OrderStatus | null) + tracking_number: ((string | null) | null) + estimated_arrival_date: ((string | null) | null) +} + +export interface OriginalA { + value?: (number | null) +} + +export interface OriginalB { + value?: (number | null) + [key: string]: any; +} + +export interface Person { + name: ((string | null) | null) + hair_color: ((string | Color | null) | null) + [key: string]: any; +} + +export interface PhoneNumber { + value?: (string | null) +} + +export interface Quantity { + amount?: ((number | null) | (number | null) | null) + unit: ((string | null) | null) +} + +export interface RaysData { + dataType?: (DataType | null) + value?: (partial_types.Resume | null | partial_types.Event | null | null) +} + +export interface ReceiptInfo { + items?: partial_types.ReceiptItem | null[] + total_cost: ((number | null) | null) + venue?: ("barisa" | "ox_burger" | null) +} + +export interface ReceiptItem { + name?: (string | null) + description: ((string | null) | null) + quantity?: (number | null) + price?: (number | null) +} + +export interface Recipe { + ingredients?: (Record | null) + recipe_type?: ("breakfast" | "dinner" | null) +} + +export interface Resume { + name?: (string | null) + email?: (string | null) + phone?: (string | null) + experience?: partial_types.Education | null[] + education?: (string | null)[] + skills?: (string | null)[] +} + +export interface Schema { + prop1?: ((string | null) | ((null | null) | null) | null) + prop2?: (partial_types.Nested | null | (string | null) | null) + prop5?: ((string | null) | ((null | null) | null) | null)[] + prop6?: ((string | null) | partial_types.Nested | null[] | null) + nested_attrs?: ((string | null) | ((null | null) | null) | partial_types.Nested | null | null)[] + parens?: ((string | null) | ((null | null) | null) | null) + other_group?: ((string | null) | ((number | null) | (string | null) | null) | null) +} + +export interface SearchParams { + dateRange: ((number | null) | null) + location?: (string | null)[] + jobTitle: (partial_types.WithReasoning | null | null) + company: (partial_types.WithReasoning | null | null) + description?: partial_types.WithReasoning | null[] + tags?: ((Tag | null) | (string | null) | null)[] +} + +export interface SemanticContainer { + sixteen_digit_number?: (number | null) + string_with_twenty_words: string + class_1?: partial_types.ClassWithoutDone | null + class_2: types.ClassWithBlockDone + class_done_needed: types.ClassWithBlockDone + class_needed: partial_types.ClassWithoutDone + three_small_things?: partial_types.SmallThing | null[] + final_string?: (string | null) +} + +export interface SmallThing { + i_16_digits: number + i_8_digits?: (number | null) +} + +export interface SomeClassNestedDynamic { + hi?: (string | null) + [key: string]: any; +} + +export interface StringToClassEntry { + word?: (string | null) +} + +export interface TestClassAlias { + key?: (string | null) + key2?: (string | null) + key3?: (string | null) + key4?: (string | null) + key5?: (string | null) +} + +export interface TestClassNested { + prop1?: (string | null) + prop2?: partial_types.InnerClass | null +} + +export interface TestClassWithEnum { + prop1?: (string | null) + prop2?: (EnumInClass | null) +} + +export interface TestOutputClass { + prop1?: (string | null) + prop2?: (number | null) +} + +export interface Tree { + data?: (number | null) + children?: partial_types.Forest | null +} + +export interface TwoStoriesOneTitle { + title?: (string | null) + story_a?: (string | null) + story_b?: (string | null) +} + +export interface UnionTest_ReturnType { + prop1?: ((string | null) | (boolean | null) | null) + prop2?: ((number | null) | (boolean | null) | null)[] + prop3?: ((boolean | null)[] | (number | null)[] | null) +} + +export interface WithReasoning { + value?: (string | null) + reasoning?: (string | null) +} diff --git a/integ-tests/typescript/baml_client/sync_client.ts b/integ-tests/typescript/baml_client/sync_client.ts index ded32a44e8..f6cec34424 100644 --- a/integ-tests/typescript/baml_client/sync_client.ts +++ b/integ-tests/typescript/baml_client/sync_client.ts @@ -17,16 +17,10 @@ $ pnpm add @boundaryml/baml // biome-ignore format: autogenerated code import { BamlRuntime, FunctionResult, BamlCtxManager, BamlSyncStream, Image, ClientRegistry, createBamlValidationError, BamlValidationError } from "@boundaryml/baml" import { Checked, Check } from "./types" -import {BigNumbers, BinaryNode, Blah, BlockConstraint, BlockConstraintForParam, BookOrder, ClassForNullLiteral, ClassOptionalOutput, ClassOptionalOutput2, ClassToRecAlias, ClassWithImage, CompoundBigNumbers, ContactInfo, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Earthling, Education, Email, EmailAddress, Event, FakeImage, FlightConfirmation, FooAny, Forest, FormatterTest0, FormatterTest1, FormatterTest2, FormatterTest3, GroceryReceipt, InnerClass, InnerClass2, InputClass, InputClassNested, LinkedList, LinkedListAliasNode, LiteralClassHello, LiteralClassOne, LiteralClassTwo, MalformedConstraints, MalformedConstraints2, Martian, MergeAttrs, NamedArgsSingleClass, Nested, Nested2, NestedBlockConstraint, NestedBlockConstraintForParam, Node, NodeWithAliasIndirection, OptionalListAndMap, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, OriginalA, OriginalB, Person, PhoneNumber, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, SearchParams, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, Tree, TwoStoriesOneTitle, UnionTest_ReturnType, WithReasoning, AliasedEnum, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, MapKey, NamedArgsSingleEnum, NamedArgsSingleEnumList, OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" +import {BigNumbers, BinaryNode, Blah, BlockConstraint, BlockConstraintForParam, BookOrder, ClassForNullLiteral, ClassOptionalOutput, ClassOptionalOutput2, ClassToRecAlias, ClassWithBlockDone, ClassWithImage, ClassWithoutDone, CompoundBigNumbers, ContactInfo, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Earthling, Education, Email, EmailAddress, Event, FakeImage, FlightConfirmation, FooAny, Forest, FormatterTest0, FormatterTest1, FormatterTest2, FormatterTest3, GroceryReceipt, InnerClass, InnerClass2, InputClass, InputClassNested, LinkedList, LinkedListAliasNode, LiteralClassHello, LiteralClassOne, LiteralClassTwo, MalformedConstraints, MalformedConstraints2, Martian, MergeAttrs, NamedArgsSingleClass, Nested, Nested2, NestedBlockConstraint, NestedBlockConstraintForParam, Node, NodeWithAliasIndirection, OptionalListAndMap, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, OriginalA, OriginalB, Person, PhoneNumber, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, SearchParams, SemanticContainer, SmallThing, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, Tree, TwoStoriesOneTitle, UnionTest_ReturnType, WithReasoning, AliasedEnum, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, MapKey, NamedArgsSingleEnum, NamedArgsSingleEnumList, OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" import TypeBuilder from "./type_builder" import { DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME } from "./globals" -export type RecursivePartialNull = T extends object - ? { - [P in keyof T]?: RecursivePartialNull; - } - : T | null; - export class BamlSyncClient { private runtime: BamlRuntime private ctx_manager: BamlCtxManager @@ -57,7 +51,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Recipe + return raw.parsed(false) as Recipe } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -82,7 +76,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as LinkedListAliasNode + return raw.parsed(false) as LinkedListAliasNode } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -107,7 +101,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Checked + return raw.parsed(false) as Checked } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -132,7 +126,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -157,7 +151,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -182,7 +176,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -207,7 +201,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -232,7 +226,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -257,7 +251,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as OptionalListAndMap + return raw.parsed(false) as OptionalListAndMap } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -282,7 +276,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number + return raw.parsed(false) as number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -307,7 +301,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -332,7 +326,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as LinkedList + return raw.parsed(false) as LinkedList } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -357,7 +351,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Tree + return raw.parsed(false) as Tree } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -382,7 +376,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as ClassToRecAlias + return raw.parsed(false) as ClassToRecAlias } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -407,7 +401,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as (string | DynEnumTwo) + return raw.parsed(false) as (string | DynEnumTwo) } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -432,7 +426,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Category + return raw.parsed(false) as Category } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -457,7 +451,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Category + return raw.parsed(false) as Category } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -482,7 +476,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Category + return raw.parsed(false) as Category } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -507,7 +501,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -532,7 +526,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as BookOrder | FlightConfirmation | GroceryReceipt + return raw.parsed(false) as BookOrder | FlightConfirmation | GroceryReceipt } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -557,7 +551,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -582,7 +576,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -607,7 +601,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -632,7 +626,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -657,7 +651,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as OriginalA | OriginalB + return raw.parsed(false) as OriginalA | OriginalB } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -682,7 +676,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as DummyOutput + return raw.parsed(false) as DummyOutput } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -707,7 +701,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as DynamicClassTwo + return raw.parsed(false) as DynamicClassTwo } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -732,7 +726,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as DynInputOutput + return raw.parsed(false) as DynInputOutput } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -757,7 +751,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as DynInputOutput[] + return raw.parsed(false) as DynInputOutput[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -782,7 +776,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -807,7 +801,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as ContactInfo + return raw.parsed(false) as ContactInfo } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -832,7 +826,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as (string | Hobby)[] + return raw.parsed(false) as (string | Hobby)[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -857,7 +851,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string[] + return raw.parsed(false) as string[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -882,7 +876,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Person[] + return raw.parsed(false) as Person[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -907,7 +901,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as ReceiptInfo + return raw.parsed(false) as ReceiptInfo } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -932,7 +926,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Resume + return raw.parsed(false) as Resume } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -957,7 +951,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Resume + return raw.parsed(false) as Resume } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -982,7 +976,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as ClassOptionalOutput | null + return raw.parsed(false) as ClassOptionalOutput | null } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1007,7 +1001,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as ClassOptionalOutput2 | null + return raw.parsed(false) as ClassOptionalOutput2 | null } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1032,7 +1026,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as EnumOutput[] + return raw.parsed(false) as EnumOutput[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1057,7 +1051,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as EnumOutput + return raw.parsed(false) as EnumOutput } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1082,7 +1076,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as LiteralClassHello + return raw.parsed(false) as LiteralClassHello } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1107,7 +1101,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as LiteralClassOne | LiteralClassTwo + return raw.parsed(false) as LiteralClassOne | LiteralClassTwo } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1132,7 +1126,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1157,7 +1151,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as boolean + return raw.parsed(false) as boolean } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1182,7 +1176,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as TestOutputClass + return raw.parsed(false) as TestOutputClass } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1207,7 +1201,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as TestOutputClass[] + return raw.parsed(false) as TestOutputClass[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1232,7 +1226,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as TestClassNested + return raw.parsed(false) as TestClassNested } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1257,7 +1251,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as TestClassWithEnum + return raw.parsed(false) as TestClassWithEnum } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1282,7 +1276,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number + return raw.parsed(false) as number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1307,7 +1301,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as false + return raw.parsed(false) as false } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1332,7 +1326,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as 5 + return raw.parsed(false) as 5 } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1357,7 +1351,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as "example output" + return raw.parsed(false) as "example output" } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1382,7 +1376,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string[] + return raw.parsed(false) as string[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1407,7 +1401,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as TestEnum + return raw.parsed(false) as TestEnum } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1432,7 +1426,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as TestClassAlias + return raw.parsed(false) as TestClassAlias } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1457,7 +1451,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1482,7 +1476,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as RaysData + return raw.parsed(false) as RaysData } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1507,7 +1501,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as OrderInfo + return raw.parsed(false) as OrderInfo } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1532,7 +1526,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as SearchParams + return raw.parsed(false) as SearchParams } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1544,9 +1538,9 @@ export class BamlSyncClient { } InOutEnumMapKey( - i1: Partial>,i2: Partial>, + i1: Partial>,i2: Partial>, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } - ): Partial> { + ): Partial> { try { const raw = this.runtime.callFunctionSync( "InOutEnumMapKey", @@ -1557,7 +1551,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Partial> + return raw.parsed(false) as Partial> } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1582,7 +1576,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Partial> + return raw.parsed(false) as Partial> } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1607,7 +1601,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Partial> + return raw.parsed(false) as Partial> } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1632,7 +1626,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as JsonValue + return raw.parsed(false) as JsonValue } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1657,7 +1651,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as 1 | true | "string output" + return raw.parsed(false) as 1 | true | "string output" } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1682,7 +1676,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Checked + return raw.parsed(false) as Checked } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1707,7 +1701,32 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as NestedBlockConstraint + return raw.parsed(false) as NestedBlockConstraint + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + MakeSemanticContainer( + + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): SemanticContainer { + try { + const raw = this.runtime.callFunctionSync( + "MakeSemanticContainer", + { + + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed(false) as SemanticContainer } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1732,7 +1751,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Record + return raw.parsed(false) as Record } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1757,7 +1776,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as MergeAttrs + return raw.parsed(false) as MergeAttrs } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1782,7 +1801,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as DynamicOutput + return raw.parsed(false) as DynamicOutput } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1807,7 +1826,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number | string | boolean | number | string[] | Record + return raw.parsed(false) as number | string | boolean | number | string[] | Record } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1832,7 +1851,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as ClassForNullLiteral + return raw.parsed(false) as ClassForNullLiteral } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1857,7 +1876,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as (OptionalTest_ReturnType | null)[] + return raw.parsed(false) as (OptionalTest_ReturnType | null)[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1882,7 +1901,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as FooAny + return raw.parsed(false) as FooAny } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1907,7 +1926,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Checked + return raw.parsed(false) as Checked } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1932,7 +1951,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number | string | boolean | number + return raw.parsed(false) as number | string | boolean | number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1957,7 +1976,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -1982,7 +2001,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2007,7 +2026,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2032,7 +2051,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2057,7 +2076,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2082,7 +2101,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2107,7 +2126,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2132,7 +2151,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as RecAliasOne + return raw.parsed(false) as RecAliasOne } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2157,7 +2176,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as NodeWithAliasIndirection + return raw.parsed(false) as NodeWithAliasIndirection } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2182,7 +2201,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Checked + return raw.parsed(false) as Checked } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2207,7 +2226,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number + return raw.parsed(false) as number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2232,7 +2251,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as MalformedConstraints + return raw.parsed(false) as MalformedConstraints } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2257,7 +2276,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Schema + return raw.parsed(false) as Schema } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2282,7 +2301,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as RecursiveListAlias + return raw.parsed(false) as RecursiveListAlias } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2307,7 +2326,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as RecursiveMapAlias + return raw.parsed(false) as RecursiveMapAlias } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2332,7 +2351,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as BigNumbers + return raw.parsed(false) as BigNumbers } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2357,7 +2376,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as TwoStoriesOneTitle + return raw.parsed(false) as TwoStoriesOneTitle } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2382,7 +2401,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number + return raw.parsed(false) as number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2407,7 +2426,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as (number | string)[] + return raw.parsed(false) as (number | string)[] } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2432,7 +2451,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as CompoundBigNumbers + return raw.parsed(false) as CompoundBigNumbers } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2457,7 +2476,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2482,7 +2501,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2507,7 +2526,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2532,7 +2551,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2557,7 +2576,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2582,7 +2601,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2607,7 +2626,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2632,7 +2651,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2657,7 +2676,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2682,7 +2701,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2707,7 +2726,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2732,7 +2751,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2757,7 +2776,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2782,7 +2801,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2807,7 +2826,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2832,7 +2851,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2857,7 +2876,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2882,7 +2901,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Record + return raw.parsed(false) as Record } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2907,7 +2926,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Record> + return raw.parsed(false) as Record> } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2932,7 +2951,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as Record + return raw.parsed(false) as Record } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2957,7 +2976,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -2982,7 +3001,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3007,7 +3026,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3032,7 +3051,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3057,7 +3076,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3082,7 +3101,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3107,7 +3126,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3132,7 +3151,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3157,7 +3176,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3182,7 +3201,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3207,7 +3226,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3232,7 +3251,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3257,7 +3276,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3282,7 +3301,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3307,7 +3326,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3332,7 +3351,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3357,7 +3376,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3382,7 +3401,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3407,7 +3426,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3432,7 +3451,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3457,7 +3476,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3482,7 +3501,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as string + return raw.parsed(false) as string } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3507,7 +3526,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as UnionTest_ReturnType + return raw.parsed(false) as UnionTest_ReturnType } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3532,7 +3551,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number + return raw.parsed(false) as number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3557,7 +3576,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number + return raw.parsed(false) as number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { @@ -3582,7 +3601,7 @@ export class BamlSyncClient { __baml_options__?.tb?.__tb(), __baml_options__?.clientRegistry, ) - return raw.parsed() as number + return raw.parsed(false) as number } catch (error: any) { const bamlError = createBamlValidationError(error); if (bamlError instanceof BamlValidationError) { diff --git a/integ-tests/typescript/baml_client/type_builder.ts b/integ-tests/typescript/baml_client/type_builder.ts index df9d42322b..ebedf4d41e 100644 --- a/integ-tests/typescript/baml_client/type_builder.ts +++ b/integ-tests/typescript/baml_client/type_builder.ts @@ -50,7 +50,7 @@ export default class TypeBuilder { constructor() { this.tb = new _TypeBuilder({ classes: new Set([ - "BigNumbers","BinaryNode","Blah","BlockConstraint","BlockConstraintForParam","BookOrder","ClassForNullLiteral","ClassOptionalOutput","ClassOptionalOutput2","ClassToRecAlias","ClassWithImage","CompoundBigNumbers","ContactInfo","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Earthling","Education","Email","EmailAddress","Event","FakeImage","FlightConfirmation","FooAny","Forest","FormatterTest0","FormatterTest1","FormatterTest2","FormatterTest3","GroceryReceipt","InnerClass","InnerClass2","InputClass","InputClassNested","LinkedList","LinkedListAliasNode","LiteralClassHello","LiteralClassOne","LiteralClassTwo","MalformedConstraints","MalformedConstraints2","Martian","MergeAttrs","NamedArgsSingleClass","Nested","Nested2","NestedBlockConstraint","NestedBlockConstraintForParam","Node","NodeWithAliasIndirection","OptionalListAndMap","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","OriginalA","OriginalB","Person","PhoneNumber","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","Tree","TwoStoriesOneTitle","UnionTest_ReturnType","WithReasoning", + "BigNumbers","BinaryNode","Blah","BlockConstraint","BlockConstraintForParam","BookOrder","ClassForNullLiteral","ClassOptionalOutput","ClassOptionalOutput2","ClassToRecAlias","ClassWithBlockDone","ClassWithImage","ClassWithoutDone","CompoundBigNumbers","ContactInfo","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Earthling","Education","Email","EmailAddress","Event","FakeImage","FlightConfirmation","FooAny","Forest","FormatterTest0","FormatterTest1","FormatterTest2","FormatterTest3","GroceryReceipt","InnerClass","InnerClass2","InputClass","InputClassNested","LinkedList","LinkedListAliasNode","LiteralClassHello","LiteralClassOne","LiteralClassTwo","MalformedConstraints","MalformedConstraints2","Martian","MergeAttrs","NamedArgsSingleClass","Nested","Nested2","NestedBlockConstraint","NestedBlockConstraintForParam","Node","NodeWithAliasIndirection","OptionalListAndMap","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","OriginalA","OriginalB","Person","PhoneNumber","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SemanticContainer","SmallThing","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","Tree","TwoStoriesOneTitle","UnionTest_ReturnType","WithReasoning", ]), enums: new Set([ "AliasedEnum","Category","Category2","Category3","Color","DataType","DynEnumOne","DynEnumTwo","EnumInClass","EnumOutput","Hobby","MapKey","NamedArgsSingleEnum","NamedArgsSingleEnumList","OptionalTest_CategoryType","OrderStatus","Tag","TestEnum", diff --git a/integ-tests/typescript/baml_client/types.ts b/integ-tests/typescript/baml_client/types.ts index 74dd06caa5..e0c8ea8ff0 100644 --- a/integ-tests/typescript/baml_client/types.ts +++ b/integ-tests/typescript/baml_client/types.ts @@ -217,6 +217,12 @@ export interface ClassToRecAlias { } +export interface ClassWithBlockDone { + i_16_digits: number + s_20_words: string + +} + export interface ClassWithImage { myImage: Image param2: string @@ -224,6 +230,12 @@ export interface ClassWithImage { } +export interface ClassWithoutDone { + i_16_digits: number + s_20_words: string + +} + export interface CompoundBigNumbers { big: BigNumbers big_nums: BigNumbers[] @@ -602,6 +614,24 @@ export interface SearchParams { } +export interface SemanticContainer { + sixteen_digit_number: number + string_with_twenty_words: string + class_1: ClassWithoutDone + class_2: ClassWithBlockDone + class_done_needed: ClassWithBlockDone + class_needed: ClassWithoutDone + three_small_things: SmallThing[] + final_string: string + +} + +export interface SmallThing { + i_16_digits: number + i_8_digits: number + +} + export interface SomeClassNestedDynamic { hi: string diff --git a/integ-tests/typescript/tests/input-output.test.ts b/integ-tests/typescript/tests/input-output.test.ts index 15025b6080..54825dc2c5 100644 --- a/integ-tests/typescript/tests/input-output.test.ts +++ b/integ-tests/typescript/tests/input-output.test.ts @@ -1,4 +1,5 @@ import { NamedArgsSingleEnumList } from '../baml_client' +import { SemanticContainer } from '../baml_client/partial_types'; import { b } from './test-setup' describe('Basic Input/Output Tests', () => { @@ -80,3 +81,43 @@ describe('Basic Input/Output Tests', () => { }) }) }) + +describe('Semantic Streaming Tests', () => { + it('should support semantic streaming', async () => { + const stream = b.stream.MakeSemanticContainer() + + let reference_string = null; + let reference_int = null; + + const msgs: SemanticContainer[] = [] + for await (const msg of stream) { + msgs.push(msg ?? '') + + // Test field stability. + if (msg.sixteen_digit_number != null){ + if (reference_int == null) { + reference_int = msg.sixteen_digit_number; + } else { + expect(msg.sixteen_digit_number).toEqual(reference_int); + } + } + + // Test @stream.with_state. + if (msg.class_needed.s_20_words.value && msg.class_needed.s_20_words.value.split(" ").length < 3 && msg.final_string == null) { + console.log(msg) + expect(msg.class_needed.s_20_words.state).toEqual("Incomplete"); + } + if (msg.final_string) { + expect(msg.class_needed.s_20_words.state).toEqual("Complete"); + } + + // Test @stream.not_null. + for (const sub of msg.three_small_things) { + expect(sub.i_16_digits).toBeDefined(); + } + } + + const final = await stream.getFinalResponse(); + console.log(final); + }, 20_000) +}) \ No newline at end of file diff --git a/integ-tests/typescript/tests/integ-tests.test.ts.old b/integ-tests/typescript/tests/integ-tests.test.ts.old index da6de8fd63..6fa6aa8257 100644 --- a/integ-tests/typescript/tests/integ-tests.test.ts.old +++ b/integ-tests/typescript/tests/integ-tests.test.ts.old @@ -15,8 +15,10 @@ import { onLogEvent, AliasedEnum, MapKey, + ClassWithDone, + ClassWithBlockDone, } from '../baml_client' -import { RecursivePartialNull } from '../baml_client/async_client' +import partial_types from '../baml_client/partial_types'; import { b as b_sync } from '../baml_client/sync_client' import { config } from 'dotenv' import { BamlLogEvent, BamlRuntime } from '@boundaryml/baml/native' @@ -224,6 +226,7 @@ describe('Integ tests', () => { it('merge alias attributes', async () => { const res = await b.MergeAliasAttributes(123) + console.log(JSON.stringify(res)); expect(res.amount.value).toEqual(123) expect(res.amount.checks['gt_ten'].status).toEqual('succeeded') }) @@ -305,6 +308,7 @@ describe('Integ tests', () => { expect(literal_string).toEqual('example output') const list = await b.FnOutputClassList(a) + console.log(list); expect(list.length).toBeGreaterThan(0) assert(list[0].prop1.length > 0) @@ -762,7 +766,7 @@ describe('Integ tests', () => { it('should work with nested classes', async () => { let stream = b.stream.FnOutputClassNested('hi!') - let msgs: RecursivePartialNull = [] + let msgs: partial_types.TestClassNested[] = [] for await (const msg of stream) { console.log('msg', msg) msgs.push(msg) @@ -1062,3 +1066,12 @@ describe('Integ tests', () => { flush() }) }) + + it('should support semantic streaming', async () => { + const stream = b.stream.MakeClassWithDone() + const msgs: ClassWithDone[] = [] + for await (const msg of stream) { + msgs.push(msg ?? '') + } + const final = await stream.getFinalResponse() + }, 20_000) \ No newline at end of file From 5b02c08eaf981a6131d7b64ec58a0cace18e2547 Mon Sep 17 00:00:00 2001 From: Greg Hale Date: Tue, 28 Jan 2025 11:27:16 -0800 Subject: [PATCH 02/13] Fix ruby ffi --- engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs b/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs index fd2b910137..cd473f2b96 100644 --- a/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs +++ b/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs @@ -124,8 +124,7 @@ impl<'rb> RubyToJson<'rb> { }; let backup_class = match backup_module.const_get::<_, RClass>(class_name.as_str()) { Ok(class_type) => class_type, - // Err(_) => ruby.eval::("Baml::DynamicStruct")?, - Err(_) => unreachable!("trying to avoid these"), + Err(_) => ruby.eval::("Baml::DynamicStruct")?, }; match preferred_class.funcall("new", (hash,)) { Ok(res) => Ok(res), From 8834560b5acaaec1765f46b26e1a92a6e9dc9720 Mon Sep 17 00:00:00 2001 From: Greg Hale Date: Tue, 28 Jan 2025 11:41:02 -0800 Subject: [PATCH 03/13] remove logging from a test --- integ-tests/typescript/tests/input-output.test.ts | 2 -- 1 file changed, 2 deletions(-) diff --git a/integ-tests/typescript/tests/input-output.test.ts b/integ-tests/typescript/tests/input-output.test.ts index 54825dc2c5..91772c7547 100644 --- a/integ-tests/typescript/tests/input-output.test.ts +++ b/integ-tests/typescript/tests/input-output.test.ts @@ -104,7 +104,6 @@ describe('Semantic Streaming Tests', () => { // Test @stream.with_state. if (msg.class_needed.s_20_words.value && msg.class_needed.s_20_words.value.split(" ").length < 3 && msg.final_string == null) { - console.log(msg) expect(msg.class_needed.s_20_words.state).toEqual("Incomplete"); } if (msg.final_string) { @@ -118,6 +117,5 @@ describe('Semantic Streaming Tests', () => { } const final = await stream.getFinalResponse(); - console.log(final); }, 20_000) }) \ No newline at end of file From 69f4d219102f0b4f0a26d817ab0cfe0dcf49af69 Mon Sep 17 00:00:00 2001 From: Greg Hale Date: Tue, 28 Jan 2025 11:56:28 -0800 Subject: [PATCH 04/13] fix unit tests --- .../baml-lib/jsonish/src/tests/animation.rs | 2 +- engine/baml-lib/jsonish/src/tests/macros.rs | 2 +- .../src/internal/llm_client/mod.rs | 21 +------------------ 3 files changed, 3 insertions(+), 22 deletions(-) diff --git a/engine/baml-lib/jsonish/src/tests/animation.rs b/engine/baml-lib/jsonish/src/tests/animation.rs index 9200dcb078..9cd7cd0ea6 100644 --- a/engine/baml-lib/jsonish/src/tests/animation.rs +++ b/engine/baml-lib/jsonish/src/tests/animation.rs @@ -43,7 +43,7 @@ pub fn make_test_data1() { serde_json::to_value(&vec![ serde_json::to_value(partial_llm_data).unwrap(), - serde_json::to_value(&value).unwrap(), + serde_json::to_value(&value.serialize_partial()).unwrap(), ]) .unwrap() }) diff --git a/engine/baml-lib/jsonish/src/tests/macros.rs b/engine/baml-lib/jsonish/src/tests/macros.rs index 241be5212d..74ddc6f860 100644 --- a/engine/baml-lib/jsonish/src/tests/macros.rs +++ b/engine/baml-lib/jsonish/src/tests/macros.rs @@ -142,7 +142,7 @@ macro_rules! test_partial_deserializer_streaming { let value = result; log::trace!("Score: {}", value.score()); - let json_value = json!(value); + let json_value = json!(value.serialize_partial()); let expected = serde_json::json!($($json)+); diff --git a/engine/baml-runtime/src/internal/llm_client/mod.rs b/engine/baml-runtime/src/internal/llm_client/mod.rs index a404a7dc40..6293f072b0 100644 --- a/engine/baml-runtime/src/internal/llm_client/mod.rs +++ b/engine/baml-runtime/src/internal/llm_client/mod.rs @@ -536,7 +536,7 @@ mod tests { let res = parsed_value_to_response(&ir, value, &field_type, true).unwrap(); - let json = serde_json::to_value(&res).unwrap(); + let json = serde_json::to_value(res.serialize_final()).unwrap(); match &json { serde_json::Value::Object(items) => { @@ -546,23 +546,4 @@ mod tests { _ => panic!("Expected json object"), } } - - #[test] - fn integ_test_failure() { - let ir = make_test_ir(r#" - class Foo { - prop1 string - prop2 int - } - "#).unwrap(); - let target_type = FieldType::class("Foo"); - let target = jsonish::helpers::render_output_format(&ir, &target_type, &Default::default()).unwrap(); - - let msg = r#"{"prop1": "something", "prop2": 2}"#; - - let parsed = jsonish::from_str(&target, &target_type, msg, true).unwrap(); - let response = parsed_value_to_response(&ir, parsed, &target_type, true).unwrap(); - let json = serde_json::to_string(&response).unwrap(); - assert_eq!(json, r#"{"prop1":"something","prop2":2}"#); - } } \ No newline at end of file From e096d965ba1aac2457e541e9f39038c98a68604f Mon Sep 17 00:00:00 2001 From: Greg Hale Date: Tue, 28 Jan 2025 12:04:10 -0800 Subject: [PATCH 05/13] Drop more-efficient semantic streaming fns --- .../src/deserializer/semantic_streaming.rs | 156 ------------------ engine/baml-lib/jsonish/src/helpers/mod.rs | 64 +++---- .../src/internal/llm_client/mod.rs | 31 ---- 3 files changed, 32 insertions(+), 219 deletions(-) diff --git a/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs b/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs index 51e61c73a7..23adb3f7f3 100644 --- a/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs +++ b/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs @@ -49,23 +49,6 @@ pub fn validate_streaming_state( res } -/// Like validate_state, but specialized to the metadata we happen to have already. -/// (This is a performance hack to allow us to skip several map_meta, zip_meta -/// steps). -pub fn validate_streaming_state2( - ir: &IntermediateRepr, - baml_value: BamlValueWithMeta<(Vec, Vec)>, - field_type: &FieldType, - allow_partials: bool, -) -> Result< - BamlValueWithMeta<(Vec, Vec, Completion)>, - StreamingError, -> { - let typed_baml_value = ir.distribute_type_with_meta(baml_value, field_type.clone())?; - let res = process_node2(ir, typed_baml_value, allow_partials); - res -} - /// Consider a node's type, streaming state, and streaming behavior annotations. Return /// an error if streaming state doesn't meet the streaming requirements. Also attach /// the streaming state to the node as metadata, if this was requested by the user @@ -93,11 +76,6 @@ fn process_node( display: streaming_behavior.state, required_done: must_be_done, }; - // let new_meta = if streaming_behavior.state && allow_partials { - // Some(completion_state.clone()) - // } else { - // None - // }; if must_be_done && allow_partials && !(completion_state == CompletionState::Complete) { return Err(StreamingError::IncompleteDoneValue); @@ -234,140 +212,6 @@ fn process_node( new_value } -fn process_node2( - ir: &IntermediateRepr, - value: BamlValueWithMeta<((Vec, Vec), FieldType)>, - allow_partials: bool, -) -> Result< - BamlValueWithMeta<(Vec, Vec, Completion)>, - StreamingError, -> { - // let value_copy = value.clone(); - let ((flags, checks), field_type) = value.meta().clone(); - let complete = completion_state(&flags); - let (base_type, (_, streaming_behavior)) = ir.distribute_metadata(&field_type); - - let must_be_done = required_done(ir, &field_type) && allow_partials; - - let new_meta = ( - flags, - checks, - Completion { - state: complete.clone(), - display: streaming_behavior.state, - required_done: must_be_done, - } - ); - - if must_be_done && !(complete == CompletionState::Complete) { - return Err(StreamingError::IncompleteDoneValue); - // return Ok(BamlValueWithMeta::Null(new_meta)) - } - - let new_value = match value { - BamlValueWithMeta::String(s, _) => Ok(BamlValueWithMeta::String(s, new_meta)), - BamlValueWithMeta::Media(m, _) => Ok(BamlValueWithMeta::Media(m, new_meta)), - BamlValueWithMeta::Null(_) => Ok(BamlValueWithMeta::Null(new_meta)), - BamlValueWithMeta::Int(i, _) => Ok(BamlValueWithMeta::Int(i, new_meta)), - BamlValueWithMeta::Float(f, _) => Ok(BamlValueWithMeta::Float(f, new_meta)), - BamlValueWithMeta::Bool(b, _) => Ok(BamlValueWithMeta::Bool(b, new_meta)), - BamlValueWithMeta::List(items, _) => Ok(BamlValueWithMeta::List( - items - .into_iter() - .filter_map(|item| process_node2(ir, item, allow_partials).ok()) - .collect(), - new_meta, - )), - BamlValueWithMeta::Class(ref class_name, fields, _) => { - let field_names: HashSet = - fields.keys().into_iter().map(|s| s.to_string()).collect(); - let needed_fields: HashSet = needed_fields(ir, &field_type, allow_partials)?; - // let missing_needed_fields = needed_fields.difference(&new_field_names); - let present_nonnull_fields: HashSet = fields.iter().filter_map(|(field_name, field_value)| { - if matches!(field_value, BamlValueWithMeta::Null(_)) { - None - } else { - Some(field_name.to_string()) - } - }).collect(); - - let missing_needed_fields: HashSet<&String> = needed_fields.difference(&present_nonnull_fields).into_iter().collect(); - let unneeded_fields = field_names.difference(&needed_fields); - - let fields_needing_null = - fields_needing_null_filler(ir, &field_type, field_names, allow_partials)?; - - let mut deleted_fields: HashMap< - String, - BamlValueWithMeta<(Vec, Vec, Completion)>, - > = HashMap::new(); - - // let unneeded_fields = field_names.difference(&needed_fields); - let needed_nulls = fields_needing_null - .into_iter() - .filter_map(|ref null_field_name| { - let field = fields - .get(null_field_name) - .expect("This field is guaranteed to be in the field set"); - let use_state = type_streaming_behavior(ir, &field.meta().1).state; - let field_stream_state = Completion { - state: CompletionState::Incomplete, - display: use_state, - required_done: false, - }; - Some(( - null_field_name.to_string(), - BamlValueWithMeta::Null((Vec::new(), Vec::new(), field_stream_state)), - )) - }) - .collect::>>(); - - let mut new_fields = fields - .into_iter() - .filter_map(|(field_name, field_value)| { - let with_state = field_value - .meta() - .1 - .streaming_behavior() - .as_ref() - .map_or(false, |b| b.state); - let complete: CompletionState = completion_state(&field_value.meta().0 .0); - - match process_node2(ir, field_value, allow_partials) { - Ok(res) => Some((field_name, res)), - _ => { - let state = Completion { state: complete, display: with_state, required_done: false }; - let null = BamlValueWithMeta::Null((Vec::new(), Vec::new(), state)); - deleted_fields.insert(field_name, null); - None - } - } - }) - .collect::>>(); - - new_fields.extend(needed_nulls); - new_fields.extend(deleted_fields); - let res = BamlValueWithMeta::Class(class_name.clone(), new_fields, new_meta); - if missing_needed_fields.clone().len() == 0 { - Ok(res) - } else { - Err(StreamingError::MissingNeededFields) - } - } - BamlValueWithMeta::Enum(name, value, _) => { - Ok(BamlValueWithMeta::Enum(name, value, new_meta)) - } - BamlValueWithMeta::Map(kvs, _) => { - let new_kvs = kvs - .into_iter() - .filter_map(|(k, v)| process_node2(ir, v, allow_partials).ok().map(|v| (k, v))) - .collect(); - Ok(BamlValueWithMeta::Map(new_kvs, new_meta)) - } - }; - - new_value -} fn fields_needing_null_filler<'a>( ir: &'a IntermediateRepr, diff --git a/engine/baml-lib/jsonish/src/helpers/mod.rs b/engine/baml-lib/jsonish/src/helpers/mod.rs index e7d4f57d93..8db65ca486 100644 --- a/engine/baml-lib/jsonish/src/helpers/mod.rs +++ b/engine/baml-lib/jsonish/src/helpers/mod.rs @@ -15,7 +15,7 @@ use internal_baml_jinja::types::{Builder, Name, OutputFormatContent}; use internal_baml_jinja::types::{Class, Enum}; use crate::deserializer::deserialize_flags::{constraint_results, Flag}; -use crate::deserializer::semantic_streaming::validate_streaming_state2; +use crate::deserializer::semantic_streaming::validate_streaming_state; use crate::{BamlValueWithFlags, ResponseBamlValue}; pub fn load_test_ir(file_content: &str) -> IntermediateRepr { @@ -256,34 +256,34 @@ fn relevant_data_models<'a>( )) } -/// Validate a parsed value, checking asserts and checks. -pub fn parsed_value_to_response( - ir: &IntermediateRepr, - baml_value: BamlValueWithFlags, - field_type: &FieldType, - allow_partials: bool, -) -> Result { - let meta_flags: BamlValueWithMeta> = baml_value.into(); - - let baml_value_with_streaming2 = meta_flags.map_meta_owned(|flags| { - let constraint_results = constraint_results(&flags); - let response_checks: Vec = constraint_results - .iter() - .map(|(label, expr, result)| { - let status = (if *result { "succeeded" } else { "failed" }).to_string(); - ResponseCheck { - name: label.clone(), - expression: expr.0.clone(), - status, - } - }) - .collect(); - (flags, response_checks) - }); - - let response_value2 = - validate_streaming_state2(ir, baml_value_with_streaming2, field_type, allow_partials) - .map_err(|s| anyhow::anyhow!("{s}"))?; - - Ok(crate::ResponseBamlValue(response_value2)) -} +// /// Validate a parsed value, checking asserts and checks. +// pub fn parsed_value_to_response( +// ir: &IntermediateRepr, +// baml_value: BamlValueWithFlags, +// field_type: &FieldType, +// allow_partials: bool, +// ) -> Result { +// let meta_flags: BamlValueWithMeta> = baml_value.into(); +// +// let baml_value_with_streaming = meta_flags.map_meta_owned(|flags| { +// let constraint_results = constraint_results(&flags); +// let response_checks: Vec = constraint_results +// .iter() +// .map(|(label, expr, result)| { +// let status = (if *result { "succeeded" } else { "failed" }).to_string(); +// ResponseCheck { +// name: label.clone(), +// expression: expr.0.clone(), +// status, +// } +// }) +// .collect(); +// (flags, response_checks) +// }); +// +// let response_value2 = +// validate_streaming_state(ir, baml_value_with_streaming, field_type, allow_partials) +// .map_err(|s| anyhow::anyhow!("{s}"))?; +// +// Ok(crate::ResponseBamlValue(response_value)) +// } diff --git a/engine/baml-runtime/src/internal/llm_client/mod.rs b/engine/baml-runtime/src/internal/llm_client/mod.rs index 6293f072b0..9e9d0a6a34 100644 --- a/engine/baml-runtime/src/internal/llm_client/mod.rs +++ b/engine/baml-runtime/src/internal/llm_client/mod.rs @@ -20,7 +20,6 @@ use jsonish::{ deserializer::{ deserialize_flags::{constraint_results, DeserializerConditions, Flag}, semantic_streaming::validate_streaming_state, - semantic_streaming::validate_streaming_state2, }, BamlValueWithFlags, }; @@ -71,36 +70,6 @@ pub fn parsed_value_to_response( Ok(ResponseBamlValue(response_value)) } -/// Validate a parsed value, checking asserts and checks. -pub fn parsed_value_to_response2( - ir: &IntermediateRepr, - baml_value: BamlValueWithFlags, - field_type: &FieldType, - allow_partials: bool, -) -> Result { - let meta_flags: BamlValueWithMeta> = baml_value.into(); - let baml_value_with_streaming2 = meta_flags.map_meta_owned(|flags| { - let constraint_results = constraint_results(&flags); - let response_checks: Vec = constraint_results - .iter() - .map(|(label, expr, result)| { - let status = (if *result { "succeeded" } else { "failed" }).to_string(); - ResponseCheck { - name: label.clone(), - expression: expr.0.clone(), - status, - } - }) - .collect(); - (flags, response_checks) - }); - - let response_value2 = - validate_streaming_state2(ir, baml_value_with_streaming2, field_type, allow_partials) - .map_err(|s| anyhow::anyhow!("TODO {s:?}"))?; - - Ok(ResponseBamlValue(response_value2)) -} #[derive(Clone, Copy, PartialEq)] pub enum ResolveMediaUrls { From 5ca94e84f5aa33a20b9a178cf14d2479e84e94d5 Mon Sep 17 00:00:00 2001 From: Greg Hale Date: Tue, 28 Jan 2025 12:12:19 -0800 Subject: [PATCH 06/13] Cleanup printf debugging --- .../baml-core/src/ir/ir_helpers/mod.rs | 40 ------------------- engine/baml-lib/baml-types/src/baml_value.rs | 1 - .../coercer/ir_ref/coerce_class.rs | 1 - .../src/deserializer/semantic_streaming.rs | 10 ----- .../jsonish/src/jsonish/parser/entry.rs | 2 - .../src/jsonish/parser/fixing_parser.rs | 4 -- .../src/jsonish/parser/markdown_parser.rs | 1 - .../baml-lib/jsonish/src/tests/test_maps.rs | 1 - .../src/runtime/runtime_interface.rs | 2 - .../language_client_codegen/src/python/mod.rs | 1 - .../src/typescript/mod.rs | 1 - .../src/types/function_results.rs | 9 ----- .../ext/ruby_ffi/src/function_result.rs | 4 -- .../ext/ruby_ffi/src/ruby_to_json.rs | 10 ----- 14 files changed, 87 deletions(-) diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs index 97a949eda8..94041ee011 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs @@ -605,25 +605,6 @@ impl IRHelper for IntermediateRepr { .collect::>>>( )?; - // let item_types: Vec<&FieldType> = mapped_fields - // .values() - // .map(|i| &i.meta().1) - // .dedup() - // .collect(); - // let items_type = match item_types.len() { - // 0 => None, - // 1 => Some(item_types[0].clone()), - // _ => Some(FieldType::Union( - // item_types.into_iter().map(|t| t.clone()).collect(), - // )), - // }; - // if let Some((key_ty, value_ty)) = map_types(self, &field_type) { - // let expected_type = FieldType::Map(Box::new(key_ty.clone()), Box::new(value_ty.clone())); - // if !self.is_subtype(&expected_type, &field_base_type) { - // anyhow::bail!("Could not unify {:?} with {:?}", expected_type, field_base_type); - // } - // } - Ok(BamlValueWithMeta::Map(mapped_fields, (meta, field_type))) } @@ -631,32 +612,13 @@ impl IRHelper for IntermediateRepr { let new_items = items .into_iter() .map(|i| { - // dbg!(&field_type); - // dbg!(&i); item_type(self, &field_type, &i) .ok_or({ - eprintln!("ty: {field_type:?}, i: {i:?}"); anyhow::anyhow!("Could not infer child type") }) .and_then(|item_type| self.distribute_type_with_meta(i, item_type)) }) .collect::>>()?; - // dbg!(&new_items); - // let item_types: Vec<&FieldType> = - // new_items.iter().map(|i| &i.meta().1).dedup().collect(); - // let items_type = match item_types.len() { - // 0 => None, - // 1 => Some(item_types[0].clone()), - // _ => Some(FieldType::Union( - // item_types.into_iter().map(|t| t.clone()).collect(), - // )), - // }; - // if let Some(ty) = items_type { - // let expected_type = FieldType::List(Box::new(ty)); - // if !self.is_subtype(&expected_type, &field_base_type) { - // anyhow::bail!("Could not unify {:?} with {:?}", expected_type, field_base_type); - // } - // } Ok(BamlValueWithMeta::List(new_items, (meta, field_type))) } @@ -845,8 +807,6 @@ fn item_type( field_type: &FieldType, baml_child_values: &BamlValueWithMeta, ) -> Option { - // dbg!(&baml_child_value); - // dbg!(&field_type); let res = match ir.distribute_metadata(field_type).0 { FieldType::Class(_) => None, FieldType::Enum(_) => None, diff --git a/engine/baml-lib/baml-types/src/baml_value.rs b/engine/baml-lib/baml-types/src/baml_value.rs index 61e5b3af91..7f05a4ef64 100644 --- a/engine/baml-lib/baml-types/src/baml_value.rs +++ b/engine/baml-lib/baml-types/src/baml_value.rs @@ -681,7 +681,6 @@ impl Serialize for BamlValueWithMeta> { where S: Serializer, { - eprintln!("ABOUT TO SERIALIZE"); match self { BamlValueWithMeta::String(v, cr) => serialize_with_checks(v, cr, serializer), BamlValueWithMeta::Int(v, cr) => serialize_with_checks(v, cr, serializer), diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs index 35231d010b..e4ba0fd623 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs @@ -82,7 +82,6 @@ impl TypeCoercer for Class { } Some(crate::jsonish::Value::Object(obj, completion)) => { // match keys, if that fails, then do something fancy later. - // dbg!(&obj); let mut extra_keys = vec![]; let mut found_keys = false; obj.iter().for_each(|(key, v)| { diff --git a/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs b/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs index 23adb3f7f3..9fee731d04 100644 --- a/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs +++ b/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs @@ -110,17 +110,7 @@ fn process_node( Some(field_name.to_string()) } }).collect(); - // let missing_needed_fields = needed_fields.difference(&new_field_names); let missing_needed_fields: Vec<_> = needed_fields.difference(&present_nonnull_fields).into_iter().collect(); - // if (class_name == "SmallThing") { - if false { - dbg!(class_name); - dbg!(&value_field_names); - dbg!(&present_nonnull_fields); - dbg!(&needed_fields); - dbg!(&missing_needed_fields); - dbg!(&value_fields); - } // The fields that need to be filled in by Null are initially the diff --git a/engine/baml-lib/jsonish/src/jsonish/parser/entry.rs b/engine/baml-lib/jsonish/src/jsonish/parser/entry.rs index 56e4413ccf..2a324d7011 100644 --- a/engine/baml-lib/jsonish/src/jsonish/parser/entry.rs +++ b/engine/baml-lib/jsonish/src/jsonish/parser/entry.rs @@ -141,7 +141,6 @@ pub fn parse(str: &str, mut options: ParseOptions) -> Result { Ok(items) => match items.len() { 0 => {} 1 => { - // eprintln!("MULTI_JSON: {items:?}"); let ret = Value::AnyOf( vec![Value::FixedJson( @@ -154,7 +153,6 @@ pub fn parse(str: &str, mut options: ParseOptions) -> Result { )], str.to_string(), ); - // eprintln!("ret: {ret:?}"); return Ok(ret); } n => { diff --git a/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser.rs b/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser.rs index 0cf0bd15c7..cf9dd94f70 100644 --- a/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser.rs +++ b/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser.rs @@ -107,7 +107,6 @@ mod tests { fn test_partial_array() { let opts = ParseOptions::default(); let vals = parse("[12", &opts).unwrap(); - dbg!(&vals); match vals[0].0.clone() { Value::Array(xs, array_cmplt) => { @@ -115,7 +114,6 @@ mod tests { assert_eq!(array_cmplt, CompletionState::Incomplete); match &xs[0] { Value::Number(n, n_cmplt) => { - dbg!(&n); assert_eq!(n, &serde_json::Number::from(12)); assert_eq!(n_cmplt, &CompletionState::Incomplete); } @@ -130,7 +128,6 @@ mod tests { fn test_partial_object() { let opts = ParseOptions::default(); let vals = parse(r#"{"a": 11, "b": 22"#, &opts).unwrap(); - dbg!(&vals); match &vals[0].0 { Value::Object(fields, obj_cmplt) => { assert_eq!(fields.len(), 2); @@ -155,7 +152,6 @@ mod tests { fn test_partial_object_newlines() { let opts = ParseOptions::default(); let vals = parse("{\n \"a\": 11, \n \"b\": 22", &opts).unwrap(); - dbg!(&vals); match &vals[0].0 { Value::Object(fields, obj_cmplt) => { assert_eq!(fields.len(), 2); diff --git a/engine/baml-lib/jsonish/src/jsonish/parser/markdown_parser.rs b/engine/baml-lib/jsonish/src/jsonish/parser/markdown_parser.rs index 93f168b292..6b0e9d010b 100644 --- a/engine/baml-lib/jsonish/src/jsonish/parser/markdown_parser.rs +++ b/engine/baml-lib/jsonish/src/jsonish/parser/markdown_parser.rs @@ -47,7 +47,6 @@ pub fn parse(str: &str, options: &ParseOptions) -> Result> { match res { Ok(v) => { - // eprintln!("Pushing value {v:?}"); // TODO: Add any more additional strings here. values.push(MarkdownResult::CodeBlock( if tag.len() > 3 { diff --git a/engine/baml-lib/jsonish/src/tests/test_maps.rs b/engine/baml-lib/jsonish/src/tests/test_maps.rs index 2b14391c66..05337a32e7 100644 --- a/engine/baml-lib/jsonish/src/tests/test_maps.rs +++ b/engine/baml-lib/jsonish/src/tests/test_maps.rs @@ -167,7 +167,6 @@ fn test_union_of_map_and_class() { assert!(result.is_ok(), "Failed to parse: {:?}", result); let value = result.unwrap(); - dbg!(&value); assert!(matches!(value, BamlValueWithFlags::Class(_, _, _))); log::trace!("Score: {}", value.score()); diff --git a/engine/baml-runtime/src/runtime/runtime_interface.rs b/engine/baml-runtime/src/runtime/runtime_interface.rs index 9655baff67..deec73e29c 100644 --- a/engine/baml-runtime/src/runtime/runtime_interface.rs +++ b/engine/baml-runtime/src/runtime/runtime_interface.rs @@ -378,8 +378,6 @@ impl RuntimeInterface for InternalBamlRuntime { // Now actually execute the code. let (history, _) = orchestrate_call(orchestrator, self.ir(), &ctx, &renderer, &baml_args, |s| { - // eprintln!("RAW"); - // eprintln!("{}", s); renderer.parse(self.ir(), s, false) }) .await; diff --git a/engine/language_client_codegen/src/python/mod.rs b/engine/language_client_codegen/src/python/mod.rs index a9ee02ac97..547aead87f 100644 --- a/engine/language_client_codegen/src/python/mod.rs +++ b/engine/language_client_codegen/src/python/mod.rs @@ -351,6 +351,5 @@ class Foo { let res = generate(&ir, &generator_args).unwrap(); let partial_types = res.get(&PathBuf::from("partial_types.py")).unwrap(); eprintln!("{}", partial_types); - assert!(false); } } diff --git a/engine/language_client_codegen/src/typescript/mod.rs b/engine/language_client_codegen/src/typescript/mod.rs index 7d02ae8d72..76c5e0e36c 100644 --- a/engine/language_client_codegen/src/typescript/mod.rs +++ b/engine/language_client_codegen/src/typescript/mod.rs @@ -478,6 +478,5 @@ function MkFoo() -> Foo { eprintln!("{}", partial_types); let async_client = res.get(&PathBuf::from("async_client.ts")).unwrap(); eprintln!("{}", async_client); - assert!(false); } } diff --git a/engine/language_client_python/src/types/function_results.rs b/engine/language_client_python/src/types/function_results.rs index e62937cc23..bf5af44952 100644 --- a/engine/language_client_python/src/types/function_results.rs +++ b/engine/language_client_python/src/types/function_results.rs @@ -42,7 +42,6 @@ impl FunctionResult { .map_err(BamlError::from_anyhow)?; let parsed = pythonize_strict(py, parsed.clone(), &enum_module, &cls_module, &partial_cls_module, allow_partials); - // eprintln!("parsed result: {:?}", parsed); Ok(parsed?) } @@ -85,7 +84,6 @@ fn pythonize_strict( allow_partials: bool, ) -> PyResult { let allow_partials = allow_partials && !parsed.0.meta().2.required_done; - // eprintln!("pythonize_strict parsed: {:?}", parsed); let meta = parsed.0.meta().clone(); let py_value_without_constraints = match parsed.0 { BamlValueWithMeta::String(val, _) => val.into_py_any(py), @@ -210,7 +208,6 @@ fn pythonize_strict( let (_, checks, completion_state) = meta; if checks.is_empty() && !completion_state.display { - // eprintln!("ret1: {:?}", py_value_without_constraints); Ok(py_value_without_constraints) } else { @@ -254,8 +251,6 @@ fn pythonize_strict( let checked_instance = class_checked_type.call_method("model_validate", (properties_dict.clone(),), None).expect("model_validate"); - // eprintln!("ret2: {:?}", checked_instance); - Ok::, PyErr>(checked_instance.into()) } else { Ok(py_value_without_constraints) @@ -263,7 +258,6 @@ fn pythonize_strict( let value_with_possible_completion_state = if completion_state.display && allow_partials { let value_type = value_with_possible_checks.bind(py).get_type(); - // eprintln!("value_type: {:?}", value_type); // Prepare the properties dictionary let properties_dict = pyo3::types::PyDict::new(py); @@ -272,15 +266,12 @@ fn pythonize_strict( // Prepare type parameters for StreamingState[...] let type_parameters_tuple = PyTuple::new(py, [value_type.as_ref()]).expect("PyTuple::new"); - // dbg!(&type_parameters_tuple); let class_streaming_state_type_constructor = partial_cls_module.getattr("StreamState").expect("getattr(StreamState)"); let class_completion_state_type: Bound<'_, PyAny> = class_streaming_state_type_constructor .call_method1("__class_getitem__", (type_parameters_tuple,)) .expect("__class_getitem__ for streaming"); - // dbg!(&class_completion_state_type); - // eprintln!("properties dict: {:?}", properties_dict); let streaming_state_instance = class_completion_state_type .call_method("model_validate", (properties_dict.clone(),), None) .expect("model_validate for streaming"); diff --git a/engine/language_client_ruby/ext/ruby_ffi/src/function_result.rs b/engine/language_client_ruby/ext/ruby_ffi/src/function_result.rs index 24f3c58ca2..e6b2d1b655 100644 --- a/engine/language_client_ruby/ext/ruby_ffi/src/function_result.rs +++ b/engine/language_client_ruby/ext/ruby_ffi/src/function_result.rs @@ -38,9 +38,6 @@ impl FunctionResult { partial_types: RModule, allow_partials: bool, ) -> Result { - dbg!(&types); - dbg!(&partial_types); - dbg!(&allow_partials); let res = match rb_self.inner.result_with_constraints_content() { Ok(parsed) => { ruby_to_json::RubyToJson::serialize_baml(ruby, types, partial_types, allow_partials, parsed.clone()) @@ -56,7 +53,6 @@ impl FunctionResult { format!("Failed to parse LLM response: {}", rb_self.inner), )), }; - dbg!(&res); res } diff --git a/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs b/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs index cd473f2b96..77e6842248 100644 --- a/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs +++ b/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs @@ -67,11 +67,9 @@ impl<'rb> RubyToJson<'rb> { // If we encounter a BamlValue node with check results, serialize it as // { value: T, checks: K }. To compute `value`, we strip the metadata // off the node and pass it back to `serialize_baml`. - eprintln!("SERIALIZE: {:?}", &from); let (_flags, checks, completion) = from.0.meta_mut(); if completion.display && allow_partials { - eprintln!("... with state"); let hash = ruby.hash_new(); let stream_state_class = ruby.eval::("Baml::StreamState")?; hash.aset(ruby.sym_new("state"), ruby.sym_new(serde_json::to_string(&completion.state).expect("Serializing CompletionState is safe.")))?; @@ -79,7 +77,6 @@ impl<'rb> RubyToJson<'rb> { let serialized_subvalue = RubyToJson::serialize_baml(ruby, types, partial_types, allow_partials, from)?; hash.aset(ruby.sym_new("value"), serialized_subvalue)?; let res = stream_state_class.funcall("new", (hash,)); - eprintln!("with_state res: {res:?}"); Ok(res?) } // Otherwise encode it directly. @@ -99,19 +96,15 @@ impl<'rb> RubyToJson<'rb> { hash.aset(ruby.sym_new("checks"), serialized_checks)?; } let res = checked_class.funcall("new", (hash,)); - dbg!(&res); - eprintln!("with_checks res: {res:?}"); Ok(res?) } // Otherwise encode it directly. else { - eprintln!("...without_state"); let res = match from.0 { BamlValueWithMeta::Class(class_name, class_fields, _) => { let hash = ruby.hash_new(); for (k, v) in class_fields.into_iter() { let subvalue_allow_partials = allow_partials && !v.meta().2.required_done; - dbg!(&subvalue_allow_partials); let k = ruby.sym_new(k.as_str()); let v = RubyToJson::serialize_baml(ruby, types, partial_types, subvalue_allow_partials, ResponseBamlValue(v))?; hash.aset(k, v)?; @@ -129,11 +122,9 @@ impl<'rb> RubyToJson<'rb> { match preferred_class.funcall("new", (hash,)) { Ok(res) => Ok(res), Err(original_error) => { - eprintln!("preferred_class {:?}, failed: {:?}. falling back to {:?}", preferred_class, original_error, backup_class); match backup_class.funcall("new", (hash,)) { Ok(res) => Ok(res), Err(e) => { - eprintln!("backup {:?} failed with {:?}", backup_class, e); Err(original_error) } } @@ -169,7 +160,6 @@ impl<'rb> RubyToJson<'rb> { } _ => serde_magnus::serialize(&from.0.value()), }; - dbg!(&res); res } } From f137258a8c051b38d41d3e01fbe0c4aee89d53b2 Mon Sep 17 00:00:00 2001 From: Greg Hale Date: Tue, 28 Jan 2025 12:12:51 -0800 Subject: [PATCH 07/13] whitespace --- engine/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/Cargo.toml b/engine/Cargo.toml index f88840ac5b..f85dc4b3b5 100644 --- a/engine/Cargo.toml +++ b/engine/Cargo.toml @@ -112,4 +112,4 @@ lto = false inherits = "dev" [profile.release] -lto = true \ No newline at end of file +lto = true From 743793976326d76c68e270127fc6dec03af93e81 Mon Sep 17 00:00:00 2001 From: Greg Hale Date: Tue, 28 Jan 2025 12:46:19 -0800 Subject: [PATCH 08/13] move docstrings to trait def --- .../baml-core/src/ir/ir_helpers/mod.rs | 65 ++++++++++--------- 1 file changed, 35 insertions(+), 30 deletions(-) diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs index 94041ee011..bfb1aff96f 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs @@ -56,22 +56,57 @@ pub trait IRHelper { params: &BamlMap, coerce_settings: ArgCoercer, ) -> Result; + + /// BAML does not support class-based subtyping. Nonetheless some builtin + /// BAML types are subtypes of others, and we need to be able to test this + /// when checking the types of values. + /// + /// For examples of pairs of types and their subtyping relationship, see + /// this module's test suite. + /// + /// Consider renaming this to `is_assignable`. fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool; + + /// For some `BamlValue` with type `FieldType`, walk the structure of both the value + /// and the type simultaneously, associating each node in the `BamlValue` with its + /// `FieldType`. fn distribute_type( &self, value: BamlValue, field_type: FieldType, ) -> anyhow::Result>; + + /// For some `BamlValueWithMeta` with type `FieldType`, walk the structure of both the value + /// and the type simultaneously, associating each node in the `BamlValue` with its + /// `FieldType`. + /// TODO (Greg): Make this function DynamicTypes-aware. Right now it assigns default metadata + /// to unknown classes, which may have been created with TypeBuilder. fn distribute_type_with_meta( &self, value: BamlValueWithMeta, field_type: FieldType, // default_meta: Option<&T>, ) -> Result>; + + /// For any FieldType, check if the field type is FieldType::WithMetadata, + /// and if so, return the metadata alongside the base type. + /// All other field types will be returned as is, alongside default metadata. fn distribute_metadata<'a>( &'a self, field_type: &'a FieldType, ) -> (&'a FieldType, (Vec, StreamingBehavior)); + + /// Constraints may live in several places. A constrained base type stors its + /// constraints by wrapping itself in the `FieldType::WithMetadata` constructor. + /// Additionally, `FieldType::Class` may have constraints stored in its class node, + /// and `FieldType::Enum` can store constraints in its `Enum` node. + /// And the `FieldType::WithMetadata` constructor might wrap another + /// `FieldType::WithMetadata` constructor. + /// + /// This function collects constraints for a given type from all these + /// possible sources. Whenever querying a type for its constraints, you + /// should do so with this function, instead of searching manually for all + /// the places that Constraints can live. fn distribute_constraints<'a>( &'a self, field_type: &'a FieldType, @@ -233,14 +268,6 @@ impl IRHelper for IntermediateRepr { } } - /// BAML does not support class-based subtyping. Nonetheless some builtin - /// BAML types are subtypes of others, and we need to be able to test this - /// when checking the types of values. - /// - /// For examples of pairs of types and their subtyping relationship, see - /// this module's test suite. - /// - /// Consider renaming this to `is_assignable`. fn is_subtype(&self, base: &FieldType, other: &FieldType) -> bool { if base == other { return true; @@ -337,9 +364,6 @@ impl IRHelper for IntermediateRepr { } } - /// For some `BamlValue` with type `FieldType`, walk the structure of both the value - /// and the type simultaneously, associating each node in the `BamlValue` with its - /// `FieldType`. fn distribute_type( &self, value: BamlValue, @@ -526,11 +550,6 @@ impl IRHelper for IntermediateRepr { } } - /// For some `BamlValueWithMeta` with type `FieldType`, walk the structure of both the value - /// and the type simultaneously, associating each node in the `BamlValue` with its - /// `FieldType`. - /// TODO (Greg): Make this function DynamicTypes-aware. Right now it assigns default metadata - /// to unknown classes, which may have been created with TypeBuilder. fn distribute_type_with_meta( &self, value: BamlValueWithMeta, @@ -685,17 +704,6 @@ impl IRHelper for IntermediateRepr { } } - /// Constraints may live in several places. A constrained base type stors its - /// constraints by wrapping itself in the `FieldType::WithMetadata` constructor. - /// Additionally, `FieldType::Class` may have constraints stored in its class node, - /// and `FieldType::Enum` can store constraints in its `Enum` node. - /// And the `FieldType::WithMetadata` constructor might wrap another - /// `FieldType::WithMetadata` constructor. - /// - /// This function collects constraints for a given type from all these - /// possible sources. Whenever querying a type for its constraints, you - /// should do so with this function, instead of searching manually for all - /// the places that Constraints can live. fn distribute_constraints<'a>( &'a self, field_type: &'a FieldType, @@ -704,9 +712,6 @@ impl IRHelper for IntermediateRepr { (field_type, metadata.0) } - /// For any FieldType, check if the field type is FieldType::WithMetadata, - /// and if so, return the metadata alongside the base type. - /// All other field types will be returned as is, alongside default metadata. fn distribute_metadata<'a>( &'a self, field_type: &'a FieldType, From bc20392b6c086e9ed865bab07c85eae4ba729419 Mon Sep 17 00:00:00 2001 From: Greg Hale Date: Tue, 28 Jan 2025 12:52:50 -0800 Subject: [PATCH 09/13] cleanup --- .../baml-core/src/ir/ir_helpers/mod.rs | 122 +----------------- .../ext/ruby_ffi/src/ruby_to_json.rs | 6 +- 2 files changed, 9 insertions(+), 119 deletions(-) diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs index bfb1aff96f..aa962a3cd9 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs @@ -432,13 +432,6 @@ impl IRHelper for IntermediateRepr { _ => Some(FieldType::Union(item_types)), }; - // let key_type = match self.distribute_metadata(&field_type) { - // (FieldType::Map(annotation_key_type, _), _) => annotation_key_type.as_ref(), - // // TODO: Make the following a baml compiler error, too. - // (FieldType::RecursiveTypeAlias(_), _) => anyhow::bail!("Type aliases are not allowed as map keys."), - // _ => anyhow::bail!("Value was not a map."), - // }; - match maybe_item_type { Some(item_type) => { let map_type = FieldType::Map( @@ -606,7 +599,6 @@ impl IRHelper for IntermediateRepr { BamlValueWithMeta::Null(meta) => Ok(BamlValueWithMeta::Null((meta, field_type))), - // TODO: Handle enums and literal keys. BamlValueWithMeta::Map(pairs, meta) => { let (annotation_key_type, annotation_value_type) = map_types(self, &field_type) .ok_or(anyhow::anyhow!("Could not unify map with {field_type:?}"))?; @@ -632,9 +624,7 @@ impl IRHelper for IntermediateRepr { .into_iter() .map(|i| { item_type(self, &field_type, &i) - .ok_or({ - anyhow::anyhow!("Could not infer child type") - }) + .ok_or({ anyhow::anyhow!("Could not infer child type") }) .and_then(|item_type| self.distribute_type_with_meta(i, item_type)) }) .collect::>>()?; @@ -824,122 +814,22 @@ fn item_type( .recursive_alias_definition(alias_name) .and_then(|resolved_type| item_type(ir, resolved_type, baml_child_values)), FieldType::Union(variants) => { - let variant_children = variants.iter().filter_map(|variant| item_type(ir, variant, baml_child_values)).collect::>(); + let variant_children = variants + .iter() + .filter_map(|variant| item_type(ir, variant, baml_child_values)) + .collect::>(); match variant_children.len() { 0 => None, 1 => Some(variant_children[0].clone()), _ => Some(FieldType::Union(variant_children)), } - }, + } FieldType::Tuple(_) => None, FieldType::WithMetadata { base, .. } => item_type(ir, base, baml_child_values), }; res } -fn typecheck_value_with_meta( - ir: &IntermediateRepr, - value: &BamlValueWithMeta, - field_type: &FieldType, -) -> bool { - let field_base_type = ir.distribute_metadata(&field_type).0; - match value { - BamlValueWithMeta::String(s, meta) => { - let literal_type = FieldType::Literal(LiteralValue::String(s.clone())); - let primitive_type = FieldType::Primitive(TypeValue::String); - - ir.is_subtype(&literal_type, &field_base_type) - || ir.is_subtype(&primitive_type, &field_base_type) - } - BamlValueWithMeta::Int(i, meta) => { - ir.is_subtype(&FieldType::Literal(LiteralValue::Int(*i)), &field_base_type) - } - BamlValueWithMeta::Float(f, meta) => { - ir.is_subtype(&FieldType::Primitive(TypeValue::Float), &field_base_type) - } - - BamlValueWithMeta::Bool(b, meta) => { - let literal_type = FieldType::Literal(LiteralValue::Bool(*b)); - let primitive_type = FieldType::Primitive(TypeValue::Bool); - - ir.is_subtype(&literal_type, &field_base_type) - || ir.is_subtype(&primitive_type, &field_base_type) - } - - BamlValueWithMeta::Null(meta) => true, - - // TODO: Handle enums and literal keys. - BamlValueWithMeta::Map(pairs, meta) => { - true - // TODO! - } - - BamlValueWithMeta::List(items, meta) => { - let items_ok = items - .iter() - .map(|i| { - item_type(ir, &field_type, i) - .map_or(true, |item_ty| typecheck_value_with_meta(ir, i, &item_ty)) - }) - .all(|x| x); - items_ok - } - - BamlValueWithMeta::Media(m, meta) => ir.is_subtype( - &FieldType::Primitive(TypeValue::Media(m.media_type)), - &field_base_type, - ), - - BamlValueWithMeta::Enum(name, val, meta) => { - ir.is_subtype(&FieldType::Enum(name.clone()), &field_base_type) - } - - BamlValueWithMeta::Class(name, fields, meta) => { - // // Classes not present in the IR may be dynamically generated. - // // In this case, all types will be inferred, rather than distributed - // // from the `field_type` parameter. - - // TODO - true - - // if ir.find_class(&name).is_err() { - // return distribute_infer_class(self, &name, fields, meta); - // } - // if !self.is_subtype(&FieldType::Class(name.clone()), &field_base_type) { - // anyhow::bail!("Could not unify Class {} with {:?}", name, field_base_type); - // } else { - // let class_type = &self.find_class(&name)?.item.elem; - // let class_fields: BamlMap = class_type - // .static_fields - // .iter() - // .map(|field_node| { - // ( - // field_node.elem.name.clone(), - // field_node.elem.r#type.elem.clone(), - // ) - // }) - // .collect(); - // let mapped_fields = fields - // .into_iter() - // .map(|(k, v)| { - // let field_type = match class_fields.get(k.as_str()) { - // Some(ft) => ft.clone(), - // None => infer_type_with_meta(&v).unwrap_or(UNIT_TYPE), - // }; - // let mapped_field = self.distribute_type_with_meta(v, field_type)?; - // Ok((k, mapped_field)) - // }) - // .collect::>>>( - // )?; - // Ok(BamlValueWithMeta::Class( - // name, - // mapped_fields, - // (meta, field_type), - // )) - // } - } - } -} /// Like item_type, but specialized for maps. fn map_types<'ir, 'a>( diff --git a/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs b/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs index 77e6842248..0a361e8c01 100644 --- a/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs +++ b/engine/language_client_ruby/ext/ruby_ffi/src/ruby_to_json.rs @@ -1,4 +1,4 @@ -use baml_types::{BamlMap, BamlValue, BamlValueWithMeta, CompletionState, ResponseCheck}; +use baml_types::{BamlMap, BamlValue, BamlValueWithMeta, ResponseCheck}; use indexmap::IndexMap; use magnus::{ prelude::*, typed_data::Obj, value::Value, Error, Float, Integer, IntoValue, RArray, RClass, @@ -10,7 +10,7 @@ use crate::types::{ self, media::{Audio, Image}, }; -use jsonish::{deserializer::deserialize_flags::Flag, ResponseBamlValue}; +use jsonish::ResponseBamlValue; struct SerializationError { position: Vec, @@ -124,7 +124,7 @@ impl<'rb> RubyToJson<'rb> { Err(original_error) => { match backup_class.funcall("new", (hash,)) { Ok(res) => Ok(res), - Err(e) => { + Err(_) => { Err(original_error) } } From f10cd70df18b71facdf5861add565d08df1d0e02 Mon Sep 17 00:00:00 2001 From: Greg Hale Date: Tue, 28 Jan 2025 14:10:56 -0800 Subject: [PATCH 10/13] cleanup --- .../baml-core/src/ir/ir_helpers/mod.rs | 3 +- engine/baml-lib/baml-core/src/ir/repr.rs | 22 +-- engine/baml-lib/baml-types/src/baml_value.rs | 155 +----------------- .../src/deserializer/deserialize_flags.rs | 2 +- .../src/deserializer/semantic_streaming.rs | 49 +++--- engine/baml-lib/jsonish/src/helpers/mod.rs | 72 ++++---- engine/baml-lib/jsonish/src/jsonish/value.rs | 31 +--- engine/baml-lib/jsonish/src/lib.rs | 29 +--- .../baml-lib/jsonish/src/tests/test_lists.rs | 2 +- .../baml-lib/jsonish/src/tests/test_maps.rs | 2 +- .../schema-ast/src/parser/parse_identifier.rs | 2 +- engine/baml-runtime/benches/lib.rs | 1 - .../benches/sap_parser_benchmark.rs | 0 .../prompt_renderer/render_output_format.rs | 2 +- .../baml-schema-wasm/src/runtime_wasm/mod.rs | 3 - .../src/python/templates/types.py.j2 | 1 + 16 files changed, 98 insertions(+), 278 deletions(-) delete mode 100644 engine/baml-runtime/benches/lib.rs delete mode 100644 engine/baml-runtime/benches/sap_parser_benchmark.rs diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs index aa962a3cd9..e8cbf6ef6c 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs @@ -624,7 +624,7 @@ impl IRHelper for IntermediateRepr { .into_iter() .map(|i| { item_type(self, &field_type, &i) - .ok_or({ anyhow::anyhow!("Could not infer child type") }) + .ok_or(anyhow::anyhow!("Could not infer child type")) .and_then(|item_type| self.distribute_type_with_meta(i, item_type)) }) .collect::>>()?; @@ -830,7 +830,6 @@ fn item_type( res } - /// Like item_type, but specialized for maps. fn map_types<'ir, 'a>( ir: &'ir IntermediateRepr, diff --git a/engine/baml-lib/baml-core/src/ir/repr.rs b/engine/baml-lib/baml-core/src/ir/repr.rs index baa4b7eb87..6f8f455590 100644 --- a/engine/baml-lib/baml-core/src/ir/repr.rs +++ b/engine/baml-lib/baml-core/src/ir/repr.rs @@ -2,8 +2,8 @@ use std::collections::HashSet; use anyhow::{anyhow, Result}; use baml_types::{ - Constraint, ConstraintLevel, FieldType, JinjaExpression, Resolvable, StreamingBehavior, StringOr, - UnresolvedValue, + Constraint, ConstraintLevel, FieldType, JinjaExpression, Resolvable, StreamingBehavior, + StringOr, UnresolvedValue, }; use either::Either; use indexmap::{IndexMap, IndexSet}; @@ -15,7 +15,9 @@ use internal_baml_parser_database::{ Attributes, ParserDatabase, PromptAst, RetryPolicyStrategy, TypeWalker, }; -use internal_baml_schema_ast::ast::{self, Attribute, FieldArity, SubType, ValExpId, WithName, WithSpan}; +use internal_baml_schema_ast::ast::{ + self, Attribute, FieldArity, SubType, ValExpId, WithName, WithSpan, +}; use internal_llm_client::{ClientProvider, ClientSpec, UnresolvedClientProperty}; use serde::Serialize; @@ -347,10 +349,7 @@ fn to_ir_attributes( }); let streaming_done = streaming_done.as_ref().and_then(|v| { if *v { - Some(( - "stream.done".to_string(), - UnresolvedValue::Bool(true, ()), - )) + Some(("stream.done".to_string(), UnresolvedValue::Bool(true, ()))) } else { None } @@ -430,7 +429,7 @@ fn type_with_arity(t: FieldType, arity: &FieldArity) -> FieldType { impl WithRepr for ast::FieldType { // TODO: (Greg) This code only extracts constraints, and ignores any // other types of attributes attached to the type directly. - fn attributes(&self, db: &ParserDatabase) -> NodeAttributes { + fn attributes(&self, _db: &ParserDatabase) -> NodeAttributes { let constraints = self .attributes() .iter() @@ -594,7 +593,6 @@ impl WithRepr for ast::FieldType { ), }; - let use_metadata = has_constraints || has_special_streaming_behavior; let with_constraints = if use_metadata { FieldType::WithMetadata { @@ -1440,7 +1438,6 @@ mod tests { let alias = class.find_field("field").unwrap(); assert_eq!(*alias.r#type(), FieldType::Primitive(TypeValue::Int)); - } #[test] @@ -1461,7 +1458,10 @@ mod tests { let class = ir.find_class("Test").unwrap(); let alias = class.find_field("field").unwrap(); - let FieldType::WithMetadata { base, constraints, .. } = alias.r#type() else { + let FieldType::WithMetadata { + base, constraints, .. + } = alias.r#type() + else { panic!( "expected resolved constrained type, found {:?}", alias.r#type() diff --git a/engine/baml-lib/baml-types/src/baml_value.rs b/engine/baml-lib/baml-types/src/baml_value.rs index 7f05a4ef64..dd5c4ad3f1 100644 --- a/engine/baml-lib/baml-types/src/baml_value.rs +++ b/engine/baml-lib/baml-types/src/baml_value.rs @@ -673,7 +673,7 @@ impl From> for BamlValue { } } -/// This special-purpose serializer is used for the public-facing API. +/// This special-purpose serializer is used for jinja. /// When we want to extend the orchestrator with BamlValues packing more /// metadata than just a `Vec`, ` impl Serialize for BamlValueWithMeta> { @@ -747,156 +747,8 @@ fn add_checks<'a, S: SerializeMap>( Ok(()) } -// impl Serialize for BamlValueWithMeta -// where T: SerializeMetadata + std::fmt::Debug, -// { -// -// fn serialize(&self, serializer: S) -> Result -// where -// S: Serializer, -// { -// let bare_value = self.value(); -// let metadata_fields = &self.meta().metadata_fields(&bare_value)?; -// match self { -// BamlValueWithMeta::String(v, _metadata) => serialize_with_checks(v, metadata_fields, serializer), -// BamlValueWithMeta::Int(v, _metadata) => serialize_with_checks(v, metadata_fields, serializer), -// BamlValueWithMeta::Float(v, _metadata) => serialize_with_checks(v, metadata_fields, serializer), -// BamlValueWithMeta::Bool(v, _metadata) => serialize_with_checks(v, metadata_fields, serializer), -// BamlValueWithMeta::Map(v, _metadata) => { -// let mut map = serializer.serialize_map(None)?; -// for (key, value) in v { -// map.serialize_entry::>(key, value)?; -// } -// add_checks(&mut map, &self.meta().metadata_fields(&bare_value)?)?; -// map.end() -// } -// BamlValueWithMeta::List(v, _metadata) => serialize_with_checks(v, metadata_fields, serializer), -// BamlValueWithMeta::Media(v, _metadata) => serialize_with_checks(v, metadata_fields, serializer), -// BamlValueWithMeta::Enum(_enum_name, v, _metadata) => serialize_with_checks(v, metadata_fields, serializer), -// BamlValueWithMeta::Class(_class_name, v, _metadata) => { -// let metadata_fields = self.meta().metadata_fields(&bare_value); -// if metadata_fields.is_empty() { -// let mut map = serializer.serialize_map(None)?; -// v.into_iter().try_for_each(|(key, value)| { -// map.serialize_entry(key, value) -// })?; -// add_checks(&mut map, &metadata_fields)?; -// map.end() -// } else { -// let mut checked_value = serializer.serialize_map(Some(2))?; -// checked_value.serialize_entry("value", &v)?; -// add_checks(&mut checked_value, &metadata_fields)?; -// checked_value.end() -// } -// } -// BamlValueWithMeta::Null(_) => serialize_with_checks(&(), &self.meta().metadata_fields(), serializer), -// } -// } -// } -// -// fn serialize_with_checks( -// value: &T, -// metadata_fields: &Vec<(String, serde_json::Value)>, -// serializer: S, -// ) -> Result -// where -// S: Serializer, -// { -// if !metadata_fields.is_empty() { -// let mut map = serializer.serialize_map(Some(2))?; -// map.serialize_entry("value", value)?; -// add_checks(&mut map, metadata_fields)?; -// map.end() -// } else { -// value.serialize(serializer) -// } -// } -// -// fn add_checks<'a, S: SerializeMap>( -// map: &'a mut S, -// metadata_fields: &Vec<(String, serde_json::Value)>, -// ) -> Result<(), S::Error> -// { -// metadata_fields.iter().try_for_each(|(field_name, value)| { -// map.serialize_entry(&field_name, &value) -// })?; -// Ok(()) -// } -// -// pub trait SerializeMetadata { -// fn metadata_fields(&self, bare_value: &BamlValue) -> Result, serde_json::Error>; -// } -// -// // This instance is used in constraint tests. -// // Consider modifying that test and deleting this instance. -// impl SerializeMetadata for Vec { -// fn metadata_fields(&self, _bare_value: &serde_json::Value) -> Result, serde_json::Error> { -// if !self.is_empty() { -// let checks_map: HashMap<_,_> = self.iter().map(|check| (check.name.clone(), check)).collect(); -// let json_checks_map = serde_json::to_value(checks_map).expect("serialization of checks is safe"); -// Ok(vec![("checks", json_checks_map)]) -// } else { -// Ok(Vec::new()) -// } -// } -// } -// -// impl SerializeMetadata for (T, Vec, Option) { -// -// // If there are only checks: -// // [("checks", checks), ("value", value)] -// // If there is completion state: -// // [("state", state), ("value", value)] -// // If there are checks and completion state: -// // [("state", state), ("value": { "checks": checks, "value": value })] -// // If there are neither checks nor completion state: -// // [("value", value)] -// fn metadata_fields(&self, bare_value: &BamlValue) -> Result, serde_json::Error> { -// let checks: Vec<(&str, &ResponseCheck)> = self.1.iter().map(|check| (check.name.as_str(), check)).collect(); -// let completion_state: Option<&CompletionState> = self.2.as_ref(); -// -// let checks_json = serde_json::to_value(&checks)?; -// let bare_value_json = serde_json::to_value(bare_value)?; -// -// match (checks.len(), completion_state) { -// (0, None) => Ok(vec![("value", bare_value_json)]), -// (_, None) => Ok(vec![("value", bare_value_json), ("checks", checks_json)]), -// (0, Some(state)) => Ok(vec![("value", bare_value_json), ("state", serde_json::to_value(state)?)]), -// (_, Some(state)) => Ok(vec![ -// ("state", serde_json::to_value(state)?), -// ("value", serde_json::to_value(&vec![ -// ("value", bare_value_json), -// ("checks", checks_json) -// ].into_iter().collect::>())?) -// ]), -// -// } -// // if !checks.is_empty() { -// // let checks_json = serde_json::to_value(checks).expect("Serializing checks is safe."); -// // fields.push(("checks".to_string(), checks_json)); -// // } -// -// // let value_considering_checks = if checks.is_empty() { -// // serde_json::to_value(bare_value)? -// // } else { -// // let object = vec![ -// // ("value", serde_json::to_value(bare_value)?), -// // ("checks", serde_json::to_value(fields)?), -// // ].into_iter().collect::>(); -// // serde_json::to_value(object)? -// // }; -// -// // let value_considering_completion_state = if let Some(state) = completion_state { -// // vec![ ("state", serde_json::to_value(&state)?) ] -// // } else { -// // value_considering_checks -// // } -// -// // Ok(value_considering_completion_state) -// } -// -// } - +/// This type is used in `BamlResponseValue` to summarize data about the +/// completion state and completion behavior of a BamlValueWithMeta node. #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct Completion { pub state: CompletionState, @@ -913,7 +765,6 @@ pub enum CompletionState { impl Default for Completion { fn default() -> Self { - panic!("I hope we don't use this default"); Completion { state: CompletionState::Complete, display: false, diff --git a/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs b/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs index 4d5c3e5e21..40e6043ac3 100644 --- a/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs +++ b/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs @@ -137,7 +137,7 @@ impl std::fmt::Debug for DeserializerConditions { impl std::fmt::Display for DeserializerConditions { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if true { + if self.flags.is_empty() { return Ok(()); } diff --git a/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs b/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs index 9fee731d04..c8102cf3b6 100644 --- a/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs +++ b/engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs @@ -41,12 +41,12 @@ pub fn validate_streaming_state( let baml_value_with_streaming_state_and_behavior = typed_baml_value.map_meta(|(flags, r#type)| (completion_state(&flags), r#type)); - let res = process_node( + let top_level_node = process_node( ir, baml_value_with_streaming_state_and_behavior, allow_partials, - ); - res + )?; + Ok(top_level_node) } /// Consider a node's type, streaming state, and streaming behavior annotations. Return @@ -57,8 +57,9 @@ pub fn validate_streaming_state( /// This function descends into child nodes when the argument is a compound value. /// /// Params: -/// value: A done in the BamlValue tree. -/// allow_partials: +/// value: A node in the BamlValue tree. +/// allow_partials: Whether this node may contain partial values. (Once we +/// see a false, all child nodes will also get false). fn process_node( ir: &IntermediateRepr, value: BamlValueWithMeta<(CompletionState, &FieldType)>, @@ -103,27 +104,21 @@ fn process_node( .collect(); let needed_fields: HashSet = needed_fields(ir, field_type, allow_partials_in_sub_nodes)?; - let present_nonnull_fields: HashSet = value_fields.iter().filter_map(|(field_name, field_value)| { - if matches!(field_value, BamlValueWithMeta::Null(_)) { - None - } else { - Some(field_name.to_string()) - } - }).collect(); - let missing_needed_fields: Vec<_> = needed_fields.difference(&present_nonnull_fields).into_iter().collect(); - - // The fields that need to be filled in by Null are initially the // fields in the Class type that are not present in the input // value. let fields_needing_null = fields_needing_null_filler(ir, field_type, value_field_names, allow_partials)?; - let mut deleted_fields: HashMap> = + // We might later delete fields from 'value_fields`, (e.g. if they + // were incomplete but required `done`). These deleted fields will + // need to be replaced with nulls. We initialize a hashmap to hold + // these nulls here. + let mut deletion_nulls: HashMap> = HashMap::new(); - // let unneeded_fields = field_names.difference(&needed_fields); - let needed_nulls = fields_needing_null + // Null values used to fill gaps in the input hashmap. + let filler_nulls = fields_needing_null .into_iter() .filter_map(|ref null_field_name| { let field = value_fields @@ -142,6 +137,8 @@ fn process_node( }) .collect::>>(); + // Fields of the input hashmap, transformed by running the + // semantic-streaming algorithm, and deleted if appropriate. let mut new_fields = value_fields .into_iter() .filter_map(|(field_name, field_value)| { @@ -161,13 +158,14 @@ fn process_node( required_done: false, }; let null = BamlValueWithMeta::Null(state); - deleted_fields.insert(field_name, null); + deletion_nulls.insert(field_name, null); None } } }) .collect::>>(); + // Names of fields from the input hashmap that survived semantic streaming. let derived_present_nonnull_fields: HashSet = new_fields.iter().filter_map(|(field_name, field_value)| { if matches!(field_value, BamlValueWithMeta::Null(_)) { None @@ -177,8 +175,8 @@ fn process_node( }).collect(); let missing_needed_fields: Vec<_> = needed_fields.difference(&derived_present_nonnull_fields).into_iter().collect(); - new_fields.extend(needed_nulls); - new_fields.extend(deleted_fields); + new_fields.extend(filler_nulls); + new_fields.extend(deletion_nulls); let res = BamlValueWithMeta::Class(class_name.clone(), new_fields, new_meta); if missing_needed_fields.clone().len() == 0 { @@ -203,6 +201,9 @@ fn process_node( } +/// Given a type and an input hashmap, if that type is a class, determine what +/// fields in the class need to be filled in by a null. A field needs to be +/// filled by a null if it is not present in the hashmap value. fn fields_needing_null_filler<'a>( ir: &'a IntermediateRepr, field_type: &'a FieldType, @@ -238,13 +239,11 @@ fn fields_needing_null_filler<'a>( /// For a given type, assume that it is a class, and list the fields of that /// class that were marked `@stream.not_null`. -/// The parameter must have already been passed through `distribute_metadata`, -/// it's an error to call this function with undistributed metadata. /// /// When allow_partials==false, we are in a context where we are done with /// streaming, so we override the normal implemenation of this function -/// and return an empty set (because we are ignoring the "needed" property, -/// which only applies to mid-stream messages). +/// and return an empty set (because we are ignoring the "@stream.not_null" property, +/// which only applies when `allow_partials==true`). fn needed_fields( ir: &IntermediateRepr, field_type: &FieldType, diff --git a/engine/baml-lib/jsonish/src/helpers/mod.rs b/engine/baml-lib/jsonish/src/helpers/mod.rs index 8db65ca486..e757b21121 100644 --- a/engine/baml-lib/jsonish/src/helpers/mod.rs +++ b/engine/baml-lib/jsonish/src/helpers/mod.rs @@ -2,7 +2,7 @@ pub mod common; use std::{collections::HashSet, path::PathBuf}; use anyhow::Result; -use baml_types::EvaluationContext; +use baml_types::{EvaluationContext, JinjaExpression}; use baml_types::{BamlValueWithMeta, ResponseCheck, StreamingBehavior}; use indexmap::{IndexMap, IndexSet}; use internal_baml_core::{ @@ -256,34 +256,42 @@ fn relevant_data_models<'a>( )) } -// /// Validate a parsed value, checking asserts and checks. -// pub fn parsed_value_to_response( -// ir: &IntermediateRepr, -// baml_value: BamlValueWithFlags, -// field_type: &FieldType, -// allow_partials: bool, -// ) -> Result { -// let meta_flags: BamlValueWithMeta> = baml_value.into(); -// -// let baml_value_with_streaming = meta_flags.map_meta_owned(|flags| { -// let constraint_results = constraint_results(&flags); -// let response_checks: Vec = constraint_results -// .iter() -// .map(|(label, expr, result)| { -// let status = (if *result { "succeeded" } else { "failed" }).to_string(); -// ResponseCheck { -// name: label.clone(), -// expression: expr.0.clone(), -// status, -// } -// }) -// .collect(); -// (flags, response_checks) -// }); -// -// let response_value2 = -// validate_streaming_state(ir, baml_value_with_streaming, field_type, allow_partials) -// .map_err(|s| anyhow::anyhow!("{s}"))?; -// -// Ok(crate::ResponseBamlValue(response_value)) -// } +/// Validate a parsed value, checking asserts and checks. +pub fn parsed_value_to_response( + ir: &IntermediateRepr, + baml_value: BamlValueWithFlags, + field_type: &FieldType, + allow_partials: bool, +) -> Result { + + let meta_flags: BamlValueWithMeta> = baml_value.clone().into(); + let baml_value_with_meta: BamlValueWithMeta> = + baml_value.clone().into(); + + let value_with_response_checks: BamlValueWithMeta> = baml_value_with_meta + .map_meta(|cs| { + cs.iter() + .map(|(label, expr, result)| { + let status = (if *result { "succeeded" } else { "failed" }).to_string(); + ResponseCheck { + name: label.clone(), + expression: expr.0.clone(), + status, + } + }) + .collect() + }); + + let baml_value_with_streaming = + validate_streaming_state(ir, &baml_value, field_type, allow_partials) + .map_err(|s| anyhow::anyhow!("{s:?}"))?; + + // Combine the baml_value, its types, the parser flags, and the streaming state + // into a final value. + // Node that we set the StreamState to `None` unless `allow_partials`. + let response_value = baml_value_with_streaming + .zip_meta(&value_with_response_checks)? + .zip_meta(&meta_flags)? + .map_meta(|((x, y), z)| (z.clone(), y.clone(), x.clone() )); + Ok(ResponseBamlValue(response_value)) +} diff --git a/engine/baml-lib/jsonish/src/jsonish/value.rs b/engine/baml-lib/jsonish/src/jsonish/value.rs index fb568fa6cd..a715a2913e 100644 --- a/engine/baml-lib/jsonish/src/jsonish/value.rs +++ b/engine/baml-lib/jsonish/src/jsonish/value.rs @@ -29,7 +29,7 @@ pub enum Value { // Fixed types Markdown(String, Box, CompletionState), - FixedJson(Box, Vec), // TODO: Does this really need a CompletionState? + FixedJson(Box, Vec), AnyOf(Vec, String), } @@ -160,25 +160,6 @@ impl Value { } } - pub fn completed_deeply(self) -> Self { - match self { - Value::String(v, _) => Value::String(v, CompletionState::Complete), - Value::Number(v, _) => Value::Number(v, CompletionState::Complete), - Value::Boolean(v) => Value::Boolean(v), - Value::Null => Value::Null, - Value::Object(v, _) => Value::Object(v, CompletionState::Complete), - Value::Array(v, _) => Value::Array( - v.into_iter().map(|v| v.completed_deeply()).collect(), - CompletionState::Complete, - ), - Value::Markdown(x, y, _) => Value::Markdown(x, y, CompletionState::Complete), - Value::FixedJson(x, y) => Value::FixedJson(x, y), - Value::AnyOf(choices, s) => Value::AnyOf( - choices.into_iter().map(|v| v.completed_deeply()).collect(), - s, - ), - } - } } impl std::fmt::Display for Value { @@ -226,14 +207,14 @@ impl std::fmt::Display for Value { // true for nested values, because serde will call the same `deserialize` // method on children of a serde container. // -// Numbers and strings should be considered Incomplete if they are encountered +// Numbers should be considered Incomplete if they are encountered // at the top level. Therefore the non-recursive callsite of `deserialize` // is responsible for setting completion state to Incomplete for top-level // strings and numbers. // -// Lists and objects at the top level are necessarily complete, because -// serde will not parse an array or an object unless the closing delimiter -// is present. +// Lists, strings and objects at the top level are necessarily complete, because +// serde will not parse an array, string or an object unless the closing +// delimiter is present. impl<'de> serde::Deserialize<'de> for Value { fn deserialize(deserializer: D) -> Result where @@ -242,7 +223,7 @@ impl<'de> serde::Deserialize<'de> for Value { let value = serde_json::Value::deserialize(deserializer)?; match value { serde_json::Value::String(s) => Ok(Value::String(s, CompletionState::Complete)), - serde_json::Value::Number(n) => Ok(Value::Number(n, CompletionState::Complete)), + serde_json::Value::Number(n) => Ok(Value::Number(n, CompletionState::Incomplete)), serde_json::Value::Bool(b) => Ok(Value::Boolean(b)), serde_json::Value::Null => Ok(Value::Null), serde_json::Value::Object(o) => { diff --git a/engine/baml-lib/jsonish/src/lib.rs b/engine/baml-lib/jsonish/src/lib.rs index 6d664bb12f..763ff8d5f7 100644 --- a/engine/baml-lib/jsonish/src/lib.rs +++ b/engine/baml-lib/jsonish/src/lib.rs @@ -35,18 +35,17 @@ pub enum SerializeMode { Partial, } -// impl serde::Serialize for (ResponseBamlValue, SerializeMode) { -// fn serialize(&self, serializer: S) -> Result { -// SerializeResponseBamlValue{ value: &self.0.0, serialize_mode: self.1 }.serialize(serializer) -// } -// } - +/// A special-purpose wrapper for specifying the serialization format of a +/// `ResponseBamlValue`. You should construct these from `ResponseBamlValue` +/// with the `serialize_final` or `serialize_partial` method. pub struct SerializeResponseBamlValue<'a>{ pub value: &'a BamlValueWithMeta<(Vec, Vec, Completion)>, pub serialize_mode: SerializeMode, } impl ResponseBamlValue { + /// Prepare a `ResponseBamlValue` for "final" serialization (serialization + /// with no stream-state metadata). pub fn serialize_final<'a> (&'a self) -> SerializeResponseBamlValue<'a> { SerializeResponseBamlValue { value: &self.0, @@ -54,6 +53,8 @@ impl ResponseBamlValue { } } + /// Prepare a `ResponseBamlValue` for "partial" serialization (serialization + /// with stream-state metadata). pub fn serialize_partial<'a> (&'a self) -> SerializeResponseBamlValue<'a> { SerializeResponseBamlValue { value: &self.0, @@ -62,7 +63,6 @@ impl ResponseBamlValue { } } - impl serde::Serialize for SerializeResponseBamlValue<'_> { fn serialize(&self, serializer: S) -> Result { use BamlValueWithMeta::*; @@ -163,25 +163,10 @@ pub fn from_str( // When the schema is just a string, i should really just return the raw_string w/o parsing it. let value = jsonish::parse(raw_string, jsonish::ParseOptions::default())?; - // let schema = deserializer::schema::from_jsonish_value(&value, None); - // eprintln!("value: {value:?}"); - - // See Note [Streaming Number Invalidation] - if allow_partials { - // invalidate_numbers_in_progress(&mut value, raw_string); - } // Pick the schema that is the most specific. - // log::info!("Parsed: {}", schema); log::debug!("Parsed JSONish (step 1 of parsing): {:#?}", value); let ctx = ParsingContext::new(of, allow_partials); - // let res = schema.cast_to(target); - // log::info!("Casted: {:?}", res); - - // match res { - // Ok(v) => Ok(v), - // Err(e) => anyhow::bail!("Failed to cast value: {}", e), - // } // Determine the best way to get the desired schema from the parsed schema. diff --git a/engine/baml-lib/jsonish/src/tests/test_lists.rs b/engine/baml-lib/jsonish/src/tests/test_lists.rs index 8804d2c6a1..4ebf882eea 100644 --- a/engine/baml-lib/jsonish/src/tests/test_lists.rs +++ b/engine/baml-lib/jsonish/src/tests/test_lists.rs @@ -137,4 +137,4 @@ test_deserializer!( r#"[1234"#, FieldType::List(FieldType::Primitive(TypeValue::Int).into()), [1234] -); \ No newline at end of file +); diff --git a/engine/baml-lib/jsonish/src/tests/test_maps.rs b/engine/baml-lib/jsonish/src/tests/test_maps.rs index 05337a32e7..e093f60d9d 100644 --- a/engine/baml-lib/jsonish/src/tests/test_maps.rs +++ b/engine/baml-lib/jsonish/src/tests/test_maps.rs @@ -212,4 +212,4 @@ test_partial_deserializer_streaming!( FieldType::Literal(LiteralValue::String("B".to_string())), ]), FieldType::string()), {"A": "one", "B": "two"} -); \ No newline at end of file +); diff --git a/engine/baml-lib/schema-ast/src/parser/parse_identifier.rs b/engine/baml-lib/schema-ast/src/parser/parse_identifier.rs index fd5f7d473a..e4a6fb6934 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_identifier.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_identifier.rs @@ -90,4 +90,4 @@ fn parse_namespaced_identifier(pair: Pair<'_>, diagnostics: &mut Diagnostics) -> ); Identifier::Local(name_parts.join("::"), span) -} \ No newline at end of file +} diff --git a/engine/baml-runtime/benches/lib.rs b/engine/baml-runtime/benches/lib.rs deleted file mode 100644 index d33fc01002..0000000000 --- a/engine/baml-runtime/benches/lib.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod sap_parser_benchmark; \ No newline at end of file diff --git a/engine/baml-runtime/benches/sap_parser_benchmark.rs b/engine/baml-runtime/benches/sap_parser_benchmark.rs deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs index 2417386ebf..bc0562efe5 100644 --- a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs +++ b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs @@ -96,7 +96,7 @@ fn find_new_class_field( let name = Name::new_with_alias(field_name.to_string(), alias.value()); let desc = desc.value(); - Ok((name, field_overrides.0.clone(), desc, false)) // TODO: Field overrides are not "stream.not_nul". Should this be configurable? + Ok((name, field_overrides.0.clone(), desc, false)) // TODO: Field overrides are not "stream.not_null". Should this be configurable? } fn find_existing_class_field( diff --git a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs index 1b4ecca501..d53d68e44b 100644 --- a/engine/baml-schema-wasm/src/runtime_wasm/mod.rs +++ b/engine/baml-schema-wasm/src/runtime_wasm/mod.rs @@ -574,9 +574,6 @@ impl WasmTestResponse { _ => Err(anyhow::anyhow!("No parsed value")), } .context("No parsed value")?; - // let baml_value_with_response_checks = parsed_response - // .0 - // .map_meta(|(_, response_checks, _)| response_checks.clone()); let (flattened_checks, check_count) = serialize_value_counting_checks(&parsed_response); Ok(WasmParsedTestResponse { value: serde_json::to_string(&flattened_checks)?, diff --git a/engine/language_client_codegen/src/python/templates/types.py.j2 b/engine/language_client_codegen/src/python/templates/types.py.j2 index 997ad8f009..bbc7a7e9ab 100644 --- a/engine/language_client_codegen/src/python/templates/types.py.j2 +++ b/engine/language_client_codegen/src/python/templates/types.py.j2 @@ -23,6 +23,7 @@ def get_checks(checks: Dict[CheckName, Check]) -> List[Check]: def all_succeeded(checks: Dict[CheckName, Check]) -> bool: return all(check.status == "succeeded" for check in get_checks(checks)) + {# Enums -#} {% for enum in enums %} class {{enum.name}}(str, Enum): From 65780a2a707445a1011bd34bcd1ba00eef1d742b Mon Sep 17 00:00:00 2001 From: Greg Hale Date: Tue, 28 Jan 2025 14:30:43 -0800 Subject: [PATCH 11/13] fix tests --- engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs | 8 ++++---- engine/baml-lib/jsonish/src/jsonish/value.rs | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs index e8cbf6ef6c..deb6a8ba4f 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs @@ -1489,10 +1489,10 @@ mod subtype_tests { &FieldType::RecursiveTypeAlias("JsonValue".to_string()), &example_json ), - Some(FieldType::Map( - Box::new(FieldType::Primitive(TypeValue::String)), - Box::new(FieldType::RecursiveTypeAlias("JsonValue".to_string())) - )) + Some(FieldType::Union(vec![ + FieldType::RecursiveTypeAlias("JsonValue".to_string()), + FieldType::RecursiveTypeAlias("JsonValue".to_string()) + ])) ); } } diff --git a/engine/baml-lib/jsonish/src/jsonish/value.rs b/engine/baml-lib/jsonish/src/jsonish/value.rs index a715a2913e..b0c1400cd7 100644 --- a/engine/baml-lib/jsonish/src/jsonish/value.rs +++ b/engine/baml-lib/jsonish/src/jsonish/value.rs @@ -223,7 +223,7 @@ impl<'de> serde::Deserialize<'de> for Value { let value = serde_json::Value::deserialize(deserializer)?; match value { serde_json::Value::String(s) => Ok(Value::String(s, CompletionState::Complete)), - serde_json::Value::Number(n) => Ok(Value::Number(n, CompletionState::Incomplete)), + serde_json::Value::Number(n) => Ok(Value::Number(n, CompletionState::Complete)), serde_json::Value::Bool(b) => Ok(Value::Boolean(b)), serde_json::Value::Null => Ok(Value::Null), serde_json::Value::Object(o) => { From 787862158a669bfb5f7c41a7ffc6d238c7aeb4c9 Mon Sep 17 00:00:00 2001 From: Greg Hale Date: Tue, 28 Jan 2025 14:32:06 -0800 Subject: [PATCH 12/13] regen types.py --- integ-tests/python/baml_client/types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/integ-tests/python/baml_client/types.py b/integ-tests/python/baml_client/types.py index e54edaa290..73aae3f586 100644 --- a/integ-tests/python/baml_client/types.py +++ b/integ-tests/python/baml_client/types.py @@ -38,6 +38,7 @@ def all_succeeded(checks: Dict[CheckName, Check]) -> bool: return all(check.status == "succeeded" for check in get_checks(checks)) + class AliasedEnum(str, Enum): KEY_ONE = "KEY_ONE" From 56c8156f70b5009c5d8d64f30503be16e045f314 Mon Sep 17 00:00:00 2001 From: Greg Hale Date: Tue, 28 Jan 2025 15:02:23 -0800 Subject: [PATCH 13/13] format --- engine/baml-lib/baml-core/src/ir/repr.rs | 20 ++++++++++---------- engine/baml-lib/jsonish/src/lib.rs | 19 +------------------ 2 files changed, 11 insertions(+), 28 deletions(-) diff --git a/engine/baml-lib/baml-core/src/ir/repr.rs b/engine/baml-lib/baml-core/src/ir/repr.rs index 6f8f455590..affd6872b4 100644 --- a/engine/baml-lib/baml-core/src/ir/repr.rs +++ b/engine/baml-lib/baml-core/src/ir/repr.rs @@ -2,8 +2,8 @@ use std::collections::HashSet; use anyhow::{anyhow, Result}; use baml_types::{ - Constraint, ConstraintLevel, FieldType, JinjaExpression, Resolvable, StreamingBehavior, - StringOr, UnresolvedValue, + Constraint, ConstraintLevel, FieldType, JinjaExpression, Resolvable, StreamingBehavior, StringOr, + UnresolvedValue, }; use either::Either; use indexmap::{IndexMap, IndexSet}; @@ -15,9 +15,7 @@ use internal_baml_parser_database::{ Attributes, ParserDatabase, PromptAst, RetryPolicyStrategy, TypeWalker, }; -use internal_baml_schema_ast::ast::{ - self, Attribute, FieldArity, SubType, ValExpId, WithName, WithSpan, -}; +use internal_baml_schema_ast::ast::{self, Attribute, FieldArity, SubType, ValExpId, WithName, WithSpan}; use internal_llm_client::{ClientProvider, ClientSpec, UnresolvedClientProperty}; use serde::Serialize; @@ -349,7 +347,10 @@ fn to_ir_attributes( }); let streaming_done = streaming_done.as_ref().and_then(|v| { if *v { - Some(("stream.done".to_string(), UnresolvedValue::Bool(true, ()))) + Some(( + "stream.done".to_string(), + UnresolvedValue::Bool(true, ()), + )) } else { None } @@ -593,6 +594,7 @@ impl WithRepr for ast::FieldType { ), }; + let use_metadata = has_constraints || has_special_streaming_behavior; let with_constraints = if use_metadata { FieldType::WithMetadata { @@ -1438,6 +1440,7 @@ mod tests { let alias = class.find_field("field").unwrap(); assert_eq!(*alias.r#type(), FieldType::Primitive(TypeValue::Int)); + } #[test] @@ -1458,10 +1461,7 @@ mod tests { let class = ir.find_class("Test").unwrap(); let alias = class.find_field("field").unwrap(); - let FieldType::WithMetadata { - base, constraints, .. - } = alias.r#type() - else { + let FieldType::WithMetadata { base, constraints, .. } = alias.r#type() else { panic!( "expected resolved constrained type, found {:?}", alias.r#type() diff --git a/engine/baml-lib/jsonish/src/lib.rs b/engine/baml-lib/jsonish/src/lib.rs index 763ff8d5f7..00d9a81d5c 100644 --- a/engine/baml-lib/jsonish/src/lib.rs +++ b/engine/baml-lib/jsonish/src/lib.rs @@ -243,21 +243,4 @@ impl WithScore for ResponseBamlValue { fn score(&self) -> i32 { self.0.iter().map(|node| node.meta().0.score()).sum() } -} - -// impl SerializeMetadata for ResponseBamlValue { -// fn metadata_fields(&self) -> Vec<(String, serde_json::Value)> { -// let mut fields = Vec::new(); -// let checks: Vec<(&str, &ResponseCheck)> = self.0.meta().1.iter().map(|check| (check.name.as_str(), check)).collect(); -// if !checks.is_empty() { -// let checks_json = serde_json::to_value(checks).expect("Serializing checks is safe."); -// fields.push(("checks".to_string(), checks_json)); -// } -// let completion_state: Option<&CompletionState> = self.0.meta().2.as_ref(); -// if let Some(state) = completion_state { -// let completion_state_json = serde_json::to_value(&state).expect("Serializing completion state is safe."); -// fields.push(("completion_state".to_string(), completion_state_json)); -// } -// fields -// } -// } +} \ No newline at end of file