diff --git a/guidance/_parser.py b/guidance/_parser.py index cec4a0a54..600d0cb43 100644 --- a/guidance/_parser.py +++ b/guidance/_parser.py @@ -8,6 +8,7 @@ class ParserException(Exception): def __init__(self, *args, **kwargs): self.current_byte = kwargs.pop("current_byte", None) self.allowed_bytes = kwargs.pop("allowed_bytes", None) + self.consumed_bytes = kwargs.pop("consumed_bytes", None) super().__init__(*args, **kwargs) @@ -357,6 +358,8 @@ def consume_byte(self, byte, log_prob=0.0): raise ParserException( "Attempted to consume a byte that the grammar does not accept!", current_byte=byte, + allowed_bytes=self.valid_next_bytes(), + consumed_bytes=self.bytes, ) if found_invalid: # only update if we changed the set self.state_sets[self.state_set_pos + 1] = OrderedSet(new_next_state_set) @@ -654,8 +657,9 @@ def _compute_children(self, state_set_pos, item, reversed_state_sets, values_pos if self._compute_children( state_set_pos, item, reversed_state_sets, values_pos + 1 ): - item.children[values_pos] = ( - EarleyItem(value, tuple(), 0, state_set_pos, 0, state_set_pos) # this child has zero length since it was nullable + # this child has zero length since it was nullable + item.children[values_pos] = EarleyItem( + value, tuple(), 0, state_set_pos, 0, state_set_pos ) return True diff --git a/tests/library/test_json.py b/tests/library/test_json.py index 3c127a864..14a0a6458 100644 --- a/tests/library/test_json.py +++ b/tests/library/test_json.py @@ -1,13 +1,14 @@ import json -from typing import Any, Union +from typing import Any, Union, Set, Dict import pytest from jsonschema import validate from guidance import json as gen_json from guidance import models -from guidance._parser import ParserException +from guidance._grammar import Byte, ByteRange from guidance.library._json import _to_compact_json +from ..utils import check_match_failure as _check_match_failure def _generate_and_check( @@ -26,7 +27,7 @@ def _generate_and_check( # So append a 'stop' character which we don't # use in any of our tests - STOP_CHAR = "\g" + STOP_CHAR = chr(7) prepared_json = _to_compact_json(target_obj) assert STOP_CHAR not in prepared_json, "STOP_CHAR in string" @@ -64,11 +65,26 @@ def _generate_and_check( ) -def _check_match_failure(bad_string, failure_byte, schema_obj): +def check_match_failure( + bad_string: str, + good_bytes: bytes, + failure_byte: bytes, + allowed_bytes: Set[Union[Byte, ByteRange]], + schema_obj: Dict[str, Any], +): grammar = gen_json(schema=schema_obj) - with pytest.raises(ParserException) as pe: - grammar.match(bad_string, raise_exceptions=True) - assert pe.value.current_byte == failure_byte + _check_match_failure( + bad_string=bad_string, + good_bytes=good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + grammar=grammar, + ) + + +# Common sets of allowed_bytes +INTEGER_LEADING = {Byte(b"-"), Byte(b"0"), ByteRange(b"19")} +INTEGER_FOLLOWING = {ByteRange(b"09")} def test_null(): @@ -111,19 +127,25 @@ def test_integer_schema(self, my_int): _generate_and_check(my_int, schema_obj) @pytest.mark.parametrize( - ["bad_string", "failure_byte"], + ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], [ - ("9999a7777", b"a"), - ("123, []", b","), - ("a321", b"a"), - ("123789.456", b"."), - ("[]", b"["), - ('{"a":4}', b"{"), + ("9999a7777", b"9999", b"a", INTEGER_FOLLOWING), + ("123, []", b"123", b",", INTEGER_FOLLOWING), + ("a321", b"", b"a", INTEGER_LEADING), + ("123789.456", b"123789", b".", INTEGER_FOLLOWING), + ("[]", b"", b"[", INTEGER_LEADING), + ('{"a":4}', b"", b"{", INTEGER_LEADING), ], ) - def test_bad_integer(self, bad_string, failure_byte): + def test_bad_integer(self, bad_string, good_bytes, failure_byte, allowed_bytes): schema_obj = json.loads(TestInteger.schema) - _check_match_failure(bad_string, failure_byte, schema_obj) + check_match_failure( + bad_string=bad_string, + good_bytes=good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + schema_obj=schema_obj, + ) class TestNumber: @@ -160,18 +182,24 @@ def test_number(self, target_obj, temperature): _generate_and_check(target_obj, schema_obj, desired_temperature=temperature) @pytest.mark.parametrize( - ["bad_string", "failure_byte"], + ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], [ - ("9999a7777", b"a"), - ("123.6, []", b","), - ("a321", b"a"), - ("[]", b"["), - ('{"a":4}', b"{"), + ("9999a7777", b"9999", b"a", {Byte(b"e"), Byte(b"."), *INTEGER_FOLLOWING}), + ("123.6, []", b"123.6", b",", {Byte(b"e"), *INTEGER_FOLLOWING}), + ("a321", b"", b"a", INTEGER_LEADING), + ("[]", b"", b"[", INTEGER_LEADING), + ('{"a":4}', b"", b"{", INTEGER_LEADING), ], ) - def test_bad_number(self, bad_string, failure_byte): + def test_bad_number(self, bad_string, good_bytes, failure_byte, allowed_bytes): schema_obj = json.loads(TestNumber.schema) - _check_match_failure(bad_string, failure_byte, schema_obj) + check_match_failure( + bad_string=bad_string, + good_bytes=good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + schema_obj=schema_obj, + ) @pytest.mark.parametrize( @@ -308,14 +336,14 @@ def test_object_containing_list(self, temperature): _generate_and_check(target_obj, schema_obj, desired_temperature=temperature) @pytest.mark.parametrize( - ["bad_string", "failure_byte"], + ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], [ - ("9999a7777", b"9"), - ('{"a":1255.4567}', b"."), - ('{"a":"123"}', b'"'), + ("9999a7777", b"", b"9", {Byte(b"{")}), + ('{"a":1255.4567}', b'{"a":1255', b".", {Byte(b"}"), *INTEGER_FOLLOWING}), + ('{"a":"123"}', b'{"a":', b'"', INTEGER_LEADING), ], ) - def test_bad_object(self, bad_string, failure_byte): + def test_bad_object(self, bad_string, good_bytes, failure_byte, allowed_bytes): schema = """{ "type": "object", "properties": { @@ -324,7 +352,13 @@ def test_bad_object(self, bad_string, failure_byte): } """ schema_obj = json.loads(schema) - _check_match_failure(bad_string, failure_byte, schema_obj) + check_match_failure( + bad_string=bad_string, + good_bytes=good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + schema_obj=schema_obj, + ) class TestSimpleArray: @@ -391,14 +425,14 @@ def test_object_list(self, target_obj, temperature): _generate_and_check(target_obj, schema_obj, desired_temperature=temperature) @pytest.mark.parametrize( - ["bad_string", "failure_byte"], + ["bad_string", "good_bytes", "failure_byte", "allowed_bytes"], [ - ("9999a7777", b"9"), - ("[321.654]", b"."), - ('["123"]', b'"'), + ("9999a7777", b"", b"9", {Byte(b"[")}), + ("[321.654]", b"[321", b".", {Byte(b"]"), Byte(b","), *INTEGER_FOLLOWING}), + ('["123"]', b"[", b'"', {Byte(b"]"), *INTEGER_LEADING}), ], ) - def test_bad_object(self, bad_string, failure_byte): + def test_bad_object(self, bad_string, good_bytes, failure_byte, allowed_bytes): schema = """{ "type" : "array", "items" : { @@ -406,7 +440,13 @@ def test_bad_object(self, bad_string, failure_byte): } }""" schema_obj = json.loads(schema) - _check_match_failure(bad_string, failure_byte, schema_obj) + check_match_failure( + bad_string=bad_string, + good_bytes=good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + schema_obj=schema_obj, + ) class TestArrayWithLengthConstraints: @@ -494,44 +534,68 @@ def test_good_with_items(self, min_items, max_items, target_obj): _generate_and_check(target_obj, schema_obj) @pytest.mark.parametrize( - "min_items, max_items, bad_obj, failure_byte", + "min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes", [ ( 1, 4, [42, "string_not_bool", "hello", "extra"], + b"[42,", b'"', + {Byte(b"t"), Byte(b"f")}, ), # Second item does not match prefix schema ( 0, 3, [42, True, 100], + b"[42,true,", b"1", + {Byte(b'"')}, ), # Last item does not match general item schema ( 3, 5, [42, True, "valid", "extra1", "extra2", "too_many"], + b'[42,true,"valid","extra1","extra2"', b",", + {Byte(b"]")}, ), # Exceeds maxItems - (2, 3, [42], b"]"), # Not enough items - (1, 1, [42, True], b","), # Too many items for maxItems + ( + 2, + 3, + [42], + b"[42", + b"]", + {Byte(b","), *INTEGER_FOLLOWING}, + ), # Not enough items + ( + 1, + 1, + [42, True], + b"[42", + b",", + {Byte(b"]"), *INTEGER_FOLLOWING}, + ), # Too many items for maxItems ( 0, 0, [42, True, "str"], + b"[", b"4", + {Byte(b"]")}, ), # maxItems set to 0, but array is not empty ( 3, 5, [42, True], + b"[42,true", b"]", + {Byte(b",")}, ), # Array has one fewer item than required by minItems ], ) def test_bad_with_prefix_and_items( - self, min_items, max_items, bad_obj, failure_byte + self, min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes ): schema_obj = { "prefixItems": self.prefix_schema_obj, @@ -541,39 +605,62 @@ def test_bad_with_prefix_and_items( "type": "array", } bad_string = _to_compact_json(bad_obj) - _check_match_failure(bad_string, failure_byte, schema_obj) + check_match_failure( + bad_string=bad_string, + good_bytes=good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + schema_obj=schema_obj, + ) @pytest.mark.parametrize( - "min_items, max_items, bad_obj, failure_byte", + "min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes", [ ( 2, 2, [42], + b"[42", b"]", + {Byte(b","), *INTEGER_FOLLOWING}, ), # Array too short to meet minItems, despite matching prefixItems ( 1, 2, [42, "not_bool"], + b"[42,", b'"', + {Byte(b"t"), Byte(b"f")}, ), # Second item violates prefixItems type requirement ( 0, 1, [42, True], + b"[42", b",", + {Byte(b"]"), *INTEGER_FOLLOWING}, ), # Array exceeds maxItems with valid prefixItems types ( 1, 5, [42, True, "extra"], + b"[42,true", b",", + {Byte(b"]")}, ), # Item beyond prefixItems with no "items" schema - (0, 0, [42], b"4"), # maxItems set to 0, but array is not empty + ( + 0, + 0, + [42], + b"[", + b"4", + {Byte(b"]")}, + ), # maxItems set to 0, but array is not empty ], ) - def test_bad_with_prefix(self, min_items, max_items, bad_obj, failure_byte): + def test_bad_with_prefix( + self, min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes + ): schema_obj = { "prefixItems": self.prefix_schema_obj, "minItems": min_items, @@ -581,18 +668,54 @@ def test_bad_with_prefix(self, min_items, max_items, bad_obj, failure_byte): "type": "array", } bad_string = _to_compact_json(bad_obj) - _check_match_failure(bad_string, failure_byte, schema_obj) + check_match_failure( + bad_string=bad_string, + good_bytes=good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + schema_obj=schema_obj, + ) @pytest.mark.parametrize( - "min_items, max_items, bad_obj, failure_byte", + "min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes", [ - (1, 2, ["hello", "world", "extra"], b","), # Too many items for maxItems - (2, 3, ["hello"], b"]"), # Not enough items - (2, 3, ["hello", 42], b"4"), # Badly typed second item - (0, 0, ["hello"], b'"'), # maxItems set to 0, but array is not empty + ( + 1, + 2, + ["hello", "world", "extra"], + b'["hello","world"', + b",", + {Byte(b"]")}, + ), # Too many items for maxItems + ( + 2, + 3, + ["hello"], + b'["hello"', + b"]", + {Byte(b",")}, + ), # Not enough items + ( + 2, + 3, + ["hello", 42], + b'["hello",', + b"4", + {Byte(b'"')}, + ), # Badly typed second item + ( + 0, + 0, + ["hello"], + b"[", + b'"', + {Byte(b"]")}, + ), # maxItems set to 0, but array is not empty ], ) - def test_bad_with_items(self, min_items, max_items, bad_obj, failure_byte): + def test_bad_with_items( + self, min_items, max_items, bad_obj, good_bytes, failure_byte, allowed_bytes + ): schema_obj = { "items": self.items_schema_obj, "minItems": min_items, @@ -600,7 +723,13 @@ def test_bad_with_items(self, min_items, max_items, bad_obj, failure_byte): "type": "array", } bad_string = _to_compact_json(bad_obj) - _check_match_failure(bad_string, failure_byte, schema_obj) + check_match_failure( + bad_string=bad_string, + good_bytes=good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + schema_obj=schema_obj, + ) class TestWithReferences: @@ -936,30 +1065,42 @@ def test_enum(self, target_obj, temperature): _generate_and_check(target_obj, schema_obj, desired_temperature=temperature) @pytest.mark.parametrize( - "bad_obj, failure_byte", + "bad_obj, good_bytes, failure_byte, allowed_bytes", [ - ("1", b"1"), - (2, b"2"), - (True, b"t"), + ("1", b'"', b"1", {Byte(b"2")}), + (2, b"", b"2", {Byte(b'"'), Byte(b"1"), Byte(b"f")}), + (True, b"", b"t", {Byte(b'"'), Byte(b"1"), Byte(b"f")}), ], ) - def test_bad_enum(self, bad_obj, failure_byte): + def test_bad_enum(self, bad_obj, good_bytes, failure_byte, allowed_bytes): schema_obj = json.loads(self.simple_schema) - bad_str = _to_compact_json(bad_obj) - _check_match_failure(bad_str, failure_byte, schema_obj) + bad_string = _to_compact_json(bad_obj) + check_match_failure( + bad_string=bad_string, + good_bytes=good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + schema_obj=schema_obj, + ) @pytest.mark.parametrize( - "bad_obj, failure_byte", + "bad_obj, good_bytes, failure_byte, allowed_bytes", [ - ("ab", b"b"), - ("bc", b"c"), - ("ca", b"a"), + ("ab", b'"a', b"b", {Byte(b"a")}), + ("bc", b'"b', b"c", {Byte(b"b")}), + ("ca", b'"c', b"a", {Byte(b"c")}), ], ) - def test_bad_prefix_enum(self, bad_obj, failure_byte): + def test_bad_prefix_enum(self, bad_obj, good_bytes, failure_byte, allowed_bytes): schema_obj = json.loads(self.prefix_schema) - bad_str = _to_compact_json(bad_obj) - _check_match_failure(bad_str, failure_byte, schema_obj) + bad_string = _to_compact_json(bad_obj) + check_match_failure( + bad_string=bad_string, + good_bytes=good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + schema_obj=schema_obj, + ) class TestAdditionalProperties: @@ -976,7 +1117,7 @@ class TestAdditionalProperties: "type": "object", "additionalProperties": { "anyOf": [ - {"type" : "string"}, + {"type": "string"}, {"type": "integer"} ] } @@ -1005,16 +1146,27 @@ def test_simple_additional_properties(self, target_obj, temperature): _generate_and_check(target_obj, schema_obj, desired_temperature=temperature) @pytest.mark.parametrize( - "bad_obj, failure_byte", + "bad_obj, good_bytes, failure_byte, allowed_bytes", [ - ({"a": "1"}, b'"'), - ({"a": 1, "b": 1.5}, b"."), + ({"a": "1"}, b'{"a":', b'"', INTEGER_LEADING), + ( + {"a": 1, "b": 1.5}, + b'{"a":1,"b":1', + b".", + {Byte(b","), Byte(b"}"), *INTEGER_FOLLOWING}, + ), ], ) - def test_simple_bad_type(self, bad_obj, failure_byte): + def test_simple_bad_type(self, bad_obj, good_bytes, failure_byte, allowed_bytes): schema_obj = json.loads(self.simple_schema) bad_string = _to_compact_json(bad_obj) - _check_match_failure(bad_string, failure_byte, schema_obj) + check_match_failure( + bad_string=bad_string, + good_bytes=good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + schema_obj=schema_obj, + ) @pytest.mark.parametrize( "target_obj", [{}, {"a": 1}, {"a": "2"}, {"a": 1, "b": "2"}] @@ -1028,13 +1180,28 @@ def test_anyOf_additional_properties(self, target_obj): _generate_and_check(target_obj, schema_obj) @pytest.mark.parametrize( - "bad_obj, failure_byte", - [({"a": 1.5}, b"."), ({"a": True}, b"t"), ({"a": 1, "b": False}, b"f")], + "bad_obj, good_bytes, failure_byte, allowed_bytes", + [ + ({"a": 1.5}, b'{"a":1', b".", {Byte(b","), Byte(b"}"), *INTEGER_FOLLOWING}), + ({"a": True}, b'{"a":', b"t", {Byte(b'"'), *INTEGER_LEADING}), + ( + {"a": 1, "b": False}, + b'{"a":1,"b":', + b"f", + {Byte(b'"'), *INTEGER_LEADING}, + ), + ], ) - def test_anyOf_bad_type(self, bad_obj, failure_byte): + def test_anyOf_bad_type(self, bad_obj, good_bytes, failure_byte, allowed_bytes): schema_obj = json.loads(self.anyOf_schema) bad_string = _to_compact_json(bad_obj) - _check_match_failure(bad_string, failure_byte, schema_obj) + check_match_failure( + bad_string=bad_string, + good_bytes=good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + schema_obj=schema_obj, + ) @pytest.mark.parametrize( "target_obj", @@ -1054,30 +1221,49 @@ def test_properties_and_additional_properties(self, target_obj, temperature): _generate_and_check(target_obj, schema_obj, desired_temperature=temperature) @pytest.mark.parametrize( - "bad_obj, failure_byte", + "bad_obj, good_bytes, failure_byte, allowed_bytes", [ - ({}, b"}"), - ({"a": 1}, b"a"), - ({"a": 1, "b": 2}, b"a"), + ({}, b"{", b"}", {Byte(b'"')}), + ({"a": 1}, b'{"', b"a", {Byte(b"m")}), + ({"a": 1, "b": 2}, b'{"', b"a", {Byte(b"m")}), ], ) - def test_combined_missing_properties(self, bad_obj, failure_byte): + def test_combined_missing_properties( + self, bad_obj, good_bytes, failure_byte, allowed_bytes + ): schema_obj = json.loads(self.combined_schema) bad_string = _to_compact_json(bad_obj) - _check_match_failure(bad_string, failure_byte, schema_obj) + check_match_failure( + bad_string=bad_string, + good_bytes=good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + schema_obj=schema_obj, + ) @pytest.mark.parametrize( - "bad_obj, failure_byte", + "bad_obj, good_bytes, failure_byte, allowed_bytes", [ - ({"mystr": 1}, b"1"), - ({"mystr": 1, "a": 2}, b"1"), - ({"mystr": "hello", "a": False}, b"f"), + ({"mystr": 1}, b'{"mystr":', b"1", {Byte(b'"')}), + ({"mystr": 1, "a": 2}, b'{"mystr":', b"1", {Byte(b'"')}), + ( + {"mystr": "hello", "a": False}, + b'{"mystr":"hello","a":', + b"f", + INTEGER_LEADING, + ), ], ) - def test_combined_bad_type(self, bad_obj, failure_byte): + def test_combined_bad_type(self, bad_obj, good_bytes, failure_byte, allowed_bytes): schema_obj = json.loads(self.combined_schema) bad_string = _to_compact_json(bad_obj) - _check_match_failure(bad_string, failure_byte, schema_obj) + check_match_failure( + bad_string=bad_string, + good_bytes=good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + schema_obj=schema_obj, + ) class TestRecursiveStructures: diff --git a/tests/library/test_pydantic.py b/tests/library/test_pydantic.py index 447285db3..016d9a4e5 100644 --- a/tests/library/test_pydantic.py +++ b/tests/library/test_pydantic.py @@ -1,6 +1,6 @@ import inspect from json import dumps as json_dumps -from typing import Any, Dict, Generic, List, Literal, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Generic, List, Literal, Tuple, Type, TypeVar, Union, Set import pydantic import pytest @@ -8,7 +8,8 @@ from guidance import json as gen_json from guidance import models -from guidance._parser import ParserException +from guidance._grammar import Byte, ByteRange +from ..utils import check_match_failure as _check_match_failure def to_compact_json(target: Any) -> str: @@ -87,14 +88,20 @@ def generate_and_check( def check_match_failure( bad_obj: Any, + good_bytes: bytes, failure_byte: bytes, + allowed_bytes: Set[Union[Byte, ByteRange]], pydantic_model: Union[Type[pydantic.BaseModel], pydantic.TypeAdapter], ): bad_string = to_compact_json(bad_obj) grammar = gen_json(schema=pydantic_model) - with pytest.raises(ParserException) as pe: - grammar.match(bad_string, raise_exceptions=True) - assert pe.value.current_byte == failure_byte + _check_match_failure( + bad_string=bad_string, + good_bytes=good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + grammar=grammar, + ) def test_simple_model(): @@ -169,7 +176,13 @@ def test_heterogeneous(self): def test_maxitems(self): model = pydantic.TypeAdapter(Tuple[int,]) - check_match_failure((1, 2), b",", model) + check_match_failure( + bad_obj=(1, 2), + good_bytes=b"[1", + failure_byte=b",", + allowed_bytes={ByteRange(b"09"), Byte(b"]")}, + pydantic_model=model, + ) class TestDict: @@ -239,14 +252,22 @@ def test_generic(self, my_type, my_obj): generate_and_check(obj, model) @pytest.mark.parametrize( - "my_type, my_obj, failure_byte", + "my_type, my_obj, good_bytes, failure_byte, allowed_bytes", [ - (bool, "True", b'"'), - (str, 42, b"4"), - (int, False, b"f"), + (bool, "True", b"", b'"', {Byte(b"t"), Byte(b"f")}), + (str, 42, b"", b"4", {Byte(b'"')}), + (int, False, b"", b"f", {Byte(b"0"), ByteRange(b"19"), Byte(b"-")}), ], ) - def test_bad_generic(self, my_type, my_obj, failure_byte): + def test_bad_generic( + self, my_type, my_obj, good_bytes, failure_byte, allowed_bytes + ): model = self.SimpleGeneric[my_type] obj = {"my_obj": my_obj} - check_match_failure(obj, failure_byte, model) + check_match_failure( + bad_obj=obj, + good_bytes=b'{"my_obj":' + good_bytes, + failure_byte=failure_byte, + allowed_bytes=allowed_bytes, + pydantic_model=model, + ) diff --git a/tests/utils.py b/tests/utils.py index 0e7982e2c..a8df4802f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,13 +1,16 @@ import os -from typing import Any +from typing import Set, Union import pytest from huggingface_hub import hf_hub_download import guidance +from guidance._grammar import Byte, ByteRange, GrammarFunction +from guidance._parser import ParserException opanai_model_cache = {} + def env_or_fail(var_name: str) -> str: env_value = os.getenv(var_name, None) @@ -15,6 +18,7 @@ def env_or_fail(var_name: str) -> str: return env_value + def get_model(model_name, caching=False, **kwargs): """Get an LLM by name.""" if model_name.startswith("openai:"): @@ -95,9 +99,7 @@ def get_llama_cpp_model(model_name, caching=False, **kwargs): # load it over and over again key = model_name + "_" + str(caching) + "_" + str(kwargs) if key not in llama_cpp_model_cache: - llama_cpp_model_cache[key] = guidance.models.LlamaCpp( - model_name, **kwargs - ) + llama_cpp_model_cache[key] = guidance.models.LlamaCpp(model_name, **kwargs) return llama_cpp_model_cache[key] @@ -132,3 +134,22 @@ def get_azure_guidance_model(model_name, caching=False, **kwargs): ) return azure_guidance_model_cache[key] + + +def check_match_failure( + bad_string: str, + good_bytes: bytes, + failure_byte: bytes, + allowed_bytes: Set[Union[Byte, ByteRange]], + grammar: GrammarFunction, +): + """ + Helper function to check that a string fails to match a grammar after consuming + zero or more bytes. It checks that the consumed bytes are as expected, that the + failure byte is as expected, and that the allowed bytes are as expected. + """ + with pytest.raises(ParserException) as pe: + grammar.match(bad_string, raise_exceptions=True) + assert pe.value.consumed_bytes[:-1] == good_bytes + assert pe.value.current_byte == failure_byte + assert pe.value.allowed_bytes == allowed_bytes