Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tesseract_core/runtime/mpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,15 @@ def __init__(self, base_dir: Optional[str] = None) -> None:
tracking_uri = config.mlflow_tracking_uri

if not tracking_uri.startswith(("http://", "https://")):
# If it's a file URI, convert to local path
tracking_uri = tracking_uri.replace("file://", "")
# If it's a db file URI, convert to local path
tracking_uri = tracking_uri.replace("sqlite:///", "")

# Relative paths are resolved against the base output path
if not Path(tracking_uri).is_absolute():
tracking_uri = (Path(get_config().output_path) / tracking_uri).resolve()

tracking_uri = f"sqlite:///{tracking_uri}"

mlflow.set_tracking_uri(tracking_uri)

def _ensure_mlflow_reachable(self) -> None:
Expand Down
59 changes: 27 additions & 32 deletions tests/endtoend_tests/test_endtoend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import os
import shutil
import sqlite3
import subprocess
import uuid
from pathlib import Path
Expand Down Expand Up @@ -1107,7 +1108,7 @@ def test_mpa_mlflow_backend(mpa_test_image, tmpdir):
"tesseract",
"run",
"--env",
"TESSERACT_MLFLOW_TRACKING_URI=mlruns",
"TESSERACT_MLFLOW_TRACKING_URI=mlflow.db",
mpa_test_image,
"apply",
'{"inputs": {}}',
Expand All @@ -1122,42 +1123,36 @@ def test_mpa_mlflow_backend(mpa_test_image, tmpdir):
)
assert run_res.returncode == 0, run_res.stderr

# Check for mlruns directory structure
mlruns_dir = Path(tmpdir) / "mlruns"
assert mlruns_dir.exists()
assert (mlruns_dir / "0").exists() # Default experiment ID is 0
# Check for MLflow database file
mlflow_db_path = Path(tmpdir) / "mlflow.db"
assert mlflow_db_path.exists(), "Expected MLflow database file to exist"

# Find run directories
run_dirs = [d for d in (mlruns_dir / "0").iterdir() if d.is_dir()]
assert len(run_dirs) == 1 # Should be only one run
run_dir = run_dirs[0]
assert run_dir.is_dir()
assert (run_dir / "artifacts").exists()
assert (run_dir / "metrics").exists()
assert (run_dir / "params").exists()
# Query the database to verify content was logged
with sqlite3.connect(str(mlflow_db_path)) as conn:
cursor = conn.cursor()

# Verify parameters file
param_file = run_dir / "params" / "test_parameter"
assert param_file.exists()
with open(param_file) as f:
param_value = f.read().strip()
assert param_value == "test_param"
# Check parameters were logged
cursor.execute("SELECT key, value FROM params")
params = dict(cursor.fetchall())
assert params["test_parameter"] == "test_param"
assert params["steps_config"] == "5" # MLflow stores params as strings

# Verify metrics file
metrics_file = run_dir / "metrics" / "squared_step"
assert metrics_file.exists()
with open(metrics_file) as f:
metrics = f.readlines()
# Check metrics were logged
cursor.execute("SELECT key, value, step FROM metrics ORDER BY step")
metrics = cursor.fetchall()
assert len(metrics) == 5
for i, metric in enumerate(metrics):
parts = metric.split()
assert len(parts) == 3
assert float(parts[1]) == i**2 # Check squared_step values
assert int(parts[2]) == i

# Verify artifacts directory and artifact file
artifacts_dir = run_dir / "artifacts"
assert artifacts_dir.exists()
# Verify some of the squared_step values
squared_metrics = [m for m in metrics if m[0] == "squared_step"]
assert len(squared_metrics) == 5
assert squared_metrics[0] == ("squared_step", 0.0, 0)
assert squared_metrics[1] == ("squared_step", 1.0, 1)
assert squared_metrics[4] == ("squared_step", 16.0, 4)

# Check artifacts were logged (MLflow stores artifact info in runs table)
cursor.execute("SELECT artifact_uri FROM runs")
artifact_uris = [row[0] for row in cursor.fetchall()]
assert len(artifact_uris) > 0 # At least one run with artifacts


def test_multi_helloworld_endtoend(
Expand Down
53 changes: 44 additions & 9 deletions tests/runtime_tests/test_mpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import csv
import json
import os
import sqlite3

import pytest

Expand Down Expand Up @@ -141,17 +143,17 @@ def test_log_artifact_missing_file():
def test_mlflow_backend_creation(tmpdir):
"""Test that MLflowBackend is created when mlflow_tracking_uri is set."""
pytest.importorskip("mlflow") # Skip if MLflow is not installed
mlflow_dir = tmpdir / "mlflow_backend_test"
update_config(mlflow_tracking_uri=f"file://{mlflow_dir}")
mlflow_db_file = tmpdir / "mlflow.db"
update_config(mlflow_tracking_uri=f"sqlite:///{mlflow_db_file}")
backend = mpa._create_backend(None)
assert isinstance(backend, mpa.MLflowBackend)


def test_mlflow_log_calls(tmpdir):
"""Test MLflow backend logging functions with temporary directory."""
pytest.importorskip("mlflow") # Skip if MLflow is not installed
mlflow_dir = tmpdir / "mlflow_logging_test"
update_config(mlflow_tracking_uri=f"file://{mlflow_dir}")
mlflow_db_file = tmpdir / "mlflow.db"
update_config(mlflow_tracking_uri=f"sqlite:///{mlflow_db_file}")

with start_run():
log_parameter("model_type", "neural_network")
Expand All @@ -164,8 +166,41 @@ def test_mlflow_log_calls(tmpdir):
artifact_file.write_text("Test content", encoding="utf-8")
log_artifact(str(artifact_file))

# Verify MLflow directory structure was created
assert mlflow_dir.exists()
# MLflow creates experiment directories, so we should see some structure
mlflow_contents = list(mlflow_dir.listdir())
assert len(mlflow_contents) > 0
# Verify MLflow database file was created
assert mlflow_db_file.exists()

# Query the database to verify content was logged
with sqlite3.connect(str(mlflow_db_file)) as conn:
cursor = conn.cursor()

# Check parameters were logged
cursor.execute("SELECT key, value FROM params")
params = dict(cursor.fetchall())
assert params["model_type"] == "neural_network"
assert params["epochs"] == "100"

# Check metrics were logged
cursor.execute("SELECT key, value, step FROM metrics ORDER BY step")
metrics = cursor.fetchall()
assert len(metrics) == 2
assert metrics[0] == ("accuracy", 0.85, 0) # step defaults to 0
assert metrics[1] == ("loss", 0.25, 1)

# Check artifacts were logged (MLflow stores artifact info in runs table)
cursor.execute("SELECT artifact_uri FROM runs")
artifact_uris = [row[0] for row in cursor.fetchall()]
assert len(artifact_uris) > 0 # At least one run with artifacts

# Verify the artifact file was actually copied to the artifact location
artifact_found = False
for artifact_uri in artifact_uris:
if artifact_uri and os.path.exists(artifact_uri):
try:
artifact_files = os.listdir(artifact_uri)
if "model_config.json" in artifact_files:
artifact_found = True
break
except OSError:
continue

assert artifact_found
Loading