Skip to content

Commit dee913e

Browse files
fix: the OrderedPip RayRuntimeEnv plugin and the SFTTrainer code that uses it (#186)
* feat: fix the OrderedPip RayRuntimeEnv plugin and the SFTTrainer code that uses it Also, add support for using the pip_install_options feature in Ray 2.50.0 Signed-off-by: Vassilis Vassiliadis <[email protected]> * refactor: reorder observed properties for SFTTrainer experiments This makes it easier to glance at the metrics while running experiments by showing the more important measurements first. Signed-off-by: Vassilis Vassiliadis <[email protected]> * test: add unit-test for propagation of PIP_FIND_LINKS to pip_instal_options Signed-off-by: Vassilis Vassiliadis <[email protected]> --------- Signed-off-by: Vassilis Vassiliadis <[email protected]>
1 parent 6ae8999 commit dee913e

File tree

6 files changed

+513
-170
lines changed

6 files changed

+513
-170
lines changed

orchestrator/utilities/ray_env/ordered_pip.py

Lines changed: 169 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
import contextlib
55
import logging
66
import os
7+
import threading
78
import typing
89

10+
import ray._private.runtime_env.packaging
911
from ray._private.runtime_env import virtualenv_utils
1012
from ray._private.runtime_env.pip import PipPlugin
1113
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
@@ -29,18 +31,24 @@ async def create_or_get_virtualenv(path: str, cwd: str, logger: logging.Logger):
2931
await original_create_or_get_virtualenv(path=path, cwd=cwd, logger=logger)
3032

3133

34+
_monkey_patch_lock = threading.RLock()
35+
36+
3237
@contextlib.contextmanager
3338
def patch_create_or_get_virtualenv(phase_index: int):
34-
if phase_index > 0:
35-
setattr(virtualenv_utils, "create_or_get_virtualenv", create_or_get_virtualenv)
36-
try:
37-
yield
38-
finally:
39-
setattr(
40-
virtualenv_utils,
41-
"create_or_get_virtualenv",
42-
original_create_or_get_virtualenv,
43-
)
39+
with _monkey_patch_lock:
40+
if phase_index > 0:
41+
setattr(
42+
virtualenv_utils, "create_or_get_virtualenv", create_or_get_virtualenv
43+
)
44+
try:
45+
yield
46+
finally:
47+
setattr(
48+
virtualenv_utils,
49+
"create_or_get_virtualenv",
50+
original_create_or_get_virtualenv,
51+
)
4452

4553

4654
class OrderedPipPlugin(RuntimeEnvPlugin):
@@ -93,27 +101,81 @@ def try_import_torch():
93101
"""
94102

95103
name = "ordered_pip"
104+
105+
# VV: Configure Ray to use this RuntimeEnvPlugin last
106+
priority = 100
96107
ClassPath = "orchestrator.utilities.ray_env.ordered_pip.OrderedPipPlugin"
97108

98109
def __init__(self, resources_dir: str | None = None):
99-
if resources_dir is None:
100-
import ray._private.ray_constants as ray_constants
110+
self._global_mtx = threading.RLock()
111+
self._create_env_mtx: dict[str, threading.RLock] = {}
112+
self._pip_resources_dir = resources_dir
101113

102-
resources_dir = os.environ.get(
103-
ray_constants.RAY_RUNTIME_ENV_CREATE_WORKING_DIR_ENV_VAR
104-
)
114+
# VV: Maintains a cache of the environments that have been built thus far
115+
self._cache = {}
105116

106-
if not resources_dir:
107-
import tempfile
117+
def _try_switch_resources_dir_from_context(
118+
self,
119+
context: "RuntimeEnvContext", # noqa: F821
120+
logger: logging.Logger | None = default_logger,
121+
):
122+
# VV: When ray instantiates custom RuntimeEnvPlugins it does not provide a resources_dir path.
123+
# This method is a HACK that the resources_dir based on the RuntimeEnvContext which is known
124+
# at the time of CREATING a virtual environment i.e. **after** the RuntimeEnvPlugin is initialized.
125+
126+
with self._global_mtx:
127+
# VV: Stick with whatever resources dir we've already picked
128+
if self._pip_resources_dir:
129+
return
130+
131+
logger.info("Generating resources dir")
132+
unique = set()
133+
if "PYTHONPATH" in context.env_vars:
134+
# VV: This is a HACK to find the "runtime_resources" path inside the PYTHONPATH env-var
135+
# This is an env-var that the WorkingDirPlugin inserts.
136+
# I noticed that sometimes the PYTHONPATH contains multiple copies of the same PATH.
137+
# The PYTHONPATH looks like this:
138+
# /tmp/ray/session_$timestamp/runtime_resources/working_dir_files/_ray_pkg_$uid
139+
many = context.env_vars["PYTHONPATH"].split(os.pathsep)
140+
logger.info(f"Current PYTHONPATH {many}")
141+
runtime_resources_followup = f"{os.sep}working_dir_files{os.sep}"
142+
unique.update(
143+
[
144+
os.path.join(
145+
x.split(runtime_resources_followup, 1)[0], "ordered_pip"
146+
)
147+
for x in many
148+
if runtime_resources_followup in x
149+
]
150+
)
108151

109-
resources_dir = tempfile.mkdtemp(prefix="ordered_pip_", dir="/tmp/ray")
152+
logger.info(f"The candidate locations of runtime_resources: {list(unique)}")
110153

111-
self._pip_resources_dir = resources_dir
154+
if len(unique) != 1:
155+
import tempfile
156+
157+
unique.clear()
158+
unique.add(tempfile.mkdtemp(prefix="ordered_pip_", dir="/tmp/ray"))
112159

113-
from ray._common.utils import try_to_create_directory
160+
self._switch_resources_dir(unique.pop())
114161

115-
try_to_create_directory(self._pip_resources_dir)
116-
self._pip_plugin = PipPlugin(self._pip_resources_dir)
162+
def _switch_resources_dir(self, resources_dir: str):
163+
with self._global_mtx:
164+
from ray._common.utils import try_to_create_directory
165+
166+
self._pip_resources_dir = resources_dir
167+
try_to_create_directory(self._pip_resources_dir)
168+
169+
@property
170+
def _pip_plugin(self) -> PipPlugin:
171+
# The PipPlugin keeps an internal cache of virtual environments it has created but not yet deleted.
172+
# When .create() is called, it checks this cache for a venv matching the given URI.
173+
# If a match is found, it assumes the venv already exists and skips re-creation.
174+
# However, ordered_pip needs to reuse the same venv multiple times (once per "phase").
175+
# Thus, we create a new PipPlugin instance on demand for each phase of ordered_pip.
176+
# Also, we maintain our own record of venvs to decide whether to create a new "ordered_pip"
177+
# venv or reuse an existing one.
178+
return PipPlugin(self._pip_resources_dir)
117179

118180
@staticmethod
119181
def validate(runtime_env_dict: dict[str, typing.Any]) -> RuntimeEnv:
@@ -132,14 +194,14 @@ def validate(runtime_env_dict: dict[str, typing.Any]) -> RuntimeEnv:
132194
raise ValueError("runtime_env must be a dictionary")
133195

134196
if "ordered_pip" not in runtime_env_dict:
135-
raise ValueError("missing the 'ordered_pip' key", runtime_env_dict)
197+
return RuntimeEnv(**runtime_env_dict)
136198

137199
if not isinstance(runtime_env_dict["ordered_pip"], dict):
138200
raise ValueError("runtime_env['ordered_pip'] must be a dictionary")
139201

140202
if not isinstance(runtime_env_dict["ordered_pip"]["phases"], list):
141203
raise ValueError(
142-
"runtime_env['ordered_pip']['phases'] must be a dictionary consistent with pip"
204+
"runtime_env['ordered_pip']['phases'] must be an array of pip entries"
143205
)
144206

145207
phases = []
@@ -164,6 +226,9 @@ def validate(runtime_env_dict: dict[str, typing.Any]) -> RuntimeEnv:
164226
return result
165227

166228
def get_uris(self, runtime_env: "RuntimeEnv") -> list[str]:
229+
if not self.is_ordered_pip_runtimeenv(runtime_env):
230+
return []
231+
167232
# VV: We want the hash to be invariant to the order of package names within a phase,
168233
# and we also want the order of phases to be reflected in the hash.
169234
aggregate_packages = [
@@ -178,31 +243,76 @@ def get_uris(self, runtime_env: "RuntimeEnv") -> list[str]:
178243
"pip://" + hashlib.sha1(str(aggregate_packages).encode("utf-8")).hexdigest()
179244
]
180245

246+
def is_ordered_pip_runtimeenv(self, runtime_env: "RuntimeEnv") -> bool:
247+
return bool(self.validate(runtime_env).get("ordered_pip"))
248+
181249
async def create(
182250
self,
183251
uri: str,
184252
runtime_env: "RuntimeEnv", # noqa: F821
185253
context: "RuntimeEnvContext", # noqa: F821
186254
logger: logging.Logger | None = default_logger,
187255
) -> int:
256+
self._try_switch_resources_dir_from_context(context, logger)
257+
258+
if not self.is_ordered_pip_runtimeenv(runtime_env):
259+
return 0
260+
188261
uri = self.get_uris(runtime_env)[0]
189-
total_bytes = 0
190-
191-
for idx, pip in enumerate(self.validate(runtime_env)["ordered_pip"]["phases"]):
192-
with patch_create_or_get_virtualenv(idx):
193-
total_bytes += await self._pip_plugin.create(
194-
uri=uri,
195-
runtime_env=RuntimeEnv(pip=pip),
196-
context=context,
197-
logger=logger,
198-
)
199262

200-
return total_bytes
263+
with self._global_mtx:
264+
if uri not in self._create_env_mtx:
265+
self._create_env_mtx[uri] = threading.RLock()
266+
267+
with self._create_env_mtx[uri]:
268+
logger.info(f"Creating {uri} for {runtime_env}")
269+
try:
270+
if os.path.isdir(self.get_path_to_pip_venv(uri)):
271+
logger.info(f"Virtual environment for {uri} already exists")
272+
return self._cache[uri]
273+
except KeyError:
274+
pass
275+
276+
self._cache[uri] = 0
277+
for idx, pip in enumerate(
278+
self.validate(runtime_env)["ordered_pip"]["phases"]
279+
):
280+
with patch_create_or_get_virtualenv(idx):
281+
logger.info(f"Creating {idx} for {uri}")
282+
283+
self._cache[uri] += await self._pip_plugin.create(
284+
uri=uri,
285+
runtime_env=RuntimeEnv(pip=pip),
286+
context=context,
287+
logger=logger,
288+
)
289+
logger.info(f"Done creating {idx} for {uri}")
290+
291+
return self._cache[uri]
292+
293+
def get_path_to_pip_venv(self, uri: str) -> str:
294+
_, env_hash = ray._private.runtime_env.packaging.parse_uri(uri)
295+
return os.path.join(self._pip_resources_dir, "pip", env_hash)
201296

202297
def delete_uri(
203298
self, uri: str, logger: logging.Logger | None = default_logger
204299
) -> int:
205-
return self._pip_plugin.delete_uri(uri=uri, logger=logger)
300+
logger.info(f"Cleaning up {uri}")
301+
del self._cache[uri]
302+
303+
import shutil
304+
305+
import ray._private.utils
306+
307+
env_dir = self.get_path_to_pip_venv(uri)
308+
num_bytes = ray._private.utils.get_directory_size_bytes(env_dir)
309+
310+
try:
311+
shutil.rmtree(env_dir)
312+
except Exception as e:
313+
logger.warning(f"Exception while cleaning up {env_dir} {e!s} - will ignore")
314+
315+
return num_bytes
206316

207317
def modify_context(
208318
self,
@@ -211,7 +321,14 @@ def modify_context(
211321
context: "RuntimeEnvContext", # noqa: F821
212322
logger: logging.Logger = default_logger,
213323
):
214-
phases = self.validate(runtime_env)["ordered_pip"]["phases"]
324+
self._try_switch_resources_dir_from_context(context)
325+
326+
runtime_env = self.validate(runtime_env)
327+
if not runtime_env.get("ordered_pip"):
328+
return
329+
330+
logger.info(f"Modifying the context for {uris} and {runtime_env}")
331+
phases = runtime_env["ordered_pip"]["phases"]
215332

216333
if not len(phases):
217334
return
@@ -222,3 +339,18 @@ def modify_context(
222339
context=context,
223340
logger=logger,
224341
)
342+
343+
if "PYTHONPATH" in context.env_vars:
344+
# VV: Ensure unique paths in PYTHONPATH
345+
paths = context.env_vars["PYTHONPATH"].split(os.pathsep)
346+
347+
unique = []
348+
for k in paths:
349+
if k not in unique:
350+
unique.append(k)
351+
352+
context.env_vars["PYTHONPATH"] = os.pathsep.join(unique)
353+
354+
logger.info(
355+
f"Modified the context for {uris} and {runtime_env} with {context.py_executable} {context.env_vars}"
356+
)

0 commit comments

Comments
 (0)