|
5 | 5 | import functools |
6 | 6 | import logging |
7 | 7 | import os |
| 8 | +import shutil |
8 | 9 | import socket |
9 | 10 | import subprocess |
10 | 11 | import sys |
11 | 12 | import tempfile |
12 | 13 | import uuid |
13 | 14 | import zipfile |
14 | 15 | from collections.abc import Awaitable |
| 16 | +from contextlib import contextmanager |
| 17 | +from importlib.util import find_spec |
| 18 | +from io import BytesIO |
15 | 19 | from typing import TYPE_CHECKING, Any, Callable, ClassVar |
| 20 | +from types import ModuleType |
| 21 | +from typing import Any, Tuple |
| 22 | +from pathlib import Path |
16 | 23 |
|
17 | 24 | from dask.typing import Key |
18 | 25 | from dask.utils import _deprecated_kwarg, funcname, tmpfile |
|
29 | 36 | from distributed.scheduler import TaskStateState as SchedulerTaskStateState |
30 | 37 | from distributed.worker import Worker |
31 | 38 | from distributed.worker_state_machine import TaskStateState as WorkerTaskStateState |
| 39 | + from distributed.node import ServerNode |
32 | 40 |
|
33 | 41 | logger = logging.getLogger(__name__) |
34 | 42 |
|
@@ -1102,3 +1110,189 @@ def setup(self, worker): |
1102 | 1110 |
|
1103 | 1111 | def teardown(self, worker): |
1104 | 1112 | self._exit_stack.close() |
| 1113 | + |
| 1114 | + |
| 1115 | +@contextmanager |
| 1116 | +def serialize_module( |
| 1117 | + module: ModuleType, exclude: Tuple[str, ...] = ("__pycache__", ".DS_Store") |
| 1118 | +) -> Generator[tuple[Path, str], Any, None]: |
| 1119 | + if module.__file__ is None: # Need this to satisfy mypy |
| 1120 | + raise ValueError(f"Module {module.__name__} has no __file__ attribute") |
| 1121 | + module_path = Path(module.__file__) |
| 1122 | + |
| 1123 | + if module_path.stem == "__init__": |
| 1124 | + # In case of package we serialize the whole package |
| 1125 | + module_path = module_path.parent |
| 1126 | + if "." in module.__name__: |
| 1127 | + # TODO: the problem is that we serialize the `package.module`, as module.egg that contains module.py, |
| 1128 | + # but it should contain the whole structure of the package (package/module.py) |
| 1129 | + raise Exception( |
| 1130 | + f"Plugin supports only top-level packages or single-file modules. You provided `{module.__name__}`, try `{module.__name__.split('.')[0]}`." |
| 1131 | + ) |
| 1132 | + |
| 1133 | + upload_id = str(uuid.uuid4()) |
| 1134 | + |
| 1135 | + # In case of single file we don't need to serialize anything |
| 1136 | + |
| 1137 | + with tempfile.TemporaryDirectory() as tmp: |
| 1138 | + package_name = module_path.name |
| 1139 | + |
| 1140 | + package_copy_path = Path(tmp).joinpath(package_name) |
| 1141 | + if module_path.is_dir(): |
| 1142 | + copied_package = Path( |
| 1143 | + shutil.copytree( |
| 1144 | + module_path, |
| 1145 | + package_copy_path, |
| 1146 | + ignore=shutil.ignore_patterns(f"{package_name}.zip", *exclude), |
| 1147 | + ) |
| 1148 | + ) |
| 1149 | + with open(package_copy_path / "__init__.py", "a") as f: |
| 1150 | + f.write(f"\n__upload_id__ = '{upload_id}'") |
| 1151 | + else: |
| 1152 | + copied_package = Path(shutil.copy2(module_path, package_copy_path)) |
| 1153 | + with open(copied_package, "a") as f: |
| 1154 | + f.write(f"\n__upload_id__ = '{upload_id}'") |
| 1155 | + |
| 1156 | + archive_path = shutil.make_archive( |
| 1157 | + # output path including a name w/o extension |
| 1158 | + base_name=str(copied_package), |
| 1159 | + format="zip", |
| 1160 | + # chroot |
| 1161 | + root_dir=copied_package.parent, |
| 1162 | + # Name of the directory to archive and a common prefix of all files and directories in the archive |
| 1163 | + base_dir=package_name, |
| 1164 | + ) |
| 1165 | + |
| 1166 | + egg_file = shutil.move(archive_path, package_copy_path.with_suffix(".egg")) |
| 1167 | + |
| 1168 | + # zip file handler |
| 1169 | + zip = ZipFile(egg_file) |
| 1170 | + # list available files in the container |
| 1171 | + logger.debug( |
| 1172 | + "The egg file %s contains the following files %s", |
| 1173 | + str(egg_file), |
| 1174 | + str(zip.namelist()), |
| 1175 | + ) |
| 1176 | + |
| 1177 | + logger.info("Created an egg file %s from %s", str(egg_file), str(module_path)) |
| 1178 | + |
| 1179 | + yield Path(egg_file), upload_id |
| 1180 | + |
| 1181 | + |
| 1182 | +class AbstractUploadModulePlugin: |
| 1183 | + def __init__(self, module: ModuleType): |
| 1184 | + self._module_name = module.__name__ |
| 1185 | + self._data: bytes |
| 1186 | + self._filepath: Path |
| 1187 | + self._filename: str |
| 1188 | + with serialize_module(module) as (filepath, upload_id): |
| 1189 | + self._upload_id = upload_id |
| 1190 | + self._filename = filepath.name |
| 1191 | + with open(filepath, "rb") as f: |
| 1192 | + self._data = f.read() |
| 1193 | + |
| 1194 | + async def _upload_file(self, node: ServerNode): |
| 1195 | + response = await node.upload_file(self._filename, self._data, load=True) |
| 1196 | + assert len(self._data) == response["nbytes"] |
| 1197 | + |
| 1198 | + async def _upload(self, node: ServerNode): |
| 1199 | + import zipfile |
| 1200 | + import sys |
| 1201 | + try: |
| 1202 | + from IPython.extensions.autoreload import superreload |
| 1203 | + except ImportError: |
| 1204 | + superreload = lambda x: x |
| 1205 | + |
| 1206 | + # Try to find already loaded module |
| 1207 | + module = ( |
| 1208 | + sys.modules[self._module_name] if self._module_name in sys.modules else None |
| 1209 | + ) |
| 1210 | + # Try to find module on disk |
| 1211 | + module_spec = find_spec(self._module_name) |
| 1212 | + |
| 1213 | + if not module_spec and not module: |
| 1214 | + # If the module does not exist, we keep it as an egg file and load it. |
| 1215 | + # It happens when we create a module that does not exist on the node. It is a rare case but still possible. |
| 1216 | + logger.info( |
| 1217 | + 'Uploading a new module "%s" to "%s" on %s "%s"', |
| 1218 | + self._module_name, |
| 1219 | + str(self._filename), |
| 1220 | + "worker" if isinstance(node, Worker) else "scheduler", |
| 1221 | + node.id, |
| 1222 | + ) |
| 1223 | + await self._upload_file(node) |
| 1224 | + return |
| 1225 | + |
| 1226 | + if module: |
| 1227 | + module_path = self._get_module_dir(module) |
| 1228 | + else: |
| 1229 | + module_path = Path(module_spec.origin) # type: ignore |
| 1230 | + |
| 1231 | + if ".egg" in str(module_path): |
| 1232 | + # Update the previously uploaded egg module and reload it. |
| 1233 | + logger.info( |
| 1234 | + 'Uploading an update for a previously uploaded a new module "%s" to "%s" on %s "%s"', |
| 1235 | + self._module_name, |
| 1236 | + str(self._filename), |
| 1237 | + "worker" if isinstance(node, Worker) else "scheduler", |
| 1238 | + node.id, |
| 1239 | + ) |
| 1240 | + await self._upload_file(node) |
| 1241 | + return |
| 1242 | + |
| 1243 | + if module_path.name == "__init__.py": |
| 1244 | + # Uploading the package |
| 1245 | + extract_to = module_path.parent.parent |
| 1246 | + else: |
| 1247 | + # When we are uploading a single file |
| 1248 | + extract_to = module_path.parent |
| 1249 | + |
| 1250 | + with zipfile.ZipFile(BytesIO(self._data), "r") as zip_ref: |
| 1251 | + # In case, we received egg file for module that exists on node in source code, |
| 1252 | + # we overwrite each file separately by extracting it from the egg. |
| 1253 | + logger.info( |
| 1254 | + 'Uploading an update for an existing module "%s" in "%s" on %s "%s"', |
| 1255 | + self._module_name, |
| 1256 | + str(extract_to), |
| 1257 | + "worker" if isinstance(node, Worker) else "scheduler", |
| 1258 | + node.id, |
| 1259 | + ) |
| 1260 | + zip_ref.extractall(extract_to) |
| 1261 | + |
| 1262 | + # TODO: Do we really need Jupyter's `superreload` here instead of built-in Python's function? |
| 1263 | + if self._module_name in sys.modules: |
| 1264 | + # Reload module if it is already loaded |
| 1265 | + superreload(sys.modules[self._module_name]) |
| 1266 | + # Validate that the uploaded module has the correct upload_id |
| 1267 | + if sys.modules[self._module_name].__upload_id__ != self._upload_id: |
| 1268 | + logger.warning( |
| 1269 | + f"Module {self._module_name} was not updated. Old version found. __upload_id__ mismatch: expected {self._upload_id}, got {sys.modules[self._module_name].__upload_id__}") |
| 1270 | + elif "__upload_id__" not in sys.modules[self._module_name].__dict__: |
| 1271 | + logger.warning( |
| 1272 | + f"Module {self._module_name} was not updated. Missing __upload_id__") |
| 1273 | + |
| 1274 | + @classmethod |
| 1275 | + def _get_module_dir(cls, module: ModuleType) -> Path: |
| 1276 | + """Get the directory of the module.""" |
| 1277 | + module_path = Path(sys.modules[module.__name__].__file__) # type: ignore |
| 1278 | + |
| 1279 | + if module_path.stem == "__init__": |
| 1280 | + # In case of package we serialize the whole package |
| 1281 | + return module_path.parent |
| 1282 | + |
| 1283 | + # In case of single file we don't need to serialize anything |
| 1284 | + return module_path |
| 1285 | + |
| 1286 | + |
| 1287 | +class UploadModule(WorkerPlugin, AbstractUploadModulePlugin): |
| 1288 | + name = "upload_module" |
| 1289 | + |
| 1290 | + async def setup(self, worker: Worker): |
| 1291 | + await self._upload(worker) |
| 1292 | + |
| 1293 | + |
| 1294 | +class SchedulerUploadModule(SchedulerPlugin, AbstractUploadModulePlugin): |
| 1295 | + name = "upload_module" |
| 1296 | + |
| 1297 | + async def start(self, scheduler: Scheduler) -> None: |
| 1298 | + await self._upload(scheduler) |
0 commit comments