99 changes: 96 additions & 3 deletions garak/generators/openai.py
@@ -14,6 +14,7 @@
import json
import logging
import re
import tiktoken
from typing import List, Union

import openai
@@ -113,12 +114,24 @@
"gpt-4o": 128000,
"gpt-4o-2024-05-13": 128000,
"gpt-4o-2024-08-06": 128000,
"gpt-4o-mini": 16384,
"gpt-4o-mini": 128000,
"gpt-4o-mini-2024-07-18": 16384,
"o1-mini": 65536,
"o1": 200000,
"o1-mini": 128000,
"o1-mini-2024-09-12": 65536,
"o1-preview": 32768,
"o1-preview-2024-09-12": 32768,
"o3-mini": 200000,
}

output_max = {
"gpt-3.5-turbo": 4096,
"gpt-4": 8192,
"gpt-4o": 16384,
"o3-mini": 100000,
"o1": 100000,
"o1-mini": 65536,
"gpt-4o-mini": 16384,
}

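The two module-level tables serve different purposes: `context_lengths` bounds prompt plus completion, while the new `output_max` bounds the completion alone. Below is a minimal sketch of how a caller could combine them; the helper `effective_output_cap` is illustrative and not part of this PR, and only the two dictionaries from this module are assumed.

```python
# Illustrative sketch only; not part of the diff.
from typing import Optional

from garak.generators.openai import context_lengths, output_max


def effective_output_cap(model: str, prompt_tokens: int, requested: int) -> Optional[int]:
    """Ceiling on generated tokens, bounded by both tables when the model is known."""
    ceiling = requested
    if model in output_max:
        # the model caps its completion length regardless of remaining context
        ceiling = min(ceiling, output_max[model])
    if model in context_lengths:
        remaining = context_lengths[model] - prompt_tokens
        if remaining <= 0:
            return None  # the prompt alone exhausts the context window
        ceiling = min(ceiling, remaining)
    return ceiling
```

For example, `effective_output_cap("gpt-4o-mini", 1000, 50000)` would come back as 16384: the output cap binds long before the 128k context window does.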

@@ -172,6 +185,83 @@ def _clear_client(self):
def _validate_config(self):
pass

def _validate_token_args(self, create_args: dict, prompt: str) -> dict:
"""Ensure maximum token limit compatibility with OpenAI create request"""
token_generation_limit_key = "max_tokens"
fixed_cost = 0
if (
self.generator == self.client.chat.completions
and self.max_tokens is not None
):
token_generation_limit_key = "max_completion_tokens"
if not hasattr(self, "max_completion_tokens"):
create_args["max_completion_tokens"] = self.max_tokens

create_args.pop(
"max_tokens", None
) # remove deprecated value, utilize `max_completion_tokens`
# every chat reply is primed with <|start|>assistant<|message|> (3 toks); with the
# per-message framing (3 toks) and 1 tok for a name change this is a fixed overhead of 7 toks
# see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
# section 6 "Counting tokens for chat completions API calls"
fixed_cost = 7

# basic token boundary validation to ensure requests are not rejected for exceeding target context length
token_generation_limit = create_args.pop(token_generation_limit_key, None)
if token_generation_limit is not None:
# Suppress max_tokens if greater than context_len
if (
hasattr(self, "context_len")
and self.context_len is not None
and token_generation_limit > self.context_len
):
logging.warning(
f"Requested garak maximum tokens {token_generation_limit} exceeds context length {self.context_len}, no limit will be applied to the request"
)
token_generation_limit = None

if (
token_generation_limit is not None
and self.name in output_max
and token_generation_limit > output_max[self.name]
):
logging.warning(
f"Requested maximum tokens {token_generation_limit} exceeds max output {output_max[self.name]}, no limit will be applied to the request"
)
token_generation_limit = None

if self.context_len is not None and token_generation_limit is not None:
# count tokens in prompt and ensure token_generation_limit requested is <= context_len or output_max allowed
prompt_tokens = 0 # NB: counts the raw prompt string as a proxy for the chat messages object
try:
encoding = tiktoken.encoding_for_model(self.name)
prompt_tokens = len(encoding.encode(prompt))
except KeyError:
prompt_tokens = int(
len(prompt.split()) * 4 / 3
) # naive fallback for unknown models: 1 token ~= 3/4 of a word, i.e. ~4/3 tokens per word

if (
prompt_tokens + fixed_cost + token_generation_limit
> self.context_len
) and (prompt_tokens + fixed_cost < self.context_len):
token_generation_limit = (
self.context_len - prompt_tokens - fixed_cost
)
elif token_generation_limit > prompt_tokens + fixed_cost:
token_generation_limit = (
token_generation_limit - prompt_tokens - fixed_cost
)
else:
raise garak.exception.GarakException(
"A response of %s toks plus prompt %s toks cannot be generated; API capped at context length %s toks"
% (
token_generation_limit,
prompt_tokens + fixed_cost,
self.context_len,
)
)
create_args[token_generation_limit_key] = token_generation_limit
return create_args

def __init__(self, name="", config_root=_config):
self.name = name
self._load_config(config_root)
@@ -217,7 +307,8 @@ def _call_model(
create_args = {}
if "n" not in self.suppressed_params:
create_args["n"] = generations_this_call
for arg in inspect.signature(self.generator.create).parameters:
create_params = inspect.signature(self.generator.create).parameters
for arg in create_params:
if arg == "model":
create_args[arg] = self.name
continue
@@ -231,6 +322,8 @@ def _call_model(
for k, v in self.extra_params.items():
create_args[k] = v

create_args = self._validate_token_args(create_args, prompt)

if self.generator == self.client.completions:
if not isinstance(prompt, str):
msg = (
112 changes: 111 additions & 1 deletion tests/generators/test_openai_compatible.py
@@ -9,7 +9,7 @@
import inspect

from collections.abc import Iterable
from garak.generators.openai import OpenAICompatible
from garak.generators.openai import OpenAICompatible, output_max, context_lengths
from garak.generators.rest import RestGenerator


@@ -105,3 +105,113 @@ def test_openai_multiprocessing(openai_compat_mocks, classname):
with Pool(parallel_attempts) as attempt_pool:
for result in attempt_pool.imap_unordered(generate_in_subprocess, prompts):
assert result is not None


def test_validate_call_model_chat_token_restrictions(openai_compat_mocks):
import lorem
import json
import tiktoken
from garak.exception import GarakException

generator = build_test_instance(OpenAICompatible)
mock_url = getattr(generator, "uri", "https://api.openai.com/v1")
with respx.mock(base_url=mock_url, assert_all_called=False) as respx_mock:
mock_response = openai_compat_mocks["chat"]
respx_mock.post("chat/completions").mock(
return_value=httpx.Response(
mock_response["code"], json=mock_response["json"]
)
)
generator._call_model("test values")
req_body = json.loads(respx_mock.routes[0].calls[0].request.content)
assert (
req_body["max_completion_tokens"] <= generator.max_tokens
), "request max_completion_tokens must account for prompt tokens"

test_large_context = ""
encoding = tiktoken.encoding_for_model(MODEL_NAME)
while len(encoding.encode(test_large_context)) < generator.max_tokens:
test_large_context += lorem.paragraph() + "\n"
large_context_len = len(encoding.encode(test_large_context))

generator.context_len = large_context_len * 2
generator.max_tokens = generator.context_len * 2
generator._call_model("test values")
req_body = json.loads(respx_mock.routes[0].calls[1].request.content)
assert (
req_body.get("max_completion_tokens", None) is None
and req_body.get("max_tokens", None) is None
), "request max_completion_tokens is suppressed when larger than context length"

generator.max_tokens = large_context_len - int(large_context_len / 2)
generator.context_len = large_context_len
with pytest.raises(GarakException) as exc_info:
generator._call_model(test_large_context)
assert "API capped" in str(
exc_info.value
), "a prompt larger than max_tokens must raise exception"

max_output_model = "gpt-3.5-turbo"
generator.name = max_output_model
generator.max_tokens = output_max[max_output_model] * 2
generator.context_len = generator.max_tokens * 2
generator._call_model("test values")
req_body = json.loads(respx_mock.routes[0].calls[2].request.content)
assert (
req_body.get("max_completion_tokens", None) is None
and req_body.get("max_tokens", None) is None
), "request max_completion_tokens is suppressed when larger than output_max limited known model"

generator.max_completion_tokens = int(output_max[max_output_model] / 2)
generator._call_model("test values")
req_body = json.loads(respx_mock.routes[0].calls[3].request.content)
assert (
req_body["max_completion_tokens"] < generator.max_completion_tokens
and req_body.get("max_tokens", None) is None
), "request max_tokens is suppressed when max_completion_tokens is sent"


def test_validate_call_model_completion_token_restrictions(openai_compat_mocks):
import lorem
import json
import tiktoken
from garak.exception import GarakException

generator = build_test_instance(OpenAICompatible)
generator._load_client()
generator.generator = generator.client.completions
mock_url = getattr(generator, "uri", "https://api.openai.com/v1")
with respx.mock(base_url=mock_url, assert_all_called=False) as respx_mock:
mock_response = openai_compat_mocks["completion"]
respx_mock.post("/completions").mock(
return_value=httpx.Response(
mock_response["code"], json=mock_response["json"]
)
)
generator._call_model("test values")
req_body = json.loads(respx_mock.routes[0].calls[0].request.content)
assert (
req_body["max_tokens"] <= generator.max_tokens
), "request max_tokens must account for prompt tokens"

test_large_context = ""
encoding = tiktoken.encoding_for_model(MODEL_NAME)
while len(encoding.encode(test_large_context)) < generator.max_tokens:
test_large_context += lorem.paragraph() + "\n"
large_context_len = len(encoding.encode(test_large_context))

generator.context_len = large_context_len * 2
generator.max_tokens = generator.context_len * 2
generator._call_model("test values")
req_body = json.loads(respx_mock.routes[0].calls[1].request.content)
assert (
req_body.get("max_tokens", None) is None
), "request max_tokens is suppressed when larger than context length"

generator.max_tokens = large_context_len - int(large_context_len / 2)
generator.context_len = large_context_len
with pytest.raises(GarakException) as exc_info:
generator._call_model(test_large_context)
assert "API capped" in str(
exc_info.value
), "a prompt larger than max_tokens must raise exception"