diff --git a/src/execution/scheduler.rs b/src/execution/scheduler.rs index f3f07227..92f20c1d 100644 --- a/src/execution/scheduler.rs +++ b/src/execution/scheduler.rs @@ -123,6 +123,10 @@ pub struct Scheduler { hook_registry: HookRegistry, /// Project root directory for resolving hook paths project_root: PathBuf, + /// Reuse existing test database (--reuse-db) + reuse_db: bool, + /// Force recreation of test database (--create-db) + create_db: bool, } impl Scheduler { @@ -143,6 +147,8 @@ impl Scheduler { None, hook_registry, project_root, + false, + false, ) } @@ -165,6 +171,8 @@ impl Scheduler { None, hook_registry, project_root, + false, + false, ) } @@ -179,6 +187,8 @@ impl Scheduler { timeout_hook: Option, hook_registry: HookRegistry, project_root: PathBuf, + reuse_db: bool, + create_db: bool, ) -> Result { let max_workers = log_capture.slot_count(); @@ -202,6 +212,8 @@ impl Scheduler { timeout_hook, hook_registry, project_root, + reuse_db, + create_db, }) } @@ -518,8 +530,8 @@ impl Scheduler { cached_effects, markers: test.markers.clone(), marker_info: test.marker_info.clone(), - reuse_db: false, - create_db: false, + reuse_db: self.reuse_db, + create_db: self.create_db, }; // Use encode_with_length which includes protocol header diff --git a/src/execution/zygote.rs b/src/execution/zygote.rs index 918bda24..54338620 100644 --- a/src/execution/zygote.rs +++ b/src/execution/zygote.rs @@ -661,7 +661,10 @@ pub fn entrypoint(cmd_socket: UnixStream, result_socket: UnixStream) -> Result<( } // Django Detection & Setup (Batteries-Included) - // Initialize Django in Zygote so workers inherit the pre-warmed state + // Initialize Django in Zygote so workers inherit the pre-warmed state. + // NOTE: We do NOT warm up DB connections here. setup_databases() creates + // a test DB after init_session(), and connections are closed before fork + // so workers get fresh file descriptors. py.run( c_str!(r#" import os @@ -670,20 +673,9 @@ import sys try: import django - # Check if DJANGO_SETTINGS_MODULE is already set if 'DJANGO_SETTINGS_MODULE' in os.environ: django.setup() print(f'[tach:zygote] Django initialized: {os.environ["DJANGO_SETTINGS_MODULE"]}', file=sys.stderr) - - # CRITICAL: Warm up DB connections before forking - # File descriptors must exist in Zygote to be inherited by workers - try: - from django.db import connections - for alias in connections: - connections[alias].ensure_connection() - print(f'[tach:zygote] Django DB connections warmed up', file=sys.stderr) - except Exception as e: - print(f'[tach:zygote] Django DB warmup failed: {e}', file=sys.stderr) except ImportError: pass # Django not installed, skip except Exception as e: @@ -708,6 +700,24 @@ except Exception as e: let target_path = std::env::var("TACH_TARGET_PATH").unwrap_or_else(|_| cwd_str.clone()); harness.getattr("init_session")?.call1((&target_path,))?; + // Django Test Database: Create test DB after pytest is configured but + // before workers fork. Reads TACH_REUSE_DB/TACH_CREATE_DB env vars. + // Connections are closed before fork so workers get fresh FDs. + py.run( + c_str!( + r#" +try: + import tach_harness + tach_harness._setup_django_test_db() +except Exception as e: + import sys + print(f'[tach:zygote] Django test DB setup error: {e}', file=sys.stderr) +"# + ), + None, + None, + )?; + // HOOK EFFECT BRIDGE (v0.2.0): Retrieve session effects from Python // After init_session(), Python has recorded effects in _SESSION_HOOK_EFFECTS. // We retrieve them here and will send them to the Supervisor for HookRegistry population. diff --git a/src/main.rs b/src/main.rs index 6efa837b..a8e85ed8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -201,6 +201,15 @@ fn main() -> Result<()> { // SAFETY: Same as above - called before worker threads spawn. unsafe { std::env::set_var("TACH_TARGET_PATH", &cli.path) }; + // Set Django test DB flags for Zygote to read during setup_databases() + // SAFETY: Same as above - called before worker threads spawn. + if cli.reuse_db { + unsafe { std::env::set_var("TACH_REUSE_DB", "1") }; + } + if cli.create_db { + unsafe { std::env::set_var("TACH_CREATE_DB", "1") }; + } + // --- LIFECYCLE SETUP --- debugger::install_panic_hook(); @@ -1160,6 +1169,8 @@ fn run_tests( timeout_hook, hook_registry, project_root, + std::env::var("TACH_REUSE_DB").unwrap_or_default() == "1", + std::env::var("TACH_CREATE_DB").unwrap_or_default() == "1", )?; let stats = scheduler.run(runnable_tests, reporter)?; diff --git a/src/tach_harness.py b/src/tach_harness.py index 4e4bd259..92ab057e 100644 --- a/src/tach_harness.py +++ b/src/tach_harness.py @@ -116,16 +116,16 @@ def get_scope_key(self, item: Any) -> str: if self._current_scope == "session": return "session" elif self._current_scope == "module": - fspath = getattr(item, 'fspath', None) + fspath = getattr(item, "fspath", None) return f"module:{fspath}" if fspath else f"module:unknown:{id(item)}" elif self._current_scope == "class": cls = getattr(item, "cls", None) if cls: return f"class:{cls.__module__}.{cls.__name__}" - nodeid = getattr(item, 'nodeid', None) + nodeid = getattr(item, "nodeid", None) return f"function:{nodeid}" if nodeid else f"function:unknown:{id(item)}" else: # function scope (default) - nodeid = getattr(item, 'nodeid', None) + nodeid = getattr(item, "nodeid", None) return f"function:{nodeid}" if nodeid else f"function:unknown:{id(item)}" def close_scope(self, scope_key: str) -> None: @@ -140,7 +140,9 @@ def close_scope(self, scope_key: str) -> None: task.cancel() # Run until all tasks are cancelled if pending: - loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True) + ) except Exception as e: _logger.debug("Cleanup error in close_scope: %s", e) finally: @@ -152,9 +154,7 @@ def close_all(self) -> None: self.close_scope(scope_key) def on_scope_transition( - self, - current_module: Optional[str], - current_class: Optional[str] + self, current_module: Optional[str], current_class: Optional[str] ) -> None: """Handle scope transitions and cleanup old scopes. @@ -163,13 +163,19 @@ def on_scope_transition( """ # Module transition: close previous module's loop if scope is module if self._current_scope == "module": - if self._previous_module is not None and self._previous_module != current_module: + if ( + self._previous_module is not None + and self._previous_module != current_module + ): old_key = f"module:{self._previous_module}" self.close_scope(old_key) # Class transition: close previous class's loop if scope is class if self._current_scope == "class": - if self._previous_class is not None and self._previous_class != current_class: + if ( + self._previous_class is not None + and self._previous_class != current_class + ): old_key = f"class:{self._previous_class}" self.close_scope(old_key) @@ -208,6 +214,7 @@ def detect_uvloop() -> bool: """Detect if uvloop is available.""" try: import uvloop # noqa: F401 + return True except ImportError: return False @@ -217,6 +224,7 @@ def get_uvloop_policy() -> Optional[asyncio.AbstractEventLoopPolicy]: """Get uvloop policy if available, None otherwise.""" try: import uvloop + return uvloop.EventLoopPolicy() except ImportError: return None @@ -231,13 +239,22 @@ class AsyncFixtureWrapper: """ _generators_by_scope: dict[str, dict[str, Any]] = { - "session": {}, "module": {}, "class": {}, "function": {}, + "session": {}, + "module": {}, + "class": {}, + "function": {}, } _consumed_by_scope: dict[str, set[str]] = { - "session": set(), "module": set(), "class": set(), "function": set(), + "session": set(), + "module": set(), + "class": set(), + "function": set(), } _values_by_scope: dict[str, dict[str, Any]] = { - "session": {}, "module": {}, "class": {}, "function": {}, + "session": {}, + "module": {}, + "class": {}, + "function": {}, } _loop: Optional[asyncio.AbstractEventLoop] = None _teardown_errors: list[Exception] = [] @@ -264,7 +281,9 @@ def get_loop(cls) -> asyncio.AbstractEventLoop: return cls._loop @classmethod - def on_test_start(cls, module_path: Optional[str], class_name: Optional[str]) -> None: + def on_test_start( + cls, module_path: Optional[str], class_name: Optional[str] + ) -> None: if cls._current_module is not None and cls._current_module != module_path: cls.teardown_module_scope() if cls._current_class is not None and cls._current_class != class_name: @@ -284,8 +303,10 @@ def consume_async_fixture( loop = cls.get_loop() try: if inspect.isasyncgen(fixture_value): + async def consume_gen(): return await fixture_value.__anext__() + result = loop.run_until_complete(consume_gen()) cls._generators_by_scope[scope][fixture_name] = fixture_value cls._consumed_by_scope[scope].add(fixture_name) @@ -310,8 +331,10 @@ def _teardown_scope(cls, scope: str) -> None: loop = cls.get_loop() for name, gen in list(generators.items()): try: + async def cleanup(): await gen.aclose() + loop.run_until_complete(cleanup()) except Exception as e: cls._teardown_errors.append( @@ -379,14 +402,16 @@ def pytest_fixture_setup(fixturedef, request): return is_async = inspect.isasyncgen(result) or asyncio.iscoroutine(result) if is_async: - scope = getattr(fixturedef, 'scope', 'function') + scope = getattr(fixturedef, "scope", "function") try: consumed_value = AsyncFixtureWrapper.consume_async_fixture( fixturedef.argname, result, scope ) outcome.force_result(consumed_value) except Exception as e: - _logger.error("Failed to consume async fixture '%s': %s", fixturedef.argname, e) + _logger.error( + "Failed to consume async fixture '%s': %s", fixturedef.argname, e + ) raise @@ -427,8 +452,13 @@ def _configure_asyncio_from_pyproject(root_dir: str) -> None: loop_scope = pytest_opts.get("asyncio_default_fixture_loop_scope", "function") auto_mode = asyncio_mode == "auto" if auto_mode or loop_scope != "function": - EventLoopManager.get_instance().configure(loop_scope=loop_scope, auto_mode=auto_mode) - os.write(2, f"[tach:harness] Asyncio config: mode={asyncio_mode}, loop_scope={loop_scope}\n".encode()) + EventLoopManager.get_instance().configure( + loop_scope=loop_scope, auto_mode=auto_mode + ) + os.write( + 2, + f"[tach:harness] Asyncio config: mode={asyncio_mode}, loop_scope={loop_scope}\n".encode(), + ) except Exception as e: _logger.warning("Failed to parse asyncio config: %s", e) @@ -583,11 +613,13 @@ def teardown_async_fixture( This helper will be integrated when Tach implements native fixture resolution. """ if gen is not None: + async def cleanup(): try: await gen.__anext__() except StopAsyncIteration: pass + try: loop.run_until_complete(cleanup()) except Exception as e: @@ -697,7 +729,10 @@ def _detect_thread_leak(initial_count: int, allow_threads: bool) -> bool: time.sleep(0.050) # 50ms intervals current_count = threading.active_count() if current_count <= initial_count: - print("[tach:harness] INFO: Threads terminated within grace period", file=sys.stderr) + print( + "[tach:harness] INFO: Threads terminated within grace period", + file=sys.stderr, + ) return False # Grace period expired, threads still running @@ -894,7 +929,9 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> bool: self._catch.__exit__(exc_type, exc_val, exc_tb) # Filter warnings to find matching ones - matching = [w for w in self._warnings if issubclass(w.category, self.expected_warning)] + matching = [ + w for w in self._warnings if issubclass(w.category, self.expected_warning) + ] if not matching: raise AssertionError(f"DID NOT WARN {self.expected_warning}") @@ -1000,7 +1037,10 @@ def __eq__(self, actual) -> bool: return False if len(self.expected) != len(actual): return False - return all(approx(e, self.rel, self.abs) == a for e, a in zip(self.expected, actual)) + return all( + approx(e, self.rel, self.abs) == a + for e, a in zip(self.expected, actual) + ) # Single value comparison # Handle special float values first @@ -1017,8 +1057,16 @@ def __eq__(self, actual) -> bool: try: # Use built-in abs function (not self.abs) - expected_abs = __builtins__["abs"](self.expected) if isinstance(__builtins__, dict) else abs(self.expected) - diff = __builtins__["abs"](self.expected - actual) if isinstance(__builtins__, dict) else abs(self.expected - actual) + expected_abs = ( + __builtins__["abs"](self.expected) + if isinstance(__builtins__, dict) + else abs(self.expected) + ) + diff = ( + __builtins__["abs"](self.expected - actual) + if isinstance(__builtins__, dict) + else abs(self.expected - actual) + ) except (TypeError, KeyError): # Fallback for edge cases import builtins @@ -1194,7 +1242,9 @@ def tach_breakpointhook(*args, **kwargs): global _debug_socket_path if not _debug_socket_path: - print("[tach] WARNING: breakpoint() called but no debug socket.", file=sys.stderr) + print( + "[tach] WARNING: breakpoint() called but no debug socket.", file=sys.stderr + ) return sock = None @@ -1238,12 +1288,17 @@ def inject_entropy(): try: import ctypes import ctypes.util - ssl_lib_path = ctypes.util.find_library('ssl') + + ssl_lib_path = ctypes.util.find_library("ssl") if ssl_lib_path: ssl_lib = ctypes.CDLL(ssl_lib_path) # Note: hasattr on CDLL may not work reliably; try/except is the real safeguard - if hasattr(ssl_lib, 'RAND_add'): - ssl_lib.RAND_add.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_double] + if hasattr(ssl_lib, "RAND_add"): + ssl_lib.RAND_add.argtypes = [ + ctypes.c_char_p, + ctypes.c_int, + ctypes.c_double, + ] entropy_bytes = os.urandom(32) ssl_lib.RAND_add(entropy_bytes, 32, 32.0) except Exception as e: @@ -1430,7 +1485,9 @@ def uninstall_tach_import_hook(): """Remove the Tach import hook from sys.meta_path.""" global _TACH_IMPORT_HOOK_INSTALLED - sys.meta_path[:] = [f for f in sys.meta_path if not isinstance(f, TachMetaPathFinder)] + sys.meta_path[:] = [ + f for f in sys.meta_path if not isinstance(f, TachMetaPathFinder) + ] _TACH_IMPORT_HOOK_INSTALLED = False @@ -1513,7 +1570,10 @@ def post_fork_init() -> bool: # 3. Capture baseline sys.modules for hot reloading # This snapshot defines what modules are "framework" vs "test-imported" _INITIAL_MODULES = set(sys.modules.keys()) - print(f"[tach:harness] Captured {len(_INITIAL_MODULES)} baseline modules", file=sys.stderr) + print( + f"[tach:harness] Captured {len(_INITIAL_MODULES)} baseline modules", + file=sys.stderr, + ) # 4. Check if snapshot mode is enabled import os @@ -1596,7 +1656,9 @@ def _format_local_value(name: str, value) -> str: return f" {name} = " -def _get_source_context(filename: str, lineno: int, context_lines: int = _CONTEXT_LINES) -> Optional[str]: +def _get_source_context( + filename: str, lineno: int, context_lines: int = _CONTEXT_LINES +) -> Optional[str]: """Get source code context around a specific line. Args: @@ -1804,9 +1866,7 @@ def detect_installed_plugins() -> dict: # Check if it's a pytest plugin (entry point group pytest11) try: eps = dist.entry_points - is_pytest_plugin = any( - ep.group == "pytest11" for ep in eps - ) + is_pytest_plugin = any(ep.group == "pytest11" for ep in eps) except Exception: is_pytest_plugin = name.startswith("pytest-") or name.startswith("pytest_") @@ -1897,20 +1957,24 @@ def _compute_sys_path_delta(before: list, after: list) -> list: action = "prepend" else: action = "append" - effects.append({ - "type": EFFECT_TYPE_MODIFY_SYS_PATH, - "action": action, - "path": path, - }) + effects.append( + { + "type": EFFECT_TYPE_MODIFY_SYS_PATH, + "action": action, + "path": path, + } + ) # Find paths removed (in before but not in after) for path in before: if path not in after: - effects.append({ - "type": EFFECT_TYPE_MODIFY_SYS_PATH, - "action": "remove", - "path": path, - }) + effects.append( + { + "type": EFFECT_TYPE_MODIFY_SYS_PATH, + "action": "remove", + "path": path, + } + ) return effects @@ -1925,17 +1989,21 @@ def _compute_env_delta(before: dict, after: dict) -> list: # Find added or changed variables for key, value in after.items(): if key not in before: - effects.append({ - "type": EFFECT_TYPE_SET_ENV, - "key": key, - "value": value, - }) + effects.append( + { + "type": EFFECT_TYPE_SET_ENV, + "key": key, + "value": value, + } + ) elif before[key] != value: - effects.append({ - "type": EFFECT_TYPE_SET_ENV, - "key": key, - "value": value, - }) + effects.append( + { + "type": EFFECT_TYPE_SET_ENV, + "key": key, + "value": value, + } + ) # Note: We don't track unset env vars for session-level hooks # as pytest_configure typically only adds env vars @@ -2163,7 +2231,10 @@ def call_collection_modifyitems( result["removed"] = removed # Determine if reordered (same items but different order) - if items_before_count == items_after_count and items_before_ids != items_after_ids: + if ( + items_before_count == items_after_count + and items_before_ids != items_after_ids + ): result["reordered"] = True finally: @@ -2227,10 +2298,16 @@ def apply_cached_effects(effects: list) -> int: # Debug logging: show count of effects provided vs applied if provided > 0: if applied > 0: - print(f"[tach:harness] Applied {applied}/{provided} cached hook effects", file=sys.stderr) + print( + f"[tach:harness] Applied {applied}/{provided} cached hook effects", + file=sys.stderr, + ) else: # Warning: effects were provided but none were applied - print(f"[tach:harness] WARNING: {provided} effects provided but 0 applied (possible mismatch)", file=sys.stderr) + print( + f"[tach:harness] WARNING: {provided} effects provided but 0 applied (possible mismatch)", + file=sys.stderr, + ) return applied @@ -2243,6 +2320,7 @@ def apply_cached_effects(effects: list) -> int: _SESSION = None _ITEMS_MAP = {} # nodeid -> pytest Item _PARAM_FUZZY_INDEX = {} # "file::test_name" -> [pytest Item, ...] +_DJANGO_OLD_CONFIG = None # Stores setup_databases() return for teardown def _fuzzy_parametrize_lookup(rust_node_id: str) -> Any: @@ -2277,7 +2355,7 @@ def _fuzzy_parametrize_lookup(rust_node_id: str) -> Any: pytest_params = item.nodeid.rsplit("[", 1)[-1].rstrip("]") pytest_parts = pytest_params.split("-") if len(rust_parts) <= len(pytest_parts): - if pytest_parts[-len(rust_parts):] == rust_parts: + if pytest_parts[-len(rust_parts) :] == rust_parts: return item # Strategy 2: Shared suffix — find candidates where the last N resolved parts match @@ -2313,8 +2391,9 @@ def _fuzzy_parametrize_lookup(rust_node_id: str) -> Any: # Strategy 3: Index extraction — Rust fallback IDs use {param_name}{index} pattern # Extract the index from fallback parts and map to the Nth candidate import re + for part in rust_parts: - m = re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*?(\d+)$', part) + m = re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*?(\d+)$", part) if m: idx = int(m.group(1)) if 0 <= idx < len(candidates): @@ -2378,14 +2457,17 @@ def init_session(root_dir: str): # Patch FixtureDef.execute to consume async fixtures at resolution time. # This MUST be in init_session() so workers inherit the patch via fork. from _pytest.fixtures import FixtureDef - if not getattr(FixtureDef.execute, '_tach_patched', False): + + if not getattr(FixtureDef.execute, "_tach_patched", False): _original_execute = FixtureDef.execute def _patched_execute(self, request): result = _original_execute(self, request) if inspect.isasyncgen(result) or asyncio.iscoroutine(result): - scope = getattr(self, 'scope', 'function') - result = AsyncFixtureWrapper.consume_async_fixture(self.argname, result, scope) + scope = getattr(self, "scope", "function") + result = AsyncFixtureWrapper.consume_async_fixture( + self.argname, result, scope + ) return result _patched_execute._tach_patched = True @@ -2470,10 +2552,15 @@ def _patched_execute(self, request): ) fuzzy_count = sum(len(v) for v in _PARAM_FUZZY_INDEX.values()) - os.write(2, f"[tach:harness] Pre-collected {len(_ITEMS_MAP)} tests ({fuzzy_count} fuzzy parametrize entries)\n".encode()) + os.write( + 2, + f"[tach:harness] Pre-collected {len(_ITEMS_MAP)} tests ({fuzzy_count} fuzzy parametrize entries)\n".encode(), + ) -def _parse_django_db_marker(marker_info: list[dict[str, Any]] | None) -> dict[str, Any] | None: +def _parse_django_db_marker( + marker_info: list[dict[str, Any]] | None, +) -> dict[str, Any] | None: """Parse @pytest.mark.django_db marker arguments. Looks for a marker named 'django_db' in the marker_info list and @@ -2516,6 +2603,7 @@ def _is_django_available() -> bool: return False try: from django.conf import settings + return settings.configured except ImportError: return False @@ -2541,12 +2629,101 @@ def _close_django_connections() -> None: connections.close_all() except DatabaseError as e: - print(f"[tach:harness] WARN: Database error closing connections: {e}", file=sys.stderr) + print( + f"[tach:harness] WARN: Database error closing connections: {e}", + file=sys.stderr, + ) except Exception as e: _logger.debug("Django connection close error: %s", e) -def _apply_django_db_isolation(marker_args: dict[str, Any] | None) -> list[tuple[str, str]]: +def _setup_django_test_db() -> None: + """Create Django test database in Zygote before forking workers. + + Calls django.test.utils.setup_databases() to create a test DB with + all migrations applied. Workers inherit this via fork. + + Reads TACH_REUSE_DB / TACH_CREATE_DB env vars for keepdb behavior. + Registers atexit handler for teardown. + """ + global _DJANGO_OLD_CONFIG + + if not _is_django_available(): + return + + reuse_db = os.environ.get("TACH_REUSE_DB", "") == "1" + create_db = os.environ.get("TACH_CREATE_DB", "") == "1" + + # --create-db overrides --reuse-db (matches pytest-django semantics) + keepdb = reuse_db and not create_db + + try: + from django.test.utils import setup_databases + + print( + f"[tach:harness] Setting up Django test database (keepdb={keepdb})", + file=sys.stderr, + ) + + _DJANGO_OLD_CONFIG = setup_databases( + verbosity=1, + interactive=False, + keepdb=keepdb, + ) + + print("[tach:harness] Django test database ready", file=sys.stderr) + + # Close connections BEFORE fork — workers must get fresh FDs + _close_django_connections() + + # Register teardown — Zygote has no clean Python shutdown path + import atexit + + atexit.register(_teardown_django_test_db) + + except Exception as e: + print( + f"[tach:harness] ERROR: Django test DB setup failed: {e}", + file=sys.stderr, + ) + raise + + +def _teardown_django_test_db() -> None: + """Tear down Django test database via atexit. + + Skips DROP DATABASE when keepdb=True (--reuse-db preserves DB for next run). + """ + global _DJANGO_OLD_CONFIG + + if _DJANGO_OLD_CONFIG is None: + return + + reuse_db = os.environ.get("TACH_REUSE_DB", "") == "1" + create_db = os.environ.get("TACH_CREATE_DB", "") == "1" + keepdb = reuse_db and not create_db + + try: + from django.test.utils import teardown_databases + + teardown_databases( + _DJANGO_OLD_CONFIG, + verbosity=1, + keepdb=keepdb, + ) + print("[tach:harness] Django test database torn down", file=sys.stderr) + except Exception as e: + print( + f"[tach:harness] WARN: Django test DB teardown failed: {e}", + file=sys.stderr, + ) + finally: + _DJANGO_OLD_CONFIG = None + + +def _apply_django_db_isolation( + marker_args: dict[str, Any] | None, +) -> list[tuple[str, str]]: """Apply database isolation based on marker args. Uses SAVEPOINT for transaction isolation when transaction=False (default). @@ -2568,8 +2745,12 @@ def _apply_django_db_isolation(marker_args: dict[str, Any] | None) -> list[tuple try: from django.conf import settings + if not settings.configured: - print("[tach:harness] WARN: Django settings not configured, skipping DB isolation", file=sys.stderr) + print( + "[tach:harness] WARN: Django settings not configured, skipping DB isolation", + file=sys.stderr, + ) return [] except ImportError: return [] @@ -2585,13 +2766,23 @@ def _apply_django_db_isolation(marker_args: dict[str, Any] | None) -> list[tuple try: connections.close_all() except DatabaseError as e: - print(f"[tach:harness] WARN: Database error closing connections: {e}", file=sys.stderr) + print( + f"[tach:harness] WARN: Database error closing connections: {e}", + file=sys.stderr, + ) except Exception as e: - print(f"[tach:harness] WARN: Failed to close Django connections: {e}", file=sys.stderr) + print( + f"[tach:harness] WARN: Failed to close Django connections: {e}", + file=sys.stderr, + ) # If no marker_args, apply default isolation to all databases if marker_args is None: - marker_args = {"transaction": False, "reset_sequences": False, "databases": None} + marker_args = { + "transaction": False, + "reset_sequences": False, + "databases": None, + } # If transaction=True, skip isolation (test manages its own transactions) if marker_args.get("transaction", False): @@ -2608,7 +2799,10 @@ def _apply_django_db_isolation(marker_args: dict[str, Any] | None) -> list[tuple if alias in connections: valid_databases.append(alias) else: - print(f"[tach:harness] WARN: Unknown database alias '{alias}', skipping", file=sys.stderr) + print( + f"[tach:harness] WARN: Unknown database alias '{alias}', skipping", + file=sys.stderr, + ) # Create savepoints for each database savepoints = [] @@ -2623,25 +2817,46 @@ def _apply_django_db_isolation(marker_args: dict[str, Any] | None) -> list[tuple savepoints.append((alias, sid)) except DatabaseError as e: # Database-specific error during savepoint creation - print(f"[tach:harness] WARN: Database error creating savepoint for '{alias}': {e}", file=sys.stderr) - print(f"[tach:harness] INFO: Rolling back {len(savepoints)} previously created savepoints", file=sys.stderr) + print( + f"[tach:harness] WARN: Database error creating savepoint for '{alias}': {e}", + file=sys.stderr, + ) + print( + f"[tach:harness] INFO: Rolling back {len(savepoints)} previously created savepoints", + file=sys.stderr, + ) for prev_alias, prev_sid in reversed(savepoints): try: transaction.savepoint_rollback(prev_sid, using=prev_alias) except DatabaseError as rollback_error: - print(f"[tach:harness] WARN: Database error rolling back savepoint for '{prev_alias}': {rollback_error}", file=sys.stderr) + print( + f"[tach:harness] WARN: Database error rolling back savepoint for '{prev_alias}': {rollback_error}", + file=sys.stderr, + ) except Exception as rollback_error: - print(f"[tach:harness] WARN: Failed to rollback savepoint for '{prev_alias}': {rollback_error}", file=sys.stderr) + print( + f"[tach:harness] WARN: Failed to rollback savepoint for '{prev_alias}': {rollback_error}", + file=sys.stderr, + ) return [] # Return empty - no isolation applied except Exception as e: # Unexpected error - still roll back and fail gracefully - print(f"[tach:harness] WARN: Failed to create savepoint for '{alias}': {e}", file=sys.stderr) - print(f"[tach:harness] INFO: Rolling back {len(savepoints)} previously created savepoints", file=sys.stderr) + print( + f"[tach:harness] WARN: Failed to create savepoint for '{alias}': {e}", + file=sys.stderr, + ) + print( + f"[tach:harness] INFO: Rolling back {len(savepoints)} previously created savepoints", + file=sys.stderr, + ) for prev_alias, prev_sid in reversed(savepoints): try: transaction.savepoint_rollback(prev_sid, using=prev_alias) except Exception as rollback_error: - print(f"[tach:harness] WARN: Failed to rollback savepoint for '{prev_alias}': {rollback_error}", file=sys.stderr) + print( + f"[tach:harness] WARN: Failed to rollback savepoint for '{prev_alias}': {rollback_error}", + file=sys.stderr, + ) return [] # Return empty - no isolation applied return savepoints @@ -2667,9 +2882,15 @@ def _cleanup_django_db_isolation(savepoints: list[tuple[str, str]]) -> None: try: transaction.savepoint_rollback(sid, using=alias) except DatabaseError as e: - print(f"[tach:harness] WARN: Database error rolling back savepoint for '{alias}': {e}", file=sys.stderr) + print( + f"[tach:harness] WARN: Database error rolling back savepoint for '{alias}': {e}", + file=sys.stderr, + ) except Exception as e: - print(f"[tach:harness] WARN: Failed to rollback savepoint for '{alias}': {e}", file=sys.stderr) + print( + f"[tach:harness] WARN: Failed to rollback savepoint for '{alias}': {e}", + file=sys.stderr, + ) def run_test( @@ -2721,11 +2942,11 @@ def run_test( # Ensure stdout/stderr use UTF-8 encoding after fork/redirect # Workers redirect stdout to memfd which may default to ASCII encoding, # causing UnicodeEncodeError in libraries like colorama - for stream_name in ('stdout', 'stderr'): + for stream_name in ("stdout", "stderr"): stream = getattr(sys, stream_name, None) - if stream is not None and hasattr(stream, 'reconfigure'): + if stream is not None and hasattr(stream, "reconfigure"): try: - stream.reconfigure(encoding='utf-8', errors='replace') + stream.reconfigure(encoding="utf-8", errors="replace") except Exception: pass @@ -2785,12 +3006,16 @@ def run_test( # Parse asyncio marker for loop_scope configuration loop_scope, _has_asyncio_marker = parse_asyncio_marker(target_item) - loop_manager.configure(loop_scope=loop_scope, auto_mode=loop_manager.auto_mode) + loop_manager.configure( + loop_scope=loop_scope, auto_mode=loop_manager.auto_mode + ) # Scope transition handling (Issue #43) - fspath = str(getattr(target_item, 'fspath', '')) - item_cls = getattr(target_item, 'cls', None) - class_name = f"{item_cls.__module__}.{item_cls.__name__}" if item_cls else None + fspath = str(getattr(target_item, "fspath", "")) + item_cls = getattr(target_item, "cls", None) + class_name = ( + f"{item_cls.__module__}.{item_cls.__name__}" if item_cls else None + ) loop_manager.on_scope_transition(fspath, class_name) scope_key = loop_manager.get_scope_key(target_item) @@ -2805,7 +3030,9 @@ def sync_wrapper(*args, **kwargs): # Consume async fixtures in kwargs before calling test for name, val in list(kwargs.items()): if inspect.isasyncgen(val) or asyncio.iscoroutine(val): - kwargs[name] = AsyncFixtureWrapper.consume_async_fixture(name, val, "function") + kwargs[name] = AsyncFixtureWrapper.consume_async_fixture( + name, val, "function" + ) try: return loop.run_until_complete(async_fn(*args, **kwargs)) @@ -2836,29 +3063,44 @@ def sync_wrapper(*args, **kwargs): AsyncFixtureWrapper.set_loop(fixture_loop) # Scope transition handling for async fixtures - fspath_for_fixture = str(getattr(target_item, 'fspath', '')) - item_cls_for_fixture = getattr(target_item, 'cls', None) - class_name_for_fixture = f"{item_cls_for_fixture.__module__}.{item_cls_for_fixture.__name__}" if item_cls_for_fixture else None + fspath_for_fixture = str(getattr(target_item, "fspath", "")) + item_cls_for_fixture = getattr(target_item, "cls", None) + class_name_for_fixture = ( + f"{item_cls_for_fixture.__module__}.{item_cls_for_fixture.__name__}" + if item_cls_for_fixture + else None + ) AsyncFixtureWrapper.on_test_start(fspath_for_fixture, class_name_for_fixture) # Invalidate stale async fixture caches inherited from Zygote parent from _pytest.fixtures import FixtureDef + fixture_info = getattr(target_item, "_fixtureinfo", None) - if fixture_info and hasattr(fixture_info, 'name2fixturedefs'): + if fixture_info and hasattr(fixture_info, "name2fixturedefs"): for fixturedefs in fixture_info.name2fixturedefs.values(): for fixturedef in fixturedefs: cached = getattr(fixturedef, "cached_result", None) - if cached is not None and isinstance(cached, tuple) and len(cached) > 0: + if ( + cached is not None + and isinstance(cached, tuple) + and len(cached) > 0 + ): val = cached[0] if asyncio.iscoroutine(val) or inspect.isasyncgen(val): - scope = getattr(fixturedef, 'scope', 'function') + scope = getattr(fixturedef, "scope", "function") consumed = AsyncFixtureWrapper.consume_async_fixture( fixturedef.argname, val, scope ) - fixturedef.cached_result = (consumed, cached[1], cached[2]) if len(cached) >= 3 else (consumed,) + fixturedef.cached_result = ( + (consumed, cached[1], cached[2]) + if len(cached) >= 3 + else (consumed,) + ) try: - reports = _pytest.runner.runtestprotocol(target_item, nextitem=None, log=False) + reports = _pytest.runner.runtestprotocol( + target_item, nextitem=None, log=False + ) finally: AsyncFixtureWrapper.teardown_function_scope() # Rollback savepoints to restore database state @@ -2895,13 +3137,23 @@ def sync_wrapper(*args, **kwargs): except Exception as enhance_err: # If enhancement fails, use the original message # Debug logging for troubleshooting enhancement failures - print(f"[tach:harness] DEBUG: Enhanced failure formatting failed: {enhance_err}", file=sys.stderr) + print( + f"[tach:harness] DEBUG: Enhanced failure formatting failed: {enhance_err}", + file=sys.stderr, + ) return (STATUS_FAIL, duration, msg, _thread_leak_detected) if skipped_report: - skip_reason = str(skipped_report.longrepr) if skipped_report.longrepr else "" - return (STATUS_SKIP, duration, f"Skipped: {skip_reason}", _thread_leak_detected) + skip_reason = ( + str(skipped_report.longrepr) if skipped_report.longrepr else "" + ) + return ( + STATUS_SKIP, + duration, + f"Skipped: {skip_reason}", + _thread_leak_detected, + ) return (STATUS_PASS, duration, "", _thread_leak_detected) @@ -3005,10 +3257,15 @@ def cleanup_test_modules() -> int: pass # Already removed except Exception as e: # Log but don't crash - dirty worker is better than dead worker - print(f"[tach:harness] WARN: Failed to remove {mod_name}: {e}", file=sys.stderr) + print( + f"[tach:harness] WARN: Failed to remove {mod_name}: {e}", + file=sys.stderr, + ) if removed_count > 0: - print(f"[tach:harness] Cleaned up {removed_count} test modules", file=sys.stderr) + print( + f"[tach:harness] Cleaned up {removed_count} test modules", file=sys.stderr + ) return removed_count @@ -3048,7 +3305,13 @@ def should_worker_exit(is_toxic: bool, thread_leaked: bool = False) -> bool: # ============================================================================= -def worker_loop_iteration(file_path: str, node_id: str, is_toxic: bool, cached_effects: list = None, marker_info: list = None) -> tuple: +def worker_loop_iteration( + file_path: str, + node_id: str, + is_toxic: bool, + cached_effects: list = None, + marker_info: list = None, +) -> tuple: """Execute one iteration of the worker loop. This is the main entry point for persistent workers. @@ -3069,7 +3332,9 @@ def worker_loop_iteration(file_path: str, node_id: str, is_toxic: bool, cached_e - should_exit: Whether worker should exit after this test """ # 1. Execute the test (now returns 4 values including thread_leaked) - status, duration, message, thread_leaked = run_test(file_path, node_id, cached_effects, marker_info) + status, duration, message, thread_leaked = run_test( + file_path, node_id, cached_effects, marker_info + ) # 2. Determine if worker should exit (consider thread leaks) exit_after = should_worker_exit(is_toxic, thread_leaked) @@ -3232,7 +3497,9 @@ def enable_coverage(): ) # Register LINE callback (coverage recording) - sys.monitoring.register_callback(_coverage_tool_id, sys.monitoring.events.LINE, _coverage_line_callback) + sys.monitoring.register_callback( + _coverage_tool_id, sys.monitoring.events.LINE, _coverage_line_callback + ) # Enable both PY_START and LINE events globally sys.monitoring.set_events( @@ -3264,8 +3531,12 @@ def disable_coverage(): sys.monitoring.set_events(_coverage_tool_id, 0) # Unregister callbacks - sys.monitoring.register_callback(_coverage_tool_id, sys.monitoring.events.PY_START, None) - sys.monitoring.register_callback(_coverage_tool_id, sys.monitoring.events.LINE, None) + sys.monitoring.register_callback( + _coverage_tool_id, sys.monitoring.events.PY_START, None + ) + sys.monitoring.register_callback( + _coverage_tool_id, sys.monitoring.events.LINE, None + ) # Free tool ID sys.monitoring.free_tool_id(_coverage_tool_id) @@ -3289,8 +3560,12 @@ def get_coverage_stats() -> dict: return { "enabled": _coverage_enabled, - "coverage_overflow": tach_rust.get_coverage_overflow() if _coverage_enabled else 0, - "mapping_overflow": tach_rust.get_mapping_overflow() if _coverage_enabled else 0, + "coverage_overflow": tach_rust.get_coverage_overflow() + if _coverage_enabled + else 0, + "mapping_overflow": tach_rust.get_mapping_overflow() + if _coverage_enabled + else 0, } except Exception as e: _logger.debug("Coverage stats error: %s", e)