Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ classifiers = [
]

dependencies = [
"dlt>=1.17.1",
"dlt[duckdb]>=1.17.1",
"fastmcp>=2.11.3",
"pandas>=2.3.3",
]

[project.urls]
Expand Down
2 changes: 2 additions & 0 deletions src/dlt_mcp/_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def register_tool(fn: FunctionType) -> FunctionType:
execute_sql_query,
get_load_table,
get_pipeline_local_state,
get_table_schema_changes,
)


Expand All @@ -32,3 +33,4 @@ def register_tool(fn: FunctionType) -> FunctionType:
register_tool(execute_sql_query)
register_tool(get_load_table)
register_tool(get_pipeline_local_state)
register_tool(get_table_schema_changes)
61 changes: 59 additions & 2 deletions src/dlt_mcp/_tools/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@
It shouldn't depend on packages that aren't installed by `dlt`
"""

from typing import Any
from difflib import unified_diff
import json
import pprint
from typing import Any, Optional

import dlt
from dlt.common.schema.typing import LOADS_TABLE_NAME
from dlt.common.pipeline import TPipelineState
from dlt.common.schema.typing import TTableSchema
from dlt.common.pipeline import get_dlt_pipelines_dir
from dlt.common.storages.file_storage import FileStorage
import pandas as pd


def list_pipelines() -> list[str]:
Expand Down Expand Up @@ -58,7 +62,7 @@ def get_load_table(pipeline_name: str) -> list[dict[str, Any]]:
pipeline = dlt.attach(pipeline_name)
dataset = pipeline.dataset()
load_table = dataset(f"SELECT * FROM {LOADS_TABLE_NAME};").fetchall()
columns = list(dataset.schema.tables[LOADS_TABLE_NAME]["columns"])
columns = list(dataset.schema.tables[LOADS_TABLE_NAME]["columns"]) # type: ignore
return [dict(zip(columns, row)) for row in load_table]


Expand All @@ -68,3 +72,56 @@ def get_pipeline_local_state(pipeline_name: str) -> TPipelineState:
"""
pipeline = dlt.attach(pipeline_name)
return pipeline.state


def get_table_schema_changes(
pipeline_name: str, table_name: str, another_version_hash: Optional[str] = None
) -> str:
"""Retrieve the diff between versions of tables compared to it's previous version"""
pipeline = dlt.attach(pipeline_name)

dataset = pipeline.dataset()
schemas = _get_schemas(
another_version_hash, pipeline.default_schema.version_hash, dataset
)

if len(schemas) < 2:
return "There has been no change in the schema"
current_schema = _load_schema_for_table(table_name, schemas.iloc[0]["schema"])
previous_schema = _load_schema_for_table(table_name, schemas.iloc[1]["schema"])

return _dict_diff(current_schema, previous_schema, "Previous Schema")


def _get_schemas(another_version_hash, version_hash, dataset) -> pd.DataFrame:
if another_version_hash:
version_hashes = [version_hash, another_version_hash]
# Properly format the list for SQL IN clause
quoted_hashes = [f"'{h}'" for h in version_hashes]
hashes_str = ",".join(quoted_hashes)
return dataset.query(
f"select schema from _dlt_version where version_hash in ({hashes_str}) order by inserted_at desc"
).df()
return dataset.query(
"select schema from _dlt_version order by inserted_at desc limit 2"
).df()


def _load_schema_for_table(table_name, schema):
schema_dict = json.loads(schema).get("tables").get(table_name)
return schema_dict


def _dict_diff(schema_dict, another_schema_dict, compared_to: str) -> str:
# Convert dictionaries to string representation
str1 = pprint.pformat(schema_dict)
str2 = pprint.pformat(another_schema_dict)

# Split into lines
lines1 = str1.splitlines(keepends=True)
lines2 = str2.splitlines(keepends=True)

# Generate diff
return "".join(
unified_diff(lines2, lines1, fromfile="Current Schema", tofile=compared_to)
)
113 changes: 113 additions & 0 deletions tests/tools/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import dlt

from dlt_mcp._tools.core import get_table_schema_changes


