|
| 1 | +# Based on the shell script's variable names for clarity in logic translation. |
| 2 | +from __future__ import annotations |
| 3 | + |
| 4 | +import dataclasses |
| 5 | +import enum |
| 6 | +import shutil |
| 7 | +import unittest |
| 8 | +from pathlib import Path |
| 9 | + |
# Notebook identifiers. get_source_of_truth_filepath() looks for these as
# substrings of the notebook id when it picks a manifest filename for Python
# versions other than 3.12.
JUPYTER_MINIMAL_NOTEBOOK_ID = "minimal"
JUPYTER_DATASCIENCE_NOTEBOOK_ID = "datascience"
JUPYTER_TRUSTYAI_NOTEBOOK_ID = "trustyai"
JUPYTER_PYTORCH_NOTEBOOK_ID = "pytorch"
JUPYTER_TENSORFLOW_NOTEBOOK_ID = "tensorflow"

RSTUDIO_NOTEBOOK_ID = "rstudio"

# Location of the make executable: `gmake` is preferred when it is on PATH,
# otherwise `make`. The value is None when neither can be found.
MAKE = shutil.which("gmake") or shutil.which("make")
| 19 | + |
| 20 | + |
@enum.unique
class NotebookType(enum.Enum):
    """Enum for the different notebook types.

    extract_metadata_from_path() assigns RUNTIME to paths anchored at a
    'runtimes' component and WORKBENCH to every other recognised root.
    """

    RUNTIME = "runtime"
    WORKBENCH = "workbench"
| 27 | + |
| 28 | + |
@dataclasses.dataclass(frozen=True)
class NotebookMetadata:
    """Stores metadata parsed from a notebook's directory path.

    Attributes:
        type: Whether the directory describes a runtime or a workbench image.
        feature: Image family. 'runtime' for paths under 'runtimes', otherwise
            the root path component itself ('jupyter', 'codeserver', 'rstudio').
        scope: Name of the notebook identifier (e.g., 'minimal', 'pytorch').
            Empty when the path has no such component.
        os_flavor: The operating system flavor (e.g., 'ubi9').
        python_flavor: The python version string (e.g., 'python-3.12').
        accelerator_flavor: The accelerator flavor (e.g., 'cuda', 'rocm'), or
            None when no accelerator was detected.
    """

    # NOTE: `type` shadows the builtin inside the class body only; the name is
    # kept because it is part of the constructor's keyword interface.
    type: NotebookType
    feature: str
    scope: str
    os_flavor: str
    python_flavor: str
    accelerator_flavor: str | None
| 47 | + |
| 48 | + |
def extract_metadata_from_path(directory: Path) -> NotebookMetadata:
    """
    Parses a notebook's directory path to extract metadata needed to find its manifest.
    This logic is derived from the test_jupyter_with_papermill.sh script.

    The accelerator is determined, in order of preference, from:
      1. a 'rocm' or 'cuda' path component (e.g. jupyter/rocm/tensorflow),
      2. the accelerator prefix of a runtime directory name (e.g. runtimes/rocm-tensorflow),
      3. a Dockerfile.cuda or Dockerfile.rocm file inside `directory`
         (the only step that touches the filesystem).

    Args:
        directory: The directory containing the notebook's pyproject.toml.
            (e.g., .../jupyter/rocm/tensorflow/ubi9-python-3.12)

    Returns:
        A dataclass containing the parsed notebook metadata.

    Raises:
        ValueError: If the path format is unexpected and metadata cannot be extracted.
    """
    known_accelerators = ("rocm", "cuda")

    # 1. Parse OS and Python flavor from the directory name
    os_python_part = directory.name  # e.g., 'ubi9-python-3.12'
    try:
        os_flavor, python_version_str = os_python_part.split("-python-")
    except ValueError as e:
        raise ValueError(f"Directory name '{os_python_part}' does not match 'os-python-version' format.") from e
    python_flavor = f"python-{python_version_str}"

    # 2. Find the notebook's characteristic path components
    path_parts = directory.parts
    # Find the root component ('jupyter', 'runtimes', etc.) to anchor the search.
    # Candidates are tried in this fixed priority order, not in path order.
    root = next(
        (candidate for candidate in ("jupyter", "codeserver", "rstudio", "runtimes") if candidate in path_parts),
        None,
    )
    if root is None:
        raise ValueError(f"Cannot determine notebook root in path: {directory}")
    start_index = path_parts.index(root)
    is_runtime = root == "runtimes"

    # The parts between the root and the OS/python dir define the notebook flavor
    # e.g., ('minimal',), ('rocm', 'tensorflow',), ('pytorch',)
    notebook_identity_parts = path_parts[start_index + 1 : -1]

    # Determine scope (e.g., 'minimal', 'tensorflow')
    # The shell script uses the last part of the path-like notebook_id.
    # rstudio and codeserver don't have a scope component at all.
    scope = notebook_identity_parts[-1] if notebook_identity_parts else ""

    # Runtime directories fold the accelerator into the name ('rocm-tensorflow').
    # Split it off and remember the prefix so that it is not lost.
    accelerator_from_name = None
    if "-" in scope:
        if not is_runtime:
            # Raised explicitly (not asserted) so the check survives `python -O`.
            raise ValueError(
                f"Unexpected '-' in notebook name '{scope}' for path {directory}: "
                "this naming pattern only appears in rocm runtime images"
            )
        prefix, scope = scope.split("-", 1)
        if prefix in known_accelerators:
            accelerator_from_name = prefix

    # Determine accelerator flavor
    accelerator_flavor = None
    if "rocm" in notebook_identity_parts:
        accelerator_flavor = "rocm"
    elif "cuda" in notebook_identity_parts:
        accelerator_flavor = "cuda"
    elif accelerator_from_name is not None:
        accelerator_flavor = accelerator_from_name
    # The shell script has an implicit rule for pytorch being cuda. We can
    # replicate this by checking for a specific Dockerfile.
    elif (directory / "Dockerfile.cuda").exists():
        accelerator_flavor = "cuda"
    elif (directory / "Dockerfile.rocm").exists():
        accelerator_flavor = "rocm"

    return NotebookMetadata(
        type=NotebookType.RUNTIME if is_runtime else NotebookType.WORKBENCH,
        feature="runtime" if is_runtime else root,
        # codeserver has no scope directory; it is reported as 'datascience'
        scope="datascience" if root == "codeserver" else scope,
        os_flavor=os_flavor,
        python_flavor=python_flavor,
        accelerator_flavor=accelerator_flavor,
    )
