Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiannis128 committed Jan 27, 2024
1 parent 73d9056 commit e64c04e
Showing 1 changed file with 38 additions and 8 deletions.
46 changes: 38 additions & 8 deletions tests/test_ai_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Author: Yiannis Charalambous

from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import (
AIMessage,
Expand All @@ -16,7 +15,6 @@
AIModel,
AIModels,
get_ai_model_by_name,
AIModelOpenAI,
AIModelTextGen,
)

Expand Down Expand Up @@ -47,12 +45,7 @@ def test_add_custom_ai_model() -> None:

assert is_valid_ai_model(custom_model.name)


def test_add_custom_ai_model_again() -> None:
custom_model: AIModel = AIModel(
name="custom_ai",
tokens=999,
)
# Test add again.

if is_valid_ai_model(custom_model.name):
with raises(Exception):
Expand Down Expand Up @@ -111,3 +104,40 @@ def test_apply_chat_template() -> None:
prompt_text: str = custom_model_2.apply_chat_template(messages=messages).to_string()

assert prompt_text == "System: M1\n\nHuman: M2\n\nAI: M3"


def test_escape_messages() -> None:
"""Tests that the brackets are escaped properly using `AIModel.escape_message`."""

messages = [
HumanMessage(content="Hello my name is {name} and I like {{apples}}"),
SystemMessage(content="Hello my {{name is system} and I like {{{apples}}}"),
AIMessage(content="Hello my {{ {{{apples}}"),
SystemMessage(content="{descreption}{descreption}{{descreption}}{descreption}"),
SystemMessage(content="{descreption}{{{descreption}}}{{{{{descreption}}}}}"),
SystemMessage(
content="{apples}{{{apples}}}{{{{{apples}}}}}{{{{{{{apples}}}}}}}"
),
]

allowed = ["name", "descreption"]

filtered = [
HumanMessage(content="Hello my name is {name} and I like {{apples}}"),
SystemMessage(content="Hello my {{name is system}} and I like {{{{apples}}}}"),
AIMessage(content="Hello my {{ {{{{apples}}"),
SystemMessage(content="{descreption}{descreption}{{descreption}}{descreption}"),
SystemMessage(content="{descreption}{{{descreption}}}{{{{{descreption}}}}}"),
SystemMessage(
content="{{apples}}{{{{apples}}}}{{{{{{apples}}}}}}{{{{{{{{apples}}}}}}}}"
),
]

result = list(AIModel.escape_messages(messages, allowed))

assert result[0] == filtered[0]
assert result[1] == filtered[1]
assert result[2] == filtered[2]
assert result[3] == filtered[3]
assert result[4] == filtered[4]
assert result[5] == filtered[5]

0 comments on commit e64c04e

Please sign in to comment.