Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tesseract_core/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class RuntimeConfig(BaseModel):
output_format: supported_format_type = "json"
output_file: str = ""
mlflow_tracking_uri: str = ""
mlflow_tracking_username: str = ""
mlflow_tracking_password: str = ""

model_config = ConfigDict(frozen=True, extra="forbid")

Expand Down
46 changes: 37 additions & 9 deletions tesseract_core/runtime/mpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from io import UnsupportedOperation
from pathlib import Path
from typing import Any
from urllib.parse import quote, urlparse

import requests

Expand Down Expand Up @@ -139,9 +140,14 @@ def __init__(self, base_dir: str | None = None) -> None:
"MLflow is required for MLflowBackend but is not installed"
) from exc

self._ensure_mlflow_reachable()
self.mlflow = mlflow
tracking_uri = MLflowBackend._build_tracking_uri()
self._ensure_mlflow_reachable(tracking_uri)
mlflow.set_tracking_uri(tracking_uri)

@staticmethod
def _build_tracking_uri() -> str:
"""Build the MLflow tracking URI with embedded credentials if provided."""
config = get_config()
tracking_uri = config.mlflow_tracking_uri

Expand All @@ -154,20 +160,42 @@ def __init__(self, base_dir: str | None = None) -> None:
tracking_uri = (Path(get_config().output_path) / tracking_uri).resolve()

tracking_uri = f"sqlite:///{tracking_uri}"
else:
username = config.mlflow_tracking_username
password = config.mlflow_tracking_password

mlflow.set_tracking_uri(tracking_uri)

def _ensure_mlflow_reachable(self) -> None:
if bool(username) != bool(password):
raise RuntimeError(
"If one of TESSERACT_MLFLOW_TRACKING_USERNAME and TESSERACT_MLFLOW_TRACKING_PASSWORD is defined, "
"both must be defined."
)

if username and password:
parsed = urlparse(tracking_uri)
# Reconstruct URI with embedded credentials
tracking_uri = (
f"{parsed.scheme}://{quote(username)}:{quote(password)}@"
f"{parsed.netloc}{parsed.path}"
)
if parsed.query:
tracking_uri += f"?{parsed.query}"
if parsed.fragment:
tracking_uri += f"#{parsed.fragment}"

return tracking_uri

def _ensure_mlflow_reachable(self, tracking_uri: str) -> None:
"""Check if the MLflow tracking server is reachable."""
config = get_config()
mlflow_tracking_uri = config.mlflow_tracking_uri
if mlflow_tracking_uri.startswith(("http://", "https://")):
if tracking_uri.startswith(("http://", "https://")):
try:
response = requests.get(mlflow_tracking_uri, timeout=5)
response = requests.get(tracking_uri, timeout=5)
response.raise_for_status()
except requests.RequestException as e:
# Don't expose credentials in error message - use the original URI
config = get_config()
display_uri = config.mlflow_tracking_uri
raise RuntimeError(
f"Failed to connect to MLflow tracking server at {mlflow_tracking_uri}. "
f"Failed to connect to MLflow tracking server at {display_uri}. "
"Please make sure an MLflow server is running and TESSERACT_MLFLOW_TRACKING_URI is set correctly, "
"or switch to file-based logging by setting TESSERACT_MLFLOW_TRACKING_URI to an empty string."
) from e
Expand Down
89 changes: 89 additions & 0 deletions tests/runtime_tests/test_mpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,92 @@ def test_mlflow_log_calls(tmpdir):
continue

assert artifact_found


def test_build_tracking_uri_with_credentials():
    """Username and password from the config end up embedded in an HTTP URI."""
    pytest.importorskip("mlflow")
    settings = {
        "mlflow_tracking_uri": "http://localhost:5000",
        "mlflow_tracking_username": "testuser",
        "mlflow_tracking_password": "testpass",
    }
    update_config(**settings)

    expected = "http://testuser:testpass@localhost:5000"
    assert mpa.MLflowBackend._build_tracking_uri() == expected


def test_build_tracking_uri_without_credentials():
    """With empty credentials the HTTP URI is returned unchanged."""
    pytest.importorskip("mlflow")
    settings = {
        "mlflow_tracking_uri": "http://localhost:5000",
        "mlflow_tracking_username": "",
        "mlflow_tracking_password": "",
    }
    update_config(**settings)

    expected = "http://localhost:5000"
    assert mpa.MLflowBackend._build_tracking_uri() == expected


def test_build_tracking_uri_url_encoded_credentials():
    """Reserved characters in the credentials are percent-encoded.

    The username and password contain ``@``, ``:`` and ``!``; in the
    resulting URI these must appear as ``%40``, ``%3A`` and ``%21`` so that
    the only literal ``@`` left is the one separating userinfo from host.
    """
    pytest.importorskip("mlflow")
    update_config(
        mlflow_tracking_uri="https://mlflow.example.com",
        mlflow_tracking_username="user@example.com",
        mlflow_tracking_password="p@ss:w0rd!",
    )
    tracking_uri = mpa.MLflowBackend._build_tracking_uri()
    assert (
        tracking_uri
        == "https://user%40example.com:p%40ss%3Aw0rd%21@mlflow.example.com"
    )


def test_build_tracking_uri_with_path_and_query():
    """Path and query string survive when credentials are embedded."""
    pytest.importorskip("mlflow")
    settings = {
        "mlflow_tracking_uri": "http://localhost:5000/api/mlflow?param=value",
        "mlflow_tracking_username": "testuser",
        "mlflow_tracking_password": "testpass",
    }
    update_config(**settings)

    expected = "http://testuser:testpass@localhost:5000/api/mlflow?param=value"
    assert mpa.MLflowBackend._build_tracking_uri() == expected


def test_build_tracking_uri_username_without_password():
    """A username with an empty password is rejected with a RuntimeError."""
    pytest.importorskip("mlflow")
    settings = {
        "mlflow_tracking_uri": "http://localhost:5000",
        "mlflow_tracking_username": "testuser",
        "mlflow_tracking_password": "",
    }
    update_config(**settings)

    expected_message = "If one of TESSERACT_MLFLOW_TRACKING_USERNAME and TESSERACT_MLFLOW_TRACKING_PASSWORD is defined"
    with pytest.raises(RuntimeError, match=expected_message):
        mpa.MLflowBackend._build_tracking_uri()


def test_build_tracking_uri_password_without_username():
    """A password with an empty username is rejected with a RuntimeError."""
    pytest.importorskip("mlflow")
    settings = {
        "mlflow_tracking_uri": "http://localhost:5000",
        "mlflow_tracking_username": "",
        "mlflow_tracking_password": "testpass",
    }
    update_config(**settings)

    expected_message = "If one of TESSERACT_MLFLOW_TRACKING_USERNAME and TESSERACT_MLFLOW_TRACKING_PASSWORD is defined"
    with pytest.raises(RuntimeError, match=expected_message):
        mpa.MLflowBackend._build_tracking_uri()


def test_build_tracking_uri_sqlite_ignores_credentials():
    """Configured credentials never leak into a sqlite tracking URI."""
    pytest.importorskip("mlflow")
    settings = {
        "mlflow_tracking_uri": "sqlite:///mlflow.db",
        "mlflow_tracking_username": "testuser",
        "mlflow_tracking_password": "testpass",
    }
    update_config(**settings)

    built_uri = mpa.MLflowBackend._build_tracking_uri()
    for secret in ("testuser", "testpass"):
        assert secret not in built_uri
    assert built_uri.startswith("sqlite:///")
Loading