diff --git a/context_scribe/main.py b/context_scribe/main.py index 801c829..d9afa0a 100644 --- a/context_scribe/main.py +++ b/context_scribe/main.py @@ -4,7 +4,7 @@ import shutil from pathlib import Path from datetime import datetime -from typing import Optional +from typing import List, Optional import click from rich.console import Console from rich.live import Live @@ -20,6 +20,7 @@ from context_scribe.evaluator import get_evaluator, EVALUATOR_REGISTRY from context_scribe.bridge.mcp_client import MemoryBankClient + logger = logging.getLogger("context_scribe") console: Console = Console() @@ -163,6 +164,31 @@ def bootstrap_claude_config() -> None: f.write(f"\n{MASTER_RETRIEVAL_RULE}\n") +TOOL_REGISTRY = { + "gemini-cli": (GeminiCliProvider, bootstrap_global_config), + "copilot": (CopilotProvider, bootstrap_copilot_config), + "claude": (ClaudeProvider, bootstrap_claude_config), +} + + +def _create_providers(tools: List[str]): + """Create and bootstrap providers for the given tool names. + + Raises ValueError for unknown tool names. + """ + providers = [] + for tool in tools: + entry = TOOL_REGISTRY.get(tool) + if entry is None: + raise ValueError( + f"Unknown tool '{tool}'. 
Available: {', '.join(sorted(TOOL_REGISTRY))}" + ) + provider_cls, bootstrap_fn = entry + bootstrap_fn() + providers.append((tool, provider_cls())) + return providers + + def _detect_evaluator(preferred_tool: Optional[str] = None) -> str: """Auto-detect which evaluator CLI is available, prioritizing the preferred tool.""" # Map tool names to their corresponding CLI commands @@ -205,22 +231,20 @@ def _status(msg: str, db, live, debug: bool): live.update(db.generate_layout()) -async def run_daemon(tool: str, bank_path: str, debug: bool = False, evaluator_name: str = "auto") -> bool: - if tool == "gemini-cli": - bootstrap_global_config() - provider = GeminiCliProvider() - elif tool == "copilot": - bootstrap_copilot_config() - provider = CopilotProvider() - elif tool == "claude": - bootstrap_claude_config() - provider = ClaudeProvider() +async def run_daemon(tool: str, bank_path: str, debug: bool = False, evaluator_name: str = "auto", tools: Optional[List[str]] = None) -> bool: + # Build provider list: --tools takes precedence over --tool + if tools is not None: + if not tools: + raise ValueError("--tools was provided but resolved to an empty list.") + tool_names = tools else: - provider = None - if not provider: return False + tool_names = [tool] + providers = _create_providers(tool_names) + if not providers: + return False if evaluator_name == "auto": - evaluator_name = _detect_evaluator(tool) + evaluator_name = _detect_evaluator(tool_names[0]) evaluator = get_evaluator(evaluator_name) mcp_client = MemoryBankClient(bank_path=bank_path) @@ -230,29 +254,54 @@ async def run_daemon(tool: str, bank_path: str, debug: bool = False, evaluator_n console.print("[bold red]Fatal Error: Could not connect to the Memory Bank MCP server.[/bold red]") raise SystemExit(1) - db = Dashboard(tool, bank_path) + display_name = ",".join(tool_names) + db = Dashboard(display_name, bank_path) + queue: asyncio.Queue = asyncio.Queue(maxsize=1000) + + async def _watch_provider(tool_name: str, 
provider): + """Run a provider's watch() in a thread and feed interactions into the shared queue.""" + loop = asyncio.get_running_loop() + watch_iter = provider.watch() + try: + while True: + interaction = await loop.run_in_executor(None, next, watch_iter) + if interaction is not None: + await queue.put((tool_name, interaction)) + except (StopIteration, asyncio.CancelledError, KeyboardInterrupt): + pass + except Exception as e: + logger.error("Watcher for %s failed: %s", tool_name, e) async def _loop(live=None): + watcher_tasks = [] try: - loop = asyncio.get_event_loop() - watch_iter = provider.watch() + # Start a watcher task for each provider + watcher_tasks = [ + asyncio.create_task(_watch_provider(name, prov)) + for name, prov in providers + ] _status("🔍 Watching log stream...", db, live, debug) while True: - if live: live.update(db.generate_layout()) - interaction = await loop.run_in_executor(None, next, watch_iter) - if interaction is None: + if live: + live.update(db.generate_layout()) + + # Wait for next interaction from any provider + try: + tool_name, interaction = await asyncio.wait_for(queue.get(), timeout=1.0) + except asyncio.TimeoutError: continue - _status(f"🤔 Analyzing user message ({interaction.project_name})", db, live, debug) + _status(f"🤔 [{tool_name}] Analyzing user message ({interaction.project_name})", db, live, debug) if debug: - logging.getLogger("context_scribe").info(" content: %s", interaction.content[:120]) + logger.info(" content: %s", interaction.content[:120]) - _status(f"📖 Accessing Memory Bank ({interaction.project_name})...", db, live, debug) + _status(f"📖 [{tool_name}] Accessing Memory Bank ({interaction.project_name})...", db, live, debug) existing_global = await mcp_client.read_rules("global", "global_rules.md") existing_project = await mcp_client.read_rules(interaction.project_name, "rules.md") - _status(f"🧠 Thinking: Extracting rules for {interaction.project_name}...", db, live, debug) + _status(f"🧠 [{tool_name}] Extracting 
rules for {interaction.project_name}...", db, live, debug) + loop = asyncio.get_running_loop() rule_output = await loop.run_in_executor(None, evaluator.evaluate_interaction, interaction, existing_global, existing_project) if rule_output: @@ -272,11 +321,11 @@ async def _loop(live=None): seen.add(stripped) deduped_content = "\n".join(unique_lines).strip() - _status(f"📝 Committing: {dest_path}", db, live, debug) + _status(f"📝 [{tool_name}] Committing: {dest_path}", db, live, debug) await mcp_client.save_rule(deduped_content, dest_proj, dest_file) db.add_history(dest_path, rule_output.description) - _status(f"✅ SUCCESS: Updated {dest_path}", db, live, debug) + _status(f"✅ [{tool_name}] Updated {dest_path}", db, live, debug) if not debug: console.print(f"[bold green]▶ UPDATED:[/bold green] [cyan]{dest_path}[/cyan] ({rule_output.description})") else: @@ -287,6 +336,8 @@ async def _loop(live=None): except (KeyboardInterrupt, asyncio.CancelledError): _status("🛑 Stopping...", db, live, debug) finally: + for task in watcher_tasks: + task.cancel() await mcp_client.close() if debug: @@ -297,16 +348,34 @@ async def _loop(live=None): return True @click.command() -@click.option('--tool', default='gemini-cli', type=click.Choice(['gemini-cli', 'copilot', 'claude']), help='The AI tool to monitor') +@click.option('--tool', default='gemini-cli', type=click.Choice(['gemini-cli', 'copilot', 'claude']), help='Single AI tool to monitor (use --tools for multiple)') +@click.option('--tools', 'tools_csv', default=None, help='Comma-separated tools to monitor concurrently (e.g. 
gemini-cli,claude,copilot)') @click.option('--bank-path', default='~/.memory-bank', help='Path to your Memory Bank root') @click.option('--evaluator', 'evaluator_name', default='auto', type=click.Choice(['auto'] + sorted(EVALUATOR_REGISTRY)), help='Evaluator LLM to use (default: auto-detect)') @click.option('--debug', is_flag=True, default=False, help='Stream plain debug logs instead of dashboard UI') -def cli(tool, bank_path, evaluator_name, debug): +def cli(tool, tools_csv, bank_path, evaluator_name, debug): """Context-Scribe: Persistent Secretary Daemon""" if debug: logging.basicConfig(level=logging.DEBUG, format='%(asctime)s [%(levelname)s] %(name)s: %(message)s') + + # Parse --tools if provided + tools = None + if tools_csv is not None: + tools = list(dict.fromkeys( # deduplicate preserving order + t.strip() for t in tools_csv.split(",") if t.strip() + )) + if not tools: + raise click.ClickException("--tools requires at least one tool name.") + valid_tools = set(TOOL_REGISTRY) + invalid = [t for t in tools if t not in valid_tools] + if invalid: + raise click.ClickException( + f"Unknown tool(s): {', '.join(invalid)}. 
" + f"Available: {', '.join(sorted(valid_tools))}" + ) + try: - asyncio.run(run_daemon(tool, bank_path, debug=debug, evaluator_name=evaluator_name)) + asyncio.run(run_daemon(tool, bank_path, debug=debug, evaluator_name=evaluator_name, tools=tools)) except KeyboardInterrupt: pass diff --git a/tests/test_daemons.py b/tests/test_daemons.py index c42de2e..0cde75b 100644 --- a/tests/test_daemons.py +++ b/tests/test_daemons.py @@ -4,41 +4,40 @@ from context_scribe.main import run_daemon @pytest.mark.asyncio -@pytest.mark.parametrize("tool, provider_class, evaluator_class, bootstrap_func, evaluator_name", [ - ("gemini-cli", "GeminiCliProvider", "GeminiCliEvaluator", "bootstrap_global_config", "gemini"), - ("copilot", "CopilotProvider", "CopilotEvaluator", "bootstrap_copilot_config", "copilot"), - ("claude", "ClaudeProvider", "ClaudeEvaluator", "bootstrap_claude_config", "claude"), +@pytest.mark.parametrize("tool, bootstrap_func, evaluator_name", [ + ("gemini-cli", "bootstrap_global_config", "gemini"), + ("copilot", "bootstrap_copilot_config", "copilot"), + ("claude", "bootstrap_claude_config", "claude"), ]) -async def test_run_daemon_tools(tool, provider_class, evaluator_class, bootstrap_func, evaluator_name, daemon_mocks): +async def test_run_daemon_tools(tool, bootstrap_func, evaluator_name, daemon_mocks): """Test the daemon run loop for all supported tools.""" - - with patch(f"context_scribe.main.{provider_class}", return_value=daemon_mocks.provider): - with patch(f"context_scribe.main.{evaluator_class}", return_value=daemon_mocks.evaluator): + + with patch("context_scribe.main._create_providers", return_value=[(tool, daemon_mocks.provider)]): + with patch("context_scribe.main.get_evaluator", return_value=daemon_mocks.evaluator): with patch("context_scribe.main.MemoryBankClient", return_value=daemon_mocks.mcp): with patch(f"context_scribe.main.{bootstrap_func}"): # Mock Live to avoid rich rendering logic completely with patch("context_scribe.main.Live") as mock_live: - 
with patch("os._exit") as mock_exit: - # Make the context manager work - mock_live.return_value.__enter__.return_value = MagicMock() + # Make the context manager work + mock_live.return_value.__enter__.return_value = MagicMock() - # Start daemon and wait for it to process the mocked interaction - daemon_task = asyncio.create_task(run_daemon(tool, "~/.memory-bank", evaluator_name=evaluator_name)) + # Start daemon and wait for it to process the mocked interaction + daemon_task = asyncio.create_task(run_daemon(tool, "~/.memory-bank", evaluator_name=evaluator_name)) - # Wait until save_rule is called (meaning interaction processed) - for _ in range(50): - if daemon_mocks.processed_interaction: - break - await asyncio.sleep(0.1) + # Wait until save_rule is called (meaning interaction processed) + for _ in range(100): + if daemon_mocks.processed_interaction: + break + await asyncio.sleep(0.1) - daemon_task.cancel() - try: - await daemon_task - except asyncio.CancelledError: - pass + daemon_task.cancel() + try: + await daemon_task + except asyncio.CancelledError: + pass - # Verify calls - daemon_mocks.mcp.connect.assert_called_once() - daemon_mocks.mcp.read_rules.assert_called() - daemon_mocks.evaluator.evaluate_interaction.assert_called() - daemon_mocks.mcp.save_rule.assert_called_once_with("Extracted Rule", "global", "global_rules.md") + # Verify calls + daemon_mocks.mcp.connect.assert_called_once() + daemon_mocks.mcp.read_rules.assert_called() + daemon_mocks.evaluator.evaluate_interaction.assert_called() + daemon_mocks.mcp.save_rule.assert_called_once_with("Extracted Rule", "global", "global_rules.md") diff --git a/tests/test_multi_tool.py b/tests/test_multi_tool.py new file mode 100644 index 0000000..a97bd77 --- /dev/null +++ b/tests/test_multi_tool.py @@ -0,0 +1,141 @@ +"""Tests for concurrent multi-tool daemon support.""" +import asyncio +import sys +from unittest.mock import patch, MagicMock +import pytest + + +@pytest.fixture(autouse=True) +def mock_heavy_deps(): 
+ """Mock heavy imports so we can import main without mcp/rich.""" + mocks = {} + for mod in ["mcp", "mcp.client", "mcp.client.stdio", + "rich", "rich.console", "rich.live", "rich.panel", + "rich.text", "rich.layout", "rich.table", "rich.spinner"]: + if mod not in sys.modules or not hasattr(sys.modules.get(mod), '__file__'): + mocks[mod] = MagicMock() + with patch.dict(sys.modules, mocks): + # Clear cached imports so they re-resolve with mocks + for key in list(sys.modules.keys()): + if key.startswith("context_scribe.main") or key.startswith("context_scribe.bridge"): + del sys.modules[key] + yield + + +def test_tool_registry_populated(): + from context_scribe.main import TOOL_REGISTRY + assert "gemini-cli" in TOOL_REGISTRY + assert "copilot" in TOOL_REGISTRY + assert "claude" in TOOL_REGISTRY + + +def test_create_providers_single(): + from context_scribe.main import _create_providers + with patch("context_scribe.main.bootstrap_global_config"): + with patch("context_scribe.main.GeminiCliProvider") as mock_cls: + mock_cls.return_value = MagicMock() + providers = _create_providers(["gemini-cli"]) + assert len(providers) == 1 + assert providers[0][0] == "gemini-cli" + + +def test_create_providers_multiple(): + from context_scribe.main import _create_providers + with patch("context_scribe.main.bootstrap_global_config"): + with patch("context_scribe.main.bootstrap_claude_config"): + with patch("context_scribe.main.GeminiCliProvider", return_value=MagicMock()): + with patch("context_scribe.main.ClaudeProvider", return_value=MagicMock()): + providers = _create_providers(["gemini-cli", "claude"]) + assert len(providers) == 2 + names = [p[0] for p in providers] + assert "gemini-cli" in names + assert "claude" in names + + +def test_create_providers_unknown_raises(): + from context_scribe.main import _create_providers + with pytest.raises(ValueError, match="Unknown tool"): + _create_providers(["nonexistent"]) + + +def test_create_providers_calls_bootstrap(): + from 
context_scribe.main import _create_providers, TOOL_REGISTRY + mock_boot = MagicMock() + original = TOOL_REGISTRY["gemini-cli"] + TOOL_REGISTRY["gemini-cli"] = (original[0], mock_boot) + try: + with patch("context_scribe.main.GeminiCliProvider", return_value=MagicMock()): + _create_providers(["gemini-cli"]) + mock_boot.assert_called_once() + finally: + TOOL_REGISTRY["gemini-cli"] = original + + +def test_create_providers_all_three(): + from context_scribe.main import _create_providers + with patch("context_scribe.main.bootstrap_global_config"): + with patch("context_scribe.main.bootstrap_copilot_config"): + with patch("context_scribe.main.bootstrap_claude_config"): + with patch("context_scribe.main.GeminiCliProvider", return_value=MagicMock()): + with patch("context_scribe.main.CopilotProvider", return_value=MagicMock()): + with patch("context_scribe.main.ClaudeProvider", return_value=MagicMock()): + providers = _create_providers(["gemini-cli", "copilot", "claude"]) + assert len(providers) == 3 + + +# --- CLI tests for --tools flag --- + +def test_cli_tools_deduplication(): + """--tools with duplicate entries should deduplicate preserving order.""" + from click.testing import CliRunner + from context_scribe.main import cli + + runner = CliRunner() + captured_tools = {} + + original_run_daemon = None + + async def fake_run_daemon(tool, bank_path, debug=False, evaluator_name="auto", tools=None): + captured_tools["tools"] = tools + return True + + with patch("context_scribe.main.run_daemon", side_effect=fake_run_daemon) as mock_rd: + loop = asyncio.new_event_loop() + with patch("asyncio.run", side_effect=lambda coro: loop.run_until_complete(coro)): + result = runner.invoke(cli, ["--tools", "gemini-cli,gemini-cli,claude"]) + + assert result.exit_code == 0 + assert captured_tools["tools"] == ["gemini-cli", "claude"] + + +def test_cli_tools_invalid_tool(): + """--tools with an unknown tool name should fail with a clear error.""" + from click.testing import CliRunner + 
from context_scribe.main import cli + + runner = CliRunner() + result = runner.invoke(cli, ["--tools", "gemini-cli,nonexistent"]) + assert result.exit_code != 0 + assert "Unknown tool(s): nonexistent" in result.output + + +def test_cli_tools_empty_string(): + """--tools with an empty string should fail.""" + from click.testing import CliRunner + from context_scribe.main import cli + + runner = CliRunner() + result = runner.invoke(cli, ["--tools", ""]) + assert result.exit_code != 0 + assert "--tools requires at least one tool name" in result.output + + +def test_cli_tools_whitespace_only(): + """--tools with only whitespace/commas should fail.""" + from click.testing import CliRunner + from context_scribe.main import cli + + runner = CliRunner() + result = runner.invoke(cli, ["--tools", " , , "]) + assert result.exit_code != 0 + assert "--tools requires at least one tool name" in result.output