@@ -5,6 +5,8 @@
 import signal
 import sys
 import os
+from typing import List
+from PIL import Image
 
 import backend_pb2
 import backend_pb2_grpc
@@ -15,6 +17,8 @@
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.multimodal.utils import fetch_image
+from vllm.assets.video import VideoAsset
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 
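The new imports bring in PIL for decoding image files plus two vLLM multimodal helpers. As a rough sketch of what each provides (not part of the patch; the paths and URL are placeholders, and note that `fetch_image` is imported but never called in the visible hunks):

```python
from PIL import Image
from vllm.multimodal.utils import fetch_image
from vllm.assets.video import VideoAsset

# Decode a local image file into a PIL Image (hypothetical path)
img = Image.open("/tmp/example.png")

# fetch_image downloads an image from a URL into a PIL Image
remote = fetch_image("https://example.com/cat.jpg")  # placeholder URL

# VideoAsset exposes a video's decoded frames as numpy arrays; it is
# vLLM's sample-asset helper, which this patch reuses for local files
frames = VideoAsset(name="sample.mp4").np_ndarrays  # hypothetical file
```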
@@ -105,6 +109,7 @@ async def LoadModel(self, request, context):
         try:
             self.llm = AsyncLLMEngine.from_engine_args(engine_args)
         except Exception as err:
+            print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
 
         try:
@@ -117,7 +122,7 @@ async def LoadModel(self, request, context):
             )
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
-
+        print("Model loaded successfully", file=sys.stderr)
         return backend_pb2.Result(message="Model loaded successfully", success=True)
 
     async def Predict(self, request, context):
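Both LoadModel hunks only add stderr diagnostics around engine construction. For context, `from_engine_args` is vLLM's standard async-engine constructor; a failure there (bad model path, missing GPU, out-of-memory) is what the new print now surfaces. Roughly (the model name is a placeholder and the remaining engine args are elided):

```python
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Minimal construction path that the new error logging wraps
engine_args = AsyncEngineArgs(model="microsoft/phi-2")  # placeholder model
llm = AsyncLLMEngine.from_engine_args(engine_args)
```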
@@ -196,15 +201,33 @@ async def _predict(self, request, context, streaming=False):
         if request.Seed != 0:
             sampling_params.seed = request.Seed
 
+        # Extract image paths and process images
         prompt = request.Prompt
-
-        # If tokenizer template is enabled and messages are provided instead of prompt apply the tokenizer template
+
+        image_paths = request.Images
+        image_data = [self.load_image(img_path) for img_path in image_paths]
+
+        videos_path = request.Videos
+        video_data = [self.load_video(video_path) for video_path in videos_path]
+
+        # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
         if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
             prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)
 
-        # Generate text
+        # Generate text using the LLM engine
         request_id = random_uuid()
-        outputs = self.llm.generate(prompt, sampling_params, request_id)
+        print(f"Generating text with request_id: {request_id}", file=sys.stderr)
+        outputs = self.llm.generate(
+            {
+                "prompt": prompt,
+                "multi_modal_data": {
+                    "image": image_data if image_data else None,
+                    "video": video_data if video_data else None,
+                } if image_data or video_data else None,
+            },
+            sampling_params=sampling_params,
+            request_id=request_id,
+        )
 
         # Stream the results
         generated_text = ""
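The switch from `generate(prompt, sampling_params, request_id)` to a single prompt dict follows vLLM's multimodal input format, where `multi_modal_data` carries the decoded images or video frames alongside the text prompt. A minimal sketch of how such a call is consumed (the `<image>` placeholder syntax is model-specific and assumed here):

```python
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

async def describe(llm, image):
    # generate returns an async generator of RequestOutput snapshots
    results = llm.generate(
        {
            "prompt": "USER: <image>\nDescribe the picture. ASSISTANT:",
            "multi_modal_data": {"image": image},
        },
        sampling_params=SamplingParams(max_tokens=64),
        request_id=random_uuid(),
    )
    final = None
    async for request_output in results:
        final = request_output  # the last snapshot holds the full text
    return final.outputs[0].text if final else ""
```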
@@ -227,9 +250,49 @@ async def _predict(self, request, context, streaming=False):
         if streaming:
             return
 
+        # Remove the image files from /tmp folder
+        for img_path in image_paths:
+            try:
+                os.remove(img_path)
+            except Exception as e:
+                print(f"Error removing image file: {img_path}, {e}", file=sys.stderr)
+
         # Sending the final generated text
         yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
 
+    def load_image(self, image_path: str):
+        """
+        Load an image from the given file path.
+
+        Args:
+            image_path (str): The path to the image file.
+
+        Returns:
+            Image: The loaded image.
+        """
+        try:
+            return Image.open(image_path)
+        except Exception as e:
+            print(f"Error loading image {image_path}: {e}", file=sys.stderr)
+            return self.load_video(image_path)
+
+    def load_video(self, video_path: str):
+        """
+        Load a video from the given file path.
+
+        Args:
+            video_path (str): The path to the video file.
+
+        Returns:
+            Video: The loaded video.
+        """
+        try:
+            video = VideoAsset(name=video_path).np_ndarrays
+            return video
+        except Exception as e:
+            print(f"Error loading video {video_path}: {e}", file=sys.stderr)
+            return None
+
 async def serve(address):
     # Start asyncio gRPC server
     server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
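Since `load_image` falls back to `load_video` on failure, each path is effectively probed as an image first and as a video second. A standalone sketch of that probe order (the helper name is hypothetical; `VideoAsset` is vLLM's sample-asset helper, and resolving arbitrary local paths through it is an assumption this patch makes):

```python
from typing import Optional, Union

import numpy as np
from PIL import Image
from vllm.assets.video import VideoAsset

def probe_media(path: str) -> Optional[Union[Image.Image, np.ndarray]]:
    """Try to decode `path` as an image; fall back to video frames."""
    try:
        return Image.open(path)  # succeeds for image files
    except Exception:
        try:
            return VideoAsset(name=path).np_ndarrays
        except Exception:
            return None
```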