diff --git a/README.md b/README.md index d83294f..6393a8f 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,9 @@ An MCP server for ClickHouse. ### ClickHouse Tools +* `list_clickhouse_tenants` + * List all clickhouse tenants. + * `run_select_query` * Execute SQL queries on your ClickHouse cluster. * Input: `sql` (string): The SQL query to execute. @@ -24,6 +27,9 @@ An MCP server for ClickHouse. ### chDB Tools +* `list_chdb_tenants` + * List all chdb tenants. + * `run_chdb_select_query` * Execute SQL queries using [chDB](https://github.com/chdb-io/chdb)'s embedded ClickHouse engine. * Input: `sql` (string): The SQL query to execute. @@ -168,6 +174,132 @@ You can also enable both ClickHouse and chDB simultaneously: } ``` +Multi-tenancy configuration is also supported. This is enabled by defining custom prefixes in front of the base environment variables. The below configuration creates two tenants: `cluster1` and `cluster2`. + +```json +{ + "mcpServers": { + "mcp-clickhouse": { + "command": "uv", + "args": [ + "run", + "--with", + "mcp-clickhouse", + "--python", + "3.10", + "mcp-clickhouse" + ], + "env": { + "cluster1_CLICKHOUSE_HOST": "", + "cluster1_CLICKHOUSE_PORT": "", + "cluster1_CLICKHOUSE_USER": "", + "cluster1_CLICKHOUSE_PASSWORD": "", + "cluster1_CLICKHOUSE_SECURE": "true", + "cluster1_CLICKHOUSE_VERIFY": "true", + "cluster1_CLICKHOUSE_CONNECT_TIMEOUT": "30", + "cluster1_CLICKHOUSE_SEND_RECEIVE_TIMEOUT": "30", + "cluster1_CHDB_ENABLED": "true", + "cluster1_CHDB_DATA_PATH": "/path/to/chdb/data", + "cluster2_CLICKHOUSE_HOST": "", + "cluster2_CLICKHOUSE_PORT": "", + "cluster2_CLICKHOUSE_USER": "", + "cluster2_CLICKHOUSE_PASSWORD": "", + "cluster2_CLICKHOUSE_SECURE": "true", + "cluster2_CLICKHOUSE_VERIFY": "true", + "cluster2_CLICKHOUSE_CONNECT_TIMEOUT": "30", + "cluster2_CLICKHOUSE_SEND_RECEIVE_TIMEOUT": "30", + "cluster2_CHDB_ENABLED": "true", + "cluster2_CHDB_DATA_PATH": "/path/to/chdb/data" + } + } + } +} +``` + +If no custom prefix is defined, a `default` tenant is automatically assigned based on the original environment variables. Defining custom tenants using the reserved `default` prefix is not allowed. The below example creates two tenants: `default` and `custom`. + +```json +{ + "mcpServers": { + "mcp-clickhouse": { + "command": "uv", + "args": [ + "run", + "--with", + "mcp-clickhouse", + "--python", + "3.10", + "mcp-clickhouse" + ], + "env": { + "CLICKHOUSE_HOST": "", + "CLICKHOUSE_PORT": "", + "CLICKHOUSE_USER": "", + "CLICKHOUSE_PASSWORD": "", + "CLICKHOUSE_SECURE": "true", + "CLICKHOUSE_VERIFY": "true", + "CLICKHOUSE_CONNECT_TIMEOUT": "30", + "CLICKHOUSE_SEND_RECEIVE_TIMEOUT": "30", + "CHDB_ENABLED": "true", + "CHDB_DATA_PATH": "/path/to/chdb/data", + "custom_CLICKHOUSE_HOST": "", + "custom_CLICKHOUSE_PORT": "", + "custom_CLICKHOUSE_USER": "", + "custom_CLICKHOUSE_PASSWORD": "", + "custom_CLICKHOUSE_SECURE": "true", + "custom_CLICKHOUSE_VERIFY": "true", + "custom_CLICKHOUSE_CONNECT_TIMEOUT": "30", + "custom_CLICKHOUSE_SEND_RECEIVE_TIMEOUT": "30", + "custom_CHDB_ENABLED": "true", + "custom_CHDB_DATA_PATH": "/path/to/chdb/data" + } + } + } +} +``` + +The below example will throw an error as `default` prefix is used. + +```json +{ + "mcpServers": { + "mcp-clickhouse": { + "command": "uv", + "args": [ + "run", + "--with", + "mcp-clickhouse", + "--python", + "3.10", + "mcp-clickhouse" + ], + "env": { + "CLICKHOUSE_HOST": "", + "CLICKHOUSE_PORT": "", + "CLICKHOUSE_USER": "", + "CLICKHOUSE_PASSWORD": "", + "CLICKHOUSE_SECURE": "true", + "CLICKHOUSE_VERIFY": "true", + "CLICKHOUSE_CONNECT_TIMEOUT": "30", + "CLICKHOUSE_SEND_RECEIVE_TIMEOUT": "30", + "CHDB_ENABLED": "true", + "CHDB_DATA_PATH": "/path/to/chdb/data", + "default_CLICKHOUSE_HOST": "", + "default_CLICKHOUSE_PORT": "", + "default_CLICKHOUSE_USER": "", + "default_CLICKHOUSE_PASSWORD": "", + "default_CLICKHOUSE_SECURE": "true", + "default_CLICKHOUSE_VERIFY": "true", + "default_CLICKHOUSE_CONNECT_TIMEOUT": "30", + "default_CLICKHOUSE_SEND_RECEIVE_TIMEOUT": "30", + "default_CHDB_ENABLED": "true", + "default_CHDB_DATA_PATH": "/path/to/chdb/data" + } + } + } +} +``` + 3. Locate the command entry for `uv` and replace it with the absolute path to the `uv` executable. This ensures that the correct version of `uv` is used when starting the server. On a mac, you can find this path using `which uv`. 4. Restart Claude Desktop to apply the changes. diff --git a/mcp_clickhouse/__init__.py b/mcp_clickhouse/__init__.py index 879259d..e22e8a1 100644 --- a/mcp_clickhouse/__init__.py +++ b/mcp_clickhouse/__init__.py @@ -1,4 +1,6 @@ from .mcp_server import ( + list_clickhouse_tenants, + list_chdb_tenants, create_clickhouse_client, list_databases, list_tables, @@ -9,6 +11,8 @@ ) __all__ = [ + "list_clickhouse_tenants", + "list_chdb_tenants", "list_databases", "list_tables", "run_select_query", diff --git a/mcp_clickhouse/main.py b/mcp_clickhouse/main.py index 97599a4..0a622b5 100644 --- a/mcp_clickhouse/main.py +++ b/mcp_clickhouse/main.py @@ -1,9 +1,9 @@ from .mcp_server import mcp -from .mcp_env import get_config, TransportType +from .mcp_env import get_mcp_config, TransportType def main(): - config = get_config() + config = get_mcp_config() transport = config.mcp_server_transport # For HTTP and SSE transports, we need to specify host and port diff --git a/mcp_clickhouse/mcp_env.py b/mcp_clickhouse/mcp_env.py index 40c0424..176c3fd 100644 --- a/mcp_clickhouse/mcp_env.py +++ b/mcp_clickhouse/mcp_env.py @@ -1,4 +1,4 @@ -"""Environment configuration for the MCP ClickHouse server. +"""Environment configuration for the MCP ClickHouse server with Multi-Tenancy support. This module handles all environment variable configuration with sensible defaults and type conversion. @@ -6,7 +6,7 @@ from dataclasses import dataclass import os -from typing import Optional +from typing import Optional, Dict, List from enum import Enum @@ -30,29 +30,43 @@ class ClickHouseConfig: This class handles all environment variable configuration with sensible defaults and type conversion. It provides typed methods for accessing each configuration value. - Required environment variables (only when CLICKHOUSE_ENABLED=true): - CLICKHOUSE_HOST: The hostname of the ClickHouse server - CLICKHOUSE_USER: The username for authentication - CLICKHOUSE_PASSWORD: The password for authentication + Required environment variables (only when _CLICKHOUSE_ENABLED=true): + _CLICKHOUSE_HOST: The hostname of the ClickHouse server + _CLICKHOUSE_USER: The username for authentication + _CLICKHOUSE_PASSWORD: The password for authentication Optional environment variables (with defaults): - CLICKHOUSE_PORT: The port number (default: 8443 if secure=True, 8123 if secure=False) - CLICKHOUSE_SECURE: Enable HTTPS (default: true) - CLICKHOUSE_VERIFY: Verify SSL certificates (default: true) - CLICKHOUSE_CONNECT_TIMEOUT: Connection timeout in seconds (default: 30) - CLICKHOUSE_SEND_RECEIVE_TIMEOUT: Send/receive timeout in seconds (default: 300) - CLICKHOUSE_DATABASE: Default database to use (default: None) - CLICKHOUSE_PROXY_PATH: Path to be added to the host URL. For instance, for servers behind an HTTP proxy (default: None) - CLICKHOUSE_MCP_SERVER_TRANSPORT: MCP server transport method - "stdio", "http", or "sse" (default: stdio) - CLICKHOUSE_MCP_BIND_HOST: Host to bind the MCP server to when using HTTP or SSE transport (default: 127.0.0.1) - CLICKHOUSE_MCP_BIND_PORT: Port to bind the MCP server to when using HTTP or SSE transport (default: 8000) - CLICKHOUSE_ENABLED: Enable ClickHouse server (default: true) + _CLICKHOUSE_PORT: The port number (default: 8443 if secure=True, 8123 if secure=False) + _CLICKHOUSE_SECURE: Enable HTTPS (default: true) + _CLICKHOUSE_VERIFY: Verify SSL certificates (default: true) + _CLICKHOUSE_CONNECT_TIMEOUT: Connection timeout in seconds (default: 30) + _CLICKHOUSE_SEND_RECEIVE_TIMEOUT: Send/receive timeout in seconds (default: 300) + _CLICKHOUSE_DATABASE: Default database to use (default: None) + _CLICKHOUSE_PROXY_PATH: Path to be added to the host URL. For instance, for servers behind an HTTP proxy (default: None) + _CLICKHOUSE_MCP_SERVER_TRANSPORT: MCP server transport method - "stdio", "http", or "sse" (default: stdio) + _CLICKHOUSE_MCP_BIND_HOST: Host to bind the MCP server to when using HTTP or SSE transport (default: 127.0.0.1) + _CLICKHOUSE_MCP_BIND_PORT: Port to bind the MCP server to when using HTTP or SSE transport (default: 8000) + _CLICKHOUSE_ENABLED: Enable ClickHouse server (default: true) """ + tenant: str - def __init__(self): + def __post_init__(self): """Initialize the configuration from environment variables.""" if self.enabled: self._validate_required_vars() + + def _getenv(self, key: str, default=None, cast=str): + prefixed_key = f"{self.tenant}_{key}" + if self.tenant == "": + prefixed_key = key # default + + val = os.getenv(prefixed_key, os.getenv(key, default)) + if val is not None and cast is not str: + try: + return cast(val) + except Exception: + raise ValueError(f"Invalid value for {prefixed_key or key}: {val}") + return val @property def enabled(self) -> bool: @@ -60,12 +74,12 @@ def enabled(self) -> bool: Default: True """ - return os.getenv("CLICKHOUSE_ENABLED", "true").lower() == "true" + return self._getenv("CLICKHOUSE_ENABLED", "true", cast=lambda v: v.lower() == "true") @property def host(self) -> str: """Get the ClickHouse host.""" - return os.environ["CLICKHOUSE_HOST"] + return self._getenv("CLICKHOUSE_HOST") @property def port(self) -> int: @@ -74,24 +88,23 @@ def port(self) -> int: Defaults to 8443 if secure=True, 8123 if secure=False. Can be overridden by CLICKHOUSE_PORT environment variable. """ - if "CLICKHOUSE_PORT" in os.environ: - return int(os.environ["CLICKHOUSE_PORT"]) - return 8443 if self.secure else 8123 + default = 8443 if self.secure else 8123 + return self._getenv("CLICKHOUSE_PORT", default, cast=int) @property def username(self) -> str: """Get the ClickHouse username.""" - return os.environ["CLICKHOUSE_USER"] + return self._getenv("CLICKHOUSE_USER") @property def password(self) -> str: """Get the ClickHouse password.""" - return os.environ["CLICKHOUSE_PASSWORD"] + return self._getenv("CLICKHOUSE_PASSWORD") @property def database(self) -> Optional[str]: """Get the default database name if set.""" - return os.getenv("CLICKHOUSE_DATABASE") + return self._getenv("CLICKHOUSE_DATABASE") @property def secure(self) -> bool: @@ -99,7 +112,7 @@ def secure(self) -> bool: Default: True """ - return os.getenv("CLICKHOUSE_SECURE", "true").lower() == "true" + return self._getenv("CLICKHOUSE_SECURE", "true", cast=lambda v: v.lower() == "true") @property def verify(self) -> bool: @@ -107,7 +120,7 @@ def verify(self) -> bool: Default: True """ - return os.getenv("CLICKHOUSE_VERIFY", "true").lower() == "true" + return self._getenv("CLICKHOUSE_VERIFY", "true", cast=lambda v: v.lower() == "true") @property def connect_timeout(self) -> int: @@ -115,7 +128,7 @@ def connect_timeout(self) -> int: Default: 30 """ - return int(os.getenv("CLICKHOUSE_CONNECT_TIMEOUT", "30")) + return self._getenv("CLICKHOUSE_CONNECT_TIMEOUT", 30, cast=int) @property def send_receive_timeout(self) -> int: @@ -123,44 +136,11 @@ def send_receive_timeout(self) -> int: Default: 300 (ClickHouse default) """ - return int(os.getenv("CLICKHOUSE_SEND_RECEIVE_TIMEOUT", "300")) + return self._getenv("CLICKHOUSE_SEND_RECEIVE_TIMEOUT", 300, cast=int) @property - def proxy_path(self) -> str: - return os.getenv("CLICKHOUSE_PROXY_PATH") - - @property - def mcp_server_transport(self) -> str: - """Get the MCP server transport method. - - Valid options: "stdio", "http", "sse" - Default: "stdio" - """ - transport = os.getenv("CLICKHOUSE_MCP_SERVER_TRANSPORT", TransportType.STDIO.value).lower() - - # Validate transport type - if transport not in TransportType.values(): - valid_options = ", ".join(f'"{t}"' for t in TransportType.values()) - raise ValueError(f"Invalid transport '{transport}'. Valid options: {valid_options}") - return transport - - @property - def mcp_bind_host(self) -> str: - """Get the host to bind the MCP server to. - - Only used when transport is "http" or "sse". - Default: "127.0.0.1" - """ - return os.getenv("CLICKHOUSE_MCP_BIND_HOST", "127.0.0.1") - - @property - def mcp_bind_port(self) -> int: - """Get the port to bind the MCP server to. - - Only used when transport is "http" or "sse". - Default: 8000 - """ - return int(os.getenv("CLICKHOUSE_MCP_BIND_PORT", "8000")) + def proxy_path(self) -> Optional[str]: + return self._getenv("CLICKHOUSE_PROXY_PATH") def get_client_config(self) -> dict: """Get the configuration dictionary for clickhouse_connect client. @@ -177,7 +157,7 @@ def get_client_config(self) -> dict: "verify": self.verify, "connect_timeout": self.connect_timeout, "send_receive_timeout": self.send_receive_timeout, - "client_name": "mcp_clickhouse", + "client_name": f"mcp_clickhouse_{self.tenant if self.tenant else 'default'}", } # Add optional database if set @@ -197,11 +177,10 @@ def _validate_required_vars(self) -> None: """ missing_vars = [] for var in ["CLICKHOUSE_HOST", "CLICKHOUSE_USER", "CLICKHOUSE_PASSWORD"]: - if var not in os.environ: + if not self._getenv(var): missing_vars.append(var) - if missing_vars: - raise ValueError(f"Missing required environment variables: {', '.join(missing_vars)}") + raise ValueError(f"Missing required environment variables for tenant '{self.tenant}': {', '.join(missing_vars)}") @dataclass @@ -214,24 +193,38 @@ class ChDBConfig: Required environment variables: CHDB_DATA_PATH: The path to the chDB data directory (only required if CHDB_ENABLED=true) """ + tenant: str - def __init__(self): + def __post_init__(self): """Initialize the configuration from environment variables.""" if self.enabled: self._validate_required_vars() + def _getenv(self, key: str, default=None, cast=str): + prefixed_key = f"{self.tenant}_{key}" + if self.tenant == "": + prefixed_key = key # default + + val = os.getenv(prefixed_key, os.getenv(key, default)) + if val is not None and cast is not str: + try: + return cast(val) + except Exception: + raise ValueError(f"Invalid value for {prefixed_key or key}: {val}") + return val + @property def enabled(self) -> bool: """Get whether chDB is enabled. Default: False """ - return os.getenv("CHDB_ENABLED", "false").lower() == "true" + return self._getenv("CHDB_ENABLED", "false", cast=lambda v: v.lower() == "true") @property def data_path(self) -> str: """Get the chDB data path.""" - return os.getenv("CHDB_DATA_PATH", ":memory:") + return self._getenv("CHDB_DATA_PATH", ":memory:") def get_client_config(self) -> dict: """Get the configuration dictionary for chDB client. @@ -251,33 +244,82 @@ def _validate_required_vars(self) -> None: """ pass - -# Global instance placeholders for the singleton pattern -_CONFIG_INSTANCE = None -_CHDB_CONFIG_INSTANCE = None - - -def get_config(): +def get_mcp_config() -> dict: """ - Gets the singleton instance of ClickHouseConfig. - Instantiates it on the first call. + Get the MCP server configuration from environment variables. """ - global _CONFIG_INSTANCE - if _CONFIG_INSTANCE is None: - # Instantiate the config object here, ensuring load_dotenv() has likely run - _CONFIG_INSTANCE = ClickHouseConfig() - return _CONFIG_INSTANCE + # Global MCP transport config + MCP_TRANSPORT = os.getenv("CLICKHOUSE_MCP_SERVER_TRANSPORT", TransportType.STDIO.value).lower() + if MCP_TRANSPORT not in TransportType.values(): + raise ValueError(f"Invalid MCP transport '{MCP_TRANSPORT}'. Valid options: {TransportType.values()}") + MCP_BIND_HOST = os.getenv("CLICKHOUSE_MCP_BIND_HOST", "127.0.0.1") + MCP_BIND_PORT = int(os.getenv("CLICKHOUSE_MCP_BIND_PORT", 8000)) -def get_chdb_config() -> ChDBConfig: - """ - Gets the singleton instance of ChDBConfig. - Instantiates it on the first call. + return { + "mcp_server_transport": MCP_TRANSPORT, + "mcp_bind_host": MCP_BIND_HOST, + "mcp_bind_port": MCP_BIND_PORT, + } - Returns: - ChDBConfig: The chDB configuration instance - """ - global _CHDB_CONFIG_INSTANCE - if _CHDB_CONFIG_INSTANCE is None: - _CHDB_CONFIG_INSTANCE = ChDBConfig() - return _CHDB_CONFIG_INSTANCE +# Global instance placeholders for the singleton pattern +_CLICKHOUSE_TENANTS: Dict[str, ClickHouseConfig] = {} +_CHDB_TENANTS: Dict[str, ChDBConfig] = {} + +def load_clickhouse_configs() -> Dict[str, ClickHouseConfig]: + global _CLICKHOUSE_TENANTS + for key in os.environ: + if key.endswith("CLICKHOUSE_HOST") and not key.startswith("CLICKHOUSE_HOST"): + # _CLICKHOUSE_HOST + tenant = key[: -len("_CLICKHOUSE_HOST")] + if tenant == "default": + raise ValueError("default is a reserved tenant") + _CLICKHOUSE_TENANTS[tenant] = ClickHouseConfig(tenant=tenant) + elif key.endswith("CLICKHOUSE_HOST") and key.startswith("CLICKHOUSE_HOST"): + # default tenant -> CLICKHOUSE_HOST + _CLICKHOUSE_TENANTS["default"] = ClickHouseConfig(tenant="") + + return _CLICKHOUSE_TENANTS + +def load_chdb_configs() -> Dict[str, ChDBConfig]: + global _CHDB_TENANTS + for key in os.environ: + if key.endswith("CHDB_DATA_PATH") and not key.startswith("CHDB_DATA_PATH"): + # _CHDB_DATA_PATH + tenant = key[: -len("_CHDB_DATA_PATH")] + if tenant == "default": + raise ValueError("default is a reserved tenant") + _CHDB_TENANTS[tenant] = ChDBConfig(tenant=tenant) + elif key.endswith("CHDB_DATA_PATH") and key.startswith("CHDB_DATA_PATH"): + # default tenant -> CHDB_DATA_PATH + _CHDB_TENANTS["default"] = ChDBConfig(tenant="") + return _CHDB_TENANTS + +def get_config(tenant: str = "default") -> ClickHouseConfig: + """Get ClickHouseConfig for a specific tenant.""" + global _CLICKHOUSE_TENANTS + + # Check for tenant in the global config map + if tenant not in _CLICKHOUSE_TENANTS: + raise ValueError(f"No ClickHouse config found for tenant '{tenant}'") + + return _CLICKHOUSE_TENANTS[tenant] + +def get_chdb_config(tenant: str = "default") -> ChDBConfig: + """Get ChDBConfig for a specific tenant.""" + global _CHDB_TENANTS + + # Check for tenant in the global config map + if tenant not in _CHDB_TENANTS: + raise ValueError(f"No ChDB config found for tenant '{tenant}'") + return _CHDB_TENANTS[tenant] + +def get_clickhouse_tenants() -> List[str]: + """Get list of all clickhouse tenant names.""" + global _CLICKHOUSE_TENANTS + return [tenant for tenant in _CLICKHOUSE_TENANTS.keys()] + +def get_chdb_tenants() -> List[str]: + """Get list of all chdb tenant names.""" + global _CHDB_TENANTS + return [tenant for tenant in _CHDB_TENANTS.keys()] \ No newline at end of file diff --git a/mcp_clickhouse/mcp_server.py b/mcp_clickhouse/mcp_server.py index 589ff2e..702a89e 100644 --- a/mcp_clickhouse/mcp_server.py +++ b/mcp_clickhouse/mcp_server.py @@ -3,7 +3,6 @@ from typing import Optional, List, Any import concurrent.futures import atexit -import os import clickhouse_connect import chdb.session as chs @@ -17,7 +16,7 @@ from starlette.requests import Request from starlette.responses import PlainTextResponse -from mcp_clickhouse.mcp_env import get_config, get_chdb_config +from mcp_clickhouse.mcp_env import load_clickhouse_configs, load_chdb_configs, get_clickhouse_tenants, get_chdb_tenants, get_config, get_chdb_config from mcp_clickhouse.chdb_prompt import CHDB_PROMPT @@ -61,11 +60,32 @@ class Table: ) logger = logging.getLogger(MCP_SERVER_NAME) -QUERY_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10) -atexit.register(lambda: QUERY_EXECUTOR.shutdown(wait=True)) -SELECT_QUERY_TIMEOUT_SECS = 30 - +# Load Configs load_dotenv() +load_clickhouse_configs() +load_chdb_configs() + +# List of Tenants +CLICKHOUSE_TENANTS = get_clickhouse_tenants() +CHDB_TENANTS = get_chdb_tenants() + +# Create ThreadPoolExecutors for each tenant +CLICKHOUSE_QUERY_EXECUTOR = { + tenant: concurrent.futures.ThreadPoolExecutor(max_workers=10) + for tenant in CLICKHOUSE_TENANTS +} + +CHDB_QUERY_EXECUTOR = { + tenant: concurrent.futures.ThreadPoolExecutor(max_workers=10) + for tenant in CHDB_TENANTS +} + +# Ensure all executors are properly shutdown on exit +atexit.register(lambda: [executor.shutdown(wait=True) for executor in CLICKHOUSE_QUERY_EXECUTOR.values()]) +atexit.register(lambda: [executor.shutdown(wait=True) for executor in CHDB_QUERY_EXECUTOR.values()]) + +# Default query timeout for selects +SELECT_QUERY_TIMEOUT_SECS = 30 mcp = FastMCP( name=MCP_SERVER_NAME, @@ -85,29 +105,40 @@ async def health_check(request: Request) -> PlainTextResponse: Returns OK if the server is running and can connect to ClickHouse. """ try: - # Check if ClickHouse is enabled by trying to create config - # If ClickHouse is disabled, this will succeed but connection will fail - clickhouse_enabled = os.getenv("CLICKHOUSE_ENABLED", "true").lower() == "true" - - if not clickhouse_enabled: - # If ClickHouse is disabled, check chDB status - chdb_config = get_chdb_config() - if chdb_config.enabled: - return PlainTextResponse("OK - MCP server running with chDB enabled") - else: - # Both ClickHouse and chDB are disabled - this is an error - return PlainTextResponse( - "ERROR - Both ClickHouse and chDB are disabled. At least one must be enabled.", - status_code=503, - ) - - # Try to create a client connection to verify ClickHouse connectivity - client = create_clickhouse_client() - version = client.server_version - return PlainTextResponse(f"OK - Connected to ClickHouse {version}") + reports = [] + + for tenant in CLICKHOUSE_TENANTS: + tenant_report = f"Tenant - '{tenant}': " + + # Check if ClickHouse is enabled by trying to create config + # If ClickHouse is disabled, this will succeed but connection will fail + try: + clickhouse_config = get_config(tenant) + if clickhouse_config.enabled: + client = create_clickhouse_client(tenant) + version = client.server_version + tenant_report += f"ClickHouse OK (v{version})" + else: + tenant_report += "ClickHouse Disabled" + except Exception as e: + tenant_report += f"ClickHouse ERROR ({str(e)})" + + # Check chDB status if enabled + try: + chdb_config = get_chdb_config(tenant) + if chdb_config.enabled: + tenant_report += ", chDB OK" + else: + tenant_report += ", chDB Disabled" + except Exception as e: + tenant_report += f", chDB ERROR ({str(e)})" + + reports.append(tenant_report) + + return PlainTextResponse("\n".join(reports)) + except Exception as e: - # Return 503 Service Unavailable if we can't connect to ClickHouse - return PlainTextResponse(f"ERROR - Cannot connect to ClickHouse: {str(e)}", status_code=503) + return PlainTextResponse(f"ERROR - Health check failed: {str(e)}", status_code=503) def result_to_table(query_columns, result) -> List[Table]: @@ -127,11 +158,34 @@ def to_json(obj: Any) -> str: return {key: to_json(value) for key, value in obj.items()} return obj +def clickhouse_tenant_available(tenant: str): + if tenant in CLICKHOUSE_TENANTS: + return True + return False + +def chdb_tenant_available(tenant: str): + if tenant in CHDB_TENANTS: + return True + return False + +def list_clickhouse_tenants(): + """List available Clickhouse tenants""" + global CLICKHOUSE_TENANTS + return json.dumps(CLICKHOUSE_TENANTS) -def list_databases(): +def list_chdb_tenants(): + """List available chDB tenants""" + global CHDB_TENANTS + return json.dumps(CHDB_TENANTS) + +def list_databases(tenant: str): """List available ClickHouse databases""" + if not clickhouse_tenant_available(tenant): + logger.warning(f"List databases not performed for invalid tenant - '{tenant}'") + raise ToolError(f"List databases not performed for invalid tenant - '{tenant}'") + logger.info("Listing all databases") - client = create_clickhouse_client() + client = create_clickhouse_client(tenant) result = client.command("SHOW DATABASES") # Convert newline-separated string to list and trim whitespace @@ -140,15 +194,20 @@ def list_databases(): else: databases = [result] - logger.info(f"Found {len(databases)} databases") + logger.info(f"Found {len(databases)} databases for tenant - '{tenant}'") return json.dumps(databases) -def list_tables(database: str, like: Optional[str] = None, not_like: Optional[str] = None): +def list_tables(tenant: str, database: str, like: Optional[str] = None, not_like: Optional[str] = None): """List available ClickHouse tables in a database, including schema, comment, row count, and column count.""" - logger.info(f"Listing tables in database '{database}'") - client = create_clickhouse_client() + + if not clickhouse_tenant_available(tenant): + logger.warning(f"List tables not performed for invalid tenant - '{tenant}'") + raise ToolError(f"List tables not performed for invalid tenant - '{tenant}'") + + logger.info(f"Listing tables for tenant - '{tenant}' in database '{database}'") + client = create_clickhouse_client(tenant) query = f"SELECT database, name, engine, create_table_query, dependencies_database, dependencies_table, engine_full, sorting_key, primary_key, total_rows, total_bytes, total_bytes_uncompressed, parts, active_parts, total_marks, comment FROM system.tables WHERE database = {format_query_value(database)}" if like: query += f" AND name LIKE {format_query_value(like)}" @@ -172,12 +231,16 @@ def list_tables(database: str, like: Optional[str] = None, not_like: Optional[st ) ] - logger.info(f"Found {len(tables)} tables") + logger.info(f"Found {len(tables)} tables for tenant - '{tenant}'") return [asdict(table) for table in tables] -def execute_query(query: str): - client = create_clickhouse_client() +def execute_query(tenant: str, query: str): + if not clickhouse_tenant_available(tenant): + logger.warning(f"Query not executed for invalid tenant - '{tenant}'") + raise ToolError(f"Query not executed for invalid tenant - '{tenant}'") + + client = create_clickhouse_client(tenant) try: read_only = get_readonly_setting(client) res = client.query(query, settings={"readonly": read_only}) @@ -188,38 +251,46 @@ def execute_query(query: str): raise ToolError(f"Query execution failed: {str(err)}") -def run_select_query(query: str): +def run_select_query(tenant: str, query: str): """Run a SELECT query in a ClickHouse database""" - logger.info(f"Executing SELECT query: {query}") + if not clickhouse_tenant_available(tenant): + logger.warning(f"Select Query not performed for invalid tenant - '{tenant}'") + raise ToolError(f"Select Query not performed for invalid tenant - '{tenant}'") + + logger.info(f"Executing SELECT query for tenant - '{tenant}': {query}") try: - future = QUERY_EXECUTOR.submit(execute_query, query) + future = CLICKHOUSE_QUERY_EXECUTOR[tenant].submit(execute_query, tenant, query) try: result = future.result(timeout=SELECT_QUERY_TIMEOUT_SECS) # Check if we received an error structure from execute_query if isinstance(result, dict) and "error" in result: - logger.warning(f"Query failed: {result['error']}") + logger.warning(f"Query failed for tenant - '{tenant}': {result['error']}") # MCP requires structured responses; string error messages can cause # serialization issues leading to BrokenResourceError return { "status": "error", - "message": f"Query failed: {result['error']}", + "message": f"Query failed for tenant - '{tenant}': {result['error']}", } return result except concurrent.futures.TimeoutError: - logger.warning(f"Query timed out after {SELECT_QUERY_TIMEOUT_SECS} seconds: {query}") + logger.warning(f"Query timed out after {SELECT_QUERY_TIMEOUT_SECS} seconds for tenant - '{tenant}': {query}") future.cancel() - raise ToolError(f"Query timed out after {SELECT_QUERY_TIMEOUT_SECS} seconds") + raise ToolError(f"Query timed out after {SELECT_QUERY_TIMEOUT_SECS} seconds for tenant - '{tenant}'") except ToolError: raise except Exception as e: - logger.error(f"Unexpected error in run_select_query: {str(e)}") - raise RuntimeError(f"Unexpected error during query execution: {str(e)}") + logger.error(f"Unexpected error in run_select_query for tenant - '{tenant}': {str(e)}") + raise RuntimeError(f"Unexpected error during query execution for tenant - '{tenant}': {str(e)}") -def create_clickhouse_client(): - client_config = get_config().get_client_config() +def create_clickhouse_client(tenant: str): + if not clickhouse_tenant_available(tenant): + logger.warning(f"Clickhouse client not created for invalid tenant - '{tenant}'") + raise ToolError(f"Clickhouse client not created for invalid tenant - '{tenant}'") + + client_config = get_config(tenant).get_client_config() logger.info( - f"Creating ClickHouse client connection to {client_config['host']}:{client_config['port']} " + f"Creating ClickHouse client connection for tenant - '{tenant}', to {client_config['host']}:{client_config['port']} " f"as {client_config['username']} " f"(secure={client_config['secure']}, verify={client_config['verify']}, " f"connect_timeout={client_config['connect_timeout']}s, " @@ -230,10 +301,10 @@ def create_clickhouse_client(): client = clickhouse_connect.get_client(**client_config) # Test the connection version = client.server_version - logger.info(f"Successfully connected to ClickHouse server version {version}") + logger.info(f"Successfully connected to ClickHouse server version {version} for tenant - '{tenant}'") return client except Exception as e: - logger.error(f"Failed to connect to ClickHouse: {str(e)}") + logger.error(f"Failed to connect to ClickHouse for tenant - '{tenant}': {str(e)}") raise @@ -267,16 +338,24 @@ def get_readonly_setting(client) -> str: return "1" # Default to basic read-only mode if setting isn't present -def create_chdb_client(): +def create_chdb_client(tenant: str): """Create a chDB client connection.""" - if not get_chdb_config().enabled: - raise ValueError("chDB is not enabled. Set CHDB_ENABLED=true to enable it.") + if not chdb_tenant_available(tenant): + logger.warning(f"chDB client not created for invalid tenant - '{tenant}'") + raise ToolError(f"chDB client not created for invalid tenant - '{tenant}'") + + if not get_chdb_config(tenant).enabled: + raise ValueError(f"chDB is not enabled for tenant - '{tenant}'. Set CHDB_ENABLED=true to enable it.") return _chdb_client -def execute_chdb_query(query: str): +def execute_chdb_query(tenant: str, query: str): """Execute a query using chDB client.""" - client = create_chdb_client() + if not chdb_tenant_available(tenant): + logger.warning(f"chDB query not executed for invalid tenant - '{tenant}'") + raise ToolError(f"chDB query not executed for invalid tenant - '{tenant}'") + + client = create_chdb_client(tenant) try: res = client.query(query, "JSON") if res.has_error(): @@ -297,33 +376,37 @@ def execute_chdb_query(query: str): return {"error": str(err)} -def run_chdb_select_query(query: str): +def run_chdb_select_query(tenant: str, query: str): """Run SQL in chDB, an in-process ClickHouse engine""" - logger.info(f"Executing chDB SELECT query: {query}") + if not chdb_tenant_available(tenant): + logger.warning(f"chDB query not performed for invalid tenant - '{tenant}'") + raise ToolError(f"chDB query not performed for invalid tenant - '{tenant}'") + + logger.info(f"Executing chDB SELECT query for tenant - '{tenant}': {query}") try: - future = QUERY_EXECUTOR.submit(execute_chdb_query, query) + future = CHDB_QUERY_EXECUTOR[tenant].submit(execute_chdb_query, tenant, query) try: result = future.result(timeout=SELECT_QUERY_TIMEOUT_SECS) # Check if we received an error structure from execute_chdb_query if isinstance(result, dict) and "error" in result: - logger.warning(f"chDB query failed: {result['error']}") + logger.warning(f"chDB query failed for tenant - '{tenant}': {result['error']}") return { "status": "error", - "message": f"chDB query failed: {result['error']}", + "message": f"chDB query failed for tenant - '{tenant}': {result['error']}", } return result except concurrent.futures.TimeoutError: logger.warning( - f"chDB query timed out after {SELECT_QUERY_TIMEOUT_SECS} seconds: {query}" + f"chDB query timed out after {SELECT_QUERY_TIMEOUT_SECS} seconds for tenant - '{tenant}': {query}" ) future.cancel() return { "status": "error", - "message": f"chDB query timed out after {SELECT_QUERY_TIMEOUT_SECS} seconds", + "message": f"chDB query timed out after {SELECT_QUERY_TIMEOUT_SECS} seconds for tenant - '{tenant}'", } except Exception as e: - logger.error(f"Unexpected error in run_chdb_select_query: {e}") - return {"status": "error", "message": f"Unexpected error: {e}"} + logger.error(f"Unexpected error in run_chdb_select_query for tenant - '{tenant}': {e}") + return {"status": "error", "message": f"Unexpected error for tenant - '{tenant}': {e}"} def chdb_initial_prompt() -> str: @@ -331,37 +414,47 @@ def chdb_initial_prompt() -> str: return CHDB_PROMPT -def _init_chdb_client(): +def _init_chdb_client(tenant: str): """Initialize the global chDB client instance.""" + if not chdb_tenant_available(tenant): + logger.warning(f"chDB client not initialised for invalid tenant - '{tenant}'") + raise ToolError(f"chDB client not initialised for invalid tenant - '{tenant}'") + try: - if not get_chdb_config().enabled: - logger.info("chDB is disabled, skipping client initialization") + if not get_chdb_config(tenant).enabled: + logger.info("chDB is disabled for tenant - '{tenant}', skipping client initialization") return None - client_config = get_chdb_config().get_client_config() + client_config = get_chdb_config(tenant).get_client_config() data_path = client_config["data_path"] - logger.info(f"Creating chDB client with data_path={data_path}") + logger.info(f"Creating chDB client with data_path={data_path} for tenant - '{tenant}'") client = chs.Session(path=data_path) - logger.info(f"Successfully connected to chDB with data_path={data_path}") + logger.info(f"Successfully connected to chDB with data_path={data_path} for tenant - '{tenant}'") return client except Exception as e: - logger.error(f"Failed to initialize chDB client: {e}") + logger.error(f"Failed to initialize chDB client for tenant - '{tenant}': {e}") return None -# Register tools based on configuration -if os.getenv("CLICKHOUSE_ENABLED", "true").lower() == "true": +# Register tools +if not CLICKHOUSE_TENANTS: + logger.info("ClickHouse tools not registered") +else: + mcp.add_tool(Tool.from_function(list_clickhouse_tenants)) mcp.add_tool(Tool.from_function(list_databases)) mcp.add_tool(Tool.from_function(list_tables)) mcp.add_tool(Tool.from_function(run_select_query)) logger.info("ClickHouse tools registered") +if not CHDB_TENANTS: + logger.info("chDB tools and prompts not registered") +else: + for tenant in CHDB_TENANTS: + _chdb_client = _init_chdb_client(tenant) + if _chdb_client: + atexit.register(lambda: _chdb_client.close()) -if os.getenv("CHDB_ENABLED", "false").lower() == "true": - _chdb_client = _init_chdb_client() - if _chdb_client: - atexit.register(lambda: _chdb_client.close()) - + mcp.add_tool(Tool.from_function(list_chdb_tenants)) mcp.add_tool(Tool.from_function(run_chdb_select_query)) chdb_prompt = Prompt.from_function( chdb_initial_prompt, diff --git a/tests/test_chdb_tool.py b/tests/test_chdb_tool.py index 1e16a93..2947fa2 100644 --- a/tests/test_chdb_tool.py +++ b/tests/test_chdb_tool.py @@ -1,37 +1,56 @@ import unittest from dotenv import load_dotenv - -from mcp_clickhouse import create_chdb_client, run_chdb_select_query +from fastmcp.exceptions import ToolError +from mcp_clickhouse import list_chdb_tenants, create_chdb_client, run_chdb_select_query load_dotenv() - class TestChDBTools(unittest.TestCase): @classmethod def setUpClass(cls): """Set up the environment before chDB tests.""" - cls.client = create_chdb_client() + cls.client = create_chdb_client(tenant="default") + + def test_list_chdb_tenants(self): + tenants = list_chdb_tenants() + self.assertIn("default", tenants) + self.assertEqual(len(tenants), 1) + + def test_run_chdb_select_query_wrong_tenant(self): + """Test running a simple SELECT query in chDB with wrong tenant.""" + tenant = "wrong_tenant" + query = "SELECT 1 as test_value" + with self.assertRaises(ToolError) as cm: + run_chdb_select_query(tenant, query) + + self.assertIn( + f"chDB query not performed for invalid tenant - '{tenant}'", + str(cm.exception) + ) def test_run_chdb_select_query_simple(self): """Test running a simple SELECT query in chDB.""" + tenant = "default" query = "SELECT 1 as test_value" - result = run_chdb_select_query(query) + result = run_chdb_select_query(tenant, query) self.assertIsInstance(result, list) self.assertIn("test_value", str(result)) def test_run_chdb_select_query_with_url_table_function(self): """Test running a SELECT query with url table function in chDB.""" + tenant = "default" query = "SELECT COUNT(1) FROM url('https://datasets.clickhouse.com/hits_compatible/athena_partitioned/hits_0.parquet', 'Parquet')" - result = run_chdb_select_query(query) + result = run_chdb_select_query(tenant, query) print(result) self.assertIsInstance(result, list) self.assertIn("1000000", str(result)) def test_run_chdb_select_query_failure(self): """Test running a SELECT query with an error in chDB.""" + tenant = "default" query = "SELECT * FROM non_existent_table_chDB" - result = run_chdb_select_query(query) + result = run_chdb_select_query(tenant, query) print(result) self.assertIsInstance(result, dict) self.assertEqual(result["status"], "error") @@ -39,8 +58,9 @@ def test_run_chdb_select_query_failure(self): def test_run_chdb_select_query_empty_result(self): """Test running a SELECT query that returns empty result in chDB.""" + tenant = "default" query = "SELECT 1 WHERE 1 = 0" - result = run_chdb_select_query(query) + result = run_chdb_select_query(tenant, query) print(result) self.assertIsInstance(result, list) self.assertEqual(len(result), 0) diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 0119790..cc8966f 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -10,7 +10,6 @@ # Load environment variables load_dotenv() - @pytest.fixture(scope="module") def event_loop(): """Create an instance of the default event loop for the test session.""" @@ -18,11 +17,11 @@ def event_loop(): yield loop loop.close() - @pytest_asyncio.fixture(scope="module") async def setup_test_database(): """Set up test database and tables before running tests.""" - client = create_clickhouse_client() + + client = create_clickhouse_client(tenant= "default") # Test database and table names test_db = "test_mcp_db" @@ -87,13 +86,71 @@ def mcp_server(): return mcp +@pytest.mark.asyncio +async def test_list_databases_wrong_tenant(mcp_server): + """Test the list_databases tool with wrong tenant.""" + tenant = "wrong_tenant" + async with Client(mcp_server) as client: + with pytest.raises(ToolError) as exc_info: + await client.call_tool("list_databases", {"tenant": tenant}) + + assert f"List databases not performed for invalid tenant - '{tenant}'" in str(exc_info.value) + +@pytest.mark.asyncio +async def test_list_tables_wrong_tenant(mcp_server, setup_test_database): + """Test the list_tables tool with wrong tenant.""" + test_db, test_table, _ = setup_test_database + tenant = "wrong_tenant" + async with Client(mcp_server) as client: + with pytest.raises(ToolError) as exc_info: + await client.call_tool("list_tables", {"tenant": tenant, "database": test_db}) + assert f"List tables not performed for invalid tenant - '{tenant}'" in str(exc_info.value) + +@pytest.mark.asyncio +async def test_run_select_query_wrong_tenant(mcp_server, setup_test_database): + """Test the run_select_query tool with wrong tenant.""" + test_db, test_table, _ = setup_test_database + tenant = "wrong_tenant" + async with Client(mcp_server) as client: + with pytest.raises(ToolError) as exc_info: + await client.call_tool("run_select_query", {"tenant": tenant, "query": f"SELECT COUNT(*) FROM {test_db}.{test_table}",}) + + assert f"Query not performed for invalid tenant - '{tenant}'" in str(exc_info.value) + +@pytest.mark.asyncio +async def test_list_clickhouse_tenants(mcp_server): + """Test the list_clickhouse_tenants tool.""" + async with Client(mcp_server) as client: + result = await client.call_tool("list_clickhouse_tenants", {}) + # The result should be a list containing one item + assert len(result) == 1 + assert isinstance(result[0].text, str) + + # Parse the result text (it's a JSON list of tenant names) + tenants = json.loads(result[0].text) + + assert "default" in tenants # default tenant is defined in .env (no prefix) + +@pytest.mark.asyncio +async def test_list_chdb_tenants(mcp_server): + """Test the list_chdb_tenants tool.""" + async with Client(mcp_server) as client: + result = await client.call_tool("list_chdb_tenants", {}) + # The result should be a list containing one item + assert len(result) == 1 + assert isinstance(result[0].text, str) + # Parse the result text (it's a JSON list of tenant names) + tenants = json.loads(result[0].text) + + assert "default" in tenants # default tenant is defined in .env (no prefix) + @pytest.mark.asyncio async def test_list_databases(mcp_server, setup_test_database): """Test the list_databases tool.""" test_db, _, _ = setup_test_database async with Client(mcp_server) as client: - result = await client.call_tool("list_databases", {}) + result = await client.call_tool("list_databases", {"tenant": "default"}) # The result should be a list containing at least one item assert len(result) >= 1 @@ -111,7 +168,7 @@ async def test_list_tables_basic(mcp_server, setup_test_database): test_db, test_table, test_table2 = setup_test_database async with Client(mcp_server) as client: - result = await client.call_tool("list_tables", {"database": test_db}) + result = await client.call_tool("list_tables", {"tenant": "default", "database": test_db}) assert len(result) >= 1 tables = json.loads(result[0].text) @@ -147,7 +204,7 @@ async def test_list_tables_with_like_filter(mcp_server, setup_test_database): async with Client(mcp_server) as client: # Test with LIKE filter - result = await client.call_tool("list_tables", {"database": test_db, "like": "test_%"}) + result = await client.call_tool("list_tables", {"tenant": "default", "database": test_db, "like": "test_%"}) tables_data = json.loads(result[0].text) @@ -168,7 +225,7 @@ async def test_list_tables_with_not_like_filter(mcp_server, setup_test_database) async with Client(mcp_server) as client: # Test with NOT LIKE filter - result = await client.call_tool("list_tables", {"database": test_db, "not_like": "test_%"}) + result = await client.call_tool("list_tables", {"tenant": "default", "database": test_db, "not_like": "test_%"}) tables_data = json.loads(result[0].text) @@ -189,7 +246,7 @@ async def test_run_select_query_success(mcp_server, setup_test_database): async with Client(mcp_server) as client: query = f"SELECT id, name, age FROM {test_db}.{test_table} ORDER BY id" - result = await client.call_tool("run_select_query", {"query": query}) + result = await client.call_tool("run_select_query", {"tenant": "default", "query": query}) query_result = json.loads(result[0].text) @@ -215,7 +272,7 @@ async def test_run_select_query_with_aggregation(mcp_server, setup_test_database async with Client(mcp_server) as client: query = f"SELECT COUNT(*) as count, AVG(age) as avg_age FROM {test_db}.{test_table}" - result = await client.call_tool("run_select_query", {"query": query}) + result = await client.call_tool("run_select_query", {"tenant": "default", "query": query}) query_result = json.loads(result[0].text) @@ -232,7 +289,7 @@ async def test_run_select_query_with_join(mcp_server, setup_test_database): async with Client(mcp_server) as client: # Insert related data for join - client_direct = create_clickhouse_client() + client_direct = create_clickhouse_client("default") client_direct.command(f""" INSERT INTO {test_db}.{test_table2} (event_id, event_type, timestamp) VALUES (2001, 'purchase', '2024-01-01 14:00:00') @@ -243,7 +300,7 @@ async def test_run_select_query_with_join(mcp_server, setup_test_database): COUNT(DISTINCT event_type) as event_types_count FROM {test_db}.{test_table2} """ - result = await client.call_tool("run_select_query", {"query": query}) + result = await client.call_tool("run_select_query", {"tenant": "default", "query": query}) query_result = json.loads(result[0].text) assert query_result["rows"][0][0] == 3 # login, logout, purchase @@ -260,13 +317,15 @@ async def test_run_select_query_error(mcp_server, setup_test_database): # Should raise ToolError with pytest.raises(ToolError) as exc_info: - await client.call_tool("run_select_query", {"query": query}) + await client.call_tool("run_select_query", {"tenant": "default", "query": query}) assert "Query execution failed" in str(exc_info.value) @pytest.mark.asyncio -async def test_run_select_query_syntax_error(mcp_server): +async def test_run_select_query_syntax_error(mcp_server, setup_test_database): + _, _, _ = setup_test_database + """Test running a SELECT query with syntax error.""" async with Client(mcp_server) as client: # Invalid SQL syntax @@ -274,7 +333,7 @@ async def test_run_select_query_syntax_error(mcp_server): # Should raise ToolError with pytest.raises(ToolError) as exc_info: - await client.call_tool("run_select_query", {"query": query}) + await client.call_tool("run_select_query", {"tenant": "default", "query": query}) assert "Query execution failed" in str(exc_info.value) @@ -285,7 +344,7 @@ async def test_table_metadata_details(mcp_server, setup_test_database): test_db, test_table, _ = setup_test_database async with Client(mcp_server) as client: - result = await client.call_tool("list_tables", {"database": test_db}) + result = await client.call_tool("list_tables", {"tenant": "default", "database": test_db}) tables = json.loads(result[0].text) # Find our test table @@ -319,11 +378,12 @@ async def test_table_metadata_details(mcp_server, setup_test_database): @pytest.mark.asyncio -async def test_system_database_access(mcp_server): +async def test_system_database_access(mcp_server, setup_test_database): """Test that we can access system databases.""" + _, _, _ = setup_test_database async with Client(mcp_server) as client: # List tables in system database - result = await client.call_tool("list_tables", {"database": "system"}) + result = await client.call_tool("list_tables", {"tenant": "default", "database": "system"}) tables = json.loads(result[0].text) # System database should have many tables @@ -352,7 +412,7 @@ async def test_concurrent_queries(mcp_server, setup_test_database): # Execute all queries concurrently results = await asyncio.gather( - *[client.call_tool("run_select_query", {"query": query}) for query in queries] + *[client.call_tool("run_select_query", {"tenant": "default", "query": query}) for query in queries] ) # Verify all queries succeeded diff --git a/tests/test_tool.py b/tests/test_tool.py index 50878c4..fb4b117 100644 --- a/tests/test_tool.py +++ b/tests/test_tool.py @@ -4,16 +4,15 @@ from dotenv import load_dotenv from fastmcp.exceptions import ToolError -from mcp_clickhouse import create_clickhouse_client, list_databases, list_tables, run_select_query +from mcp_clickhouse import create_clickhouse_client, list_clickhouse_tenants, list_databases, list_tables, run_select_query load_dotenv() - class TestClickhouseTools(unittest.TestCase): @classmethod def setUpClass(cls): """Set up the environment before tests.""" - cls.client = create_clickhouse_client() + cls.client = create_clickhouse_client(tenant="default") # Prepare test database and table cls.test_db = "test_tool_db" @@ -41,23 +40,62 @@ def tearDownClass(cls): """Clean up the environment after tests.""" cls.client.command(f"DROP DATABASE IF EXISTS {cls.test_db}") + def test_list_clickhouse_tenants(self): + tenants = list_clickhouse_tenants() + self.assertIn("default", tenants) + self.assertEqual(len(tenants), 1) + + def test_list_databases_wrong_tenant(self): + """Test listing tables with wrong tenant.""" + tenant = "wrong_tenant" + with self.assertRaises(ToolError) as cm: + list_databases(tenant) + + self.assertIn( + f"List databases not performed for invalid tenant - '{tenant}'", + str(cm.exception) + ) + + def test_list_tables_wrong_tenant(self): + """Test listing tables with wrong tenant.""" + tenant = "wrong_tenant" + with self.assertRaises(ToolError) as cm: + list_tables(tenant, self.test_db) + + self.assertIn( + f"List tables not performed for invalid tenant - '{tenant}'", + str(cm.exception) + ) + + def test_run_select_query_wrong_tenant(self): + """Test run select query with wrong tenant.""" + tenant = "wrong_tenant" + query = f"SELECT * FROM {self.test_db}.{self.test_table}" + with self.assertRaises(ToolError) as cm: + run_select_query(tenant, query) + + self.assertIn( + f"Query not performed for invalid tenant - '{tenant}'", + str(cm.exception) + ) + def test_list_databases(self): """Test listing databases.""" - result = list_databases() + result = list_databases("default") # Parse JSON response databases = json.loads(result) self.assertIn(self.test_db, databases) def test_list_tables_without_like(self): """Test listing tables without a 'LIKE' filter.""" - result = list_tables(self.test_db) + result = list_tables("default", self.test_db) self.assertIsInstance(result, list) self.assertEqual(len(result), 1) self.assertEqual(result[0]["name"], self.test_table) def test_list_tables_with_like(self): """Test listing tables with a 'LIKE' filter.""" - result = list_tables(self.test_db, like=f"{self.test_table}%") + result = list_tables("default", self.test_db, like=f"{self.test_table}%") self.assertIsInstance(result, list) self.assertEqual(len(result), 1) self.assertEqual(result[0]["name"], self.test_table) @@ -65,7 +103,7 @@ def test_list_tables_with_like(self): def test_run_select_query_success(self): """Test running a SELECT query successfully.""" query = f"SELECT * FROM {self.test_db}.{self.test_table}" - result = run_select_query(query) + result = run_select_query("default", query) self.assertIsInstance(result, dict) self.assertEqual(len(result["rows"]), 2) self.assertEqual(result["rows"][0][0], 1) @@ -77,13 +115,13 @@ def test_run_select_query_failure(self): # Should raise ToolError with self.assertRaises(ToolError) as context: - run_select_query(query) + run_select_query("default", query) self.assertIn("Query execution failed", str(context.exception)) def test_table_and_column_comments(self): """Test that table and column comments are correctly retrieved.""" - result = list_tables(self.test_db) + result = list_tables("default", self.test_db) self.assertIsInstance(result, list) self.assertEqual(len(result), 1)