1+ import asyncio
2+ import contextlib
3+ import os
4+ import shlex
5+ import subprocess
6+ import sys
7+ import threading
8+ import time
9+ from tempfile import TemporaryDirectory
10+
11+ import docker
12+ import pytest
13+ from docker .errors import NotFound
14+ import logging
15+ from gaudi .test_embed import TEST_CONFIGS
16+ import aiohttp
17+
# Root log format. The stdlib ``logging`` module uses %-style placeholders by
# default; loguru-style markup such as ``<green>{time}</green>`` contains no
# ``%(...)s`` field and is rejected by ``logging.Formatter`` validation (or,
# without validation, emitted verbatim for every record).
LOG_FORMAT = "%(asctime)s | %(levelname)-8s | %(name)s:%(funcName)s:%(lineno)d - %(message)s"
LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
    datefmt=LOG_DATE_FORMAT,
    stream=sys.stdout,
)
# Module-wide logger, named after this file's path.
logger = logging.getLogger(__file__)
24+
# Image under test; defaults to the tag of the latest local docker build.
DOCKER_IMAGE = os.environ.get("DOCKER_IMAGE", "tei_hpu")
# Optional volume source; stays None when the variable is not exported.
DOCKER_VOLUME = os.environ.get("DOCKER_VOLUME")

if DOCKER_VOLUME is None:
    logger.warning(
        "DOCKER_VOLUME is not set, this will lead to the tests redownloading the models on each run, consider setting it to speed up testing"
    )

# Log level forwarded to the container through BASE_ENV.
LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")

# Environment variables shared by every launched container.
BASE_ENV = dict(
    HF_HUB_ENABLE_HF_TRANSFER="1",
    LOG_LEVEL=LOG_LEVEL,
    HABANA_VISIBLE_DEVICES="all",
)

# Extra keyword arguments passed to ``client.containers.run``.
HABANA_RUN_ARGS = dict(runtime="habana")
45+
def stream_container_logs(container, test_name):
    """Stream container logs in a separate thread.

    Follows the container's log stream and echoes every chunk to stderr,
    prefixed with the test name. Meant to be used as a thread target: it
    returns when the stream ends, and any error raised while streaming is
    logged rather than propagated.

    Args:
        container: Object exposing ``logs(stream=True, follow=True)`` that
            yields ``bytes`` chunks.
        test_name: Label inserted into the prefix of every echoed chunk.
    """
    prefix = f"[TEI Server Logs - {test_name}] "
    try:
        for chunk in container.logs(stream=True, follow=True):
            # Decode leniently: with strict decoding a single invalid byte
            # sequence would raise and end log streaming for the whole test.
            text = chunk.decode("utf-8", errors="replace")
            print(f"{prefix}{text}", end="", file=sys.stderr, flush=True)
    except Exception as e:
        logger.error(f"Error streaming container logs: {str(e)}")
