
Commit 8b68660

Luodian and kcz358 authored
[Model] add openai compatible API interface (EvolvingLMMs-Lab#546)
* Update API usage and add environment configuration

  - Adjust API client initialization for OpenAI and Azure.
  - Modify image encoding to handle size limits and maintain aspect ratio.
  - Add new `OpenAICompatible` model class.
  - Introduce retry mechanism for API requests with configurable retries.
  - Update `.gitignore` to exclude `.env` and `scripts/`.

  .gitignore:
  - Exclude `.env` file for security.
  - Ensure no scripts directory is tracked.

  lmms_eval/models/gpt4v.py:
  - Refactor API client initialization to use new OpenAI and Azure clients.
  - Update image encoding to handle size limits with resizing logic.
  - Adjust retry logic for API calls, reducing sleep time.

  lmms_eval/models/openai_compatible.py:
  - Create new `OpenAICompatible` model class with similar structure.
  - Implement encoding functions for images and videos.
  - Integrate environment variable loading and persistent response caching.

  miscs/model_dryruns/openai_compatible.sh:
  - Add sample script for running the new model.

* Improve code readability and organization

  - Remove unused import for `deepcopy` in `openai_compatible.py`.
  - Add a blank line for better separation of code sections.
  - Adjust comment formatting for `max_size_in_mb` for consistency.
  - Ensure consistent spacing around comments.

  File: `lmms_eval/models/openai_compatible.py`
  - Removed `deepcopy` import: cleaned up unnecessary code.
  - Added blank line after `load_dotenv`: improved readability.
  - Reformatted comment on `max_size_in_mb`: enhanced clarity.
  - Removed extra blank line before `Accelerator`: tightened spacing.

* Fix init

---------

Co-authored-by: kcz358 <[email protected]>
1 parent 968d5f1 commit 8b68660
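The diff for the new `lmms_eval/models/openai_compatible.py` is not reproduced in this excerpt. Going only by the commit message (same structure as `gpt4v.py`, with `load_dotenv` and OpenAI/Azure client selection), its setup plausibly mirrors the `gpt4v.py` changes shown below; a minimal sketch, with every detail inferred rather than copied from the actual file:

# Sketch only: reconstructed from the commit description, not the actual file.
import os

from dotenv import load_dotenv
from openai import AzureOpenAI, OpenAI

load_dotenv()  # pick up keys from the now git-ignored .env file

API_TYPE = os.getenv("API_TYPE", "openai")

if API_TYPE == "openai":
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
elif API_TYPE == "azure":
    client = AzureOpenAI(
        api_key=os.getenv("AZURE_API_KEY"),
        azure_endpoint=os.getenv("AZURE_ENDPOINT"),
        api_version=os.getenv("AZURE_API_VERSION", "2023-07-01-preview"),
    )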

File tree

5 files changed (+306, -60 lines)

.gitignore

+2 -1
@@ -42,4 +42,5 @@ VATEX/
 lmms_eval/tasks/vatex/__pycache__/utils.cpython-310.pyc
 lmms_eval/tasks/mlvu/__pycache__/utils.cpython-310.pyc
 
-scripts/
+scripts/
+.env
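Ignoring `.env` keeps credentials out of the repository; the models read them from the environment instead. A sample `.env` consistent with the variables used in `gpt4v.py` below, with placeholder values only:

API_TYPE=openai
OPENAI_API_URL=https://api.openai.com/v1/chat/completions
OPENAI_API_KEY=YOUR_API_KEY
# or, for Azure:
# API_TYPE=azure
# AZURE_ENDPOINT=YOUR_AZURE_ENDPOINT
# AZURE_API_KEY=YOUR_API_KEY
# AZURE_API_VERSION=2023-07-01-preview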

lmms_eval/models/__init__.py

+1
@@ -39,6 +39,7 @@
     "minimonkey": "MiniMonkey",
     "moviechat": "MovieChat",
     "mplug_owl_video": "mplug_Owl",
+    "openai_compatible": "OpenAICompatible",
     "oryx": "Oryx",
     "phi3v": "Phi3v",
     "qwen2_5_vl": "Qwen2_5_VL",

lmms_eval/models/gpt4v.py

+65 -59
@@ -4,11 +4,12 @@
 import time
 from copy import deepcopy
 from io import BytesIO
-from typing import List, Tuple
+from typing import List, Tuple, Union
 
 import numpy as np
 import requests as url_requests
 from accelerate import Accelerator, DistributedType
+from openai import AzureOpenAI, OpenAI
 from tqdm import tqdm
 
 from lmms_eval.api.instance import Instance
@@ -20,26 +21,19 @@
 except ImportError:
     pass
 
+from loguru import logger as eval_logger
 from PIL import Image
 
 API_TYPE = os.getenv("API_TYPE", "openai")
-NUM_SECONDS_TO_SLEEP = 30
-from loguru import logger as eval_logger
-
+NUM_SECONDS_TO_SLEEP = 10
 if API_TYPE == "openai":
     API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
     API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
-    headers = {
-        "Authorization": f"Bearer {API_KEY}",
-        "Content-Type": "application/json",
-    }
+
 elif API_TYPE == "azure":
     API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
     API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
-    headers = {
-        "api-key": API_KEY,
-        "Content-Type": "application/json",
-    }
+    API_VERSION = os.getenv("AZURE_API_VERSION", "2023-07-01-preview")
 
 
 @register_model("gpt4v")
@@ -52,6 +46,7 @@ def __init__(
         timeout: int = 120,
         continual_mode: bool = False,
         response_persistent_folder: str = None,
+        max_size_in_mb: int = 20,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -80,6 +75,11 @@
             self.response_cache = {}
             self.cache_mode = "start"
 
+        if API_TYPE == "openai":
+            self.client = OpenAI(api_key=API_KEY)
+        elif API_TYPE == "azure":
+            self.client = AzureOpenAI(api_key=API_KEY, azure_endpoint=API_URL, api_version=API_VERSION)
+
         accelerator = Accelerator()
         # assert self.batch_size_per_gpu == 1, "Llava currently does not support batched generation. See https://github.com/haotian-liu/LLaVA/issues/754. HF Llava also has this issue."
         if accelerator.num_processes > 1:
@@ -94,13 +94,30 @@
             self._rank = self.accelerator.local_process_index
             self._world_size = self.accelerator.num_processes
 
+        self.max_size_in_mb = max_size_in_mb
         self.device = self.accelerator.device
 
     # Function to encode the image
-    def encode_image(self, image: Image):
+    def encode_image(self, image: Union[Image.Image, str]):
+        max_size = self.max_size_in_mb * 1024 * 1024  # 20MB in bytes
+        if isinstance(image, str):
+            img = Image.open(image).convert("RGB")
+        else:
+            img = image.copy()
+
         output_buffer = BytesIO()
-        image.save(output_buffer, format="PNG")
+        img.save(output_buffer, format="PNG")
         byte_data = output_buffer.getvalue()
+
+        # If image is too large, resize it while maintaining aspect ratio
+        while len(byte_data) > max_size and img.size[0] > 100 and img.size[1] > 100:
+            new_size = (int(img.size[0] * 0.75), int(img.size[1] * 0.75))
+            img = img.resize(new_size, Image.Resampling.LANCZOS)
+
+            output_buffer = BytesIO()
+            img.save(output_buffer, format="PNG")
+            byte_data = output_buffer.getvalue()
+
         base64_str = base64.b64encode(byte_data).decode("utf-8")
         return base64_str
 
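Each pass of the resize loop above scales width and height by 0.75, cutting the pixel count to roughly 56%, then re-encodes to PNG until the bytes fit under the `max_size_in_mb` cap or either side would shrink past 100 px. A hypothetical usage sketch; constructor arguments other than `max_size_in_mb` are assumed, not taken from this diff:

# Hypothetical usage; GPT4V's full constructor is not shown in this hunk.
import base64

model = GPT4V(model_version="gpt-4o", max_size_in_mb=20)
b64 = model.encode_image("large_figure.png")  # accepts a path or a PIL.Image
mb = len(base64.b64decode(b64)) / (1024 * 1024)
print(mb)  # <= 20 unless the 100 px floor stopped the loop early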
@@ -150,39 +167,30 @@ def generate_until(self, requests) -> List[str]:
                     continue
 
             visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
-            visuals = self.flatten(visuals)
-            imgs = []  # multiple images or frames for video
-            for visual in visuals:
-                if self.modality == "image":
-                    img = self.encode_image(visual)
-                    imgs.append(img)
-                elif self.modality == "video":
-                    frames = self.encode_video(visual, self.max_frames_num)
-                    imgs.extend(frames)
+            if None in visuals:
+                visuals = []
+                imgs = []
+            else:
+                visuals = self.flatten(visuals)
+                imgs = []  # multiple images or frames for video
+                for visual in visuals:
+                    if isinstance(visual, str) and (".mp4" in visual or ".avi" in visual or ".mov" in visual or ".flv" in visual or ".wmv" in visual):
+                        frames = self.encode_video(visual, self.max_frames_num)
+                        imgs.extend(frames)
+                    elif isinstance(visual, str) and (".jpg" in visual or ".jpeg" in visual or ".png" in visual or ".gif" in visual or ".bmp" in visual or ".tiff" in visual or ".webp" in visual):
+                        img = self.encode_image(visual)
+                        imgs.append(img)
+                    elif isinstance(visual, Image.Image):
+                        img = self.encode_image(visual)
+                        imgs.append(img)
 
             payload = {"messages": []}
-            if API_TYPE == "openai":
-                payload["model"] = self.model_version
-
-            response_json = {"role": "user", "content": []}
-            # When there is no image token in the context, append the image to the text
-            if self.image_token not in contexts:
-                payload["messages"].append(deepcopy(response_json))
-                payload["messages"][0]["content"].append({"type": "text", "text": contexts})
-                for img in imgs:
-                    payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})
-            else:
-                contexts = contexts.split(self.image_token)
-                for idx, img in enumerate(imgs):
-                    payload["messages"].append(deepcopy(response_json))
-                    payload["messages"][idx]["content"].append({"type": "text", "text": contexts[idx]})
-                    payload["messages"][idx]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})
-
-                # If n image tokens are in the contexts
-                # contexts will be splitted into n+1 chunks
-                # Manually add it into the payload
-                payload["messages"].append(deepcopy(response_json))
-                payload["messages"][-1]["content"].append({"type": "text", "text": contexts[-1]})
+            payload["model"] = self.model_version
+
+            payload["messages"].append({"role": "user", "content": []})
+            payload["messages"][0]["content"].append({"type": "text", "text": contexts})
+            for img in imgs:
+                payload["messages"][0]["content"].append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}})
 
             if "max_new_tokens" not in gen_kwargs:
                 gen_kwargs["max_new_tokens"] = 1024
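Stripped of harness bookkeeping, the payload built above is a plain Chat Completions request: one user message whose content mixes a text part with base64 PNG `image_url` parts. An equivalent standalone call, where the model name and file path are placeholders:

# Standalone sketch of the request shape the code above builds; placeholders only.
import base64

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

with open("frame.png", "rb") as f:
    b64 = base64.b64encode(f.read()).decode("utf-8")

response = client.chat.completions.create(
    model="gpt-4o",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
            ],
        }
    ],
    max_tokens=1024,
    temperature=0,
)
print(response.choices[0].message.content)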
@@ -198,26 +206,24 @@
             payload["max_tokens"] = gen_kwargs["max_new_tokens"]
             payload["temperature"] = gen_kwargs["temperature"]
 
-            for attempt in range(5):
+            MAX_RETRIES = 5
+            for attempt in range(MAX_RETRIES):
                 try:
-                    response = url_requests.post(API_URL, headers=headers, json=payload, timeout=self.timeout)
-                    response_data = response.json()
-
-                    response_text = response_data["choices"][0]["message"]["content"].strip()
+                    response = self.client.chat.completions.create(**payload)
+                    response_text = response.choices[0].message.content
                     break  # If successful, break out of the loop
 
                 except Exception as e:
-                    try:
-                        error_msg = response.json()
-                    except:
-                        error_msg = ""
+                    error_msg = str(e)
+                    eval_logger.info(f"Attempt {attempt + 1}/{MAX_RETRIES} failed with error: {error_msg}")
 
-                    eval_logger.info(f"Attempt {attempt + 1} failed with error: {str(e)}.\nReponse: {error_msg}")
-                    if attempt <= 5:
-                        time.sleep(NUM_SECONDS_TO_SLEEP)
-                    else:  # If this was the last attempt, log and return empty string
-                        eval_logger.error(f"All 5 attempts failed. Last error message: {str(e)}.\nResponse: {response.json()}")
+                    # On last attempt, log error and set empty response
+                    if attempt == MAX_RETRIES - 1:
+                        eval_logger.error(f"All {MAX_RETRIES} attempts failed. Last error: {error_msg}")
                         response_text = ""
+                    else:
+                        time.sleep(NUM_SECONDS_TO_SLEEP)
+
             res.append(response_text)
             pbar.update(1)

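The rewritten loop retries up to five times with a flat `NUM_SECONDS_TO_SLEEP = 10` wait and falls back to an empty string after the last failure. A common alternative, not what this commit implements, is exponential backoff; a minimal sketch reusing a `client` and `payload` shaped like those above:

# Alternative to the fixed 10 s sleep: exponential backoff (2 s, 4 s, 8 s, ...).
import time


def create_with_backoff(client, payload, max_retries=5, base_delay=2):
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(**payload)
            return response.choices[0].message.content
        except Exception:
            if attempt == max_retries - 1:
                return ""  # mirror the commit: empty response after the final failure
            time.sleep(base_delay * 2**attempt)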