|
11 | 11 | import shutil |
12 | 12 | import tempfile |
13 | 13 | import time |
| 14 | +import re |
14 | 15 | from distutils import dir_util |
15 | 16 | from typing import Dict, Tuple, Union |
16 | 17 |
|
17 | 18 | from ads.common.auth import AuthContext, AuthType, create_signer |
18 | 19 | from ads.common.oci_client import OCIClientFactory |
| 20 | +from ads.config import ( |
| 21 | + CONDA_BUCKET_NAME, |
| 22 | + CONDA_BUCKET_NS, |
| 23 | +) |
19 | 24 | from ads.jobs import ( |
20 | 25 | ContainerRuntime, |
21 | 26 | DataScienceJob, |
@@ -65,6 +70,32 @@ def __init__(self, config: Dict) -> None: |
65 | 70 | self.auth_type = config["execution"].get("auth") |
66 | 71 | self.profile = config["execution"].get("oci_profile", None) |
67 | 72 | self.client = OCIClientFactory(**self.oci_auth).data_science |
| 73 | + self.object_storage = OCIClientFactory(**self.oci_auth).object_storage |
| 74 | + |
| 75 | + def _get_latest_conda_pack(self, |
| 76 | + prefix, |
| 77 | + python_version, |
| 78 | + base_conda) -> str: |
| 79 | + """ |
| 80 | + get the latest conda pack. |
| 81 | + """ |
| 82 | + try: |
| 83 | + objects = self.object_storage.list_objects(namespace_name=CONDA_BUCKET_NS, |
| 84 | + bucket_name=CONDA_BUCKET_NAME, |
| 85 | + prefix=prefix).data.objects |
| 86 | + py_str = python_version.replace(".", "") |
| 87 | + py_filter = [obj for obj in objects if f"p{py_str}" in obj.name] |
| 88 | + |
| 89 | + def extract_version(obj_name): |
| 90 | + match = re.search(rf"{prefix}([\d.]+)/", obj_name) |
| 91 | + return tuple(map(int, match.group(1).split("."))) if match else (0,) |
| 92 | + |
| 93 | + latest_obj = max(py_filter, key=lambda obj: extract_version(obj.name)) |
| 94 | + return latest_obj.name.split("/")[-1] |
| 95 | + except Exception as e: |
| 96 | + logger.warning(f"Error while fetching latest conda pack: {e}") |
| 97 | + return base_conda |
| 98 | + |
68 | 99 |
|
69 | 100 | def init( |
70 | 101 | self, |
@@ -100,6 +131,16 @@ def init( |
100 | 131 | or "" |
101 | 132 | ).lower() |
102 | 133 |
|
| 134 | + # If a tag is present |
| 135 | + if ":" in conda_slug: |
| 136 | + base_conda = conda_slug.split(":")[0] |
| 137 | + conda_slug = self._get_latest_conda_pack( |
| 138 | + self.config["prefix"], |
| 139 | + self.config["python_version"], |
| 140 | + base_conda |
| 141 | + ) |
| 142 | + logger.info(f"Proceeding with the {conda_slug} conda pack.") |
| 143 | + |
103 | 144 | # if conda slug contains '/' then the assumption is that it is a custom conda pack |
104 | 145 | # the conda prefix needs to be added |
105 | 146 | if "/" in conda_slug: |
|
0 commit comments