58+
59+
class LauncherHandle:
    """Handle for a server reachable at ``http://127.0.0.1:<port>``.

    Subclasses implement ``_inner_health`` to report whether the thing that
    hosts the server is still alive; the base implementation raises
    ``NotImplementedError``.
    """

    def __init__(self, port: int):
        self.port = port
        self.base_url = f"http://127.0.0.1:{port}"

    async def generate(self, prompt: str):
        """POST ``prompt`` to the ``/embed`` endpoint and return the JSON body.

        Raises:
            RuntimeError: If the response status is not 200; the message
                carries the status code and the response text.
        """
        endpoint = f"{self.base_url}/embed"
        payload = {"inputs": prompt}
        request_headers = {"Content-Type": "application/json"}
        async with aiohttp.ClientSession() as session, session.post(
            endpoint,
            json=payload,
            headers=request_headers,
        ) as response:
            if response.status == 200:
                return await response.json()
            error_text = await response.text()
            raise RuntimeError(
                f"Request failed with status {response.status}: {error_text}"
            )

    def _inner_health(self):
        """Liveness hook; must be overridden by subclasses."""
        raise NotImplementedError

    async def health(self, timeout: int = 60):
        """Poll the server until an ``/embed`` request succeeds.

        Makes at most ``timeout`` attempts and sleeps one second after each
        attempt that fails with a connection error or timeout.

        Raises:
            RuntimeError: If ``_inner_health`` reports a dead launcher, or if
                the last attempt still fails with ``aiohttp.ClientError`` /
                ``asyncio.TimeoutError``.
            Exception: Any other error from ``generate`` is logged and
                re-raised unchanged.
        """
        assert timeout > 0
        started_at = time.time()
        logger.info(f"Starting health check with timeout of {timeout}s")

        final_attempt = timeout - 1
        for attempt in range(timeout):
            launcher_alive = self._inner_health()
            if not launcher_alive:
                logger.error("Launcher crashed during health check")
                raise RuntimeError("Launcher crashed")

            try:
                # Probe the server with a real embedding request.
                await self.generate("test")
            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                if attempt == final_attempt:
                    logger.error(f"Health check failed after {timeout}s: {str(e)}")
                    raise RuntimeError(f"Health check failed: {str(e)}")
                # Only log every 10th attempt to keep the output readable.
                if attempt != 0 and attempt % 10 == 0:
                    logger.debug(f"Connection attempt {attempt}/{timeout} failed: {str(e)}")
                await asyncio.sleep(1)
            except Exception as e:
                logger.error(f"Unexpected error during health check: {str(e)}")
                import traceback
                logger.error(f"Full traceback:\n{traceback.format_exc()}")
                raise
            else:
                elapsed = time.time() - started_at
                logger.info(f"Health check passed after {elapsed:.1f}s")
                return
108+
109+
class ContainerLauncherHandle(LauncherHandle):
    """Launcher handle whose liveness is derived from a named container."""

    def __init__(self, docker_client, container_name, port: int):
        super().__init__(port)
        self.docker_client = docker_client
        self.container_name = container_name

    def _inner_health(self) -> bool:
        """Return True while the container status is "running" or "created".

        Any other status is logged together with the container logs and
        yields False; an error while querying the container is logged and
        also yields False.
        """
        try:
            container = self.docker_client.containers.get(self.container_name)
            state = container.status
            if state in ("running", "created"):
                return True
            logger.warning(f"Container status is {state}")
            # Dump the container output to help diagnose the failure.
            output = container.logs().decode("utf-8")
            logger.debug(f"Container logs:\n{output}")
            return False
        except Exception as e:
            logger.error(f"Error checking container health: {str(e)}")
            return False
130+
class ProcessLauncherHandle(LauncherHandle):
    """Launcher handle whose liveness is derived from a process object."""

    def __init__(self, process, port: int):
        # ``process`` must expose ``poll()``; presumably a subprocess.Popen
        # -- confirm against callers.
        super().__init__(port)
        self.process = process

    def _inner_health(self) -> bool:
        """Return True while ``process.poll()`` still reports None."""
        exit_code = self.process.poll()
        return exit_code is None