| 120 | + |
| 121 | + |
def _legacy_imagestream_filename(notebook_id: str, accelerator_flavor: str | None) -> str:
    """Return the manifests/base filename for a non-3.12 notebook, or '' when no rule matches."""
    file_suffix = "notebook-imagestream.yaml"
    accelerator_prefix = f"{accelerator_flavor}-" if accelerator_flavor else ""

    if JUPYTER_MINIMAL_NOTEBOOK_ID in notebook_id:
        # cuda minimal images use a '-gpu' infix instead of an accelerator prefix
        if accelerator_flavor == "cuda":
            return f"jupyter-{notebook_id}-gpu-{file_suffix}"
        return f"jupyter-{accelerator_prefix}{notebook_id}-{file_suffix}"

    if JUPYTER_DATASCIENCE_NOTEBOOK_ID in notebook_id or JUPYTER_TRUSTYAI_NOTEBOOK_ID in notebook_id:
        return f"jupyter-{notebook_id}-{file_suffix}"

    if JUPYTER_PYTORCH_NOTEBOOK_ID in notebook_id or JUPYTER_TENSORFLOW_NOTEBOOK_ID in notebook_id:
        # This override is intentionally different from the 'minimal' one, as per the script:
        # cuda is the implicit default for these images, so it adds nothing to the name.
        if accelerator_flavor == "cuda":
            return f"jupyter-{notebook_id}-{file_suffix}"
        return f"jupyter-{accelerator_prefix}{notebook_id}-{file_suffix}"

    if RSTUDIO_NOTEBOOK_ID in notebook_id:
        return f"rstudio-gpu-{file_suffix}"

    return ""


def get_source_of_truth_filepath(
    root_repo_directory: Path,
    metadata: NotebookMetadata,
) -> Path:
    """
    Computes the path of the imagestream manifest for the notebook under test.
    This is a Python conversion of the shell function `_get_source_of_truth_filepath`.

    Args:
        root_repo_directory: Root of the repository; the 'manifests' directory is
            looked up directly beneath it.
        metadata: Metadata describing the notebook, typically produced by
            extract_metadata_from_path().

    Returns:
        The path to the imagestream manifest file, rooted at `root_repo_directory`
        (so it is absolute only when `root_repo_directory` is).

    Raises:
        ValueError: If the logic cannot determine the filename for the given inputs.
        NotImplementedError: If `metadata.type` is not a known NotebookType.
    """
    python_flavor = metadata.python_flavor
    accelerator_flavor = metadata.accelerator_flavor
    manifest_directory = root_repo_directory / "manifests"

    if python_flavor == "python-3.12":
        if metadata.type == NotebookType.WORKBENCH:
            feature = metadata.feature
        elif metadata.type == NotebookType.RUNTIME:
            # WARNING: we need the jupyter imagestream, because runtime stream does not list software versions
            feature = "jupyter"
        else:
            raise NotImplementedError(f"Unsupported notebook type: {metadata.type}")

        scope = metadata.scope.replace("+", "-")  # pytorch+llmcompressor
        # Shell script defaults accelerator to 'cpu' if it's not set
        current_accelerator = accelerator_flavor or "cpu"
        # Assumes python_flavor is like 'python-3.12' -> 'py312'
        py_version_short = "py" + python_flavor.split("-")[1].replace(".", "")
        filename = f"{feature}-{scope}-{current_accelerator}-{py_version_short}-{metadata.os_flavor}-imagestream.yaml"
        return manifest_directory / "overlays" / "additional" / filename

    # Default case from the shell script for other python versions.
    # The rules below match on the notebook's name ('minimal', 'pytorch', ...).
    # For jupyter images NotebookMetadata keeps that name in `scope`, while
    # `feature` is just 'jupyter' and would never match any of them. Other
    # features (e.g. 'rstudio') are still matched on the feature itself.
    notebook_id = metadata.scope if metadata.feature == "jupyter" else metadata.feature

    filename = _legacy_imagestream_filename(notebook_id, accelerator_flavor)
    if not filename:
        raise ValueError(
            f"Unable to determine imagestream filename for notebook_id='{notebook_id}', "
            f"python_flavor='{python_flavor}', accelerator_flavor='{accelerator_flavor}'"
        )

    return manifest_directory / "base" / filename
