diff --git a/tesseract_core/runtime/config.py b/tesseract_core/runtime/config.py
index 5376b738..aa21cf25 100644
--- a/tesseract_core/runtime/config.py
+++ b/tesseract_core/runtime/config.py
@@ -23,6 +23,8 @@ class RuntimeConfig(BaseModel):
     output_format: supported_format_type = "json"
     output_file: str = ""
     mlflow_tracking_uri: str = ""
+    mlflow_tracking_username: str = ""
+    mlflow_tracking_password: str = ""
 
     model_config = ConfigDict(frozen=True, extra="forbid")
 
diff --git a/tesseract_core/runtime/mpa.py b/tesseract_core/runtime/mpa.py
index 1c69fc5d..4c6d1dda 100644
--- a/tesseract_core/runtime/mpa.py
+++ b/tesseract_core/runtime/mpa.py
@@ -16,6 +16,7 @@
 from io import UnsupportedOperation
 from pathlib import Path
 from typing import Any
+from urllib.parse import quote, unquote, urlparse
 
 import requests
 
@@ -139,9 +140,19 @@ def __init__(self, base_dir: str | None = None) -> None:
                 "MLflow is required for MLflowBackend but is not installed"
             ) from exc
 
-        self._ensure_mlflow_reachable()
         self.mlflow = mlflow
+        tracking_uri = MLflowBackend._build_tracking_uri()
+        self._ensure_mlflow_reachable(tracking_uri)
+        mlflow.set_tracking_uri(tracking_uri)
 
+    @staticmethod
+    def _build_tracking_uri() -> str:
+        """Build the MLflow tracking URI with embedded credentials if provided.
+
+        Raises:
+            RuntimeError: If exactly one of the tracking username and
+                password settings is non-empty when credentials are embedded.
+        """
         config = get_config()
         tracking_uri = config.mlflow_tracking_uri
 
@@ -154,20 +165,46 @@ def __init__(self, base_dir: str | None = None) -> None:
                 tracking_uri = (Path(get_config().output_path) / tracking_uri).resolve()
 
             tracking_uri = f"sqlite:///{tracking_uri}"
+        else:
+            username = config.mlflow_tracking_username
+            password = config.mlflow_tracking_password
 
-        mlflow.set_tracking_uri(tracking_uri)
-
-    def _ensure_mlflow_reachable(self) -> None:
+            if bool(username) != bool(password):
+                raise RuntimeError(
+                    "If one of TESSERACT_MLFLOW_TRACKING_USERNAME and TESSERACT_MLFLOW_TRACKING_PASSWORD is defined, "
+                    "both must be defined."
+                )
+
+            if username and password:
+                parsed = urlparse(tracking_uri)
+                # Percent-encode with safe="" so reserved characters such as "/"
+                # cannot terminate the authority early, and drop any userinfo
+                # already present in the URI so credentials are not duplicated.
+                userinfo = f"{quote(username, safe='')}:{quote(password, safe='')}"
+                host = parsed.netloc.rpartition("@")[2]
+                tracking_uri = parsed._replace(netloc=f"{userinfo}@{host}").geturl()
+
+        return tracking_uri
+
+    def _ensure_mlflow_reachable(self, tracking_uri: str) -> None:
         """Check if the MLflow tracking server is reachable."""
-        config = get_config()
-        mlflow_tracking_uri = config.mlflow_tracking_uri
-        if mlflow_tracking_uri.startswith(("http://", "https://")):
+        if tracking_uri.startswith(("http://", "https://")):
+            # Pass credentials via `auth` and request the credential-free URL:
+            # requests includes the request URL in HTTPError messages, and that
+            # exception is chained into the RuntimeError below.
+            parsed = urlparse(tracking_uri)
+            auth = None
+            if parsed.username is not None:
+                auth = (unquote(parsed.username), unquote(parsed.password or ""))
+            display_uri = parsed._replace(
+                netloc=parsed.netloc.rpartition("@")[2]
+            ).geturl()
             try:
-                response = requests.get(mlflow_tracking_uri, timeout=5)
+                response = requests.get(display_uri, auth=auth, timeout=5)
                 response.raise_for_status()
             except requests.RequestException as e:
                 raise RuntimeError(
-                    f"Failed to connect to MLflow tracking server at {mlflow_tracking_uri}. "
+                    f"Failed to connect to MLflow tracking server at {display_uri}. "
                     "Please make sure an MLflow server is running and TESSERACT_MLFLOW_TRACKING_URI is set correctly, "
                     "or switch to file-based logging by setting TESSERACT_MLFLOW_TRACKING_URI to an empty string."
                 ) from e
