44import contextlib
55import logging
66import os
7+ import threading
78import typing
89
10+ import ray ._private .runtime_env .packaging
911from ray ._private .runtime_env import virtualenv_utils
1012from ray ._private .runtime_env .pip import PipPlugin
1113from ray ._private .runtime_env .plugin import RuntimeEnvPlugin
@@ -29,18 +31,24 @@ async def create_or_get_virtualenv(path: str, cwd: str, logger: logging.Logger):
2931 await original_create_or_get_virtualenv (path = path , cwd = cwd , logger = logger )
3032
3133
# Guards the monkey-patching of ``virtualenv_utils.create_or_get_virtualenv``.
# NOTE(review): an RLock is re-entrant for the thread that owns it, so it does
# not keep apart coroutines that share one event-loop thread. Callers in this
# file ``await`` inside the ``with`` block - confirm that this is acceptable.
_monkey_patch_lock = threading.RLock()


@contextlib.contextmanager
def patch_create_or_get_virtualenv(phase_index: int):
    """Temporarily install the custom ``create_or_get_virtualenv``.

    While the context is active the module-level lock is held. For every
    phase after the first (``phase_index > 0``) the attribute
    ``virtualenv_utils.create_or_get_virtualenv`` is replaced with this
    module's ``create_or_get_virtualenv``. On exit the attribute is always
    reset to ``original_create_or_get_virtualenv``, whether or not it was
    replaced and whether or not the body raised.

    Args:
        phase_index: Zero-based index of the ordered_pip phase being built.
    """
    attribute = "create_or_get_virtualenv"
    needs_patch = phase_index > 0

    with _monkey_patch_lock:
        if needs_patch:
            setattr(virtualenv_utils, attribute, create_or_get_virtualenv)
        try:
            yield
        finally:
            # Unconditional restore keeps the module in a known state.
            setattr(virtualenv_utils, attribute, original_create_or_get_virtualenv)
