Skip to content

Commit 1fed636

Browse files
committed
Implement module upload plugin (#8698)
1 parent 3d2685c commit 1fed636

File tree

1 file changed

+194
-0
lines changed

1 file changed

+194
-0
lines changed

distributed/diagnostics/plugin.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,21 @@
55
import functools
66
import logging
77
import os
8+
import shutil
89
import socket
910
import subprocess
1011
import sys
1112
import tempfile
1213
import uuid
1314
import zipfile
1415
from collections.abc import Awaitable
16+
from contextlib import contextmanager
17+
from importlib.util import find_spec
18+
from io import BytesIO
1519
from typing import TYPE_CHECKING, Any, Callable, ClassVar
20+
from types import ModuleType
21+
from typing import Any, Tuple
22+
from pathlib import Path
1623

1724
from dask.typing import Key
1825
from dask.utils import _deprecated_kwarg, funcname, tmpfile
@@ -29,6 +36,7 @@
2936
from distributed.scheduler import TaskStateState as SchedulerTaskStateState
3037
from distributed.worker import Worker
3138
from distributed.worker_state_machine import TaskStateState as WorkerTaskStateState
39+
from distributed.node import ServerNode
3240

3341
logger = logging.getLogger(__name__)
3442

@@ -1102,3 +1110,189 @@ def setup(self, worker):
11021110

11031111
def teardown(self, worker):
11041112
self._exit_stack.close()
1113+
1114+
1115+
@contextmanager
def serialize_module(
    module: ModuleType, exclude: Tuple[str, ...] = ("__pycache__", ".DS_Store")
) -> "Generator[tuple[Path, str], Any, None]":
    """Package a module into a temporary ``.egg`` archive.

    The module's source (a single file, or the whole directory when the
    module is a package) is copied into a temporary directory, a unique
    ``__upload_id__`` assignment is appended to the copy, and the copy is
    zipped and renamed to ``<name>.egg``.

    Parameters
    ----------
    module
        An imported top-level package or single-file module.
    exclude
        Glob patterns of files/directories left out when copying a package.

    Yields
    ------
    tuple[Path, str]
        The path of the egg file and the upload id written into it. The egg
        lives inside a temporary directory that is deleted when the context
        exits, so it must be read inside the ``with`` block.

    Raises
    ------
    ValueError
        If the module has no ``__file__`` or is a dotted submodule.
    """
    # The return annotation is quoted because ``Generator`` is not among the
    # imports visible next to this function.
    # ``getattr`` also covers modules that have no ``__file__`` attribute at all.
    module_file = getattr(module, "__file__", None)
    if module_file is None:
        raise ValueError(f"Module {module.__name__} has no __file__ attribute")
    module_path = Path(module_file)

    if module_path.stem == "__init__":
        # In case of package we serialize the whole package
        module_path = module_path.parent
    if "." in module.__name__:
        # TODO: the problem is that we serialize the `package.module`, as module.egg that contains module.py,
        #  but it should contain the whole structure of the package (package/module.py)
        raise ValueError(
            f"Plugin supports only top-level packages or single-file modules. You provided `{module.__name__}`, try `{module.__name__.split('.')[0]}`."
        )

    upload_id = str(uuid.uuid4())
    upload_marker = f"\n__upload_id__ = '{upload_id}'"

    with tempfile.TemporaryDirectory() as tmp:
        package_name = module_path.name
        package_copy_path = Path(tmp).joinpath(package_name)

        if module_path.is_dir():
            copied_package = Path(
                shutil.copytree(
                    module_path,
                    package_copy_path,
                    ignore=shutil.ignore_patterns(f"{package_name}.zip", *exclude),
                )
            )
            marker_target = copied_package / "__init__.py"
        else:
            # A single-file module is copied as-is
            copied_package = Path(shutil.copy2(module_path, package_copy_path))
            marker_target = copied_package

        # Tag the copy (never the original source) with the upload id
        with open(marker_target, "a", encoding="utf-8") as f:
            f.write(upload_marker)

        archive_path = shutil.make_archive(
            # output path including a name w/o extension
            base_name=str(copied_package),
            format="zip",
            # chroot
            root_dir=str(copied_package.parent),
            # Name of the directory to archive and a common prefix of all files and directories in the archive
            base_dir=package_name,
        )

        egg_file = shutil.move(
            archive_path, str(package_copy_path.with_suffix(".egg"))
        )

        # List the archive content for debugging; the handle is closed right away
        with zipfile.ZipFile(egg_file) as archive:
            logger.debug(
                "The egg file %s contains the following files %s",
                str(egg_file),
                str(archive.namelist()),
            )

        logger.info("Created an egg file %s from %s", str(egg_file), str(module_path))

        yield Path(egg_file), upload_id
1180+
1181+
1182+
class AbstractUploadModulePlugin:
    """Shared logic for plugins that ship a local module to a remote node.

    On construction the module is serialized into an egg archive whose bytes
    are kept in memory. ``_upload`` then installs it on a node: either as a
    new egg file, or by overwriting the source files of a module that already
    exists there and reloading it.
    """

    def __init__(self, module: ModuleType):
        self._module_name = module.__name__
        self._data: bytes
        self._filepath: Path
        self._filename: str
        with serialize_module(module) as (filepath, upload_id):
            self._upload_id = upload_id
            self._filename = filepath.name
            # Read inside the context: the bytes are what gets sent later on.
            with open(filepath, "rb") as f:
                self._data = f.read()

    async def _upload_file(self, node: "ServerNode") -> None:
        """Send the egg to ``node`` through its ``upload_file`` handler."""
        response = await node.upload_file(self._filename, self._data, load=True)
        assert len(self._data) == response["nbytes"]

    @staticmethod
    def _node_kind(node: "ServerNode") -> str:
        """Return ``"worker"`` or ``"scheduler"`` for log messages."""
        # NOTE(review): imported locally because the module-level import of
        # ``Worker`` may be for type checking only (not visible from here);
        # ``isinstance`` needs the class at runtime.
        from distributed.worker import Worker

        return "worker" if isinstance(node, Worker) else "scheduler"

    @staticmethod
    def _reload(module: ModuleType) -> None:
        """Reload ``module`` so the files just written take effect."""
        import importlib

        # Files were written while the interpreter was running.
        importlib.invalidate_caches()
        try:
            from IPython.extensions.autoreload import superreload
        except ImportError:
            # Without IPython fall back to the built-in reload; doing nothing
            # would leave the old module version loaded.
            importlib.reload(module)
        else:
            superreload(module)

    def _verify_upload_id(self, module: ModuleType) -> bool:
        """Check that ``module`` carries this plugin's upload id.

        Logs a warning and returns ``False`` when the id is missing or does
        not match; returns ``True`` otherwise.
        """
        found = getattr(module, "__upload_id__", None)
        if found is None:
            # Checked first: reading the attribute directly would raise.
            logger.warning(
                "Module %s was not updated. Missing __upload_id__",
                self._module_name,
            )
            return False
        if found != self._upload_id:
            logger.warning(
                "Module %s was not updated. Old version found. __upload_id__ mismatch: expected %s, got %s",
                self._module_name,
                self._upload_id,
                found,
            )
            return False
        return True

    async def _upload(self, node: "ServerNode") -> None:
        """Install the serialized module on ``node``.

        Three cases are handled:

        * the module is unknown on the node: upload it as an egg file;
        * the module was previously uploaded as an egg: upload a new egg;
        * the module exists as source files: overwrite them from the
          archive, reload the module if it is loaded and verify its id.
        """
        # Try to find already loaded module
        module = sys.modules.get(self._module_name)
        # Try to find module on disk. Only needed when it is not loaded; for
        # loaded modules ``find_spec`` may raise and its result is unused.
        module_spec = find_spec(self._module_name) if module is None else None
        node_kind = self._node_kind(node)

        if module is None and module_spec is None:
            # If the module does not exist, we keep it as an egg file and load it.
            # It happens when we create a module that does not exist on the node. It is a rare case but still possible.
            logger.info(
                'Uploading a new module "%s" to "%s" on %s "%s"',
                self._module_name,
                str(self._filename),
                node_kind,
                node.id,
            )
            await self._upload_file(node)
            return

        if module is not None:
            module_path = self._get_module_dir(module)
        else:
            module_path = Path(module_spec.origin)  # type: ignore

        if ".egg" in str(module_path):
            # Update the previously uploaded egg module and reload it.
            logger.info(
                'Uploading an update for a previously uploaded a new module "%s" to "%s" on %s "%s"',
                self._module_name,
                str(self._filename),
                node_kind,
                node.id,
            )
            await self._upload_file(node)
            return

        if module_path.name == "__init__.py":
            # Uploading the package: the archive entries start with the
            # package directory, so extract next to that directory.
            extract_to = module_path.parent.parent
        else:
            # Single file, or a package directory returned by _get_module_dir
            extract_to = module_path.parent

        with zipfile.ZipFile(BytesIO(self._data), "r") as zip_ref:
            # In case, we received egg file for module that exists on node in source code,
            # we overwrite each file separately by extracting it from the egg.
            logger.info(
                'Uploading an update for an existing module "%s" in "%s" on %s "%s"',
                self._module_name,
                str(extract_to),
                node_kind,
                node.id,
            )
            zip_ref.extractall(extract_to)

        loaded = sys.modules.get(self._module_name)
        if loaded is not None:
            # Reload module if it is already loaded
            self._reload(loaded)
            # Validate that the uploaded module has the correct upload_id
            self._verify_upload_id(sys.modules[self._module_name])

    @classmethod
    def _get_module_dir(cls, module: ModuleType) -> Path:
        """Return the package directory, or the file path for a single-file module."""
        module_path = Path(sys.modules[module.__name__].__file__)  # type: ignore

        if module_path.stem == "__init__":
            # A package is identified by its directory
            return module_path.parent

        # A single-file module is identified by the file itself
        return module_path
1285+
1286+
1287+
class UploadModule(WorkerPlugin, AbstractUploadModulePlugin):
    """Worker-side plugin that installs the serialized module on a worker."""

    # Name under which the plugin is registered
    name = "upload_module"

    async def setup(self, worker: Worker):
        """Upload the module to ``worker``."""
        await self._upload(worker)
1292+
1293+
1294+
class SchedulerUploadModule(SchedulerPlugin, AbstractUploadModulePlugin):
    """Scheduler-side plugin that installs the serialized module on the scheduler."""

    # Name under which the plugin is registered
    name = "upload_module"

    async def start(self, scheduler: Scheduler) -> None:
        """Upload the module to ``scheduler``."""
        await self._upload(scheduler)

0 commit comments

Comments
 (0)