diff --git a/.gitignore b/.gitignore index 31b7ff6..b91377b 100644 --- a/.gitignore +++ b/.gitignore @@ -109,7 +109,7 @@ ipython_config.py # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. -#uv.lock +uv.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. diff --git a/README.md b/README.md index 5ca65df..17a3495 100644 --- a/README.md +++ b/README.md @@ -115,9 +115,34 @@ Use this when you need a practical validation loop around a phenotype. ### Install ```bash -pip install -e . +pip install -e ".[dev]" ``` +## Dependency Management + +The project currently uses a simple split: + +- `pyproject.toml` defines the Python package, runtime dependencies, console scripts, and optional dev tools. +- `environment.yml` bootstraps a Conda or Micromamba environment with the Python tooling commonly used in this repo. +- `uv.lock` is not tracked as a repo source of truth. If you use `uv` locally, generate your own lockfile after cloning. + +Official local workflow: + +```bash +conda env create -f environment.yml +conda activate study-agent +pip install -e ".[dev]" +``` + +Optional `uv` workflow for users who prefer it: + +```bash +uv lock +uv run pytest +``` + +The repo does not currently require `uv`, and Docker still builds from `environment.yml` plus an editable install. + ### Start MCP over HTTP ```bash @@ -230,4 +255,3 @@ The repository still contains broader plans that are not the main implemented st - expansion toward a larger study-agent service catalog The planned-service inventory in older docs should not be read as "fully available now". - diff --git a/core/study_agent_core/logging_utils.py b/core/study_agent_core/logging_utils.py index 20a944a..b214d0b 100644 --- a/core/study_agent_core/logging_utils.py +++ b/core/study_agent_core/logging_utils.py @@ -2,10 +2,11 @@ import logging import os +import re import sys from logging.handlers import RotatingFileHandler from pathlib import Path -from typing import Literal +from typing import Any, Literal _LEVELS = { @@ -19,6 +20,108 @@ "OFF": logging.CRITICAL + 10, } +_SENSITIVE_KEY_NAMES = ( + "password", + "passwd", + "pwd", + "secret", + "token", + "api_key", + "apikey", + "access_key", + "access_token", + "refresh_token", + "authorization", + "bearer", + "dsn", + "connection_string", + "database_url", + "person_id", + "personid", + "patient_id", + "subject_id", + "visit_id", + "mrn", + "medical_record_number", +) + +_URI_CREDENTIALS_RE = re.compile(r"([a-z][a-z0-9+.\-]*://)([^/\s:@]+)(?::([^@/\s]*))?@", re.IGNORECASE) +_BEARER_RE = re.compile(r"\b(Bearer)\s+[A-Za-z0-9._~+/=-]+\b", re.IGNORECASE) +_KV_SECRET_RE = re.compile( + r"(?i)\b(password|passwd|pwd|secret|token|api[_-]?key|access[_-]?token|refresh[_-]?token|authorization)\b" + r"(\s*[:=]\s*)([^\s,;]+)" +) +_EMAIL_RE = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}") +_PHONE_RE = re.compile(r"\b\+?\d{1,2}[\s.-]?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}\b") +_DATE_RE = re.compile(r"\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b|\b\d{4}-\d{1,2}-\d{1,2}\b") +_SSN_RE = re.compile(r"\b\d{3}-\d{2}-\d{4}\b") +_MRN_RE = re.compile(r"(?i)\b(mrn|medical_record_number|person_id|personid|subject_id|patient_id)\b(\s*[:=]\s*)([^\s,;]+)") + + +def _sanitize_string(text: str) -> str: + value = str(text) + value = _URI_CREDENTIALS_RE.sub(r"\1[REDACTED_CREDENTIALS]@", value) + value = _BEARER_RE.sub(r"\1 [REDACTED_TOKEN]", value) + value = _KV_SECRET_RE.sub(r"\1\2[REDACTED]", value) + value = _EMAIL_RE.sub("[REDACTED_EMAIL]", value) + value = _PHONE_RE.sub("[REDACTED_PHONE]", value) + value = _DATE_RE.sub("[REDACTED_DATE]", value) + value = _SSN_RE.sub("[REDACTED_SSN]", value) + value = _MRN_RE.sub(r"\1\2[REDACTED_ID]", value) + return value + + +def _is_sensitive_key(key: Any) -> bool: + key_norm = re.sub(r"[^a-z0-9]+", "_", str(key).strip().lower()).strip("_") + return key_norm in _SENSITIVE_KEY_NAMES + + +def _sanitize_field(key: Any, value: Any, depth: int) -> Any: + if _is_sensitive_key(key): + return "[REDACTED]" + return sanitize_log_value(value, depth + 1) + + +def sanitize_log_value(value: Any, depth: int = 0) -> Any: + if depth > 4: + return _sanitize_string(repr(value)) + if value is None or isinstance(value, (bool, int, float)): + return value + if isinstance(value, str): + return _sanitize_string(value) + if isinstance(value, dict): + sanitized: dict[Any, Any] = {} + for key, inner in value.items(): + key_text = _sanitize_string(str(key)) + sanitized[key_text] = _sanitize_field(key, inner, depth) + return sanitized + if isinstance(value, tuple): + return tuple(sanitize_log_value(item, depth + 1) for item in value) + if isinstance(value, list): + return [sanitize_log_value(item, depth + 1) for item in value] + if isinstance(value, set): + return {sanitize_log_value(item, depth + 1) for item in value} + return _sanitize_string(str(value)) + + +def format_log_kv(fields: dict[str, Any]) -> str: + parts = [] + for key in sorted(fields): + parts.append(f"{key}={_sanitize_field(key, fields[key], 0)}") + return " ".join(parts) + + +class SensitiveDataFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + record.msg = sanitize_log_value(record.msg) + if isinstance(record.args, dict): + record.args = {key: sanitize_log_value(value) for key, value in record.args.items()} + elif isinstance(record.args, tuple): + record.args = tuple(sanitize_log_value(value) for value in record.args) + elif record.args: + record.args = sanitize_log_value(record.args) + return True + def _parse_level(value: str | None, default: str) -> int: return _LEVELS.get(str(value or default).strip().upper(), _LEVELS[default]) @@ -83,6 +186,7 @@ def configure_service_logger( console_handler = logging.StreamHandler(sys.stdout if stream == "stdout" else sys.stderr) console_handler.setLevel(level) console_handler.setFormatter(formatter) + console_handler.addFilter(SensitiveDataFilter()) logger.addHandler(console_handler) log_path = _resolve_log_path(service_name, default_filename) @@ -98,7 +202,7 @@ def configure_service_logger( ) file_handler.setLevel(level) file_handler.setFormatter(formatter) + file_handler.addFilter(SensitiveDataFilter()) logger.addHandler(file_handler) return logger - diff --git a/docs/TESTING.md b/docs/TESTING.md index b4228c7..dcaae86 100644 --- a/docs/TESTING.md +++ b/docs/TESTING.md @@ -7,11 +7,17 @@ This repo uses lightweight CLI smoke tests for the ACP and MCP layers. Keep thes Install the repo in editable mode so the CLI entrypoints are on your PATH and changes take effect immediately: ```bash -pip install -e . +pip install -e ".[dev]" ``` Editable mode means Python imports the local source tree directly. You do not need to reinstall after edits; just re-run the commands. Manage this per environment (venv/conda) and remove with `pip uninstall study-agent` if needed. +Dependency notes: + +- `pyproject.toml` is the source of truth for the Python package and the optional `dev` extras. +- `environment.yml` bootstraps the Conda or Micromamba environment used by Docker and many local setups. +- `uv.lock` is intentionally not tracked. If you prefer `uv`, generate a local lockfile after cloning with `uv lock`. + ## Test output verbosity Use pytest's built-in verbosity: diff --git a/dodo.py b/dodo.py index cd69822..7136ade 100644 --- a/dodo.py +++ b/dodo.py @@ -138,7 +138,16 @@ def task_test_all(): def task_run_all(): return { "actions": None, - "task_dep": ["test_all","smoke_phenotype_recommend_flow", "smoke_phenotype_intent_split_flow", "smoke_phenotype_recommendation_advice_flow", "smoke_phenotype_improvements_flow", "smoke_concept_sets_review_flow", "smoke_cohort_critique_flow"], + "task_dep": [ + "test_all", + "smoke_phenotype_recommend_flow", + "smoke_phenotype_intent_split_flow", + "smoke_phenotype_recommendation_advice_flow", + "smoke_phenotype_improvements_flow", + "smoke_concept_sets_review_flow", + "smoke_cohort_critique_flow", + "smoke_case_causal_review_flow", + ], } @@ -779,3 +788,14 @@ def _run_smoke() -> None: "actions": [_run_smoke], "verbosity": 2, } + + +def task_smoke_case_causal_review_flow(): + def _run_smoke() -> None: + print("Running case causal review flow smoke test...") + subprocess.run(["python", "tests/case_causal_review_flow_smoke_test.py"], check=True) + + return { + "actions": [_run_smoke], + "verbosity": 2, + } diff --git a/environment.yml b/environment.yml index 2d50e38..20f5f0c 100644 --- a/environment.yml +++ b/environment.yml @@ -13,4 +13,5 @@ dependencies: - pydantic - pyyaml - pytest + - ruff - requests diff --git a/mcp_server/study_agent_mcp/tools/_log.py b/mcp_server/study_agent_mcp/tools/_log.py index f9a3d11..3599a95 100644 --- a/mcp_server/study_agent_mcp/tools/_log.py +++ b/mcp_server/study_agent_mcp/tools/_log.py @@ -4,6 +4,8 @@ import os from typing import Any +from study_agent_core.logging_utils import format_log_kv + logger = logging.getLogger("study_agent.mcp") @@ -21,6 +23,6 @@ def log_debug(message: str, **fields: Any) -> None: if not _level_enabled("DEBUG"): return if fields: - logger.debug("%s %s", message, " ".join([f"{k}={v}" for k, v in fields.items()])) + logger.debug("%s %s", message, format_log_kv(fields)) else: logger.debug("%s", message) diff --git a/mcp_server/study_agent_mcp/tools/keeper_concept_sets.py b/mcp_server/study_agent_mcp/tools/keeper_concept_sets.py index 9dbaf1f..75dc6c3 100644 --- a/mcp_server/study_agent_mcp/tools/keeper_concept_sets.py +++ b/mcp_server/study_agent_mcp/tools/keeper_concept_sets.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hashlib import json import logging import os @@ -22,6 +23,10 @@ logger = logging.getLogger("study_agent.mcp.keeper_concept_sets") +def _text_fingerprint(value: str) -> str: + return hashlib.sha256((value or "").encode("utf-8")).hexdigest()[:12] + + def _prompt_dir() -> str: return os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "prompts", "keeper_concept_sets")) @@ -262,8 +267,9 @@ def _search_standard_via_hecate( endpoint = rewrite_container_host_url(endpoint) timeout = int(os.getenv("VOCAB_SEARCH_TIMEOUT", "30")) logger.debug( - "vocab_search provider=hecate_api query=%s domains=%s concept_classes=%s limit=%s timeout=%s", - query, + "vocab_search provider=hecate_api query_len=%s query_sha=%s domains=%s concept_classes=%s limit=%s timeout=%s", + len(query or ""), + _text_fingerprint(query), domains, concept_classes, limit, @@ -297,7 +303,11 @@ def _search_standard_via_hecate( elif isinstance(item, dict): concept_rows.append(item) normalized = _dedupe_concepts(concept_rows) - logger.debug("vocab_search provider=hecate_api query=%s results=%s", query, len(normalized)) + logger.debug( + "vocab_search provider=hecate_api query_sha=%s results=%s", + _text_fingerprint(query), + len(normalized), + ) return {"concepts": normalized, "count": len(normalized), "provider": "hecate_api", "url": endpoint} @@ -323,8 +333,9 @@ def _search_standard_via_generic_api( endpoint = rewrite_container_host_url(endpoint) timeout = int(os.getenv("VOCAB_SEARCH_TIMEOUT", "30")) logger.debug( - "vocab_search provider=generic_search_api query=%s domains=%s concept_classes=%s limit=%s timeout=%s", - query, + "vocab_search provider=generic_search_api query_len=%s query_sha=%s domains=%s concept_classes=%s limit=%s timeout=%s", + len(query or ""), + _text_fingerprint(query), domains, concept_classes, limit, @@ -417,7 +428,6 @@ def _phoebe_via_db(concept_ids: List[int], relationship_ids: List[str] | None) - engine = create_engine_with_dependencies(engine_name, future=True) logger.debug( "phoebe provider=db engine= concept_ids=%s relationship_ids=%s", - engine_name, len(concept_ids), relationship_ids, ) @@ -472,7 +482,6 @@ def _phoebe_via_db(concept_ids: List[int], relationship_ids: List[str] | None) - filtered, controls = _apply_phoebe_expansion_controls(raw_deduped, relationship_ids) logger.debug( "phoebe provider=db engine= query_seconds=%.2f total_seconds=%.2f rows=%s raw_results=%s final_results=%s relationships=%s applied_relationship_ids=%s max_per_relationship=%s max_total=%s", - engine_name, query_seconds, time.perf_counter() - started, len(rows), @@ -507,7 +516,6 @@ def _fetch_concepts_via_db( engine = create_engine_with_dependencies(engine_name, future=True) logger.debug( "vocab_fetch provider=db engine= concept_ids=%s domains=%s concept_classes=%s require_standard=%s", - engine_name, len(concept_ids), domains, concept_classes, @@ -564,7 +572,6 @@ def _fetch_concepts_via_db( deduped = _dedupe_concepts(concepts) logger.debug( "vocab_fetch provider=db engine= query_seconds=%.2f total_seconds=%.2f rows=%s results=%s missing=%s", - engine_name, query_seconds, time.perf_counter() - started, len(rows), diff --git a/pyproject.toml b/pyproject.toml index 2e692e6..c08db21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,12 @@ dependencies = [ "PyYAML>=6.0", ] +[project.optional-dependencies] +dev = [ + "pytest>=9.0.0", + "ruff>=0.11.0", +] + [project.scripts] study-agent-mcp = "study_agent_mcp.server:main" study-agent-acp = "study_agent_acp.server:main" diff --git a/tests/case_causal_review_flow_smoke_test.py b/tests/case_causal_review_flow_smoke_test.py new file mode 100644 index 0000000..8fcbd8d --- /dev/null +++ b/tests/case_causal_review_flow_smoke_test.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import json +import sys + +from study_agent_acp.agent import StudyAgent +import study_agent_acp.agent as agent_module + + +class StubMCPClient: + def __init__(self) -> None: + self.calls = [] + + def list_tools(self): + return [] + + def call_tool(self, name, arguments): + self.calls.append((name, arguments)) + if name == "case_causal_review_sanitize_row": + case_row = arguments.get("case_row", {}) + candidate_items = list(case_row.get("candidate_items") or []) + context_items = list(case_row.get("context_items") or []) + return { + "sanitized_row": { + "case_id": case_row.get("case_id") or "", + "case_summary": case_row.get("case_summary") or "", + "index_event": case_row.get("index_event") or {}, + "candidate_items": candidate_items, + "candidate_items_by_domain": {"drug_exposures": candidate_items}, + "context_items": context_items, + "context_items_by_domain": {"labs": context_items} if context_items else {}, + "case_metadata": case_row.get("case_metadata") or {}, + "annotations": case_row.get("annotations") or {}, + "tool_hints": case_row.get("tool_hints") or {}, + }, + "diagnostics": {"sanitization_status": "ok"}, + } + if name == "case_causal_review_prompt_bundle": + return { + "overview": "overview", + "spec": "spec", + "output_schema": {"type": "object"}, + "system_prompt": "system", + } + if name == "get_case_review_drug_signal_details": + return { + "status": "ok", + "source_record_id": arguments.get("source_record_id"), + "adverse_event_concept_id": arguments.get("adverse_event_concept_id"), + "has_disproportional_signal": True, + } + if name == "get_case_review_report_literature_stub": + return { + "status": "ok", + "case_id": arguments.get("case_id"), + "literature_reference_present": True, + } + if name == "case_causal_review_build_prompt": + return { + "prompt": "main", + "prompt_payload": { + "task": "case_causal_review", + "adverse_event_name": arguments.get("adverse_event_name"), + "source_type": arguments.get("source_type"), + "allowed_domains": arguments.get("allowed_domains") or [], + "enrichment": arguments.get("enrichment") or {}, + }, + } + if name == "case_causal_review_parse_response": + return { + "candidates_by_domain": { + "drug_exposures": [ + { + "domain": "drug_exposures", + "label": "Warfarin", + "source_record_id": "drug-1", + "why_it_may_contribute": "Bleeding risk", + "confidence": "high", + "rank": 1, + "candidate_role": "primary_suspect", + "evidence_basis": "Signal annotation and clinical plausibility", + } + ] + }, + "narrative": "Warfarin is a plausible contributor.", + "mode": "case_causal_review", + "diagnostics": {"parse_mode": "dict"}, + } + raise ValueError(f"unexpected tool: {name}") + + +def _fake_llm(prompt, required_keys=None): + return { + "candidates_by_domain": { + "drug_exposures": [ + { + "domain": "drug_exposures", + "label": "Warfarin", + "source_record_id": "drug-1", + "why_it_may_contribute": "Bleeding risk", + "confidence": "high", + "rank": 1, + } + ] + }, + "narrative": "Warfarin is a plausible contributor.", + "mode": "case_causal_review", + "diagnostics": {}, + } + + +def main() -> int: + original_call_llm = agent_module.call_llm + agent_module.call_llm = _fake_llm + try: + client = StubMCPClient() + agent = StudyAgent(mcp_client=client) + result = agent.run_case_causal_review_flow( + adverse_event_name="GI bleed", + case_row={ + "case_id": "case-1", + "case_summary": "GI bleed after anticoagulation.", + "index_event": { + "domain": "index_event", + "label": "GI bleed", + "source_record_id": "reaction-1", + "annotations": {"adverse_event_concept_id": 321, "adverse_event_meddra_id": "789"}, + }, + "candidate_items": [ + { + "domain": "drug_exposures", + "label": "Warfarin", + "source_record_id": "drug-1", + "subrole": "primary_suspect", + "annotations": {"ingredient_concept_id": 123, "ingred_rxcui": "456"}, + } + ], + "context_items": [ + { + "domain": "labs", + "label": "INR 4.2", + "source_record_id": "lab-1", + "subrole": "proximate_marker", + "annotations": {}, + } + ], + "case_metadata": { + "literature_reference_present": True, + "lookup_key": {"primaryid": None, "isr": "6526923"}, + }, + "annotations": {"concept_set_available_domains": ["drug_exposures"]}, + "tool_hints": { + "available_expansions": [ + "get_case_review_drug_signal_details", + "get_case_review_report_literature_stub", + ], + "prefetch_expansions": [ + "get_case_review_drug_signal_details", + "get_case_review_report_literature_stub", + ], + }, + }, + source_type="signal_validation", + allowed_domains=["drug_exposures"], + ) + finally: + agent_module.call_llm = original_call_llm + + assert result["status"] == "ok" + assert result["flow_name"] == "case_causal_review" + assert result["mode"] == "case_causal_review" + assert result["candidates_by_domain"]["drug_exposures"][0]["label"] == "Warfarin" + assert result["diagnostics"]["optional_enrichment"]["called"] == [ + "get_case_review_drug_signal_details", + "get_case_review_report_literature_stub", + ] + print(json.dumps(result, indent=2, sort_keys=True)) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/test_keeper_concept_sets_tools.py b/tests/test_keeper_concept_sets_tools.py index 3d02efa..cee0919 100644 --- a/tests/test_keeper_concept_sets_tools.py +++ b/tests/test_keeper_concept_sets_tools.py @@ -105,7 +105,8 @@ def test_vocab_add_nonchildren_merges_and_skips_descendants() -> None: @pytest.mark.mcp -def test_vocab_search_standard_reports_unconfigured_provider() -> None: +def test_vocab_search_standard_reports_unconfigured_provider(monkeypatch) -> None: + monkeypatch.delenv("VOCAB_SEARCH_PROVIDER", raising=False) tools = _registered_tools() result = tools["vocab_search_standard"]( query="GI bleed", @@ -119,7 +120,8 @@ def test_vocab_search_standard_reports_unconfigured_provider() -> None: @pytest.mark.mcp -def test_phoebe_related_concepts_reports_unconfigured_provider() -> None: +def test_phoebe_related_concepts_reports_unconfigured_provider(monkeypatch) -> None: + monkeypatch.delenv("PHOEBE_PROVIDER", raising=False) tools = _registered_tools() result = tools["phoebe_related_concepts"]( concept_ids=[1, 2], diff --git a/tests/test_logging_utils.py b/tests/test_logging_utils.py index 3c2c61b..8420ea3 100644 --- a/tests/test_logging_utils.py +++ b/tests/test_logging_utils.py @@ -1,6 +1,6 @@ import logging -from study_agent_core.logging_utils import configure_service_logger +from study_agent_core.logging_utils import configure_service_logger, format_log_kv, sanitize_log_value def test_configure_service_logger_writes_to_file(tmp_path, monkeypatch): @@ -40,3 +40,74 @@ def test_configure_service_logger_off_disables_logger(monkeypatch): logger.disabled = False logger.handlers.clear() logging.getLogger("study_agent.test.acp.off").handlers.clear() + + +def test_sanitize_log_value_redacts_credentials_and_phi(): + payload = { + "database_url": "postgresql://alice:supersecret@db.internal:5432/omop", + "authorization": "Bearer abc.def.ghi", + "patient_email": "patient@example.com", + "dob": "1984-07-15", + "person_id": "12345", + } + + sanitized = sanitize_log_value(payload) + + assert sanitized["database_url"] == "[REDACTED]" + assert sanitized["authorization"] == "[REDACTED]" + assert sanitized["patient_email"] == "[REDACTED_EMAIL]" + assert sanitized["dob"] == "[REDACTED_DATE]" + assert sanitized["person_id"] == "[REDACTED]" + + +def test_format_log_kv_redacts_helper_fields(): + rendered = format_log_kv( + { + "password": "secret123", + "embed_url": "https://user:pass@example.com/embed", + "owner_email": "owner@example.com", + } + ) + + assert "secret123" not in rendered + assert "user:pass@" not in rendered + assert "owner@example.com" not in rendered + assert "[REDACTED]" in rendered + assert "[REDACTED_CREDENTIALS]" in rendered + assert "[REDACTED_EMAIL]" in rendered + + +def test_configure_service_logger_redacts_formatted_args_in_file(tmp_path, monkeypatch): + log_dir = tmp_path / "logs" + monkeypatch.setenv("STUDY_AGENT_LOG_DIR", str(log_dir)) + monkeypatch.setenv("ACP_LOG_LEVEL", "DEBUG") + monkeypatch.setenv("ACP_LOG_TO_CONSOLE", "0") + + logger = configure_service_logger( + "ACP", + "study_agent.test.acp.redaction", + default_level="INFO", + stream="stderr", + default_filename="study-agent-acp.log", + ) + logger.info( + "dsn=%s auth=%s patient_email=%s dob=%s payload=%s", + "postgresql://alice:supersecret@db.internal:5432/omop", + "Bearer abc.def.ghi", + "patient@example.com", + "1984-07-15", + {"password": "swordfish", "person_id": "12345"}, + ) + + contents = (log_dir / "study-agent-acp.log").read_text(encoding="utf-8") + assert "supersecret" not in contents + assert "abc.def.ghi" not in contents + assert "patient@example.com" not in contents + assert "1984-07-15" not in contents + assert "swordfish" not in contents + assert "12345" not in contents + assert "[REDACTED_CREDENTIALS]" in contents + assert "[REDACTED_TOKEN]" in contents + assert "[REDACTED_EMAIL]" in contents + assert "[REDACTED_DATE]" in contents + assert "[REDACTED]" in contents