4452
4553
4654class OrderedPipPlugin (RuntimeEnvPlugin ):
@@ -93,27 +101,81 @@ def try_import_torch():
93101 """
94102
95103 name = "ordered_pip"
104+
105+ # VV: Configure Ray to use this RuntimeEnvPlugin last
106+ priority = 100
96107 ClassPath = "orchestrator.utilities.ray_env.ordered_pip.OrderedPipPlugin"
97108
98109 def __init__ (self , resources_dir : str | None = None ):
99- if resources_dir is None :
100- import ray ._private .ray_constants as ray_constants
110+ self ._global_mtx = threading .RLock ()
111+ self ._create_env_mtx : dict [str , threading .RLock ] = {}
112+ self ._pip_resources_dir = resources_dir
101113
102- resources_dir = os .environ .get (
103- ray_constants .RAY_RUNTIME_ENV_CREATE_WORKING_DIR_ENV_VAR
104- )
114+ # VV: Maintains a cache of the environments that have been built thus far
115+ self ._cache = {}
105116
106- if not resources_dir :
107- import tempfile
def _try_switch_resources_dir_from_context(
    self,
    context: "RuntimeEnvContext",  # noqa: F821
    logger: logging.Logger | None = default_logger,
):
    """Choose the resources directory the first time a context is available.

    Does nothing when a resources directory has already been picked.
    Otherwise every ``PYTHONPATH`` entry in ``context.env_vars`` that contains
    ``<sep>working_dir_files<sep>`` contributes the candidate
    ``<prefix before that marker>/ordered_pip``. If exactly one distinct
    candidate is found it is used; in every other case a fresh temporary
    directory is created under ``/tmp/ray``.

    Args:
        context: Runtime-env context whose ``env_vars`` mapping is inspected.
        logger: Logger for progress messages; ``None`` falls back to the
            module's ``default_logger``.
    """
    # VV: When ray instantiates custom RuntimeEnvPlugins it does not provide a resources_dir path.
    # This method is a HACK that derives the resources_dir from the RuntimeEnvContext which is known
    # at the time of CREATING a virtual environment i.e. **after** the RuntimeEnvPlugin is initialized.
    if logger is None:
        logger = default_logger

    with self._global_mtx:
        # VV: Stick with whatever resources dir we've already picked
        if self._pip_resources_dir:
            return

        logger.info("Generating resources dir")
        candidates: set[str] = set()
        if "PYTHONPATH" in context.env_vars:
            # VV: This is a HACK to find the "runtime_resources" path inside the PYTHONPATH env-var
            # This is an env-var that the WorkingDirPlugin inserts.
            # I noticed that sometimes the PYTHONPATH contains multiple copies of the same PATH.
            # The PYTHONPATH looks like this:
            # /tmp/ray/session_$timestamp/runtime_resources/working_dir_files/_ray_pkg_$uid
            many = context.env_vars["PYTHONPATH"].split(os.pathsep)
            logger.info(f"Current PYTHONPATH {many}")
            runtime_resources_followup = f"{os.sep}working_dir_files{os.sep}"
            candidates = {
                os.path.join(
                    entry.split(runtime_resources_followup, 1)[0], "ordered_pip"
                )
                for entry in many
                if runtime_resources_followup in entry
            }

        # Sorted so the log line is stable from run to run.
        logger.info(
            f"The candidate locations of runtime_resources: {sorted(candidates)}"
        )

        if len(candidates) != 1:
            import tempfile

            fallback_root = "/tmp/ray"
            # mkdtemp() raises FileNotFoundError when its parent directory is
            # missing, so make sure the parent exists first.
            os.makedirs(fallback_root, exist_ok=True)
            candidates = {
                tempfile.mkdtemp(prefix="ordered_pip_", dir=fallback_root)
            }

        self._switch_resources_dir(candidates.pop())
114161
115- try_to_create_directory (self ._pip_resources_dir )
116- self ._pip_plugin = PipPlugin (self ._pip_resources_dir )
def _switch_resources_dir(self, resources_dir: str):
    """Adopt ``resources_dir`` as the plugin's resources directory.

    The path is stored on the instance and then handed to Ray's
    ``try_to_create_directory`` helper, all under the plugin-wide lock.

    Args:
        resources_dir: Path to use from now on.
    """
    from ray._common.utils import try_to_create_directory as ensure_directory

    with self._global_mtx:
        self._pip_resources_dir = resources_dir
        ensure_directory(resources_dir)
168+
@property
def _pip_plugin(self) -> PipPlugin:
    """Return a brand-new ``PipPlugin`` bound to the current resources dir.

    A new instance is built on every access; nothing is memoised here.
    """
    # Why a fresh instance each time (per the original author's notes):
    # PipPlugin remembers the virtual environments it has created and not yet
    # deleted, and its .create() skips work when the URI is already known.
    # ordered_pip has to run .create() against the same venv once per
    # "phase", so each phase gets its own PipPlugin, and ordered_pip keeps
    # its own record of which venvs exist to decide between reuse and
    # creation.
    fresh_plugin = PipPlugin(self._pip_resources_dir)
    return fresh_plugin
117179
118180 @staticmethod
119181 def validate (runtime_env_dict : dict [str , typing .Any ]) -> RuntimeEnv :
@@ -132,14 +194,14 @@ def validate(runtime_env_dict: dict[str, typing.Any]) -> RuntimeEnv:
132194 raise ValueError ("runtime_env must be a dictionary" )
133195
134196 if "ordered_pip" not in runtime_env_dict :
135- raise ValueError ( "missing the 'ordered_pip' key" , runtime_env_dict )
197+ return RuntimeEnv ( ** runtime_env_dict )
136198
137199 if not isinstance (runtime_env_dict ["ordered_pip" ], dict ):
138200 raise ValueError ("runtime_env['ordered_pip'] must be a dictionary" )
139201
140202 if not isinstance (runtime_env_dict ["ordered_pip" ]["phases" ], list ):
141203 raise ValueError (
142- "runtime_env['ordered_pip']['phases'] must be a dictionary consistent with pip"
204+ "runtime_env['ordered_pip']['phases'] must be an array of pip entries "
143205 )
144206
145207 phases = []
@@ -164,6 +226,9 @@ def validate(runtime_env_dict: dict[str, typing.Any]) -> RuntimeEnv:
164226 return result
165227
166228 def get_uris (self , runtime_env : "RuntimeEnv" ) -> list [str ]:
229+ if not self .is_ordered_pip_runtimeenv (runtime_env ):
230+ return []
231+
167232 # VV: We want the hash to be invariant to the order of package names within a phase,
168233 # and we also want the order of phases to be reflected in the hash.
169234 aggregate_packages = [
@@ -178,31 +243,76 @@ def get_uris(self, runtime_env: "RuntimeEnv") -> list[str]:
178243 "pip://" + hashlib .sha1 (str (aggregate_packages ).encode ("utf-8" )).hexdigest ()
179244 ]
180245
def is_ordered_pip_runtimeenv(self, runtime_env: "RuntimeEnv") -> bool:
    """Tell whether ``runtime_env`` carries a non-empty ``ordered_pip`` entry.

    The environment is first passed through ``self.validate``; the result is
    True only when the validated mapping has a truthy ``"ordered_pip"`` value.
    """
    validated = self.validate(runtime_env)
    ordered_pip_section = validated.get("ordered_pip")
    return True if ordered_pip_section else False
248+
async def create(
    self,
    uri: str,
    runtime_env: "RuntimeEnv",  # noqa: F821
    context: "RuntimeEnvContext",  # noqa: F821
    logger: logging.Logger | None = default_logger,
) -> int:
    """Build the ordered_pip virtual environment, one pip phase at a time.

    Each entry of ``runtime_env["ordered_pip"]["phases"]`` is installed by a
    fresh ``PipPlugin`` into the same venv. Phases after the first run with
    ``create_or_get_virtualenv`` monkey-patched (see
    ``patch_create_or_get_virtualenv``).

    Args:
        uri: Ignored; the URI is recomputed from ``runtime_env`` via
            ``get_uris`` so it always reflects the phases.
        runtime_env: Runtime environment that may contain ``ordered_pip``.
        context: Runtime-env context forwarded to ``PipPlugin.create``; also
            used to pick the resources directory on first use.
        logger: Logger for progress messages; ``None`` falls back to the
            module's ``default_logger``.

    Returns:
        0 when ``runtime_env`` has no ``ordered_pip`` section. Otherwise the
        sum of the values returned by ``PipPlugin.create`` over all phases,
        or the previously recorded sum when the venv directory already
        exists and is known to this plugin.
    """
    import asyncio

    if logger is None:
        logger = default_logger

    self._try_switch_resources_dir_from_context(context, logger)

    if not self.is_ordered_pip_runtimeenv(runtime_env):
        return 0

    uri = self.get_uris(runtime_env)[0]

    # A threading lock held across ``await`` does not serialize coroutines:
    # they all run on the event-loop thread, which may re-enter an RLock.
    # An asyncio.Lock makes later callers for the same URI actually wait.
    with self._global_mtx:
        if uri not in self._create_env_mtx:
            self._create_env_mtx[uri] = asyncio.Lock()
        uri_lock = self._create_env_mtx[uri]

    async with uri_lock:
        logger.info(f"Creating {uri} for {runtime_env}")
        if os.path.isdir(self.get_path_to_pip_venv(uri)) and uri in self._cache:
            logger.info(f"Virtual environment for {uri} already exists")
            return self._cache[uri]

        # Forget any stale record; the size is stored again only after every
        # phase has succeeded, so a failed build never looks like a finished
        # environment.
        self._cache.pop(uri, None)

        phases = self.validate(runtime_env)["ordered_pip"]["phases"]
        total_bytes = 0
        for idx, pip in enumerate(phases):
            with patch_create_or_get_virtualenv(idx):
                logger.info(f"Creating {idx} for {uri}")

                total_bytes += await self._pip_plugin.create(
                    uri=uri,
                    runtime_env=RuntimeEnv(pip=pip),
                    context=context,
                    logger=logger,
                )
                logger.info(f"Done creating {idx} for {uri}")

        self._cache[uri] = total_bytes
        return total_bytes
292+
def get_path_to_pip_venv(self, uri: str) -> str:
    """Return ``<resources_dir>/pip/<hash>`` for the given URI.

    The hash is the second element returned by Ray's ``parse_uri``.
    """
    protocol, env_hash = ray._private.runtime_env.packaging.parse_uri(uri)
    pip_root = os.path.join(self._pip_resources_dir, "pip")
    return os.path.join(pip_root, env_hash)
201296
def delete_uri(
    self, uri: str, logger: logging.Logger | None = default_logger
) -> int:
    """Remove the virtual environment that belongs to ``uri``.

    Deletion is best-effort: a failure is logged and swallowed.

    Args:
        uri: URI of the environment to delete.
        logger: Logger for progress messages; ``None`` falls back to the
            module's ``default_logger``.

    Returns:
        The size in bytes of the environment directory, measured just before
        it was removed. 0 when the directory does not exist or could not be
        removed, since nothing was reclaimed.
    """
    import shutil

    import ray._private.utils

    if logger is None:
        logger = default_logger

    logger.info(f"Cleaning up {uri}")
    # pop() rather than ``del``: a URI this instance never recorded must not
    # abort the cleanup with a KeyError.
    self._cache.pop(uri, None)

    env_dir = self.get_path_to_pip_venv(uri)
    if not os.path.isdir(env_dir):
        return 0

    num_bytes = ray._private.utils.get_directory_size_bytes(env_dir)

    try:
        shutil.rmtree(env_dir)
    except Exception as e:
        logger.warning(f"Exception while cleaning up {env_dir} {e!s} - will ignore")
        # Nothing was reclaimed, so do not report the directory size.
        return 0

    return num_bytes
206316
207317 def modify_context (
208318 self ,
@@ -211,7 +321,14 @@ def modify_context(
211321 context : "RuntimeEnvContext" , # noqa: F821
212322 logger : logging .Logger = default_logger ,
213323 ):
214- phases = self .validate (runtime_env )["ordered_pip" ]["phases" ]
324+ self ._try_switch_resources_dir_from_context (context )
325+
326+ runtime_env = self .validate (runtime_env )
327+ if not runtime_env .get ("ordered_pip" ):
328+ return
329+
330+ logger .info (f"Modifying the context for { uris } and { runtime_env } " )
331+ phases = runtime_env ["ordered_pip" ]["phases" ]
215332
216333 if not len (phases ):
217334 return
@@ -222,3 +339,18 @@ def modify_context(
222339 context = context ,
223340 logger = logger ,
224341 )
342+
343+ if "PYTHONPATH" in context .env_vars :
344+ # VV: Ensure unique paths in PYTHONPATH
345+ paths = context .env_vars ["PYTHONPATH" ].split (os .pathsep )
346+
347+ unique = []
348+ for k in paths :
349+ if k not in unique :
350+ unique .append (k )
351+
352+ context .env_vars ["PYTHONPATH" ] = os .pathsep .join (unique )
353+
354+ logger .info (
355+ f"Modified the context for { uris } and { runtime_env } with { context .py_executable } { context .env_vars } "
356+ )
0 commit comments