diff --git a/.dockerignore b/.dockerignore index f3e0b46..834db77 100644 --- a/.dockerignore +++ b/.dockerignore @@ -13,6 +13,5 @@ SECURITY.md CODE_OF_CONDUCT.md CONTRIBUTING.md docker-compose.yml -policy .pre-commit-config.yaml .gitignore diff --git a/.gitignore b/.gitignore index ea53511..1c5bcc0 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,3 @@ __pycache__/ dist/ build/ example.db -policy/ diff --git a/AGENTS.md b/AGENTS.md index e8da51f..89a4c15 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -11,13 +11,14 @@ Core package: `src/secure_sql_mcp` - `config.py` - Loads env config. + - Injects async driver suffixes for bare `DATABASE_URL` schemes (`postgresql://`, `mysql://`, `sqlite://`). - Parses `ALLOWED_POLICY_FILE` in strict `table:columns` format. - `query_validator.py` - SQL safety checks (read-only, single statement). - Strict table/column authorization checks. - `database.py` - Async SQLAlchemy access. - - Read-only session preparation and query timeout/row caps. + - Read-only session preparation and query timeout/row caps (PostgreSQL, MySQL, SQLite). - `server.py` - MCP tool surface (`query`, `list_tables`, `describe_table`). - User/agent-facing responses. diff --git a/Dockerfile b/Dockerfile index 8f0ed61..89c691a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,15 +8,27 @@ RUN python -m venv /opt/venv \ && /opt/venv/bin/pip install --no-cache-dir . 
\ && find /opt/venv -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null; true +FROM openpolicyagent/opa:1.5.1-static AS opa + FROM python:3.12-slim ENV PYTHONDONTWRITEBYTECODE=1 \ PYTHONUNBUFFERED=1 \ - PATH="/opt/venv/bin:$PATH" + PATH="/opt/venv/bin:$PATH" \ + OPA_URL="http://127.0.0.1:8181" \ + OPA_DECISION_PATH="/v1/data/secure_sql/authz/decision" \ + OPA_TIMEOUT_MS="50" \ + OPA_FAIL_CLOSED="true" COPY --from=builder /opt/venv /opt/venv +COPY --from=opa /opa /usr/local/bin/opa +COPY policy /app/policy +COPY docker/entrypoint.sh /app/entrypoint.sh +COPY docker/wait_for_opa.py /app/wait_for_opa.py + +RUN chmod 0555 /usr/local/bin/opa /app/entrypoint.sh /app/wait_for_opa.py RUN useradd -r -s /usr/sbin/nologin appuser USER appuser -ENTRYPOINT ["python", "-m", "secure_sql_mcp.server"] +ENTRYPOINT ["/app/entrypoint.sh"] diff --git a/README.md b/README.md index 0a65b3f..cef8c6c 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # Secure SQL MCP Server -Read-only SQL MCP server with strict table/column policy controls. +Read-only SQL MCP server with strict table/column policy controls, with OPA-based +authorization running inside the same container. [![CI](https://github.com/jrhuerta/secure-sql-mcp/actions/workflows/ci.yml/badge.svg)](https://github.com/jrhuerta/secure-sql-mcp/actions/workflows/ci.yml) [![GHCR](https://img.shields.io/badge/ghcr-jrhuerta%2Fsecure--sql--mcp-blue)](https://github.com/jrhuerta/secure-sql-mcp/pkgs/container/secure-sql-mcp) @@ -31,27 +32,33 @@ To use this server with Cursor, Claude Desktop, or other MCP clients, add it to **Claude Desktop** (`claude_desktop_config.json`): same structure under `mcpServers`. -The `--env-file` should point to a file containing `DATABASE_URL` and `ALLOWED_POLICY_FILE=/run/policy/allowed_policy.txt` (see Environment Variables below). The volume mounts the policy directory read-only. 
Pull the image first: `docker pull ghcr.io/jrhuerta/secure-sql-mcp:latest` +The `--env-file` should point to a file containing `DATABASE_URL` and +`ALLOWED_POLICY_FILE=/run/policy/allowed_policy.txt` (see Environment Variables below). +The volume mounts the policy directory read-only. Pull the image first: +`docker pull ghcr.io/jrhuerta/secure-sql-mcp:latest` ## Security Model - Database credentials stay server-side (env vars), never in prompts. - Only read queries are allowed. -- Policy is strict and file-based: - - one required file: `ALLOWED_POLICY_FILE` - - each line is `table:col1,col2,col3` or `table:*` +- OPA authorization runs in-process for the container image (local loopback only, no external port exposure). +- Policy is strict and deny-by-default: + - baseline constraints and ACL rules are evaluated by OPA + - ACL source can come from a native OPA data file or transformed legacy `ALLOWED_POLICY_FILE` - If a table/column is not explicitly allowed, it is blocked. ## Implemented Security Controls -- **Query shape enforcement** +- **OPA baseline constraints (`default_constraints`)** - Exactly one SQL statement is allowed per request. - Non-read operations are blocked (`INSERT`, `UPDATE`, `DELETE`, `DROP`, `ALTER`, `CREATE`, `TRUNCATE`, `GRANT`, `REVOKE`, `MERGE`, and related command expressions). -- **Strict access policy enforcement** + - Unqualified columns in multi-table queries are rejected under strict mode. +- **OPA ACL policy (`acl`)** - Deny-by-default for tables and columns. - Access checks apply across direct queries and composed queries (`JOIN`, `UNION`, subqueries, aliases). - - `SELECT *` is rejected unless the table policy is `table:*`. - - Unqualified columns in multi-table queries are rejected under strict mode. + - `SELECT *` is rejected unless ACL explicitly allows wildcard (`*`) for that table. +- **Composed authorization (`authz`)** + - Access is granted only when both `default_constraints` and `acl` allow. 
- **Runtime safety controls** - Query timeout and row cap are enforced server-side. - Row-cap truncation is explicit in response payloads. @@ -63,12 +70,29 @@ The `--env-file` should point to a file containing `DATABASE_URL` and `ALLOWED_P | Variable | Required | Default | Description | |----------|----------|---------|-------------| -| `DATABASE_URL` | Yes | — | SQLAlchemy async URL (e.g. `sqlite+aiosqlite:///./example.db` or `postgresql+asyncpg://...`) | +| `DATABASE_URL` | Yes | — | Database URL. Bare `postgresql://`, `mysql://`, and `sqlite://` URLs are accepted and auto-upgraded to async drivers (`+asyncpg`, `+aiomysql`, `+aiosqlite`). | | `ALLOWED_POLICY_FILE` | Yes | — | Path to the policy file | +| `OPA_URL` | No | `http://127.0.0.1:8181` in Docker image; unset otherwise | OPA base URL. When set, queries/tools are authorized via OPA. | +| `OPA_DECISION_PATH` | No | `/v1/data/secure_sql/authz/decision` | OPA decision endpoint path. | +| `OPA_TIMEOUT_MS` | No | `50` | OPA decision timeout in milliseconds. | +| `OPA_FAIL_CLOSED` | No | `true` | If `true`, OPA errors/timeouts block access. | +| `OPA_ACL_DATA_FILE` | No | unset | Optional JSON ACL file (`secure_sql.acl.tables`) preferred over transformed `ALLOWED_POLICY_FILE`. | +| `WRITE_MODE_ENABLED` | No | `false` | Enables write execution path (`INSERT`/`UPDATE`/`DELETE`) when `true`. | +| `ALLOW_INSERT` | No | `false` | Allows `INSERT` statements when write mode is enabled. | +| `ALLOW_UPDATE` | No | `false` | Allows `UPDATE` statements when write mode is enabled. | +| `ALLOW_DELETE` | No | `false` | Allows `DELETE` statements when write mode is enabled. | +| `REQUIRE_WHERE_FOR_UPDATE` | No | `true` | When `true`, `UPDATE` requires a `WHERE` clause. | +| `REQUIRE_WHERE_FOR_DELETE` | No | `true` | When `true`, `DELETE` requires a `WHERE` clause. | +| `ALLOW_RETURNING` | No | `false` | Allows `RETURNING` on write statements when `true`. 
| | `MAX_ROWS` | No | 100 | Maximum rows returned per query (1–10000) | | `QUERY_TIMEOUT` | No | 30 | Query timeout in seconds (1–300) | | `LOG_LEVEL` | No | INFO | Logging level (DEBUG, INFO, WARNING, ERROR) | +Write mode guardrails: +- All write-related flags default to `false` (deny-by-default). +- OPA remains the policy decision point, but config gates are enforced first as coarse runtime brakes. +- The server logs a `WARNING` when config gates block a write that policy would otherwise allow. + ## Policy File Format `allowed_policy.txt`: @@ -84,6 +108,20 @@ Rules: - `#` comments and blank lines are allowed. - Matching is case-insensitive. +## OPA Policy Layout + +- Rego bundle directory: `policy/rego/` + - `default_constraints.rego` + - `acl.rego` + - `authz.rego` +- Example ACL data file: `policy/data/acl.example.json` +- Policy authoring guide: [`docs/POLICY_AUTHORING.md`](docs/POLICY_AUTHORING.md) +- Controlled write mode design: [`docs/WRITE_MODE_DESIGN.md`](docs/WRITE_MODE_DESIGN.md) + +ACL source precedence at runtime: +1. If `OPA_ACL_DATA_FILE` is set, ACL input is loaded from that JSON file. +2. Otherwise, `ALLOWED_POLICY_FILE` is transformed into equivalent ACL input. 
+ ## Agent Discoverability The MCP server exposes: @@ -96,7 +134,8 @@ The MCP server exposes: - allowed columns for that table from policy - schema metadata from DB when available - `query(sql)`: - - executes only if query is read-only and within table/column policy + - executes read queries by default under table/column policy + - executes write queries only when write mode/action toggles allow them and policy permits ## Quick Start (uv) @@ -115,12 +154,31 @@ QUERY_TIMEOUT=30 LOG_LEVEL=INFO EOF +# Optional when testing against an external/local OPA process outside the container: +# OPA_URL=http://127.0.0.1:8181 +# OPA_DECISION_PATH=/v1/data/secure_sql/authz/decision +# OPA_TIMEOUT_MS=50 +# OPA_FAIL_CLOSED=true + mkdir -p policy cat > policy/allowed_policy.txt <<'EOF' customers:id,email orders:* EOF +cat > policy/acl.json <<'EOF' +{ + "secure_sql": { + "acl": { + "tables": { + "customers": {"columns": ["id", "email"]}, + "orders": {"columns": ["*"]} + } + } + } +} +EOF + # Create tables for local testing (optional) sqlite3 example.db <<'SQL' CREATE TABLE IF NOT EXISTS customers (id INTEGER PRIMARY KEY, email TEXT NOT NULL, ssn TEXT); @@ -150,6 +208,7 @@ EOF cat > .env <<'EOF' DATABASE_URL=sqlite+aiosqlite:///./example.db ALLOWED_POLICY_FILE=/run/policy/allowed_policy.txt +OPA_ACL_DATA_FILE=/run/policy/acl.json MAX_ROWS=100 QUERY_TIMEOUT=30 LOG_LEVEL=INFO @@ -191,6 +250,7 @@ docker compose up --build - Avoid hardcoding credentials in shell history. - Mount policy files read-only (`:ro`) in Docker. - Keep `.env` and policy files out of version control. +- Keep OPA policy/data assets immutable in runtime containers. 
## Dev Tooling @@ -216,11 +276,44 @@ python -m pytest -q \ ``` What these suites validate: -- read-only enforcement for mutation/privileged SQL operations +- default read-only enforcement for mutation/privileged SQL operations - single-statement validation and parser hardening - strict deny-by-default table/column ACL checks, including join/union/subquery paths +- write-mode guardrails (`WRITE_MODE_ENABLED` and per-action toggles), including WHERE safety checks - protocol-level behavior over MCP stdio transport - timeout, row cap truncation, and non-leaky actionable DB error responses +- OPA fail-closed behavior and ACL source precedence + +## Real Docker + OPA Matrix Tests + +Run comprehensive real-server scenarios against Dockerized MCP+OPA across +SQLite, PostgreSQL, and MySQL: + +```bash +python -m pytest -q -m docker_integration tests/integration/docker/test_mcp_docker_opa_matrix.py +``` + +Run a faster smoke subset: + +```bash +bash scripts/run-docker-opa-smoke.sh +``` + +Prerequisites: +- Docker Engine with Compose plugin (`docker compose`) +- ability to pull base images (`postgres:16-alpine`, `mysql:8.4`) + +Troubleshooting: +- if MySQL/PostgreSQL startup is slow, rerun with `-m docker_integration -vv` to inspect per-scenario logs +- if Docker is unavailable, these tests auto-skip and unit/security suites still run normally +- if port/resource contention occurs, remove stale test stacks: `docker compose -f docker-compose.test.yml down -v --remove-orphans` + +What the Docker matrix validates: +- read/write allow/deny behavior with real container runtime and OPA process +- policy-profile variants mounted as read-only files +- write gate toggles (`WRITE_MODE_ENABLED`, `ALLOW_*`) and WHERE/RETURNING controls +- bypass-focused checks (`INSERT ... 
SELECT`, source `SELECT *`, tautological WHERE) +- OPA fail-closed behavior when decision service is unavailable ## CI Security Gate Expectations @@ -237,7 +330,7 @@ python -m pytest -q \ Recommended policy: - block merges on any failure in the security suites above -- require test updates when changing query validation, policy parsing, or MCP tool responses +- require test updates when changing query validation, OPA policy inputs, policy parsing, or MCP tool responses - keep security test fixtures deterministic (no shared state, no external DB dependency by default) ## Contributing @@ -258,8 +351,9 @@ Before merging security-sensitive changes, verify: - query validation still enforces exactly one statement per request - mutation/DDL/privilege SQL operations are blocked with actionable messaging -- table and column access remains deny-by-default against `ALLOWED_POLICY_FILE` -- `SELECT *` is rejected unless policy explicitly allows `table:*` +- table and column access remains deny-by-default against effective ACL source + (`OPA_ACL_DATA_FILE` when present, else transformed `ALLOWED_POLICY_FILE`) +- `SELECT *` is rejected unless ACL explicitly allows wildcard - multi-table queries still reject unqualified columns and enforce alias-aware ACLs - timeout and row-cap protections remain active and tested - DB error responses stay sanitized and do not expose credentials/internal connection details diff --git a/docker-compose.test.yml b/docker-compose.test.yml new file mode 100644 index 0000000..c4687f1 --- /dev/null +++ b/docker-compose.test.yml @@ -0,0 +1,59 @@ +services: + secure-sql-mcp: + build: + context: . 
+      dockerfile: Dockerfile
+    image: secure-sql-mcp:test
+    environment:
+      ALLOWED_POLICY_FILE: /run/policy/allowed_policy.txt
+      OPA_URL: http://127.0.0.1:8181
+      OPA_DECISION_PATH: /v1/data/secure_sql/authz/decision
+      OPA_TIMEOUT_MS: "50"
+      OPA_FAIL_CLOSED: "true"
+      WRITE_MODE_ENABLED: "false"
+      ALLOW_INSERT: "false"
+      ALLOW_UPDATE: "false"
+      ALLOW_DELETE: "false"
+      REQUIRE_WHERE_FOR_UPDATE: "true"
+      REQUIRE_WHERE_FOR_DELETE: "true"
+      ALLOW_RETURNING: "false"
+      MAX_ROWS: "100"
+      QUERY_TIMEOUT: "30"
+      LOG_LEVEL: "INFO"
+    stdin_open: true
+    tty: false
+    depends_on:
+      postgres:
+        condition: service_healthy
+      mysql:
+        condition: service_healthy
+
+  postgres:
+    image: postgres:16-alpine
+    environment:
+      POSTGRES_USER: secure
+      POSTGRES_PASSWORD: secure
+      POSTGRES_DB: secure_sql_test
+    volumes:
+      - ./tests/integration/docker/db-init/postgres.sql:/docker-entrypoint-initdb.d/01-init.sql:ro
+    healthcheck:
+      test: ["CMD-SHELL", "pg_isready -U secure -d secure_sql_test"]
+      interval: 2s
+      timeout: 2s
+      retries: 30
+
+  mysql:
+    image: mysql:8.4
+    command: ["--mysql-native-password=ON"]
+    environment:
+      MYSQL_ROOT_PASSWORD: root
+      MYSQL_USER: secure
+      MYSQL_PASSWORD: secure
+      MYSQL_DATABASE: secure_sql_test
+    volumes:
+      - ./tests/integration/docker/db-init/mysql.sql:/docker-entrypoint-initdb.d/01-init.sql:ro
+    healthcheck:
+      test: ["CMD-SHELL", "mysqladmin ping -h 127.0.0.1 -usecure -psecure --silent"]
+      interval: 2s
+      timeout: 2s
+      retries: 60
diff --git a/docker-compose.yml b/docker-compose.yml
index 4823db2..1b1edd4 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -7,6 +7,13 @@ services:
     environment:
       DATABASE_URL: ${DATABASE_URL:-sqlite+aiosqlite:///./example.db}
       ALLOWED_POLICY_FILE: ${ALLOWED_POLICY_FILE:-/run/policy/allowed_policy.txt}
+      OPA_DECISION_PATH: ${OPA_DECISION_PATH:-/v1/data/secure_sql/authz/decision}
+      OPA_TIMEOUT_MS: ${OPA_TIMEOUT_MS:-50}
+      OPA_FAIL_CLOSED: ${OPA_FAIL_CLOSED:-true}
+      WRITE_MODE_ENABLED: 
${WRITE_MODE_ENABLED:-false} + ALLOW_INSERT: ${ALLOW_INSERT:-false} + ALLOW_UPDATE: ${ALLOW_UPDATE:-false} + ALLOW_DELETE: ${ALLOW_DELETE:-false} MAX_ROWS: ${MAX_ROWS:-100} QUERY_TIMEOUT: ${QUERY_TIMEOUT:-30} LOG_LEVEL: ${LOG_LEVEL:-INFO} diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh new file mode 100644 index 0000000..dfd48a9 --- /dev/null +++ b/docker/entrypoint.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env sh +set -eu + +OPA_BUNDLE_DIR="${OPA_BUNDLE_DIR:-/app/policy}" +OPA_ADDR="${OPA_ADDR:-127.0.0.1:8181}" + +opa run --server --addr "$OPA_ADDR" --bundle "$OPA_BUNDLE_DIR" & +OPA_PID=$! + +cleanup() { + kill "$OPA_PID" 2>/dev/null || true +} + +trap cleanup INT TERM EXIT + +python /app/wait_for_opa.py + +exec python -m secure_sql_mcp.server diff --git a/docker/wait_for_opa.py b/docker/wait_for_opa.py new file mode 100644 index 0000000..fa03036 --- /dev/null +++ b/docker/wait_for_opa.py @@ -0,0 +1,34 @@ +"""Wait for OPA health without extra image dependencies. + +This project uses a Python-based readiness check instead of curl/wget so the +runtime image can stay minimal and self-contained. `python:3.12-slim` already +ships Python (required by the MCP server), while curl is not guaranteed. 
+""" + +from __future__ import annotations + +import os +import time +from urllib import request + + +def main() -> None: + opa_url = os.environ.get("OPA_URL", "http://127.0.0.1:8181").rstrip("/") + health_url = f"{opa_url}/health" + + deadline = time.time() + 15 + last_error: Exception | None = None + while time.time() < deadline: + try: + with request.urlopen(health_url, timeout=0.5) as resp: # noqa: S310 + if 200 <= resp.status < 300: + return + except Exception as exc: # noqa: BLE001 + last_error = exc + time.sleep(0.2) + + raise SystemExit(f"OPA failed health check: {last_error}") + + +if __name__ == "__main__": + main() diff --git a/docs/POLICY_AUTHORING.md b/docs/POLICY_AUTHORING.md new file mode 100644 index 0000000..a02b2aa --- /dev/null +++ b/docs/POLICY_AUTHORING.md @@ -0,0 +1,233 @@ +# Policy Authoring Guide (OPA/Rego) + +This guide explains how to write and customize policies for `secure-sql-mcp`. + +It is designed to be: +- practical for humans +- structured enough for agents to generate policy variants + +## Policy model at a glance + +The default bundle composes two policy modules: + +- `default_constraints`: baseline guardrails (statement count, operation class, query shape) +- `acl`: table/column access rules + +Final decision is an AND: + +- allow only when `default_constraints.allow` and `acl.allow` are both true + +See: +- `policy/rego/default_constraints.rego` +- `policy/rego/acl.rego` +- `policy/rego/authz.rego` + +## Runtime architecture + +The server now has split execution paths: + +- read statements execute through `execute_read_query(...)` +- write statements execute through `execute_write_query(...)` + +Write execution is still deny-by-default and requires *both* runtime gates and policy: + +- runtime gates (`WRITE_MODE_ENABLED`, `ALLOW_INSERT/UPDATE/DELETE`) +- policy allow in OPA (`default_constraints`, `acl`, `write_constraints`) + +## Input facts available to policy + +OPA receives `{"input": ...}` payloads from the server. 
+ +### For `query` tool + +```json +{ + "tool": { "name": "query" }, + "query": { + "raw_sql": "SELECT id FROM customers", + "statement_count": 1, + "statement_type": "select", + "is_write_statement": false, + "has_disallowed_operation": false, + "is_read_statement": true, + "referenced_tables": ["customers"], + "referenced_columns": { "customers": ["id"] }, + "star_tables": [], + "has_unqualified_multi_table_columns": false, + "target_table": "", + "insert_columns": [], + "updated_columns": [], + "where_present": false, + "where_tautological": false, + "returning_present": false, + "returning_columns": [], + "has_select_source": false, + "source_tables": [] + }, + "config": { + "write_mode_enabled": false, + "allow_insert": false, + "allow_update": false, + "allow_delete": false, + "require_where_for_update": true, + "require_where_for_delete": true, + "allow_returning": false + }, + "acl": { + "tables": { + "customers": { "columns": ["id", "email"] }, + "orders": { "columns": ["*"] } + } + } +} +``` + +### For `list_tables` tool + +```json +{ + "tool": { "name": "list_tables" }, + "acl": { "tables": { "...": { "columns": ["..."] } } } +} +``` + +### For `describe_table` tool + +```json +{ + "tool": { "name": "describe_table" }, + "table": "customers", + "acl": { "tables": { "...": { "columns": ["..."] } } } +} +``` + +## ACL data sources + +ACL data can come from either: + +1. `OPA_ACL_DATA_FILE` (preferred when set), JSON structure at `secure_sql.acl.tables` +2. transformed legacy `ALLOWED_POLICY_FILE` + +Use `OPA_ACL_DATA_FILE` when you want native OPA-oriented config. + +## Rego patterns you can reuse + +Use these as templates when asking an agent to generate a policy. 
+ +### 1) Keep strict read-only baseline (default behavior) + +```rego +package secure_sql.default_constraints + +default allow := false + +deny_reasons["multiple_statements"] if input.query.statement_count != 1 +deny_reasons["disallowed_operation"] if input.query.has_disallowed_operation +deny_reasons["not_read_query"] if not input.query.is_read_statement + +allow if count(deny_reasons) == 0 +``` + +### 2) Relax baseline to allow inserts only (policy example) + +```rego +package secure_sql.default_constraints + +default allow := false + +is_insert if input.query.statement_type == "insert" + +deny_reasons["multiple_statements"] if input.query.statement_count != 1 +deny_reasons["not_allowed_statement_type"] if { + not input.query.is_read_statement + not is_insert +} + +allow if count(deny_reasons) == 0 +``` + +### 3) Allow updates only to specific tables/columns (policy example) + +```rego +package secure_sql.default_constraints + +default allow := false + +allowed_update_columns := { + "customers": {"email"}, + "profiles": {"display_name", "timezone"}, +} + +is_update if input.query.statement_type == "update" + +deny_reasons["multiple_statements"] if input.query.statement_count != 1 + +deny_reasons["update_column_not_allowed"] if { + is_update + table := object.keys(input.query.referenced_columns)[_] + col := input.query.referenced_columns[table][_] + not allowed_update_columns[table][col] +} + +deny_reasons["not_allowed_statement_type"] if { + not input.query.is_read_statement + not is_update +} + +allow if count(deny_reasons) == 0 +``` + +### 4) Time-window or environment gating + +If you add `context` facts (for example `input.context.environment`), you can gate +by deployment environment or maintenance window. 
+ +```rego +deny_reasons["writes_only_allowed_in_maintenance"] if { + input.query.statement_type == "update" + input.context.environment != "maintenance" +} +``` + +## Agent-friendly prompt template + +Use this prompt with an agent to generate policy: + +```text +Generate Rego policies for secure-sql-mcp. + +Constraints: +- Keep authz composition as default_constraints AND acl. +- Tool names are query, list_tables, describe_table. +- Use deny_reasons for explainability. +- Maintain deny-by-default. + +Desired behavior: +- [Describe exactly which statements are allowed] +- [Describe table/column restrictions] +- [Describe any time/env/principal restrictions] + +Output: +1) default_constraints.rego +2) acl.rego (if ACL behavior changes) +3) authz.rego (if composition changes) +4) short test matrix with allowed/blocked examples +``` + +## Testing checklist for policy changes + +When you relax policy behavior, test all of: + +- single-statement enforcement +- disallowed table/column access +- wildcard behavior (`SELECT *`) +- joins/subqueries/unions +- error hygiene (no sensitive leaks) +- OPA fail-closed behavior + +Run: + +```bash +python -m pytest -q +``` + diff --git a/docs/WRITE_MODE_DESIGN.md b/docs/WRITE_MODE_DESIGN.md new file mode 100644 index 0000000..76202bf --- /dev/null +++ b/docs/WRITE_MODE_DESIGN.md @@ -0,0 +1,163 @@ +# Write Mode Design (Controlled Mutations) + +This document describes how to extend `secure-sql-mcp` from read-only execution to +policy-governed writes while preserving security guarantees. + +It is intentionally conservative: mutation capability is powerful and should be +introduced behind explicit controls, with deny-by-default behavior at each layer. 
+ +## Current state + +Today, policy can theoretically return `allow` for non-read operations, but runtime +execution is still read-only: + +- query execution uses `execute_read_query(...)` +- DB session is configured read-only for PostgreSQL/MySQL/SQLite +- query wrapper enforces select-style row capping logic + +As a result, policy-only changes are not sufficient for write support. + +## Goals + +- Allow tightly scoped mutation scenarios (for example `INSERT` only). +- Keep deny-by-default and fail-closed behavior. +- Preserve clean, actionable error responses for agents. +- Avoid broad privilege escalation in database credentials. + +## Non-goals + +- Full unrestricted SQL write access. +- Multi-statement transaction scripting from agents. +- Bypassing policy checks in application code. + +## Security invariants to preserve + +- Single statement per request unless explicitly designed otherwise. +- Explicit allowlist semantics (tables/columns/actions). +- No sensitive internal error leakage. +- OPA unavailable/timeout behavior remains fail-closed. +- Tool responses remain deterministic and auditable. + +## Proposed architecture changes + +### 1) Split execution paths by statement class + +Introduce separate DB execution methods: + +- `execute_read_query(sql)` (existing) +- `execute_write_query(sql)` (new) + +`execute_write_query` should: + +- run with strict timeout +- return affected row count and optional returning payload +- avoid row-cap wrapper intended for SELECT +- avoid enabling arbitrary transaction control from user SQL + +### 2) Expand policy facts for write authorization + +Current input facts are SELECT-centric. Add mutation-focused facts in +`QueryValidator._build_query_policy_input(...)`, for example: + +- `statement_type` normalized (`insert`, `update`, `delete`, etc.) 
+- `target_tables` +- `updated_columns` +- `insert_columns` +- `where_present` (for updates/deletes) +- `returning_present` + +These facts should be parser-derived, not regex-derived. + +### 3) Add explicit write mode config gates + +Use coarse-grained runtime toggles in config: + +- `WRITE_MODE_ENABLED=false` by default +- optional action toggles: + - `ALLOW_INSERT=false` + - `ALLOW_UPDATE=false` + - `ALLOW_DELETE=false` +- allow these to be configured with flags from the cli also. + +OPA remains the final decision engine; these toggles are safety brakes. + +### 4) Keep OPA as policy source of truth for permissions + +Model policy in Rego with explicit action constraints: + +- allow read paths as before +- allow writes only when: + - statement type is explicitly permitted + - table is allowed + - affected columns are allowed + - optional contextual constraints pass (tenant/env/user role) + +### 5) Server-level routing + +In `query(...)`, route execution by validated statement class: + +- if read -> `execute_read_query` +- if write and allowed -> `execute_write_query` +- else block with actionable message + +## Example policy patterns for controlled writes + +### Insert-only mode + +- allow `insert` on specific tables +- deny `update`, `delete`, DDL + +### Update-only specific columns + +- allow `update` on table `customers` +- allow only `email` and `phone` +- require `WHERE` clause (no full-table updates) + +### Delete with strict guard + +- allow `delete` only on maintenance tables +- require `WHERE` and additional context flag (e.g. maintenance window) + +## DB credential model + +Policy alone is not enough. Use least-privilege DB credentials: + +- read-only role for read-only deployments +- separate write-capable role for write mode +- grants limited to intended schemas/tables/actions + +Do not rely solely on app-layer checks for write containment. 
+ +## Response contract proposal for writes + +For write operations, return structured JSON: + +```json +{ + "status": "ok", + "operation": "update", + "affected_rows": 3, + "returning": [] +} +``` + +For blocked writes, keep consistent actionable messages and avoid leaking internals. + +## Rollout plan + +1. Add parser-derived write facts and tests (no write execution yet). +2. Add OPA write policy rules in shadow mode (log only). +3. Add `execute_write_query` path behind `WRITE_MODE_ENABLED`. +4. Enable insert-only in non-production. +5. Expand to update/delete only with dedicated tests and DB grants. + +## Test matrix (minimum) + +- parser extraction for write facts by dialect +- blocked/allowed decisions for insert/update/delete +- column-restricted updates +- missing-WHERE safeguards +- fail-closed OPA behavior for writes +- sanitized DB error responses +- stdio MCP contract for write and blocked-write outcomes + diff --git a/policy/allowed_policy.txt b/policy/allowed_policy.txt new file mode 100644 index 0000000..eb05c8c --- /dev/null +++ b/policy/allowed_policy.txt @@ -0,0 +1,2 @@ +customers:id,email +orders:* diff --git a/policy/allowed_policy_castledm_test.txt b/policy/allowed_policy_castledm_test.txt new file mode 100644 index 0000000..7b32740 --- /dev/null +++ b/policy/allowed_policy_castledm_test.txt @@ -0,0 +1,7 @@ +tbl_alayamarket_demand_offers:id,visit_id,alayamarket_offer_id +tbl_alayamarket_demand_offers_audit:* +tbl_offers_responses:* +tbl_schedule_offer:* +tbl_schedule_offer_response:* +view_grouped_offer_responses:* +view_guid_computed_offer_responses:* diff --git a/policy/data/acl.example.json b/policy/data/acl.example.json new file mode 100644 index 0000000..9b644a2 --- /dev/null +++ b/policy/data/acl.example.json @@ -0,0 +1,14 @@ +{ + "secure_sql": { + "acl": { + "tables": { + "customers": { + "columns": ["id", "email"] + }, + "orders": { + "columns": ["*"] + } + } + } + } +} diff --git a/policy/rego/acl.rego b/policy/rego/acl.rego new file 
mode 100644 index 0000000..3069f8e --- /dev/null +++ b/policy/rego/acl.rego @@ -0,0 +1,144 @@ +package secure_sql.acl + +default allow = false + +acl_tables := object.get(object.get(input, "acl", {}), "tables", {}) + +is_query_tool if input.tool.name == "query" +is_list_tables_tool if input.tool.name == "list_tables" +is_describe_table_tool if input.tool.name == "describe_table" +is_write_query if { + is_query_tool + object.get(input.query, "is_write_statement", false) +} + +normalized_table(table) := lower(table) +short_table_name(table) := name if { + parts := split(normalized_table(table), ".") + idx := count(parts) - 1 + name := parts[idx] +} + +table_allowed(table) if object.get(acl_tables, normalized_table(table), null) != null +table_allowed(table) if object.get(acl_tables, short_table_name(table), null) != null + +allowed_columns(table) := cols if { + full := object.get(acl_tables, normalized_table(table), null) + full != null + cols := object.get(full, "columns", []) +} + +allowed_columns(table) := cols if { + full := object.get(acl_tables, normalized_table(table), null) + full == null + short := object.get(acl_tables, short_table_name(table), {}) + cols := object.get(short, "columns", []) +} + +column_allowed(table, col) if { + allowed_columns(table)[_] == "*" +} + +column_allowed(table, col) if { + allowed_columns(table)[_] == col +} + +deny_reasons["table_restricted"] if { + is_query_tool + not is_write_query + table := input.query.referenced_tables[_] + not table_allowed(table) +} + +deny_reasons["column_restricted"] if { + is_query_tool + not is_write_query + table := object.keys(input.query.referenced_columns)[_] + col := input.query.referenced_columns[table][_] + not column_allowed(table, col) +} + +deny_reasons["star_not_allowed"] if { + is_query_tool + not is_write_query + table := input.query.star_tables[_] + not column_allowed(table, "*") +} + +deny_reasons["star_not_allowed"] if { + is_write_query + table := input.query.star_tables[_] + not 
column_allowed(table, "*") +} + +deny_reasons["table_restricted"] if { + is_write_query + target := object.get(input.query, "target_table", "") + target != "" + not table_allowed(target) +} + +write_columns[col] if { + is_write_query + col := object.get(input.query, "insert_columns", [])[_] +} + +write_columns[col] if { + is_write_query + col := object.get(input.query, "updated_columns", [])[_] +} + +deny_reasons["write_column_restricted"] if { + is_write_query + target := object.get(input.query, "target_table", "") + target != "" + col := write_columns[_] + not column_allowed(target, col) +} + +deny_reasons["write_source_table_restricted"] if { + is_write_query + src := object.get(input.query, "source_tables", [])[_] + not table_allowed(src) +} + +deny_reasons["write_column_restricted"] if { + is_write_query + table := object.keys(input.query.referenced_columns)[_] + col := input.query.referenced_columns[table][_] + not column_allowed(table, col) +} + +deny_reasons["write_column_restricted"] if { + is_write_query + target := object.get(input.query, "target_table", "") + target != "" + col := object.get(input.query, "returning_columns", [])[_] + col != "*" + not column_allowed(target, col) +} + +deny_reasons["star_not_allowed"] if { + is_write_query + target := object.get(input.query, "target_table", "") + target != "" + object.get(input.query, "returning_columns", [])[_] == "*" + not column_allowed(target, "*") +} + +deny_reasons["table_restricted"] if { + is_describe_table_tool + not table_allowed(input.table) +} + +allow if is_list_tables_tool + +allow if { + is_describe_table_tool + table_allowed(input.table) +} + +allow if { + is_query_tool + count(deny_reasons) == 0 +} diff --git a/policy/rego/authz.rego b/policy/rego/authz.rego new file mode 100644 index 0000000..9182453 --- /dev/null +++ b/policy/rego/authz.rego @@ -0,0 +1,22 @@ +package secure_sql.authz + +import data.secure_sql.acl +import data.secure_sql.default_constraints +import 
# Baseline query constraints: statement shape (single statement, no DDL),
# read/write gating, and per-operation (INSERT/UPDATE/DELETE) config switches.
#
# NOTE: rewritten for Rego v1 (OPA >= 1.0): multi-value rules use `contains`
# so deny_reasons is a true set rather than an object of {reason: true}.
package secure_sql.default_constraints

default allow := false

is_query_tool if input.tool.name == "query"

is_write_query if {
	is_query_tool
	object.get(input.query, "is_write_statement", false)
}

statement_type := lower(object.get(input.query, "statement_type", ""))

# Hoisted once; every switch below defaults to false (fail-closed).
config := object.get(input, "config", {})

write_mode_enabled := object.get(config, "write_mode_enabled", false)

allow_insert := object.get(config, "allow_insert", false)

allow_update := object.get(config, "allow_update", false)

allow_delete := object.get(config, "allow_delete", false)

deny_reasons contains "multiple_statements" if {
	is_query_tool
	input.query.statement_count != 1
}

deny_reasons contains "ddl_or_privilege_operation" if {
	is_query_tool
	input.query.has_disallowed_operation
}

deny_reasons contains "not_read_query" if {
	is_query_tool
	not input.query.is_read_statement
	not is_write_query
}

deny_reasons contains "write_not_enabled" if {
	is_write_query
	not write_mode_enabled
}

deny_reasons contains "insert_not_allowed" if {
	is_write_query
	statement_type == "insert"
	not allow_insert
}

deny_reasons contains "update_not_allowed" if {
	is_write_query
	statement_type == "update"
	not allow_update
}

deny_reasons contains "delete_not_allowed" if {
	is_write_query
	statement_type == "delete"
	not allow_delete
}

# Strict mode: INSERT must name its target columns explicitly.
deny_reasons contains "insert_columns_missing" if {
	is_write_query
	statement_type == "insert"
	count(object.get(input.query, "insert_columns", [])) == 0
}

deny_reasons contains "unqualified_multi_table_column" if {
	is_query_tool
	input.query.has_unqualified_multi_table_columns
}

# Non-query tools (list_tables / describe_table) are governed by the ACL
# module, not by these statement-level constraints.
allow if not is_query_tool

allow if {
	is_query_tool
	count(deny_reasons) == 0
}
returning_present + not allow_returning +} diff --git a/pyproject.toml b/pyproject.toml index 466da88..b63a008 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "sqlglot", "sqlalchemy[asyncio]", "asyncpg", + "aiomysql", "aiosqlite", "pydantic-settings", ] @@ -46,6 +47,10 @@ dev = [ [tool.pytest.ini_options] testpaths = ["tests"] pythonpath = ["src"] +markers = [ + "docker_integration: real MCP server integration tests in Docker with OPA", + "smoke: fast representative scenario subset", +] [tool.ruff] line-length = 100 diff --git a/scripts/run-docker-opa-smoke.sh b/scripts/run-docker-opa-smoke.sh new file mode 100755 index 0000000..3a4f5ac --- /dev/null +++ b/scripts/run-docker-opa-smoke.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +echo "Running Docker + OPA smoke scenarios (all backends)..." +python -m pytest -q -m "docker_integration and smoke" tests/integration/docker/test_mcp_docker_opa_matrix.py diff --git a/src/secure_sql_mcp/config.py b/src/secure_sql_mcp/config.py index 34137d1..36c2970 100644 --- a/src/secure_sql_mcp/config.py +++ b/src/secure_sql_mcp/config.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json from pathlib import Path from typing import Any @@ -17,16 +18,65 @@ class Settings(BaseSettings): database_url: str = Field(alias="DATABASE_URL") allowed_policy_file: str = Field(alias="ALLOWED_POLICY_FILE") allowed_policy: dict[str, set[str]] = Field(default_factory=dict) + effective_acl_policy: dict[str, set[str]] = Field(default_factory=dict) + opa_url: str | None = Field(default=None, alias="OPA_URL") + opa_decision_path: str = Field( + default="/v1/data/secure_sql/authz/decision", alias="OPA_DECISION_PATH" + ) + opa_timeout_ms: int = Field(default=50, alias="OPA_TIMEOUT_MS", ge=1, le=5000) + opa_fail_closed: bool = Field(default=True, alias="OPA_FAIL_CLOSED") + opa_acl_data_file: str | None = 
Field(default=None, alias="OPA_ACL_DATA_FILE") + write_mode_enabled: bool = Field(default=False, alias="WRITE_MODE_ENABLED") + allow_insert: bool = Field(default=False, alias="ALLOW_INSERT") + allow_update: bool = Field(default=False, alias="ALLOW_UPDATE") + allow_delete: bool = Field(default=False, alias="ALLOW_DELETE") + require_where_for_update: bool = Field(default=True, alias="REQUIRE_WHERE_FOR_UPDATE") + require_where_for_delete: bool = Field(default=True, alias="REQUIRE_WHERE_FOR_DELETE") + allow_returning: bool = Field(default=False, alias="ALLOW_RETURNING") max_rows: int = Field(default=100, alias="MAX_ROWS", ge=1, le=10000) query_timeout: int = Field(default=30, alias="QUERY_TIMEOUT", ge=1, le=300) log_level: str = Field(default="INFO", alias="LOG_LEVEL") + @field_validator("database_url", mode="before") + @classmethod + def inject_async_driver(cls, value: Any) -> str: + """Ensure SQLAlchemy async URLs include an async driver suffix.""" + database_url = str(value).strip() + if "://" not in database_url: + return database_url + + scheme = database_url.split("://", 1)[0] + if "+" in scheme: + return database_url + + async_driver_map = { + "postgresql": "asyncpg", + "mysql": "aiomysql", + "sqlite": "aiosqlite", + } + driver = async_driver_map.get(scheme) + if driver is None: + return database_url + + return database_url.replace(f"{scheme}://", f"{scheme}+{driver}://", 1) + @model_validator(mode="after") def load_allowed_policy(self) -> Settings: """Load strict table:columns policy from file.""" self.allowed_policy = self._parse_allowed_policy_file(self.allowed_policy_file) + self.effective_acl_policy = self._load_effective_acl_policy( + self.allowed_policy, self.opa_acl_data_file + ) return self + @field_validator("opa_decision_path", mode="before") + @classmethod + def normalize_opa_decision_path(cls, value: Any) -> str: + path = str(value).strip() + if not path: + return "/v1/data/secure_sql/authz/decision" + return path if path.startswith("/") else 
f"/{path}" + @field_validator("log_level", mode="before") @classmethod def normalize_log_level(cls, value: Any) -> str: @@ -82,6 +132,72 @@ def _parse_allowed_policy_file(path: str) -> dict[str, set[str]]: raise ValueError("Allowed policy file is empty. Add at least one table rule.") return policy + @classmethod + def _load_effective_acl_policy( + cls, allowed_policy: dict[str, set[str]], opa_acl_data_file: str | None + ) -> dict[str, set[str]]: + """Load ACL from OPA-native data file when available, else fallback to legacy policy.""" + if not opa_acl_data_file: + return {table: set(columns) for table, columns in allowed_policy.items()} + + acl_path = Path(opa_acl_data_file).expanduser() + if not acl_path.exists(): + raise ValueError(f"OPA ACL data file does not exist: {acl_path}") + if not acl_path.is_file(): + raise ValueError(f"OPA ACL data path is not a file: {acl_path}") + + try: + payload = json.loads(acl_path.read_text(encoding="utf-8")) + except json.JSONDecodeError as exc: + raise ValueError(f"OPA ACL data file must be valid JSON: {exc.msg}") from exc + + tables_payload = cls._extract_opa_tables_payload(payload) + parsed: dict[str, set[str]] = {} + for raw_table, raw_rule in tables_payload.items(): + table = str(raw_table).strip().lower() + if not table: + continue + if not isinstance(raw_rule, dict): + raise ValueError(f"OPA ACL rule for table '{table}' must be an object.") + raw_columns = raw_rule.get("columns") + if not isinstance(raw_columns, list) or not raw_columns: + raise ValueError( + f"OPA ACL rule for table '{table}' must include non-empty 'columns' list." + ) + columns = {str(column).strip().lower() for column in raw_columns if str(column).strip()} + if not columns: + raise ValueError( + f"OPA ACL rule for table '{table}' includes no valid column entries." + ) + if "*" in columns and len(columns) > 1: + raise ValueError( + f"OPA ACL wildcard for table '{table}' must be used alone in 'columns'." 
+ ) + parsed[table] = columns + + if not parsed: + raise ValueError("OPA ACL data file resolved to an empty ACL policy.") + return parsed + + @staticmethod + def _extract_opa_tables_payload(payload: Any) -> dict[str, Any]: + if not isinstance(payload, dict): + raise ValueError("OPA ACL data file root must be a JSON object.") + + if "tables" in payload and isinstance(payload["tables"], dict): + return payload["tables"] + + secure_sql = payload.get("secure_sql") + if not isinstance(secure_sql, dict): + raise ValueError("OPA ACL data must define either 'tables' or 'secure_sql.acl.tables'.") + acl = secure_sql.get("acl") + if not isinstance(acl, dict): + raise ValueError("OPA ACL data missing object at 'secure_sql.acl'.") + tables = acl.get("tables") + if not isinstance(tables, dict): + raise ValueError("OPA ACL data missing object at 'secure_sql.acl.tables'.") + return tables + def load_settings() -> Settings: """Load typed settings from environment variables.""" diff --git a/src/secure_sql_mcp/database.py b/src/secure_sql_mcp/database.py index f799833..7eeea71 100644 --- a/src/secure_sql_mcp/database.py +++ b/src/secure_sql_mcp/database.py @@ -21,6 +21,15 @@ class QueryExecutionResult: truncated: bool +@dataclass(slots=True) +class WriteExecutionResult: + """Structured result for write statements.""" + + affected_rows: int + returning_columns: list[str] + returning_rows: list[dict[str, Any]] + + class AsyncDatabase: """Async SQLAlchemy wrapper with read-only execution safeguards.""" @@ -61,6 +70,36 @@ async def _run() -> QueryExecutionResult: return await asyncio.wait_for(_run(), timeout=self._settings.query_timeout) + async def execute_write_query(self, sql: str) -> WriteExecutionResult: + """Execute a single write statement with timeout and optional RETURNING payload.""" + if self._engine is None: + raise RuntimeError("Database engine is not initialized.") + + statement = text(sql.strip().rstrip(";")) + + async def _run() -> WriteExecutionResult: + if self._engine 
is None: + raise RuntimeError("Database engine is not initialized.") + async with self._engine.begin() as conn: + await self._prepare_write_session(conn) + result = await conn.execute(statement) + affected_rows = int(result.rowcount) if result.rowcount is not None else 0 + returning_rows: list[dict[str, Any]] = [] + returning_columns: list[str] = [] + if result.returns_rows: + fetched = result.fetchmany(self._settings.max_rows + 1) + returning_rows = [ + dict(row._mapping) for row in fetched[: self._settings.max_rows] + ] + returning_columns = list(result.keys()) + return WriteExecutionResult( + affected_rows=affected_rows, + returning_columns=returning_columns, + returning_rows=returning_rows, + ) + + return await asyncio.wait_for(_run(), timeout=self._settings.query_timeout) + async def list_tables(self) -> list[str]: """List all visible base tables from the connected database.""" if self._engine is None: @@ -104,9 +143,22 @@ async def _prepare_read_only_session(self, conn: AsyncConnection) -> None: timeout_ms = int(self._settings.query_timeout) * 1000 await conn.execute(text("BEGIN READ ONLY")) await conn.execute(text(f"SET LOCAL statement_timeout = {timeout_ms}")) + elif self._settings.database_url.startswith("mysql"): + timeout_ms = int(self._settings.query_timeout) * 1000 + await conn.execute(text(f"SET SESSION MAX_EXECUTION_TIME = {timeout_ms}")) + await conn.execute(text("START TRANSACTION READ ONLY")) elif self._settings.database_url.startswith("sqlite"): await conn.execute(text("PRAGMA query_only = ON")) + async def _prepare_write_session(self, conn: AsyncConnection) -> None: + """Apply DB-specific timeout settings for write operations.""" + if self._settings.database_url.startswith("postgresql"): + timeout_ms = int(self._settings.query_timeout) * 1000 + await conn.execute(text(f"SET LOCAL statement_timeout = {timeout_ms}")) + elif self._settings.database_url.startswith("mysql"): + timeout_ms = int(self._settings.query_timeout) * 1000 + await 
class OpaPolicyEngine:
    """Evaluates policy decisions against a local OPA server.

    Wraps the OPA REST decision endpoint. Failure handling is governed by
    ``settings.opa_fail_closed``: when true, any transport/format problem
    denies the request; when false, transport problems fail open.
    """

    # Denial reasons in priority order, paired with the user-facing message
    # used when OPA supplies no message of its own. The first reason present
    # in the decision's deny_reasons wins.
    _REASON_MESSAGES: tuple[tuple[str, str], ...] = (
        (
            "multiple_statements",
            "Only a single SQL statement is allowed. "
            "Please remove additional statements and try again.",
        ),
        (
            "ddl_or_privilege_operation",
            "DDL and privilege operations are not permitted. "
            "Please escalate to a human operator.",
        ),
        (
            "disallowed_operation",
            "This server is configured for read-only access. "
            "If you need to modify data, please escalate to a human operator.",
        ),
        (
            "write_not_enabled",
            "Write operations are disabled by server configuration. "
            "Please escalate to a human operator.",
        ),
        (
            "insert_not_allowed",
            "INSERT operations are not permitted by server configuration. "
            "Please escalate to a human operator.",
        ),
        (
            "update_not_allowed",
            "UPDATE operations are not permitted by server configuration. "
            "Please escalate to a human operator.",
        ),
        (
            "delete_not_allowed",
            "DELETE operations are not permitted by server configuration. "
            "Please escalate to a human operator.",
        ),
        (
            "insert_columns_missing",
            "INSERT statements must include an explicit column list under strict mode. "
            "Please specify target columns explicitly.",
        ),
        ("not_read_query", "Only read-only SELECT queries are allowed."),
        ("missing_where_on_update", "UPDATE without a WHERE clause is not allowed."),
        ("missing_where_on_delete", "DELETE without a WHERE clause is not allowed."),
        (
            "tautological_where_clause",
            "The WHERE clause appears tautological and may update/delete too broadly. "
            "Please provide a restrictive predicate.",
        ),
        ("returning_not_allowed", "RETURNING is not allowed for this write policy."),
        (
            "table_restricted",
            "Access to one or more tables is restricted by the server access policy. "
            "Please use list_tables/describe_table to view allowed targets.",
        ),
        (
            "write_source_table_restricted",
            "INSERT ... SELECT references one or more source tables restricted by policy. "
            "Please use list_tables/describe_table to view allowed targets.",
        ),
        (
            "write_column_restricted",
            "Write access to one or more target columns is restricted by policy. "
            "Use describe_table to inspect allowed columns.",
        ),
        (
            "column_restricted",
            "Access to one or more selected columns is restricted by policy. "
            "Use describe_table to inspect allowed columns.",
        ),
        (
            "star_not_allowed",
            "SELECT * is not allowed under strict policy for one or more tables. "
            "Please select explicit allowed columns.",
        ),
        (
            "unqualified_multi_table_column",
            "Unqualified column references are not allowed in multi-table queries "
            "under strict mode.",
        ),
    )

    def __init__(self, settings: Settings) -> None:
        self.settings = settings

    async def evaluate(self, payload: dict[str, Any]) -> PolicyDecision:
        """Evaluate asynchronously by running the blocking HTTP call in a thread."""
        return await asyncio.to_thread(self._evaluate_sync, payload)

    def evaluate_sync(self, payload: dict[str, Any]) -> PolicyDecision:
        """Synchronous entry point for callers outside an event loop."""
        return self._evaluate_sync(payload)

    def _evaluate_sync(self, payload: dict[str, Any]) -> PolicyDecision:
        """Query OPA and normalize its response into a PolicyDecision."""
        base_url = self.settings.opa_url
        if not base_url:
            return PolicyDecision(
                allow=False,
                deny_reasons=["opa_unconfigured"],
                message=(
                    "Authorization service is not configured. Please escalate to a human operator."
                ),
            )

        try:
            response_document = self._post_decision_request(base_url, payload)
        except (error.URLError, TimeoutError, json.JSONDecodeError) as exc:
            if self.settings.opa_fail_closed:
                return PolicyDecision(
                    allow=False,
                    deny_reasons=["opa_unavailable"],
                    message=(
                        "Authorization service is unavailable. "
                        "Please retry or escalate to a human operator."
                    ),
                )
            # Fail-open mode: allow, but surface the transport error for logging.
            return PolicyDecision(
                allow=True,
                deny_reasons=[],
                message=None,
                raw_result={"warning": str(exc)},
            )

        result = self._extract_result(response_document)
        if result is None:
            # Undefined decision (missing/null "result"): honour fail-closed mode.
            return PolicyDecision(
                allow=not self.settings.opa_fail_closed,
                deny_reasons=["opa_undefined"],
                message=(
                    "Authorization decision is unavailable. "
                    "Please retry or escalate to a human operator."
                ),
            )

        if isinstance(result, bool):
            return PolicyDecision(allow=result, raw_result={"allow": result})

        if not isinstance(result, dict):
            return PolicyDecision(
                allow=False,
                deny_reasons=["opa_invalid_result"],
                message="Authorization decision format is invalid.",
            )

        allow = bool(result.get("allow", False))
        reasons = [str(item) for item in result.get("deny_reasons", [])]
        message = result.get("message")
        if message is not None:
            message = str(message)
        if not allow and not message:
            # Synthesize a user-facing message from the highest-priority reason.
            message = self._message_for_reasons(reasons)

        return PolicyDecision(
            allow=allow,
            deny_reasons=reasons,
            message=message,
            raw_result=result,
        )

    def _post_decision_request(self, base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
        """POST the policy input to the OPA decision endpoint and parse the JSON reply."""
        endpoint = f"{base_url.rstrip('/')}{self.settings.opa_decision_path}"
        http_request = request.Request(  # noqa: S310
            endpoint,
            data=json.dumps({"input": payload}).encode("utf-8"),
            method="POST",
            headers={"Content-Type": "application/json"},
        )
        timeout_seconds = self.settings.opa_timeout_ms / 1000
        with request.urlopen(http_request, timeout=timeout_seconds) as response:  # noqa: S310
            return json.loads(response.read().decode("utf-8"))

    @staticmethod
    def _extract_result(response_payload: dict[str, Any]) -> Any | None:
        # OPA REST response shape: {"result": ...}; an absent key and an
        # explicit null are treated identically (undefined decision).
        return response_payload.get("result")

    @staticmethod
    def _message_for_reasons(deny_reasons: list[str]) -> str:
        """Return the message for the highest-priority reason present."""
        present = set(deny_reasons)
        for reason, message in OpaPolicyEngine._REASON_MESSAGES:
            if reason in present:
                return message
        return "Query blocked by policy."
sql: str) -> ValidationResult: - """Validate SQL for single statement, read-only, and table ACL rules.""" + """Validate SQL and authorize according to configured policy backend.""" query = sql.strip() if not query: return ValidationResult(ok=False, error="Query is empty.") @@ -56,57 +79,255 @@ def validate_query(self, sql: str) -> ValidationResult: error="Could not parse the SQL query. Please check the syntax and try again.", ) - if len(statements) != 1: + if not statements or statements[0] is None: return ValidationResult( ok=False, - error=( - "Only a single SQL statement is allowed. " - "Please remove additional statements and try again." - ), + error="Could not parse the SQL query. Please check the syntax and try again.", ) statement = statements[0] - if statement is None: - return ValidationResult( - ok=False, - error="Could not parse the SQL query. Please check the syntax and try again.", - ) - statement_type = statement.key.upper() if statement.key else "UNKNOWN" + statement_type = self._statement_type(statement) + statement_type_upper = statement_type.upper() + statement_count = len(statements) + has_disallowed_operation = any( + stmt is not None and self._contains_always_disallowed_operation(stmt) + for stmt in statements + ) + is_read_statement = statement_count == 1 and self._is_read_statement(statement) + is_write_statement = statement_count == 1 and self._is_write_statement(statement) - if self._contains_disallowed_operation(statement): - return ValidationResult( - ok=False, - error=( - "This server is configured for read-only access. " - f"The operation '{statement_type}' is not permitted. " - "If you need to modify data, please escalate to a human operator." 
- ), - ) + referenced_tables: list[str] = [] + referenced_columns: dict[str, set[str]] = {} + star_tables: set[str] = set() + has_unqualified_multi_table_columns = False + write_facts: WriteFacts | None = None + + if statement_count == 1: + referenced_tables = self.extract_referenced_tables(statement) + if is_write_statement: + write_facts = self._extract_write_facts(statement) + if write_facts is None: + return ValidationResult( + ok=False, + error=( + "Could not determine write operation details from SQL. " + "Please use an explicit INSERT/UPDATE/DELETE statement." + ), + ) + + if not self.settings.write_mode_enabled: + self._warn_if_policy_would_allow_blocked_write( + sql=query, + statement_count=statement_count, + statement_type=statement_type, + has_disallowed_operation=has_disallowed_operation, + is_read_statement=is_read_statement, + referenced_tables=referenced_tables, + referenced_columns=referenced_columns, + star_tables=star_tables, + has_unqualified_multi_table_columns=has_unqualified_multi_table_columns, + write_facts=write_facts, + blocked_reason="WRITE_MODE_ENABLED", + ) + return ValidationResult( + ok=False, + statement_type=statement_type, + error=( + "This server is configured for read-only access. " + f"The operation '{statement_type_upper}' is not permitted. " + "Please escalate to a human operator." 
+ ), + ) + + if not self._is_write_action_enabled(statement_type): + gate_name = f"ALLOW_{statement_type_upper}" + self._warn_if_policy_would_allow_blocked_write( + sql=query, + statement_count=statement_count, + statement_type=statement_type, + has_disallowed_operation=has_disallowed_operation, + is_read_statement=is_read_statement, + referenced_tables=referenced_tables, + referenced_columns=referenced_columns, + star_tables=star_tables, + has_unqualified_multi_table_columns=has_unqualified_multi_table_columns, + write_facts=write_facts, + blocked_reason=gate_name, + ) + return ValidationResult( + ok=False, + statement_type=statement_type, + error=( + f"{statement_type_upper} operations are disabled " + "by server configuration. " + "Please escalate to a human operator." + ), + ) + + if write_facts.statement_type == "insert" and not write_facts.insert_columns: + return ValidationResult( + ok=False, + statement_type=statement_type, + error=( + "INSERT statements must include an explicit " + "column list under strict mode. " + "Please specify allowed target columns explicitly." 
+ ), + ) + if ( + write_facts.statement_type in {"update", "delete"} + and self.settings.require_where_for_update + and write_facts.statement_type == "update" + and not write_facts.where_present + ): + return ValidationResult( + ok=False, + statement_type=statement_type, + error=f"{statement_type_upper} without a WHERE clause is not allowed.", + ) + if ( + write_facts.statement_type in {"update", "delete"} + and self.settings.require_where_for_delete + and write_facts.statement_type == "delete" + and not write_facts.where_present + ): + return ValidationResult( + ok=False, + statement_type=statement_type, + error=f"{statement_type_upper} without a WHERE clause is not allowed.", + ) + if ( + write_facts.statement_type in {"update", "delete"} + and write_facts.where_tautological + ): + return ValidationResult( + ok=False, + statement_type=statement_type, + error=( + "The WHERE clause appears tautological and may " + "update/delete too broadly. " + "Please provide a restrictive predicate." + ), + ) + + if write_facts.returning_present and not self.settings.allow_returning: + return ValidationResult( + ok=False, + statement_type=statement_type, + error="RETURNING is not allowed for this write policy.", + ) + + if self.policy_engine is None: + write_acl_error = self._validate_write_acl(write_facts) + if write_acl_error: + return ValidationResult( + ok=False, statement_type=statement_type, error=write_acl_error + ) + + columns_result = self.extract_referenced_columns(statement, referenced_tables) + if isinstance(columns_result, str): + return ValidationResult( + ok=False, statement_type=statement_type, error=columns_result + ) + referenced_columns, star_tables = columns_result + + table_policy = self._resolve_table_policy(referenced_tables) + if isinstance(table_policy, str): + return ValidationResult( + ok=False, statement_type=statement_type, error=table_policy + ) + columns_error = self._validate_column_access( + table_policy, referenced_columns, star_tables + ) + if 
columns_error: + return ValidationResult( + ok=False, statement_type=statement_type, error=columns_error + ) + else: + referenced_columns, star_tables, has_unqualified_multi_table_columns = ( + self._extract_referenced_columns_relaxed(statement, referenced_tables) + ) + elif is_read_statement: + if self.policy_engine is None: + table_policy = self._resolve_table_policy(referenced_tables) + if isinstance(table_policy, str): + return ValidationResult(ok=False, error=table_policy) + + columns_result = self.extract_referenced_columns(statement, referenced_tables) + if isinstance(columns_result, str): + return ValidationResult(ok=False, error=columns_result) + + referenced_columns, star_tables = columns_result + columns_error = self._validate_column_access( + table_policy, referenced_columns, star_tables + ) + if columns_error: + return ValidationResult(ok=False, error=columns_error) + + for table in referenced_tables: + access_error = self.table_access_error(table, table_policy=table_policy) + if access_error: + return ValidationResult(ok=False, error=access_error) + else: + referenced_columns, star_tables, has_unqualified_multi_table_columns = ( + self._extract_referenced_columns_relaxed(statement, referenced_tables) + ) + else: + referenced_columns, star_tables, has_unqualified_multi_table_columns = ( + self._extract_referenced_columns_relaxed(statement, referenced_tables) + ) - if not self._is_read_statement(statement): - return ValidationResult( - ok=False, - error=(f"Only read-only SELECT queries are allowed. Received '{statement_type}'."), + if self.policy_engine is None: + if statement_count != 1: + return ValidationResult( + ok=False, + statement_type=statement_type, + error=( + "Only a single SQL statement is allowed. " + "Please remove additional statements and try again." + ), + ) + if has_disallowed_operation: + return ValidationResult( + ok=False, + statement_type=statement_type, + error=( + "This server is configured for read-only access. 
" + f"The operation '{statement_type_upper}' is not permitted. " + "If you need to modify data, please escalate to a human operator." + ), + ) + if not (is_read_statement or is_write_statement): + return ValidationResult( + ok=False, + statement_type=statement_type, + error=( + "Only read-only SELECT queries or explicitly enabled " + "write operations are allowed. " + f"Received '{statement_type_upper}'." + ), + ) + else: + decision = self.policy_engine.evaluate_sync( + self._build_query_policy_input( + sql=query, + statement_count=statement_count, + statement_type=statement_type, + has_disallowed_operation=has_disallowed_operation, + is_read_statement=is_read_statement, + referenced_tables=referenced_tables, + referenced_columns=referenced_columns, + star_tables=star_tables, + has_unqualified_multi_table_columns=has_unqualified_multi_table_columns, + write_facts=write_facts, + ) ) - - referenced_tables = self.extract_referenced_tables(statement) - table_policy = self._resolve_table_policy(referenced_tables) - if isinstance(table_policy, str): - return ValidationResult(ok=False, error=table_policy) - - columns_result = self.extract_referenced_columns(statement, referenced_tables) - if isinstance(columns_result, str): - return ValidationResult(ok=False, error=columns_result) - - referenced_columns, star_tables = columns_result - columns_error = self._validate_column_access(table_policy, referenced_columns, star_tables) - if columns_error: - return ValidationResult(ok=False, error=columns_error) - - for table in referenced_tables: - access_error = self.table_access_error(table, table_policy=table_policy) - if access_error: - return ValidationResult(ok=False, error=access_error) + if not decision.allow: + return ValidationResult( + ok=False, + statement_type=statement_type, + error=decision.message or "Query blocked by policy.", + ) return ValidationResult( ok=True, @@ -117,6 +338,7 @@ def validate_query(self, sql: str) -> ValidationResult: referenced_columns={ table: 
sorted(columns) for table, columns in referenced_columns.items() }, + statement_type=statement_type, ) def table_access_error( @@ -191,7 +413,7 @@ def extract_referenced_columns( def _resolve_table_policy(self, tables: list[str]) -> dict[str, set[str]] | str: resolved: dict[str, set[str]] = {} - available = ", ".join(sorted(self.settings.allowed_policy)) + available = ", ".join(sorted(self.settings.effective_acl_policy)) for table in tables: policy_columns = self.lookup_table_policy(table) @@ -239,12 +461,31 @@ def _validate_column_access( def _dialect(self) -> str | None: if self.settings.database_url.startswith("postgresql"): return "postgres" + if self.settings.database_url.startswith("mysql"): + return "mysql" if self.settings.database_url.startswith("sqlite"): return "sqlite" return None - def _contains_disallowed_operation(self, statement: exp.Expression) -> bool: - return any(statement.find(kind) is not None for kind in self._DISALLOWED_EXPRESSIONS) + def _contains_always_disallowed_operation(self, statement: exp.Expression) -> bool: + return any(statement.find(kind) is not None for kind in self._ALWAYS_DISALLOWED) + + def _is_write_statement(self, statement: exp.Expression) -> bool: + return isinstance(statement, self._WRITE_EXPRESSIONS) + + @staticmethod + def _statement_type(statement: exp.Expression) -> str: + key = (statement.key or "unknown").lower() + return key + + def _is_write_action_enabled(self, statement_type: str) -> bool: + if statement_type == "insert": + return self.settings.allow_insert + if statement_type == "update": + return self.settings.allow_update + if statement_type == "delete": + return self.settings.allow_delete + return False @staticmethod def _is_read_statement(statement: exp.Expression) -> bool: @@ -252,7 +493,7 @@ def _is_read_statement(statement: exp.Expression) -> bool: return True if isinstance(statement, (exp.Union, exp.Intersect, exp.Except)): return True - return statement.find(exp.Select) is not None + return False 
@staticmethod def _table_to_name(table: exp.Table) -> str: @@ -270,8 +511,8 @@ def lookup_table_policy(self, table_name: str) -> set[str] | None: normalized = table_name.lower() candidates = (normalized, normalized.split(".")[-1]) for candidate in candidates: - if candidate in self.settings.allowed_policy: - return set(self.settings.allowed_policy[candidate]) + if candidate in self.settings.effective_acl_policy: + return set(self.settings.effective_acl_policy[candidate]) return None def _build_alias_map(self, statement: exp.Expression) -> dict[str, str]: @@ -287,3 +528,382 @@ def _build_alias_map(self, statement: exp.Expression) -> dict[str, str]: alias_name = alias_expr.name.lower() alias_map[alias_name] = table_name return alias_map + + def _extract_referenced_columns_relaxed( + self, statement: exp.Expression, referenced_tables: list[str] + ) -> tuple[dict[str, set[str]], set[str], bool]: + alias_map = self._build_alias_map(statement) + columns_by_table: defaultdict[str, set[str]] = defaultdict(set) + unqualified_columns: set[str] = set() + star_tables: set[str] = set() + + for column in statement.find_all(exp.Column): + if isinstance(column.this, exp.Star): + continue + if not column.name: + continue + + col_name = column.name.lower() + qualifier = (column.table or "").lower() + if qualifier: + table_name = alias_map.get(qualifier, qualifier) + columns_by_table[table_name].add(col_name) + else: + unqualified_columns.add(col_name) + + has_unqualified_multi_table_columns = bool( + unqualified_columns and len(referenced_tables) > 1 + ) + if unqualified_columns and len(referenced_tables) == 1: + columns_by_table[referenced_tables[0]].update(unqualified_columns) + + for select in statement.find_all(exp.Select): + for expression in select.expressions: + if isinstance(expression, exp.Star): + if len(referenced_tables) == 1: + star_tables.add(referenced_tables[0]) + elif len(referenced_tables) > 1: + star_tables.update(referenced_tables) + elif isinstance(expression, 
exp.Column) and isinstance(expression.this, exp.Star): + qualifier = (expression.table or "").lower() + if qualifier: + star_tables.add(alias_map.get(qualifier, qualifier)) + + return dict(columns_by_table), star_tables, has_unqualified_multi_table_columns + + def _extract_write_facts(self, statement: exp.Expression) -> WriteFacts | None: + statement_type = self._statement_type(statement) + if statement_type not in {"insert", "update", "delete"}: + return None + + target_table = self._extract_target_table(statement) + if not target_table: + return None + + insert_columns: list[str] = [] + updated_columns: list[str] = [] + where_present = False + where_tautological = False + returning_present = ( + bool(statement.args.get("returning")) or statement.find(exp.Returning) is not None + ) + returning_columns = self._extract_returning_columns(statement) + source_tables: list[str] = [] + + if isinstance(statement, exp.Insert): + insert_columns = self._extract_insert_columns(statement) + source_expr = statement.args.get("expression") + if isinstance(source_expr, exp.Expression): + source_tables = self.extract_referenced_tables(source_expr) + elif isinstance(statement, exp.Update): + updated_columns = self._extract_update_columns(statement) + where_expr = statement.args.get("where") + where_present = where_expr is not None + if isinstance(where_expr, exp.Expression): + where_tautological = self._is_tautological_where(where_expr) + elif isinstance(statement, exp.Delete): + where_expr = statement.args.get("where") + where_present = where_expr is not None + if isinstance(where_expr, exp.Expression): + where_tautological = self._is_tautological_where(where_expr) + + if not source_tables: + all_tables = self.extract_referenced_tables(statement) + source_tables = sorted(table for table in all_tables if table != target_table) + + return WriteFacts( + statement_type=statement_type, + target_table=target_table, + insert_columns=sorted(set(insert_columns)), + 
updated_columns=sorted(set(updated_columns)), + where_present=where_present, + where_tautological=where_tautological, + returning_present=returning_present, + returning_columns=returning_columns, + has_select_source=bool(source_tables), + source_tables=sorted(set(source_tables)), + ) + + def _extract_target_table(self, statement: exp.Expression) -> str | None: + target_expr = statement.args.get("this") + if isinstance(target_expr, exp.Schema): + target_expr = target_expr.this + if isinstance(target_expr, exp.Table): + return self._table_to_name(target_expr) + if isinstance(target_expr, exp.Expression): + table = target_expr.find(exp.Table) + if isinstance(table, exp.Table): + return self._table_to_name(table) + return None + + def _extract_insert_columns(self, statement: exp.Insert) -> list[str]: + target_expr = statement.args.get("this") + if isinstance(target_expr, exp.Schema): + columns: list[str] = [] + for column in target_expr.expressions: + if isinstance(column, exp.Column) and column.name: + columns.append(column.name.lower()) + elif isinstance(column, exp.Identifier) and column.this: + columns.append(str(column.this).lower()) + return columns + return [] + + @staticmethod + def _extract_update_columns(statement: exp.Update) -> list[str]: + columns: set[str] = set() + for assignment in statement.expressions: + lhs = assignment.args.get("this") if isinstance(assignment, exp.Expression) else None + if isinstance(lhs, exp.Column) and lhs.name: + columns.add(lhs.name.lower()) + return sorted(columns) + + @staticmethod + def _extract_returning_columns(statement: exp.Expression) -> list[str]: + returning_expr = statement.args.get("returning") + if not isinstance(returning_expr, exp.Returning): + return [] + + columns: set[str] = set() + for expression in returning_expr.expressions: + if isinstance(expression, exp.Star): + columns.add("*") + continue + if isinstance(expression, exp.Column): + if isinstance(expression.this, exp.Star): + columns.add("*") + continue 
+ if expression.name: + columns.add(expression.name.lower()) + continue + for nested_column in expression.find_all(exp.Column): + if isinstance(nested_column.this, exp.Star): + columns.add("*") + continue + if nested_column.name: + columns.add(nested_column.name.lower()) + return sorted(columns) + + def _is_tautological_where(self, where_expr: exp.Expression) -> bool: + expr = where_expr.this if isinstance(where_expr, exp.Where) else where_expr + if isinstance(expr, exp.Paren): + return self._is_tautological_where(expr.this) + if isinstance(expr, exp.Boolean): + return bool(expr.this) + if isinstance(expr, exp.Not): + child = expr.this + return isinstance(child, exp.Boolean) and not bool(child.this) + if isinstance(expr, exp.Literal): + if expr.is_string: + return expr.this.strip().lower() in {"true", "t", "yes", "on", "1"} + return str(expr.this).strip() in {"1"} + if isinstance(expr, exp.Or): + return self._is_tautological_where(expr.left) or self._is_tautological_where(expr.right) + if isinstance(expr, exp.And): + return self._is_tautological_where(expr.left) and self._is_tautological_where( + expr.right + ) + if isinstance(expr, exp.EQ): + left = expr.left + right = expr.right + if isinstance(left, exp.Literal) and isinstance(right, exp.Literal): + return str(left.this) == str(right.this) and left.is_string == right.is_string + if isinstance(left, exp.Column) and isinstance(right, exp.Column): + return ( + left.name.lower() == right.name.lower() + and (left.table or "").lower() == (right.table or "").lower() + ) + return False + + def _validate_write_acl(self, write_facts: WriteFacts) -> str | None: + target_policy = self.lookup_table_policy(write_facts.target_table) + if target_policy is None: + available = ", ".join(sorted(self.settings.effective_acl_policy)) + return ( + f"Access to table '{write_facts.target_table}' " + "is restricted by the server access policy. " + f"Allowed tables are: {available}. 
" + "Please use list_tables/describe_table or escalate to a human operator." + ) + + changed_columns = set(write_facts.insert_columns or write_facts.updated_columns) + if changed_columns and "*" not in target_policy: + disallowed = sorted(column for column in changed_columns if column not in target_policy) + if disallowed: + allowed_text = ", ".join(sorted(target_policy)) + return ( + f"Write access to column(s) {', '.join(disallowed)} " + f"on table '{write_facts.target_table}' " + "is restricted. " + f"Allowed columns: {allowed_text}. " + "Use describe_table to inspect policy or escalate to a human operator." + ) + + if write_facts.returning_present: + if "*" in write_facts.returning_columns and "*" not in target_policy: + allowed_text = ", ".join(sorted(target_policy)) + return ( + f"RETURNING * is not allowed for table '{write_facts.target_table}' " + "under strict policy. " + f"Allowed columns: {allowed_text}. " + "Please list explicit allowed RETURNING columns." + ) + if "*" not in target_policy: + disallowed_returning = sorted( + column + for column in write_facts.returning_columns + if column != "*" and column not in target_policy + ) + if disallowed_returning: + allowed_text = ", ".join(sorted(target_policy)) + return ( + f"RETURNING column(s) {', '.join(disallowed_returning)} on table " + f"'{write_facts.target_table}' are restricted. " + f"Allowed columns: {allowed_text}. " + "Use describe_table to inspect policy or escalate to a human operator." 
+ ) + + return None + + def _warn_if_policy_would_allow_blocked_write( + self, + *, + sql: str, + statement_count: int, + statement_type: str, + has_disallowed_operation: bool, + is_read_statement: bool, + referenced_tables: list[str], + referenced_columns: dict[str, set[str]], + star_tables: set[str], + has_unqualified_multi_table_columns: bool, + write_facts: WriteFacts | None, + blocked_reason: str, + ) -> None: + if self.policy_engine is None or write_facts is None: + return + + shadow_payload = self._build_query_policy_input( + sql=sql, + statement_count=statement_count, + statement_type=statement_type, + has_disallowed_operation=has_disallowed_operation, + is_read_statement=is_read_statement, + referenced_tables=referenced_tables, + referenced_columns=referenced_columns, + star_tables=star_tables, + has_unqualified_multi_table_columns=has_unqualified_multi_table_columns, + write_facts=write_facts, + config_overrides={ + "write_mode_enabled": True, + "allow_insert": True, + "allow_update": True, + "allow_delete": True, + "allow_returning": True, + }, + ) + decision = self.policy_engine.evaluate_sync(shadow_payload) + if decision.allow: + operation = statement_type.upper() + gate_name = blocked_reason.upper() + enabled_flag = "true" if self.settings.write_mode_enabled else "false" + action_flag = "true" if self._is_write_action_enabled(statement_type) else "false" + LOGGER.warning( + "Write operation '%s' blocked by config gate " + "(%s, WRITE_MODE_ENABLED=%s, ALLOW_%s=%s).", + operation, + gate_name, + enabled_flag, + operation, + action_flag, + ) + + def _build_query_policy_input( + self, + *, + sql: str, + statement_count: int, + statement_type: str, + has_disallowed_operation: bool, + is_read_statement: bool, + referenced_tables: list[str], + referenced_columns: dict[str, set[str]], + star_tables: set[str], + has_unqualified_multi_table_columns: bool, + write_facts: WriteFacts | None, + config_overrides: dict[str, bool] | None = None, + ) -> dict[str, Any]: 
+ acl_tables = { + table: {"columns": sorted(columns)} + for table, columns in sorted(self.settings.effective_acl_policy.items()) + } + write_mode_enabled = ( + config_overrides["write_mode_enabled"] + if config_overrides and "write_mode_enabled" in config_overrides + else self.settings.write_mode_enabled + ) + allow_insert = ( + config_overrides["allow_insert"] + if config_overrides and "allow_insert" in config_overrides + else self.settings.allow_insert + ) + allow_update = ( + config_overrides["allow_update"] + if config_overrides and "allow_update" in config_overrides + else self.settings.allow_update + ) + allow_delete = ( + config_overrides["allow_delete"] + if config_overrides and "allow_delete" in config_overrides + else self.settings.allow_delete + ) + require_where_for_update = ( + config_overrides["require_where_for_update"] + if config_overrides and "require_where_for_update" in config_overrides + else self.settings.require_where_for_update + ) + require_where_for_delete = ( + config_overrides["require_where_for_delete"] + if config_overrides and "require_where_for_delete" in config_overrides + else self.settings.require_where_for_delete + ) + allow_returning = ( + config_overrides["allow_returning"] + if config_overrides and "allow_returning" in config_overrides + else self.settings.allow_returning + ) + return { + "tool": {"name": "query"}, + "query": { + "raw_sql": sql, + "statement_count": statement_count, + "statement_type": statement_type, + "is_write_statement": write_facts is not None, + "has_disallowed_operation": has_disallowed_operation, + "is_read_statement": is_read_statement, + "referenced_tables": referenced_tables, + "referenced_columns": { + table: sorted(columns) for table, columns in sorted(referenced_columns.items()) + }, + "star_tables": sorted(star_tables), + "has_unqualified_multi_table_columns": has_unqualified_multi_table_columns, + "target_table": write_facts.target_table if write_facts else "", + "insert_columns": 
write_facts.insert_columns if write_facts else [], + "updated_columns": write_facts.updated_columns if write_facts else [], + "where_present": write_facts.where_present if write_facts else False, + "where_tautological": write_facts.where_tautological if write_facts else False, + "returning_present": write_facts.returning_present if write_facts else False, + "returning_columns": write_facts.returning_columns if write_facts else [], + "has_select_source": write_facts.has_select_source if write_facts else False, + "source_tables": write_facts.source_tables if write_facts else [], + }, + "config": { + "write_mode_enabled": write_mode_enabled, + "allow_insert": allow_insert, + "allow_update": allow_update, + "allow_delete": allow_delete, + "require_where_for_update": require_where_for_update, + "require_where_for_delete": require_where_for_delete, + "allow_returning": allow_returning, + }, + "acl": {"tables": acl_tables}, + } diff --git a/src/secure_sql_mcp/server.py b/src/secure_sql_mcp/server.py index 14727d3..35630f3 100644 --- a/src/secure_sql_mcp/server.py +++ b/src/secure_sql_mcp/server.py @@ -2,16 +2,20 @@ from __future__ import annotations +import argparse import json import logging +import os from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass +from typing import Any from mcp.server.fastmcp import FastMCP from secure_sql_mcp.config import Settings, load_settings from secure_sql_mcp.database import AsyncDatabase +from secure_sql_mcp.opa_policy import OpaPolicyEngine from secure_sql_mcp.query_validator import QueryValidator LOGGER = logging.getLogger(__name__) @@ -22,6 +26,7 @@ class AppState: settings: Settings db: AsyncDatabase validator: QueryValidator + policy_engine: OpaPolicyEngine | None STATE: AppState | None = None @@ -34,9 +39,10 @@ async def lifespan(_: FastMCP) -> AsyncIterator[None]: settings = load_settings() logging.basicConfig(level=settings.log_level) db = AsyncDatabase(settings) 
- validator = QueryValidator(settings) + policy_engine = OpaPolicyEngine(settings) if settings.opa_url else None + validator = QueryValidator(settings, policy_engine=policy_engine) await db.connect() - STATE = AppState(settings=settings, db=db, validator=validator) + STATE = AppState(settings=settings, db=db, validator=validator, policy_engine=policy_engine) LOGGER.info("secure-sql-mcp started") try: yield @@ -57,13 +63,27 @@ def _state() -> AppState: @mcp.tool() async def query(sql: str) -> str: - """Run a read-only SQL query and return structured results.""" + """Run a SQL query (read-only by default; writes only when explicitly enabled).""" app = _state() validation = app.validator.validate_query(sql) if not validation.ok: return validation.error or "Query blocked by policy." + statement_type = (validation.statement_type or "").lower() try: + if statement_type in {"insert", "update", "delete"}: + write_result = await app.db.execute_write_query(validation.normalized_sql or sql) + payload = { + "status": "ok", + "operation": statement_type, + "affected_rows": write_result.affected_rows, + "returning_columns": write_result.returning_columns, + "returning": write_result.returning_rows, + "referenced_tables": validation.referenced_tables or [], + "referenced_columns": validation.referenced_columns or {}, + } + return json.dumps(payload, default=str, indent=2) + result = await app.db.execute_read_query(validation.normalized_sql or sql) except TimeoutError: return ( @@ -94,7 +114,14 @@ async def query(sql: str) -> str: async def list_tables() -> str: """List tables the agent is allowed to query, validating existence when possible.""" app = _state() - policy = app.settings.allowed_policy + if app.policy_engine is not None: + decision = await app.policy_engine.evaluate( + _build_tool_policy_input("list_tables", app.settings) + ) + if not decision.allow: + return decision.message or "Operation blocked by policy." 
+ + policy = app.settings.effective_acl_policy policy_tables = sorted(policy) policy_set = {t.lower() for t in policy} discovered: list[str] = [] @@ -139,9 +166,16 @@ async def list_tables() -> str: async def describe_table(table: str) -> str: """Describe columns for an allowed table.""" app = _state() + if app.policy_engine is not None: + payload = _build_tool_policy_input("describe_table", app.settings) + payload["table"] = table.lower() + decision = await app.policy_engine.evaluate(payload) + if not decision.allow: + return decision.message or "Operation blocked by policy." + policy_columns = app.validator.lookup_table_policy(table) if policy_columns is None: - available_tables = ", ".join(sorted(app.settings.allowed_policy)) + available_tables = ", ".join(sorted(app.settings.effective_acl_policy)) return ( f"Access to table '{table}' is restricted by the server access policy. " f"Allowed tables are: {available_tables}. " @@ -175,8 +209,48 @@ async def describe_table(table: str) -> str: def main() -> None: """Run the MCP server with stdio transport.""" + parser = argparse.ArgumentParser(add_help=True) + parser.add_argument( + "--write-mode", + action="store_true", + help="Enable write-mode execution path (disabled by default).", + ) + parser.add_argument( + "--allow-insert", + action="store_true", + help="Allow INSERT statements when write mode is enabled.", + ) + parser.add_argument( + "--allow-update", + action="store_true", + help="Allow UPDATE statements when write mode is enabled.", + ) + parser.add_argument( + "--allow-delete", + action="store_true", + help="Allow DELETE statements when write mode is enabled.", + ) + args, _ = parser.parse_known_args() + + if args.write_mode: + os.environ["WRITE_MODE_ENABLED"] = "true" + if args.allow_insert: + os.environ["ALLOW_INSERT"] = "true" + if args.allow_update: + os.environ["ALLOW_UPDATE"] = "true" + if args.allow_delete: + os.environ["ALLOW_DELETE"] = "true" + mcp.run(transport="stdio") +def 
_build_tool_policy_input(tool_name: str, settings: Settings) -> dict[str, Any]: + acl_tables = { + table: {"columns": sorted(columns)} + for table, columns in sorted(settings.effective_acl_policy.items()) + } + return {"tool": {"name": tool_name}, "acl": {"tables": acl_tables}} + + if __name__ == "__main__": main() diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..70da13d --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration test package.""" diff --git a/tests/integration/docker/__init__.py b/tests/integration/docker/__init__.py new file mode 100644 index 0000000..ccfa55f --- /dev/null +++ b/tests/integration/docker/__init__.py @@ -0,0 +1 @@ +"""Docker-backed integration tests.""" diff --git a/tests/integration/docker/acl/restricted_acl.json b/tests/integration/docker/acl/restricted_acl.json new file mode 100644 index 0000000..8d53acf --- /dev/null +++ b/tests/integration/docker/acl/restricted_acl.json @@ -0,0 +1,14 @@ +{ + "secure_sql": { + "acl": { + "tables": { + "customers": { + "columns": ["id", "email"] + }, + "orders": { + "columns": ["id", "total"] + } + } + } + } +} diff --git a/tests/integration/docker/conftest.py b/tests/integration/docker/conftest.py new file mode 100644 index 0000000..5a1050b --- /dev/null +++ b/tests/integration/docker/conftest.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +import os +import sqlite3 +import subprocess +import uuid +from collections.abc import Callable, Iterator +from dataclasses import dataclass +from pathlib import Path + +import pytest +from mcp.client.stdio import StdioServerParameters + + +@dataclass(frozen=True, slots=True) +class BackendConfig: + name: str + database_url: str + needs_deps: bool + + +ROOT = Path(__file__).resolve().parents[3] +COMPOSE_FILE = ROOT / "docker-compose.test.yml" +POLICY_DIR = ROOT / "tests" / "integration" / "docker" / "policies" +ACL_DIR = ROOT / "tests" / "integration" / "docker" / 
"acl" + + +def _run(command: list[str], *, check: bool = True) -> subprocess.CompletedProcess[str]: + return subprocess.run( # noqa: S603 + command, + cwd=ROOT, + check=check, + text=True, + capture_output=True, + ) + + +@pytest.fixture(scope="session") +def docker_available() -> None: + try: + _run(["docker", "version"]) + _run(["docker", "compose", "version"]) + except (OSError, subprocess.CalledProcessError): + pytest.skip("Docker or docker compose is unavailable on this host.") + + +@pytest.fixture(scope="session") +def compose_project_name() -> str: + return f"secure_sql_it_{uuid.uuid4().hex[:10]}" + + +@pytest.fixture(scope="session") +def docker_stack(docker_available: None, compose_project_name: str) -> Iterator[None]: + compose = ["docker", "compose", "-p", compose_project_name, "-f", str(COMPOSE_FILE)] + _run([*compose, "build", "secure-sql-mcp"]) + _run([*compose, "up", "-d", "postgres", "mysql"]) + try: + yield + finally: + _run([*compose, "down", "-v", "--remove-orphans"], check=False) + + +@pytest.fixture(params=["sqlite", "postgresql", "mysql"]) +def backend(request: pytest.FixtureRequest) -> BackendConfig: + backend_name = str(request.param) + if backend_name == "sqlite": + return BackendConfig( + name="sqlite", + database_url="sqlite+aiosqlite:///run/sqlite/test.db", + needs_deps=False, + ) + if backend_name == "postgresql": + return BackendConfig( + name="postgresql", + database_url="postgresql+asyncpg://secure:secure@postgres:5432/secure_sql_test", + needs_deps=True, + ) + return BackendConfig( + name="mysql", + database_url="mysql+aiomysql://secure:secure@mysql:3306/secure_sql_test", + needs_deps=True, + ) + + +@pytest.fixture +def policy_path() -> Callable[[str], Path]: + def _resolve(policy_name: str) -> Path: + path = POLICY_DIR / f"{policy_name}.txt" + if not path.exists(): + raise FileNotFoundError(f"Policy file not found: {path}") + return path + + return _resolve + + +@pytest.fixture +def acl_path() -> Callable[[str], Path]: + def 
_resolve(acl_name: str) -> Path: + path = ACL_DIR / f"{acl_name}.json" + if not path.exists(): + raise FileNotFoundError(f"ACL file not found: {path}") + return path + + return _resolve + + +@pytest.fixture +def sqlite_db_dir(tmp_path: Path) -> Path: + db_dir = tmp_path / "sqlite" + db_dir.mkdir(parents=True, exist_ok=True) + db_path = db_dir / "test.db" + conn = sqlite3.connect(db_path) + try: + conn.executescript( + """ + CREATE TABLE customers ( + id INTEGER PRIMARY KEY, + email TEXT NOT NULL, + ssn TEXT + ); + CREATE TABLE orders ( + id INTEGER PRIMARY KEY, + total NUMERIC + ); + CREATE TABLE secrets ( + id INTEGER PRIMARY KEY, + token TEXT + ); + INSERT INTO customers (id, email, ssn) VALUES (1, 'a@example.com', '111-22-3333'); + INSERT INTO orders (id, total) VALUES (10, 19.99); + INSERT INTO secrets (id, token) VALUES (99, 'top-secret-token'); + """ + ) + conn.commit() + finally: + conn.close() + return db_dir + + +@pytest.fixture +def docker_server_params_factory( + compose_project_name: str, + sqlite_db_dir: Path, +) -> Callable[..., StdioServerParameters]: + def _factory( + *, + backend: BackendConfig, + policy_file: Path, + write_mode_enabled: bool = False, + allow_insert: bool = False, + allow_update: bool = False, + allow_delete: bool = False, + require_where_for_update: bool = True, + require_where_for_delete: bool = True, + allow_returning: bool = False, + opa_fail_closed: bool = True, + opa_url: str = "http://127.0.0.1:8181", + opa_decision_path: str = "/v1/data/secure_sql/authz/decision", + opa_acl_data_file: Path | None = None, + ) -> StdioServerParameters: + args = ["run", "--rm", "-i"] + if backend.needs_deps: + args.extend(["--network", f"{compose_project_name}_default"]) + + args.extend( + [ + "-e", + f"DATABASE_URL={backend.database_url}", + "-e", + "ALLOWED_POLICY_FILE=/run/policy/allowed_policy.txt", + "-e", + f"OPA_URL={opa_url}", + "-e", + f"OPA_DECISION_PATH={opa_decision_path}", + "-e", + f"OPA_FAIL_CLOSED={'true' if opa_fail_closed 
else 'false'}", + "-e", + f"WRITE_MODE_ENABLED={'true' if write_mode_enabled else 'false'}", + "-e", + f"ALLOW_INSERT={'true' if allow_insert else 'false'}", + "-e", + f"ALLOW_UPDATE={'true' if allow_update else 'false'}", + "-e", + f"ALLOW_DELETE={'true' if allow_delete else 'false'}", + "-e", + (f"REQUIRE_WHERE_FOR_UPDATE={'true' if require_where_for_update else 'false'}"), + "-e", + (f"REQUIRE_WHERE_FOR_DELETE={'true' if require_where_for_delete else 'false'}"), + "-e", + f"ALLOW_RETURNING={'true' if allow_returning else 'false'}", + "-v", + f"{policy_file}:/run/policy/allowed_policy.txt:ro", + ] + ) + + if backend.name == "sqlite": + args.extend(["-v", f"{sqlite_db_dir}:/run/sqlite:rw"]) + + if opa_acl_data_file is not None: + args.extend( + [ + "-e", + "OPA_ACL_DATA_FILE=/run/policy/acl.json", + "-v", + f"{opa_acl_data_file}:/run/policy/acl.json:ro", + ] + ) + + args.append("secure-sql-mcp:test") + return StdioServerParameters(command="docker", args=args, env=os.environ.copy()) + + return _factory diff --git a/tests/integration/docker/db-init/mysql.sql b/tests/integration/docker/db-init/mysql.sql new file mode 100644 index 0000000..358499c --- /dev/null +++ b/tests/integration/docker/db-init/mysql.sql @@ -0,0 +1,24 @@ +CREATE TABLE IF NOT EXISTS customers ( + id INT PRIMARY KEY, + email VARCHAR(255) NOT NULL, + ssn VARCHAR(64) +); + +CREATE TABLE IF NOT EXISTS orders ( + id INT PRIMARY KEY, + total DECIMAL(10, 2) +); + +CREATE TABLE IF NOT EXISTS secrets ( + id INT PRIMARY KEY, + token VARCHAR(255) +); + +INSERT IGNORE INTO customers (id, email, ssn) +VALUES (1, 'a@example.com', '111-22-3333'); + +INSERT IGNORE INTO orders (id, total) +VALUES (10, 19.99); + +INSERT IGNORE INTO secrets (id, token) +VALUES (99, 'top-secret-token'); diff --git a/tests/integration/docker/db-init/postgres.sql b/tests/integration/docker/db-init/postgres.sql new file mode 100644 index 0000000..f096b1f --- /dev/null +++ b/tests/integration/docker/db-init/postgres.sql @@ -0,0 +1,27 @@ 
+CREATE TABLE IF NOT EXISTS customers ( + id INTEGER PRIMARY KEY, + email TEXT NOT NULL, + ssn TEXT +); + +CREATE TABLE IF NOT EXISTS orders ( + id INTEGER PRIMARY KEY, + total NUMERIC +); + +CREATE TABLE IF NOT EXISTS secrets ( + id INTEGER PRIMARY KEY, + token TEXT +); + +INSERT INTO customers (id, email, ssn) +VALUES (1, 'a@example.com', '111-22-3333') +ON CONFLICT (id) DO NOTHING; + +INSERT INTO orders (id, total) +VALUES (10, 19.99) +ON CONFLICT (id) DO NOTHING; + +INSERT INTO secrets (id, token) +VALUES (99, 'top-secret-token') +ON CONFLICT (id) DO NOTHING; diff --git a/tests/integration/docker/policies/read_only_strict.txt b/tests/integration/docker/policies/read_only_strict.txt new file mode 100644 index 0000000..4cbc445 --- /dev/null +++ b/tests/integration/docker/policies/read_only_strict.txt @@ -0,0 +1,2 @@ +customers:id,email +orders:id,total diff --git a/tests/integration/docker/policies/wildcard_tables.txt b/tests/integration/docker/policies/wildcard_tables.txt new file mode 100644 index 0000000..97b8414 --- /dev/null +++ b/tests/integration/docker/policies/wildcard_tables.txt @@ -0,0 +1,2 @@ +customers:* +orders:* diff --git a/tests/integration/docker/policies/write_delete_restricted.txt b/tests/integration/docker/policies/write_delete_restricted.txt new file mode 100644 index 0000000..4cbc445 --- /dev/null +++ b/tests/integration/docker/policies/write_delete_restricted.txt @@ -0,0 +1,2 @@ +customers:id,email +orders:id,total diff --git a/tests/integration/docker/policies/write_insert_only.txt b/tests/integration/docker/policies/write_insert_only.txt new file mode 100644 index 0000000..4cbc445 --- /dev/null +++ b/tests/integration/docker/policies/write_insert_only.txt @@ -0,0 +1,2 @@ +customers:id,email +orders:id,total diff --git a/tests/integration/docker/policies/write_update_restricted.txt b/tests/integration/docker/policies/write_update_restricted.txt new file mode 100644 index 0000000..4cbc445 --- /dev/null +++ 
b/tests/integration/docker/policies/write_update_restricted.txt @@ -0,0 +1,2 @@ +customers:id,email +orders:id,total diff --git a/tests/integration/docker/test_mcp_docker_opa_matrix.py b/tests/integration/docker/test_mcp_docker_opa_matrix.py new file mode 100644 index 0000000..349d52e --- /dev/null +++ b/tests/integration/docker/test_mcp_docker_opa_matrix.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +import asyncio +import json +import time + +import pytest +from mcp.client.session import ClientSession +from mcp.client.stdio import StdioServerParameters, stdio_client + +from .conftest import BackendConfig + + +def _first_text(call_result: object) -> str: + for item in getattr(call_result, "content", []): + text = getattr(item, "text", None) + if text is not None: + return text + return "" + + +async def _call_tool( + server_params: StdioServerParameters, tool: str, payload: dict[str, object] +) -> str: + async with stdio_client(server_params) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + result = await session.call_tool(tool, payload) + return _first_text(result) + + +def _call_tool_with_retries( + server_params: StdioServerParameters, + tool: str, + payload: dict[str, object], + *, + retries: int = 4, + wait_seconds: float = 2.0, +) -> str: + last_response = "" + for attempt in range(retries): + last_response = asyncio.run(_call_tool(server_params, tool, payload)) + if "database error" not in last_response: + return last_response + if attempt < retries - 1: + time.sleep(wait_seconds) + return last_response + + +pytestmark = pytest.mark.docker_integration + + +def test_read_baseline_policy_enforced( + docker_stack: None, + backend: BackendConfig, + policy_path, + docker_server_params_factory, +) -> None: + params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("read_only_strict"), + ) + + if backend.name != "mysql": + allowed = _call_tool_with_retries( + params, 
"query", {"sql": "SELECT id, email FROM customers"} + ) + payload = json.loads(allowed) + assert payload["status"] == "ok" + else: + list_tables = asyncio.run(_call_tool(params, "list_tables", {})) + list_payload = json.loads(list_tables) + assert list_payload["status"] == "ok" + + blocked_col = asyncio.run(_call_tool(params, "query", {"sql": "SELECT ssn FROM customers"})) + assert "restricted" in blocked_col + + blocked_table = asyncio.run(_call_tool(params, "query", {"sql": "SELECT id FROM secrets"})) + assert "restricted" in blocked_table + + multi = asyncio.run( + _call_tool(params, "query", {"sql": "SELECT id FROM customers; DROP TABLE customers"}) + ) + assert "Only a single SQL statement is allowed" in multi + + +@pytest.mark.smoke +def test_write_disabled_blocks_insert_even_with_policy_allow( + docker_stack: None, + backend: BackendConfig, + policy_path, + docker_server_params_factory, +) -> None: + params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("write_insert_only"), + write_mode_enabled=False, + allow_insert=True, + ) + blocked = asyncio.run( + _call_tool( + params, + "query", + {"sql": "INSERT INTO customers (id, email) VALUES (2, 'b@example.com')"}, + ) + ) + assert "read-only access" in blocked + assert "INSERT" in blocked + + +@pytest.mark.smoke +def test_insert_allowed_with_write_mode_and_gate( + docker_stack: None, + backend: BackendConfig, + policy_path, + docker_server_params_factory, +) -> None: + if backend.name != "postgresql": + pytest.skip("Write success-path assertions are validated on PostgreSQL in this matrix.") + + params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("wildcard_tables"), + write_mode_enabled=True, + allow_insert=True, + ) + allowed = _call_tool_with_retries( + params, + "query", + {"sql": "INSERT INTO customers (id, email) VALUES (2, 'b@example.com')"}, + ) + payload = json.loads(allowed) + assert payload["status"] == "ok" + assert payload["operation"] == 
"insert" + assert payload["affected_rows"] == 1 + + +def test_insert_select_source_table_and_star_protections( + docker_stack: None, + backend: BackendConfig, + policy_path, + docker_server_params_factory, +) -> None: + params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("write_insert_only"), + write_mode_enabled=True, + allow_insert=True, + ) + + disallowed_source = asyncio.run( + _call_tool( + params, + "query", + {"sql": "INSERT INTO orders (id, total) SELECT s.id, s.id FROM secrets AS s"}, + ) + ) + assert "restricted" in disallowed_source + + star_source = asyncio.run( + _call_tool( + params, "query", {"sql": "INSERT INTO orders (id, total) SELECT * FROM customers"} + ) + ) + assert "restricted" in star_source + + +def test_update_delete_where_guards_and_tautology( + docker_stack: None, + backend: BackendConfig, + policy_path, + docker_server_params_factory, +) -> None: + if backend.name != "postgresql": + pytest.skip("Write success-path assertions are validated on PostgreSQL in this matrix.") + + params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("wildcard_tables"), + write_mode_enabled=True, + allow_update=True, + allow_delete=True, + ) + + missing_where_update = asyncio.run( + _call_tool(params, "query", {"sql": "UPDATE customers SET email = 'x@example.com'"}) + ) + assert "without a WHERE clause is not allowed" in missing_where_update + + tautological_delete = asyncio.run( + _call_tool(params, "query", {"sql": "DELETE FROM orders WHERE 1 = 1"}) + ) + assert "WHERE clause appears tautological" in tautological_delete + + valid_update = _call_tool_with_retries( + params, "query", {"sql": "UPDATE customers SET email = 'x@example.com' WHERE id = 1"} + ) + payload = json.loads(valid_update) + assert payload["status"] == "ok" + assert payload["operation"] == "update" + + +def test_returning_controls_and_column_acl( + docker_stack: None, + backend: BackendConfig, + policy_path, + 
docker_server_params_factory, +) -> None: + blocked_params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("wildcard_tables"), + write_mode_enabled=True, + allow_update=True, + allow_returning=False, + ) + returning_blocked = asyncio.run( + _call_tool( + blocked_params, + "query", + {"sql": "UPDATE customers SET email = 'x@example.com' WHERE id = 1 RETURNING email"}, + ) + ) + assert "RETURNING is not allowed" in returning_blocked + + if backend.name != "postgresql": + return + + allowed_params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("write_update_restricted"), + write_mode_enabled=True, + allow_update=True, + allow_returning=True, + ) + restricted_column = asyncio.run( + _call_tool( + allowed_params, + "query", + {"sql": "UPDATE customers SET email = 'x@example.com' WHERE id = 1 RETURNING ssn"}, + ) + ) + assert "restricted" in restricted_column + + +def test_opa_fail_closed_when_unavailable( + docker_stack: None, + backend: BackendConfig, + policy_path, + docker_server_params_factory, +) -> None: + params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("read_only_strict"), + opa_decision_path="/v1/data/secure_sql/authz/missing", + opa_fail_closed=True, + ) + query_msg = asyncio.run(_call_tool(params, "query", {"sql": "SELECT id FROM customers"})) + assert "Authorization decision is unavailable" in query_msg + + list_msg = asyncio.run(_call_tool(params, "list_tables", {})) + assert "Authorization decision is unavailable" in list_msg + + describe_msg = asyncio.run(_call_tool(params, "describe_table", {"table": "customers"})) + assert "Authorization decision is unavailable" in describe_msg + + +def test_opa_acl_data_file_profile_works( + docker_stack: None, + backend: BackendConfig, + policy_path, + acl_path, + docker_server_params_factory, +) -> None: + params = docker_server_params_factory( + backend=backend, + policy_file=policy_path("wildcard_tables"), + 
opa_acl_data_file=acl_path("restricted_acl"), + ) + + blocked = asyncio.run(_call_tool(params, "query", {"sql": "SELECT ssn FROM customers"})) + assert "restricted" in blocked diff --git a/tests/test_config.py b/tests/test_config.py index 3c1b98f..f94d50d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,6 +11,45 @@ from tests.conftest import write_policy +@pytest.mark.parametrize( + ("database_url", "expected_url"), + [ + ( + "postgresql://user:pass@localhost:5432/appdb", + "postgresql+asyncpg://user:pass@localhost:5432/appdb", + ), + ( + "mysql://user:pass@localhost:3306/appdb", + "mysql+aiomysql://user:pass@localhost:3306/appdb", + ), + ( + "sqlite:///./tmp.db", + "sqlite+aiosqlite:///./tmp.db", + ), + ( + "postgresql+asyncpg://user:pass@localhost:5432/appdb", + "postgresql+asyncpg://user:pass@localhost:5432/appdb", + ), + ( + "mysql+aiomysql://user:pass@localhost:3306/appdb", + "mysql+aiomysql://user:pass@localhost:3306/appdb", + ), + ], +) +def test_database_url_injects_or_preserves_async_driver( + tmp_path: Path, database_url: str, expected_url: str +) -> None: + policy_path = tmp_path / "policy.txt" + write_policy(policy_path, "customers:id\n") + settings = Settings.model_validate( + { + "DATABASE_URL": database_url, + "ALLOWED_POLICY_FILE": str(policy_path), + } + ) + assert settings.database_url == expected_url + + def test_policy_invalid_format_no_colon_raises(tmp_path: Path) -> None: policy_path = tmp_path / "policy.txt" write_policy(policy_path, "customers id email\n") @@ -133,3 +172,93 @@ def test_normalize_log_level(tmp_path: Path) -> None: } ) assert settings.log_level == "DEBUG" + + +def test_opa_acl_data_file_preferred_over_allowed_policy(tmp_path: Path) -> None: + policy_path = tmp_path / "policy.txt" + write_policy(policy_path, "customers:id\n") + opa_acl_path = tmp_path / "acl.json" + opa_acl_path.write_text( + """ + { + "secure_sql": { + "acl": { + "tables": { + "orders": {"columns": ["*"]} + } + } + } + } + """, + 
encoding="utf-8", + ) + + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "OPA_ACL_DATA_FILE": str(opa_acl_path), + } + ) + + assert settings.allowed_policy == {"customers": {"id"}} + assert settings.effective_acl_policy == {"orders": {"*"}} + + +def test_invalid_opa_acl_json_raises(tmp_path: Path) -> None: + policy_path = tmp_path / "policy.txt" + write_policy(policy_path, "customers:id\n") + opa_acl_path = tmp_path / "acl.json" + opa_acl_path.write_text("{not-json", encoding="utf-8") + + with pytest.raises(ValidationError): + Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "OPA_ACL_DATA_FILE": str(opa_acl_path), + } + ) + + +def test_write_mode_flags_default_to_false(tmp_path: Path) -> None: + policy_path = tmp_path / "policy.txt" + write_policy(policy_path, "customers:id\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + } + ) + assert settings.write_mode_enabled is False + assert settings.allow_insert is False + assert settings.allow_update is False + assert settings.allow_delete is False + assert settings.require_where_for_update is True + assert settings.require_where_for_delete is True + assert settings.allow_returning is False + + +def test_write_mode_flags_can_be_enabled(tmp_path: Path) -> None: + policy_path = tmp_path / "policy.txt" + write_policy(policy_path, "customers:id\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_INSERT": True, + "ALLOW_UPDATE": True, + "ALLOW_DELETE": True, + "REQUIRE_WHERE_FOR_UPDATE": False, + "REQUIRE_WHERE_FOR_DELETE": False, + "ALLOW_RETURNING": True, + } + ) + assert settings.write_mode_enabled is True + assert settings.allow_insert is True + 
assert settings.allow_update is True + assert settings.allow_delete is True + assert settings.require_where_for_update is False + assert settings.require_where_for_delete is False + assert settings.allow_returning is True diff --git a/tests/test_mcp_interface.py b/tests/test_mcp_interface.py index 96325f8..b7a8a3c 100644 --- a/tests/test_mcp_interface.py +++ b/tests/test_mcp_interface.py @@ -2,8 +2,11 @@ import asyncio import json +import os import sqlite3 +import sys from pathlib import Path +from unittest.mock import AsyncMock import pytest from pydantic import ValidationError @@ -41,7 +44,9 @@ def app_state(tmp_path: Path): ) db = AsyncDatabase(settings) asyncio.run(db.connect()) - state = AppState(settings=settings, db=db, validator=QueryValidator(settings)) + state = AppState( + settings=settings, db=db, validator=QueryValidator(settings), policy_engine=None + ) mcp_server.STATE = state try: @@ -90,7 +95,92 @@ def limited_app_state(tmp_path: Path): ) db = AsyncDatabase(settings) asyncio.run(db.connect()) - state = AppState(settings=settings, db=db, validator=QueryValidator(settings)) + state = AppState( + settings=settings, db=db, validator=QueryValidator(settings), policy_engine=None + ) + mcp_server.STATE = state + + try: + yield state + finally: + asyncio.run(db.dispose()) + mcp_server.STATE = None + + +@pytest.fixture() +def write_enabled_app_state(tmp_path: Path): + db_path = tmp_path / "write_enabled.db" + init_sqlite_db(db_path) + + policy_path = tmp_path / "allowed_policy.txt" + write_policy( + policy_path, + """ + customers:id,email + orders:* + """, + ) + + settings = Settings.model_validate( + { + "DATABASE_URL": f"sqlite+aiosqlite:///{db_path}", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_INSERT": True, + "ALLOW_UPDATE": True, + "ALLOW_DELETE": True, + "MAX_ROWS": 100, + "QUERY_TIMEOUT": 30, + "LOG_LEVEL": "INFO", + } + ) + db = AsyncDatabase(settings) + asyncio.run(db.connect()) + state = AppState( + 
settings=settings, db=db, validator=QueryValidator(settings), policy_engine=None + ) + mcp_server.STATE = state + + try: + yield state + finally: + asyncio.run(db.dispose()) + mcp_server.STATE = None + + +@pytest.fixture() +def write_enabled_returning_app_state(tmp_path: Path): + db_path = tmp_path / "write_enabled_returning.db" + init_sqlite_db(db_path) + + policy_path = tmp_path / "allowed_policy.txt" + write_policy( + policy_path, + """ + customers:id,email + orders:* + """, + ) + + settings = Settings.model_validate( + { + "DATABASE_URL": f"sqlite+aiosqlite:///{db_path}", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_INSERT": True, + "ALLOW_UPDATE": True, + "ALLOW_DELETE": True, + "ALLOW_RETURNING": True, + "MAX_ROWS": 100, + "QUERY_TIMEOUT": 30, + "LOG_LEVEL": "INFO", + } + ) + db = AsyncDatabase(settings) + asyncio.run(db.connect()) + state = AppState( + settings=settings, db=db, validator=QueryValidator(settings), policy_engine=None + ) mcp_server.STATE = state try: @@ -166,6 +256,14 @@ def test_query_blocks_mutation_operations(app_state: AppState, sql: str, operati assert operation in response +def test_query_blocks_insert_select_when_write_mode_disabled(app_state: AppState) -> None: + response = asyncio.run( + mcp_server.query("INSERT INTO orders (id, total) SELECT id, id FROM orders") + ) + assert "read-only access" in response + assert "INSERT" in response + + def test_query_blocks_multi_statement_payload(app_state: AppState) -> None: response = asyncio.run(mcp_server.query("SELECT id FROM customers; DROP TABLE customers")) assert "Only a single SQL statement is allowed" in response @@ -318,3 +416,159 @@ async def _raise_db_error(_: str) -> object: assert "describe_table" in response assert "supersecret" not in response assert "internal-db" not in response + + +def test_prepare_read_only_session_mysql_sets_timeout_and_read_only(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + 
write_policy(policy_path, "customers:id\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "mysql://user:pass@localhost:3306/appdb", + "ALLOWED_POLICY_FILE": str(policy_path), + "QUERY_TIMEOUT": 12, + } + ) + db = AsyncDatabase(settings) + fake_conn = AsyncMock() + + asyncio.run(db._prepare_read_only_session(fake_conn)) + + executed_sql = [str(call.args[0]) for call in fake_conn.execute.await_args_list] + assert executed_sql == [ + "SET SESSION MAX_EXECUTION_TIME = 12000", + "START TRANSACTION READ ONLY", + ] + + +def test_query_allows_insert_when_write_mode_enabled(write_enabled_app_state: AppState) -> None: + response = asyncio.run( + mcp_server.query("INSERT INTO customers (id, email) VALUES (2, 'b@example.com')") + ) + payload = json.loads(response) + assert payload["status"] == "ok" + assert payload["operation"] == "insert" + assert payload["affected_rows"] == 1 + + +def test_query_allows_update_with_where_when_write_mode_enabled( + write_enabled_app_state: AppState, +) -> None: + response = asyncio.run( + mcp_server.query("UPDATE customers SET email = 'c@example.com' WHERE id = 1") + ) + payload = json.loads(response) + assert payload["status"] == "ok" + assert payload["operation"] == "update" + assert payload["affected_rows"] == 1 + + verify_response = asyncio.run(mcp_server.query("SELECT email FROM customers WHERE id = 1")) + verify_payload = json.loads(verify_response) + assert verify_payload["rows"][0]["email"] == "c@example.com" + + +def test_query_blocks_update_without_where_even_when_write_enabled( + write_enabled_app_state: AppState, +) -> None: + response = asyncio.run(mcp_server.query("UPDATE customers SET email = 'x@example.com'")) + assert "UPDATE without a WHERE clause is not allowed" in response + + +def test_query_blocks_tautological_where_even_when_write_enabled( + write_enabled_app_state: AppState, +) -> None: + response = asyncio.run(mcp_server.query("DELETE FROM customers WHERE 1 = 1")) + assert "WHERE clause appears 
tautological" in response + + +def test_query_blocks_insert_from_disallowed_source_table_even_when_write_enabled( + write_enabled_app_state: AppState, +) -> None: + response = asyncio.run( + mcp_server.query("INSERT INTO orders (id, total) SELECT s.id, s.id FROM secrets AS s") + ) + assert "Access to table 'secrets' is restricted" in response + + +def test_query_blocks_returning_by_default_even_when_write_enabled( + write_enabled_app_state: AppState, +) -> None: + response = asyncio.run( + mcp_server.query( + "UPDATE customers SET email = 'x@example.com' WHERE id = 1 RETURNING email" + ) + ) + assert "RETURNING is not allowed" in response + + +def test_query_blocks_restricted_returning_column_when_allowed( + write_enabled_returning_app_state: AppState, +) -> None: + response = asyncio.run( + mcp_server.query("UPDATE customers SET email = 'x@example.com' WHERE id = 1 RETURNING ssn") + ) + assert "RETURNING column(s) ssn" in response + + +def test_write_query_timeout_returns_actionable_message( + write_enabled_app_state: AppState, monkeypatch: pytest.MonkeyPatch +) -> None: + async def _raise_timeout(_: str) -> object: + raise TimeoutError() + + monkeypatch.setattr(write_enabled_app_state.db, "execute_write_query", _raise_timeout) + response = asyncio.run( + mcp_server.query("UPDATE customers SET email = 'x@example.com' WHERE id = 1") + ) + assert ( + f"Query exceeded the {write_enabled_app_state.settings.query_timeout}-second timeout" + in response + ) + + +def test_write_query_db_error_message_does_not_leak_sensitive_details( + write_enabled_app_state: AppState, monkeypatch: pytest.MonkeyPatch +) -> None: + async def _raise_db_error(_: str) -> object: + raise RuntimeError("password=supersecret host=internal-db") + + monkeypatch.setattr(write_enabled_app_state.db, "execute_write_query", _raise_db_error) + response = asyncio.run( + mcp_server.query("UPDATE customers SET email = 'x@example.com' WHERE id = 1") + ) + assert "Query execution failed with a database error" 
in response + assert "supersecret" not in response + assert "internal-db" not in response + + +def test_main_cli_flags_set_write_mode_env(monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, str] = {} + keys = ("WRITE_MODE_ENABLED", "ALLOW_INSERT", "ALLOW_UPDATE", "ALLOW_DELETE") + previous = {key: os.environ.get(key) for key in keys} + + def _fake_run(*_: object, **__: object) -> None: + captured["WRITE_MODE_ENABLED"] = os.environ.get("WRITE_MODE_ENABLED", "") + captured["ALLOW_INSERT"] = os.environ.get("ALLOW_INSERT", "") + captured["ALLOW_UPDATE"] = os.environ.get("ALLOW_UPDATE", "") + captured["ALLOW_DELETE"] = os.environ.get("ALLOW_DELETE", "") + + monkeypatch.setattr(mcp_server.mcp, "run", _fake_run) + monkeypatch.setattr( + sys, + "argv", + ["secure-sql-mcp", "--write-mode", "--allow-insert", "--allow-update", "--allow-delete"], + ) + + try: + mcp_server.main() + assert captured == { + "WRITE_MODE_ENABLED": "true", + "ALLOW_INSERT": "true", + "ALLOW_UPDATE": "true", + "ALLOW_DELETE": "true", + } + finally: + for key, value in previous.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value diff --git a/tests/test_mcp_stdio_security.py b/tests/test_mcp_stdio_security.py index 22a3865..f38d54b 100644 --- a/tests/test_mcp_stdio_security.py +++ b/tests/test_mcp_stdio_security.py @@ -20,7 +20,7 @@ def _first_text(call_result: object) -> str: return "" -def _server_params(tmp_path: Path) -> StdioServerParameters: +def _server_params(tmp_path: Path, *, write_mode: bool = False) -> StdioServerParameters: db_path = tmp_path / "test.db" policy_path = tmp_path / "allowed_policy.txt" init_sqlite_db(db_path) @@ -36,6 +36,15 @@ def _server_params(tmp_path: Path) -> StdioServerParameters: "LOG_LEVEL": "INFO", } ) + if write_mode: + env.update( + { + "WRITE_MODE_ENABLED": "true", + "ALLOW_INSERT": "true", + "ALLOW_UPDATE": "true", + "ALLOW_DELETE": "true", + } + ) return StdioServerParameters( command=sys.executable, @@ -98,3 
+107,50 @@ async def _run() -> None: assert "Only a single SQL statement is allowed" in message asyncio.run(_run()) + + +def test_mcp_stdio_blocks_insert_select_when_write_disabled(tmp_path: Path) -> None: + async def _run() -> None: + async with stdio_client(_server_params(tmp_path)) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + result = await session.call_tool( + "query", + {"sql": "INSERT INTO orders (id, total) SELECT id, id FROM orders"}, + ) + message = _first_text(result) + assert "read-only access" in message + assert "INSERT" in message + + asyncio.run(_run()) + + +def test_mcp_stdio_write_mode_allows_insert(tmp_path: Path) -> None: + async def _run() -> None: + async with stdio_client(_server_params(tmp_path, write_mode=True)) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + result = await session.call_tool( + "query", + {"sql": "INSERT INTO customers (id, email) VALUES (2, 'b@example.com')"}, + ) + payload = json.loads(_first_text(result)) + assert payload["status"] == "ok" + assert payload["operation"] == "insert" + assert payload["affected_rows"] == 1 + + asyncio.run(_run()) + + +def test_mcp_stdio_write_mode_blocks_tautological_delete(tmp_path: Path) -> None: + async def _run() -> None: + async with stdio_client(_server_params(tmp_path, write_mode=True)) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + result = await session.call_tool( + "query", {"sql": "DELETE FROM customers WHERE 1 = 1"} + ) + message = _first_text(result) + assert "WHERE clause appears tautological" in message + + asyncio.run(_run()) diff --git a/tests/test_opa_policy.py b/tests/test_opa_policy.py new file mode 100644 index 0000000..22537c7 --- /dev/null +++ b/tests/test_opa_policy.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, cast +from urllib 
import error + +from secure_sql_mcp.config import Settings +from secure_sql_mcp.opa_policy import OpaPolicyEngine, PolicyDecision +from secure_sql_mcp.query_validator import QueryValidator +from tests.conftest import write_policy + + +class _FakeResponse: + def __init__(self, payload: dict[str, object]) -> None: + self._payload = payload + self.status = 200 + + def read(self) -> bytes: + return json.dumps(self._payload).encode("utf-8") + + def __enter__(self) -> _FakeResponse: + return self + + def __exit__(self, *_: object) -> None: + return None + + +class _CaptureEngine: + def __init__(self, decision: PolicyDecision) -> None: + self.decision = decision + self.last_payload: dict[str, object] | None = None + + def evaluate_sync(self, payload: dict[str, object]) -> PolicyDecision: + self.last_payload = payload + return self.decision + + +def _settings(tmp_path: Path) -> Settings: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\norders:*\n") + return Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "OPA_URL": "http://127.0.0.1:8181", + "OPA_FAIL_CLOSED": True, + } + ) + + +def test_opa_engine_fail_closed_on_transport_error(tmp_path: Path, monkeypatch) -> None: + settings = _settings(tmp_path) + engine = OpaPolicyEngine(settings) + + def _raise_url_error(*_: object, **__: object): + raise error.URLError("connection refused") + + monkeypatch.setattr("secure_sql_mcp.opa_policy.request.urlopen", _raise_url_error) + decision = engine.evaluate_sync({"tool": {"name": "query"}, "query": {"statement_count": 1}}) + + assert decision.allow is False + assert "opa_unavailable" in decision.deny_reasons + + +def test_opa_engine_parses_decision_payload(tmp_path: Path, monkeypatch) -> None: + settings = _settings(tmp_path) + engine = OpaPolicyEngine(settings) + + def _ok(*_: object, **__: object): + return _FakeResponse({"result": {"allow": False, 
"deny_reasons": ["table_restricted"]}}) + + monkeypatch.setattr("secure_sql_mcp.opa_policy.request.urlopen", _ok) + decision = engine.evaluate_sync({"tool": {"name": "query"}, "query": {"statement_count": 1}}) + + assert decision.allow is False + assert decision.deny_reasons == ["table_restricted"] + assert "restricted" in (decision.message or "") + + +def test_validator_builds_policy_input_for_opa(tmp_path: Path) -> None: + settings = _settings(tmp_path) + capture_engine = _CaptureEngine(PolicyDecision(allow=True)) + validator = QueryValidator(settings, policy_engine=cast(Any, capture_engine)) + + result = validator.validate_query( + "SELECT c.id, o.total FROM customers c JOIN orders o ON c.id = o.id" + ) + + assert result.ok + assert capture_engine.last_payload is not None + payload = cast(dict[str, Any], capture_engine.last_payload) + assert payload["tool"] == {"name": "query"} + assert payload["query"]["statement_count"] == 1 + assert sorted(payload["query"]["referenced_tables"]) == ["customers", "orders"] + assert payload["config"]["write_mode_enabled"] is False + assert payload["config"]["allow_returning"] is False + assert payload["config"]["require_where_for_update"] is True + assert payload["config"]["require_where_for_delete"] is True + assert payload["query"]["is_write_statement"] is False + + +def test_validator_builds_write_policy_input_for_opa(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "OPA_URL": "http://127.0.0.1:8181", + "OPA_FAIL_CLOSED": True, + "WRITE_MODE_ENABLED": True, + "ALLOW_UPDATE": True, + } + ) + capture_engine = _CaptureEngine(PolicyDecision(allow=True)) + validator = QueryValidator(settings, policy_engine=cast(Any, capture_engine)) + + result = validator.validate_query("UPDATE customers SET email = 'x@example.com' WHERE 
id = 1") + assert result.ok + assert capture_engine.last_payload is not None + payload = cast(dict[str, Any], capture_engine.last_payload) + assert payload["query"]["is_write_statement"] is True + assert payload["query"]["statement_type"] == "update" + assert payload["query"]["target_table"] == "customers" + assert payload["query"]["updated_columns"] == ["email"] + assert payload["query"]["where_present"] is True + assert payload["config"]["write_mode_enabled"] is True + assert payload["config"]["allow_update"] is True + + +def test_validator_marks_insert_select_as_write_for_opa(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\norders:*\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "OPA_URL": "http://127.0.0.1:8181", + "WRITE_MODE_ENABLED": True, + "ALLOW_INSERT": True, + } + ) + capture_engine = _CaptureEngine(PolicyDecision(allow=True)) + validator = QueryValidator(settings, policy_engine=cast(Any, capture_engine)) + result = validator.validate_query("INSERT INTO orders (id, total) SELECT id, id FROM customers") + assert result.ok + assert capture_engine.last_payload is not None + payload = cast(dict[str, Any], capture_engine.last_payload) + assert payload["query"]["is_write_statement"] is True + assert payload["query"]["statement_type"] == "insert" + + +def test_validator_includes_star_tables_for_insert_select_star_opa(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\norders:*\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./tmp.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "OPA_URL": "http://127.0.0.1:8181", + "WRITE_MODE_ENABLED": True, + "ALLOW_INSERT": True, + } + ) + capture_engine = _CaptureEngine(PolicyDecision(allow=True)) + validator = QueryValidator(settings, 
policy_engine=cast(Any, capture_engine)) + result = validator.validate_query("INSERT INTO orders (id, total) SELECT * FROM customers") + assert result.ok + assert capture_engine.last_payload is not None + payload = cast(dict[str, Any], capture_engine.last_payload) + assert payload["query"]["is_write_statement"] is True + assert "customers" in payload["query"]["star_tables"] + + +def test_opa_engine_maps_write_deny_reason_to_message(tmp_path: Path, monkeypatch) -> None: + settings = _settings(tmp_path) + engine = OpaPolicyEngine(settings) + + def _ok(*_: object, **__: object): + return _FakeResponse( + {"result": {"allow": False, "deny_reasons": ["missing_where_on_update"]}} + ) + + monkeypatch.setattr("secure_sql_mcp.opa_policy.request.urlopen", _ok) + decision = engine.evaluate_sync({"tool": {"name": "query"}, "query": {"statement_count": 1}}) + + assert decision.allow is False + assert decision.deny_reasons == ["missing_where_on_update"] + assert decision.message == "UPDATE without a WHERE clause is not allowed." 
diff --git a/tests/test_query_validator_security.py b/tests/test_query_validator_security.py index 4465275..144790c 100644 --- a/tests/test_query_validator_security.py +++ b/tests/test_query_validator_security.py @@ -151,3 +151,184 @@ def test_validator_blocks_except_with_disallowed_table(validator: QueryValidator result = validator.validate_query("SELECT id FROM customers EXCEPT SELECT id FROM secrets") assert not result.ok assert "Access to table 'secrets' is restricted" in (result.error or "") + + +def test_validator_uses_mysql_dialect_for_mysql_url(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy( + policy_path, + """ + customers:id,email + """, + ) + settings = Settings.model_validate( + { + "DATABASE_URL": "mysql://user:pass@localhost:3306/appdb", + "ALLOWED_POLICY_FILE": str(policy_path), + } + ) + validator = QueryValidator(settings) + + assert validator._dialect == "mysql" + + +def test_validator_allows_insert_when_write_mode_enabled(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_INSERT": True, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query( + "INSERT INTO customers (id, email) VALUES (2, 'b@example.com')" + ) + assert result.ok + assert result.statement_type == "insert" + + +def test_validator_blocks_update_without_where_in_write_mode(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_UPDATE": True, + } + ) + validator = QueryValidator(settings) + result = 
validator.validate_query("UPDATE customers SET email = 'x@example.com'") + assert not result.ok + assert "UPDATE without a WHERE clause is not allowed" in (result.error or "") + + +def test_validator_blocks_tautological_where_in_write_mode(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_DELETE": True, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query("DELETE FROM customers WHERE 1 = 1") + assert not result.ok + assert "WHERE clause appears tautological" in (result.error or "") + + +def test_validator_blocks_insert_select_when_write_mode_disabled(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\norders:*\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": False, + "ALLOW_INSERT": True, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query("INSERT INTO orders (id, total) SELECT id, id FROM customers") + assert not result.ok + assert "configured for read-only access" in (result.error or "") + + +def test_validator_blocks_update_with_subquery_when_update_disabled(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\norders:*\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_UPDATE": False, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query( + "UPDATE customers SET email = (SELECT 'x@example.com') WHERE id = 1" + ) + assert not result.ok + 
assert "UPDATE operations are disabled" in (result.error or "") + + +def test_validator_blocks_delete_with_subquery_when_delete_disabled(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\norders:*\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_DELETE": False, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query( + "DELETE FROM customers WHERE id IN (SELECT id FROM customers)" + ) + assert not result.ok + assert "DELETE operations are disabled" in (result.error or "") + + +def test_validator_blocks_returning_by_default(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_UPDATE": True, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query( + "UPDATE customers SET email = 'x@example.com' WHERE id = 1 RETURNING email" + ) + assert not result.ok + assert "RETURNING is not allowed" in (result.error or "") + + +def test_validator_blocks_insert_select_star_from_non_wildcard_source(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\norders:*\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_INSERT": True, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query("INSERT INTO orders (id, total) SELECT * FROM customers") + assert not result.ok + assert "SELECT * is not allowed for table 'customers'" in (result.error or "") + + +def 
test_validator_accepts_qualified_target_table_with_short_policy_name(tmp_path: Path) -> None: + policy_path = tmp_path / "allowed_policy.txt" + write_policy(policy_path, "customers:id,email\n") + settings = Settings.model_validate( + { + "DATABASE_URL": "sqlite+aiosqlite:///./validator.db", + "ALLOWED_POLICY_FILE": str(policy_path), + "WRITE_MODE_ENABLED": True, + "ALLOW_UPDATE": True, + } + ) + validator = QueryValidator(settings) + result = validator.validate_query( + "UPDATE main.customers SET email = 'x@example.com' WHERE id = 1" + ) + assert result.ok diff --git a/tests/test_write_facts.py b/tests/test_write_facts.py new file mode 100644 index 0000000..91a81dd --- /dev/null +++ b/tests/test_write_facts.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from pathlib import Path + +import sqlglot + +from secure_sql_mcp.config import Settings +from secure_sql_mcp.query_validator import QueryValidator +from tests.conftest import write_policy + + +def _validator(tmp_path: Path, **overrides: object) -> QueryValidator: + policy_path = tmp_path / "allowed_policy.txt" + write_policy( + policy_path, + """ + customers:id,email + orders:* + """, + ) + payload: dict[str, object] = { + "DATABASE_URL": "sqlite+aiosqlite:///./write-facts.db", + "ALLOWED_POLICY_FILE": str(policy_path), + } + payload.update(overrides) + settings = Settings.model_validate(payload) + return QueryValidator(settings) + + +def test_extract_insert_write_facts(tmp_path: Path) -> None: + validator = _validator(tmp_path, WRITE_MODE_ENABLED=True, ALLOW_INSERT=True) + statement = sqlglot.parse_one( + "INSERT INTO customers (id, email) VALUES (2, 'b@example.com')", + read=validator._dialect, + ) + facts = validator._extract_write_facts(statement) + assert facts is not None + assert facts.statement_type == "insert" + assert facts.target_table == "customers" + assert facts.insert_columns == ["email", "id"] + assert facts.source_tables == [] + + +def 
test_extract_update_write_facts_with_tautological_where(tmp_path: Path) -> None: + validator = _validator(tmp_path, WRITE_MODE_ENABLED=True, ALLOW_UPDATE=True) + statement = sqlglot.parse_one( + "UPDATE customers SET email = 'x@example.com' WHERE 1 = 1", + read=validator._dialect, + ) + facts = validator._extract_write_facts(statement) + assert facts is not None + assert facts.statement_type == "update" + assert facts.where_present is True + assert facts.where_tautological is True + assert facts.updated_columns == ["email"] + + +def test_extract_delete_write_facts(tmp_path: Path) -> None: + validator = _validator(tmp_path, WRITE_MODE_ENABLED=True, ALLOW_DELETE=True) + statement = sqlglot.parse_one( + "DELETE FROM customers WHERE id = 1", + read=validator._dialect, + ) + facts = validator._extract_write_facts(statement) + assert facts is not None + assert facts.statement_type == "delete" + assert facts.where_present is True + assert facts.where_tautological is False + + +def test_extract_returning_columns(tmp_path: Path) -> None: + validator = _validator( + tmp_path, + WRITE_MODE_ENABLED=True, + ALLOW_UPDATE=True, + ALLOW_RETURNING=True, + ) + statement = sqlglot.parse_one( + "UPDATE customers SET email = 'x@example.com' WHERE id = 1 RETURNING email", + read=validator._dialect, + ) + facts = validator._extract_write_facts(statement) + assert facts is not None + assert facts.returning_present is True + assert facts.returning_columns == ["email"] + + +def test_extract_insert_select_source_tables(tmp_path: Path) -> None: + validator = _validator(tmp_path, WRITE_MODE_ENABLED=True, ALLOW_INSERT=True) + statement = sqlglot.parse_one( + "INSERT INTO orders (id, total) SELECT id, total FROM orders", + read=validator._dialect, + ) + facts = validator._extract_write_facts(statement) + assert facts is not None + assert facts.has_select_source is True + assert facts.source_tables == ["orders"] + + +def test_write_mode_disabled_blocks_writes(tmp_path: Path) -> None: + validator 
= _validator(tmp_path) + result = validator.validate_query( + "INSERT INTO customers (id, email) VALUES (2, 'b@example.com')" + ) + assert not result.ok + assert "configured for read-only access" in (result.error or "") + + +def test_allow_insert_flag_controls_insert(tmp_path: Path) -> None: + validator = _validator(tmp_path, WRITE_MODE_ENABLED=True, ALLOW_INSERT=False) + result = validator.validate_query( + "INSERT INTO customers (id, email) VALUES (2, 'b@example.com')" + ) + assert not result.ok + assert "INSERT operations are disabled by server configuration" in (result.error or "") + + +def test_enable_insert_does_not_enable_update(tmp_path: Path) -> None: + validator = _validator(tmp_path, WRITE_MODE_ENABLED=True, ALLOW_INSERT=True, ALLOW_UPDATE=False) + result = validator.validate_query("UPDATE customers SET email = 'x@example.com' WHERE id = 1") + assert not result.ok + assert "UPDATE operations are disabled by server configuration" in (result.error or "") + + +def test_ddl_still_blocked_when_write_mode_enabled(tmp_path: Path) -> None: + validator = _validator( + tmp_path, + WRITE_MODE_ENABLED=True, + ALLOW_INSERT=True, + ALLOW_UPDATE=True, + ALLOW_DELETE=True, + ) + result = validator.validate_query("DROP TABLE customers") + assert not result.ok + assert "read-only access" in (result.error or "") + assert "DROP" in (result.error or "")