11import openai
22from typing import Iterator , List , Optional , Dict , Any
3- from urllib .parse import urlparse
43from transformers import AutoTokenizer
54from loguru import logger
65from guidellm .backend import Backend , BackendTypes , GenerativeResponse
@@ -24,8 +23,10 @@ class OpenAIBackend(Backend):
2423 :type path: Optional[str]
2524 :param model: The OpenAI model to use, defaults to the first available model.
2625 :type model: Optional[str]
27- :param model_args: Additional model arguments for the request.
28- :type model_args: Optional[Dict[str, Any]]
26+ :param api_key: The OpenAI API key to use.
27+ :type api_key: Optional[str]
28+ :param request_args: Optional arguments for the OpenAI request.
29+ :type request_args: Dict[str, Any]
2930 """
3031
def __init__(
    self,
    target: Optional[str] = None,
    host: Optional[str] = None,
    port: Optional[int] = None,
    path: Optional[str] = None,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    **request_args,
):
    """
    Initialize the OpenAI backend and point the openai client at it.

    :param target: Full target URL of the OpenAI-compatible server;
        takes precedence over host/port/path when given.
    :type target: Optional[str]
    :param host: Host used to build the target URL when ``target`` is
        absent (required in that case).
    :type host: Optional[str]
    :param port: Optional port appended to the host.
    :type port: Optional[int]
    :param path: Optional path appended after host[:port].
    :type path: Optional[str]
    :param model: The OpenAI model to use, defaults to the first available model.
    :type model: Optional[str]
    :param api_key: The OpenAI API key to use.
    :type api_key: Optional[str]
    :param request_args: Optional arguments for the OpenAI request.
    :type request_args: Dict[str, Any]
    :raises ValueError: If neither ``target`` nor ``host`` is provided.
    """
    self.request_args = request_args

    # Resolve the endpoint: an explicit target wins, otherwise the URL
    # is assembled from host[:port][path].
    if target:
        self.target = target
    else:
        if not host:
            raise ValueError("Host is required if target is not provided.")

        port_part = f":{port}" if port else ""
        path_part = path if path else ""
        self.target = f"http://{host}{port_part}{path_part}"

    # The openai client is configured via module-level globals.
    openai.api_base = self.target
    openai.api_key = api_key

    # Fall back to the server's first available model when none is given.
    self.model = model if model else self.default_model()

    logger.info(
        f"Initialized OpenAIBackend with target: {self.target} "
        f"and model: {self.model}"
    )
5363
5464 def make_request (self , request : BenchmarkRequest ) -> Iterator [GenerativeResponse ]:
5565 """
@@ -61,14 +71,20 @@ def make_request(self, request: BenchmarkRequest) -> Iterator[GenerativeResponse
6171 :rtype: Iterator[GenerativeResponse]
6272 """
6373 logger .debug (f"Making request to OpenAI backend with prompt: { request .prompt } " )
74+ num_gen_tokens = request .params .get ("generated_tokens" , None )
75+ request_args = {
76+ "n" : 1 ,
77+ }
78+
79+ if num_gen_tokens :
80+ request_args ["max_tokens" ] = num_gen_tokens
81+ request_args ["stop" ] = None
82+
83+ if self .request_args :
84+ request_args .update (self .request_args )
85+
6486 response = openai .Completion .create (
65- engine = self .model or self .default_model (),
66- prompt = request .prompt ,
67- max_tokens = request .params .get ("max_tokens" , 100 ),
68- n = request .params .get ("n" , 1 ),
69- stop = request .params .get ("stop" , None ),
70- stream = True ,
71- ** self .model_args ,
87+ engine = self .model , prompt = request .prompt , stream = True , ** request_args ,
7288 )
7389
7490 for chunk in response :
@@ -80,8 +96,16 @@ def make_request(self, request: BenchmarkRequest) -> Iterator[GenerativeResponse
8096 type_ = "final" ,
8197 output = choice ["text" ],
8298 prompt = request .prompt ,
83- prompt_token_count = self ._token_count (request .prompt ),
84- output_token_count = self ._token_count (choice ["text" ]),
99+ prompt_token_count = (
100+ request .token_count
101+ if request .token_count
102+ else self ._token_count (request .prompt )
103+ ),
104+ output_token_count = (
105+ num_gen_tokens
106+ if num_gen_tokens
107+ else self ._token_count (choice ["text" ])
108+ ),
85109 )
86110 break
87111 else :
@@ -133,14 +157,6 @@ def model_tokenizer(self, model: str) -> Optional[Any]:
133157 return None
134158
def _token_count(self, text: str) -> int:
    """
    Count the number of tokens in a text.

    NOTE: this is a whitespace-split approximation, not a call to a
    model-specific tokenizer.

    :param text: The text to tokenize.
    :type text: str
    :return: The number of tokens.
    :rtype: int
    """
    token_count = len(text.split())
    logger.debug(f"Token count for text '{text}': {token_count}")
    return token_count
0 commit comments