| 199 | + |
| 200 | + |
class SelfTests(unittest.TestCase):
    """Self-checks for extract_metadata_from_path() and get_source_of_truth_filepath()."""

    def _make_rocm_runtime_directory(self) -> Path:
        """Create a throwaway runtimes/rocm-tensorflow/ubi9-python-3.12 directory.

        A Dockerfile.rocm is placed inside it so that accelerator detection does
        not depend on any particular developer checkout being present on disk.
        The directory is removed again when the test finishes.
        """
        import tempfile  # local import: only this fixture needs it

        tmp = tempfile.TemporaryDirectory()
        self.addCleanup(tmp.cleanup)
        directory = Path(tmp.name) / "runtimes" / "rocm-tensorflow" / "ubi9-python-3.12"
        directory.mkdir(parents=True)
        (directory / "Dockerfile.rocm").touch()
        return directory

    def test_rstudio_path(self):
        metadata = extract_metadata_from_path(Path("notebooks/rstudio/rhel9-python-3.11"))
        self.assertEqual(
            metadata,
            NotebookMetadata(
                type=NotebookType.WORKBENCH,
                feature="rstudio",
                scope="",
                os_flavor="rhel9",
                python_flavor="python-3.11",
                accelerator_flavor=None,
            ),
        )

    def test_rstudio_truth_manifest(self):
        metadata = extract_metadata_from_path(Path("notebooks/rstudio/rhel9-python-3.11"))
        path = get_source_of_truth_filepath(root_repo_directory=Path("notebooks"), metadata=metadata)
        self.assertEqual(path, Path("notebooks/manifests/base/rstudio-gpu-notebook-imagestream.yaml"))

    def test_jupyter_path(self):
        metadata = extract_metadata_from_path(Path("notebooks/jupyter/rocm/tensorflow/ubi9-python-3.12"))
        self.assertEqual(
            metadata,
            NotebookMetadata(
                type=NotebookType.WORKBENCH,
                feature="jupyter",
                scope="tensorflow",
                os_flavor="ubi9",
                python_flavor="python-3.12",
                accelerator_flavor="rocm",
            ),
        )

    def test_codeserver(self):
        metadata = extract_metadata_from_path(Path("notebooks/codeserver/ubi9-python-3.12"))
        self.assertEqual(
            metadata,
            NotebookMetadata(
                type=NotebookType.WORKBENCH,
                feature="codeserver",
                scope="datascience",
                os_flavor="ubi9",
                python_flavor="python-3.12",
                accelerator_flavor=None,
            ),
        )

    def test_codeserver_path(self):
        metadata = extract_metadata_from_path(Path("notebooks/codeserver/ubi9-python-3.12"))
        path = get_source_of_truth_filepath(root_repo_directory=Path("notebooks"), metadata=metadata)
        self.assertEqual(
            path,
            Path("notebooks/manifests/overlays/additional/codeserver-datascience-cpu-py312-ubi9-imagestream.yaml"),
        )

    def test_runtime_pytorch_path(self):
        """Metadata of a rocm-tensorflow runtime directory (test name kept for compatibility)."""
        metadata = extract_metadata_from_path(self._make_rocm_runtime_directory())
        self.assertEqual(
            metadata,
            NotebookMetadata(
                type=NotebookType.RUNTIME,
                feature="runtime",
                scope="tensorflow",
                os_flavor="ubi9",
                python_flavor="python-3.12",
                accelerator_flavor="rocm",
            ),
        )

    def test_jupyter_pytorch_path(self):
        """We need to get path to the Jupyter imagestream, not to runtime imagestream"""
        metadata = extract_metadata_from_path(self._make_rocm_runtime_directory())
        path = get_source_of_truth_filepath(root_repo_directory=Path("notebooks"), metadata=metadata)
        self.assertEqual(
            path,
            Path("notebooks/manifests/overlays/additional/jupyter-tensorflow-rocm-py312-ubi9-imagestream.yaml"),
        )

    def test_source_of_truth_jupyter_tensorflow_rocm(self):
        metadata = extract_metadata_from_path(Path("notebooks/jupyter/rocm/tensorflow/ubi9-python-3.12"))
        path = get_source_of_truth_filepath(root_repo_directory=Path("notebooks"), metadata=metadata)
        self.assertEqual(
            path,
            Path("notebooks/manifests/overlays/additional/jupyter-tensorflow-rocm-py312-ubi9-imagestream.yaml"),
        )
0 commit comments