138+
139+
@pytest.fixture(scope="module")
def data_volume():
    """Yield the path of a module-scoped temporary directory.

    On teardown the directory is deleted with ``sudo rm -rf``. Cleanup is
    best-effort: a failing or unavailable ``sudo`` is logged, not raised.
    """
    tmpdir = TemporaryDirectory()
    yield tmpdir.name
    try:
        # Cleanup the temporary directory using sudo as it contains root files created by the container.
        # The command is passed as an argv list so a path containing
        # whitespace or shell metacharacters reaches rm as one argument.
        subprocess.run(["sudo", "rm", "-rf", tmpdir.name], check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers the case where the sudo binary cannot be executed.
        logger.error(f"Error cleaning up temporary directory: {str(e)}")
149+
150+
def _remove_stale_container(client, container_name, test_name):
    """Stop and remove a leftover container named ``container_name``, if any."""
    try:
        container = client.containers.get(container_name)
        logger.info(
            f"Stopping existing container {container_name} for test {test_name}"
        )
        container.stop()
        container.wait()
        # A stopped container still owns its name; remove it so the following
        # ``containers.run(name=...)`` does not fail with a name conflict.
        container.remove()
    except NotFound:
        pass
    except Exception as e:
        logger.error(f"Error handling existing container: {str(e)}")


def _cleanup_container(client, container_name):
    """Stop the named container, print its logs to stderr and remove it."""
    try:
        container = client.containers.get(container_name)
        logger.info(f"Stopping container {container_name}")
        container.stop()
        container.wait()

        container_output = container.logs().decode("utf-8")
        print(container_output, file=sys.stderr)

        container.remove()
        logger.info(f"Container {container_name} removed successfully")
    except NotFound:
        pass
    except Exception as e:
        logger.warning(f"Error cleaning up container: {str(e)}")


@pytest.fixture(scope="function")
def gaudi_launcher(event_loop):
    """Return a context-manager factory that runs the server in a container.

    The returned ``docker_launcher(model_id, test_name)`` starts a container
    from ``DOCKER_IMAGE`` with the arguments and optional ``env_config`` found
    in ``TEST_CONFIGS[test_name]``, publishes container port 80 on host port
    8080, yields a ``ContainerLauncherHandle`` and, on exit, stops and removes
    the container.

    ``event_loop`` is requested as a fixture dependency but is not used
    directly in this function.
    """

    @contextlib.contextmanager
    def docker_launcher(
        model_id: str,
        test_name: str,
    ):
        logger.info(
            f"Starting docker launcher for model {model_id} and test {test_name}"
        )

        port = 8080

        client = docker.from_env()

        container_name = f"tei-hpu-test-{test_name.replace('/', '-')}"

        _remove_stale_container(client, container_name, test_name)

        test_config = TEST_CONFIGS[test_name]
        tei_args = test_config["args"].copy()

        # add model_id to tei args
        tei_args.extend(["--model-id", model_id])

        env = BASE_ENV.copy()
        env["HF_TOKEN"] = os.getenv("HF_TOKEN")

        # Add env config that is defined in the test configuration
        if "env_config" in test_config:
            env.update(test_config["env_config"].copy())

        # Only mount a volume when one is configured; formatting an unset
        # DOCKER_VOLUME would otherwise mount a volume literally named "None".
        volumes = [f"{DOCKER_VOLUME}:/data"] if DOCKER_VOLUME else []
        logger.debug(f"Using volume {volumes}")

        try:
            logger.info(f"Creating container with name {container_name}")

            container = client.containers.run(
                DOCKER_IMAGE,
                command=tei_args,
                name=container_name,
                environment=env,
                detach=True,
                volumes=volumes,
                ports={"80/tcp": port},
                **HABANA_RUN_ARGS,
            )

            logger.info(f"Container {container_name} started successfully")

            # Start log streaming in a background thread
            log_thread = threading.Thread(
                target=stream_container_logs,
                args=(container, test_name),
                daemon=True,  # This ensures the thread will be killed when the main program exits
            )
            log_thread.start()

            # Add a small delay to allow container to initialize
            time.sleep(2)

            # Refresh the cached attributes first: without reload() the status
            # is the one captured when the container object was created.
            container.reload()
            status = container.status
            logger.debug(f"Initial container status: {status}")
            if status not in ["running", "created"]:
                logs = container.logs().decode("utf-8")
                logger.error(f"Container failed to start properly. Logs:\n{logs}")

            yield ContainerLauncherHandle(client, container.name, port)

        except Exception as e:
            logger.error(f"Error starting container: {str(e)}")
            # Get full traceback for debugging
            import traceback

            logger.error(f"Full traceback:\n{traceback.format_exc()}")
            raise
        finally:
            _cleanup_container(client, container_name)

    return docker_launcher
0 commit comments