@dlt.resource(table_name="users")
def user_data(updated_user: bool):
if updated_user:
yield [
{"id": 1, "name": "Alice", "email": "[email protected]", "age": 30},
{"id": 2, "name": "Bob", "email": "[email protected]", "age": 35},
]
yield [
{"id": 1, "name": "Alice", "email": "[email protected]"},
{"id": 2, "name": "Bob", "email": "[email protected]"},
]


def test_get_table_schema_changes_when_schema_has_changed():
pipeline_name = "table_schema_change_pipeline"
pipeline = dlt.pipeline(pipeline_name, destination="duckdb", dev_mode=True)

pipeline.run(user_data(False))
pipeline.run(user_data(True))

expected_diff_message = """--- Current Schema
+++ Previous Schema
@@ -6,6 +6,7 @@
'_dlt_load_id': {'data_type': 'text',
'name': '_dlt_load_id',
'nullable': False},
+ 'age': {'data_type': 'bigint', 'name': 'age', 'nullable': True},
'email': {'data_type': 'text', 'name': 'email', 'nullable': True},
'id': {'data_type': 'bigint', 'name': 'id', 'nullable': True},
'name': {'data_type': 'text', 'name': 'name', 'nullable': True}},
"""

diff = get_table_schema_changes(pipeline_name, "users")
# Print actual and expected diff for readability in case of failure
assert diff.strip() == expected_diff_message.strip(), f"""
Expected and actual schema differences do not match.

Expected:
{expected_diff_message.strip()}

Actual:
{diff.strip()}
"""


def test_get_table_schema_should_say_no_change():
pipeline_name = "no_change_pipeline"
pipeline = dlt.pipeline(pipeline_name, destination="duckdb", dev_mode=True)

# Run the resource twice with the same schema
pipeline.run(user_data(False))
pipeline.run(user_data(False))

# Get schema changes
diff = get_table_schema_changes(pipeline_name, "users")

# Assert that there are no schema changes
assert diff.strip() == "There has been no change in the schema", f"""
Expected no schema changes, but got:
{diff.strip()}
"""


def test_get_table_schema_with_same_version_hash():
pipeline_name = "schema_time_comparison_pipeline"
pipeline = dlt.pipeline(pipeline_name, destination="duckdb", dev_mode=True)

load_info = pipeline.run(user_data(False))
version_hash = load_info.load_packages[0].schema_hash
diff = get_table_schema_changes(pipeline_name, "users", version_hash)

assert diff.strip() == "There has been no change in the schema", f"""
Expected no schema changes, but got:
{diff.strip()}
"""


def test_get_table_schema_with_different_version_hash():
pipeline_name = "schema_time_comparison_pipeline"
pipeline = dlt.pipeline(pipeline_name, destination="duckdb", dev_mode=True)

load_info = pipeline.run(user_data(False))
version_hash = load_info.load_packages[0].schema_hash

pipeline.run(user_data(True))
diff = get_table_schema_changes(pipeline_name, "users", version_hash)

expected_diff_message = """--- Current Schema
+++ Previous Schema
@@ -6,6 +6,7 @@
'_dlt_load_id': {'data_type': 'text',
'name': '_dlt_load_id',
'nullable': False},
+ 'age': {'data_type': 'bigint', 'name': 'age', 'nullable': True},
'email': {'data_type': 'text', 'name': 'email', 'nullable': True},
'id': {'data_type': 'bigint', 'name': 'id', 'nullable': True},
'name': {'data_type': 'text', 'name': 'name', 'nullable': True}},
"""

assert diff.strip() == expected_diff_message.strip(), f"""
Expected and actual schema differences do not match.

Expected:
{expected_diff_message.strip()}

Actual:
{diff.strip()}
"""
2 changes: 2 additions & 0 deletions tests/tools/test_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_expected_tools_in_all_clause():
"execute_sql_query",
"get_load_table",
"get_pipeline_local_state",
"get_table_schema_changes",
]

assert len(TOOLS_REGISTRY) == len(expected_tool_names)
Expand All @@ -35,6 +36,7 @@ def test_expected_tools_are_registered():
"execute_sql_query",
"get_load_table",
"get_pipeline_local_state",
"get_table_schema_changes",
]

mcp_server = create_server()
Expand Down
Loading