diff --git a/tests/runtime_tests/test_mpa.py b/tests/runtime_tests/test_mpa.py
index 2ff7de9d..86cbd5f1 100644
--- a/tests/runtime_tests/test_mpa.py
+++ b/tests/runtime_tests/test_mpa.py
@@ -204,3 +204,114 @@ def test_mlflow_log_calls(tmpdir):
             continue
 
     assert artifact_found
+
+
+def test_build_tracking_uri_with_credentials():
+    pytest.importorskip("mlflow")
+    update_config(
+        mlflow_tracking_uri="http://localhost:5000",
+        mlflow_tracking_username="testuser",
+        mlflow_tracking_password="testpass",
+    )
+    tracking_uri = mpa.MLflowBackend._build_tracking_uri()
+    assert tracking_uri == "http://testuser:testpass@localhost:5000"
+
+
+def test_build_tracking_uri_without_credentials():
+    pytest.importorskip("mlflow")
+    update_config(
+        mlflow_tracking_uri="http://localhost:5000",
+        mlflow_tracking_username="",
+        mlflow_tracking_password="",
+    )
+    tracking_uri = mpa.MLflowBackend._build_tracking_uri()
+    assert tracking_uri == "http://localhost:5000"
+
+
+def test_build_tracking_uri_url_encoded_credentials():
+    pytest.importorskip("mlflow")
+    update_config(
+        mlflow_tracking_uri="https://mlflow.example.com",
+        mlflow_tracking_username="user@example.com",
+        mlflow_tracking_password="p@ss:w0rd!",
+    )
+    tracking_uri = mpa.MLflowBackend._build_tracking_uri()
+    assert (
+        tracking_uri == "https://user%40example.com:p%40ss%3Aw0rd%21@mlflow.example.com"
+    )
+
+
+def test_build_tracking_uri_with_path_and_query():
+    pytest.importorskip("mlflow")
+    update_config(
+        mlflow_tracking_uri="http://localhost:5000/api/mlflow?param=value",
+        mlflow_tracking_username="testuser",
+        mlflow_tracking_password="testpass",
+    )
+    tracking_uri = mpa.MLflowBackend._build_tracking_uri()
+    assert (
+        tracking_uri == "http://testuser:testpass@localhost:5000/api/mlflow?param=value"
+    )
+
+
+def test_build_tracking_uri_username_without_password():
+    pytest.importorskip("mlflow")
+    update_config(
+        mlflow_tracking_uri="http://localhost:5000",
+        mlflow_tracking_username="testuser",
+        mlflow_tracking_password="",
+    )
+    with pytest.raises(
+        RuntimeError,
+        match="If one of TESSERACT_MLFLOW_TRACKING_USERNAME and TESSERACT_MLFLOW_TRACKING_PASSWORD is defined",
+    ):
+        mpa.MLflowBackend._build_tracking_uri()
+
+
+def test_build_tracking_uri_password_without_username():
+    pytest.importorskip("mlflow")
+    update_config(
+        mlflow_tracking_uri="http://localhost:5000",
+        mlflow_tracking_username="",
+        mlflow_tracking_password="testpass",
+    )
+    with pytest.raises(
+        RuntimeError,
+        match="If one of TESSERACT_MLFLOW_TRACKING_USERNAME and TESSERACT_MLFLOW_TRACKING_PASSWORD is defined",
+    ):
+        mpa.MLflowBackend._build_tracking_uri()
+
+
+def test_build_tracking_uri_sqlite_ignores_credentials():
+    pytest.importorskip("mlflow")
+    update_config(
+        mlflow_tracking_uri="sqlite:///mlflow.db",
+        mlflow_tracking_username="testuser",
+        mlflow_tracking_password="testpass",
+    )
+    tracking_uri = mpa.MLflowBackend._build_tracking_uri()
+    assert "testuser" not in tracking_uri
+    assert "testpass" not in tracking_uri
+    assert tracking_uri.startswith("sqlite:///")
+
+
+def test_build_tracking_uri_encodes_slash_in_credentials():
+    pytest.importorskip("mlflow")
+    update_config(
+        mlflow_tracking_uri="http://localhost:5000",
+        mlflow_tracking_username="domain/user",
+        mlflow_tracking_password="pa/ss",
+    )
+    tracking_uri = mpa.MLflowBackend._build_tracking_uri()
+    assert tracking_uri == "http://domain%2Fuser:pa%2Fss@localhost:5000"
+
+
+def test_build_tracking_uri_replaces_existing_userinfo():
+    pytest.importorskip("mlflow")
+    update_config(
+        mlflow_tracking_uri="http://old:creds@localhost:5000",
+        mlflow_tracking_username="testuser",
+        mlflow_tracking_password="testpass",
+    )
+    tracking_uri = mpa.MLflowBackend._build_tracking_uri()
+    assert tracking_uri == "http://testuser:testpass@localhost:5000"