diff --git a/toplocvalidator/batch_processing/validation_engine.py b/toplocvalidator/batch_processing/validation_engine.py
index 29a19b7..e69b53d 100644
--- a/toplocvalidator/batch_processing/validation_engine.py
+++ b/toplocvalidator/batch_processing/validation_engine.py
@@ -9,6 +9,7 @@
 import asyncio
 import multiprocessing as mp
 import time
+import threading
 from vllm.model_executor.model_loader.loader import _process_weights_after_loading
 from toplocvalidator import env
 from toplocvalidator.validation_server import ValidationStatus
@@ -180,7 +181,7 @@ def validate_toploc2(
         )
 
         chunk_logits /= temperature
-        
+
         # top_logits = chunk_logits.max(dim=1).values
         neg_gumbel_noise = generate_neg_gumbel_noise(
             (chunk_logits.shape[0], chunk_logits.shape[-1]),
@@ -564,6 +565,11 @@ def __init__(
         self.config = ValidationConfig.from_toml(config_path)
         self.model_name_or_path = model_name_or_path
         self.tp = tp
+        self.active_validations: Dict[str, float] = {}  # rid -> start_time
+        self.active_validations_lock = threading.Lock()
+        self.monitor_thread = None
+        self.stop_monitoring = False
+        self.worker = None
         self.start()
 
     def is_model_loaded(self) -> bool:
@@ -585,6 +591,112 @@ def start(self):
         )
         self.worker.start()
 
+        # Start the monitoring thread
+        self.stop_monitoring = False
+        self.monitor_thread = threading.Thread(target=self._monitor_timeouts)
+        self.monitor_thread.daemon = True
+        self.monitor_thread.start()
+
+    def _monitor_timeouts(self) -> None:
+        """Monitor active validations for timeouts and restart worker if needed.
+
+        Loops until ``self.stop_monitoring`` is true. Each pass collects the rids
+        in ``self.active_validations`` that are older than
+        ``env.VALIDATION_TIMEOUT`` seconds and, if any, restarts the worker.
+        """
+        while not self.stop_monitoring:
+            try:
+                current_time = time.time()
+                timed_out_rids = []
+
+                with self.active_validations_lock:
+                    for rid, start_time in self.active_validations.items():
+                        if current_time - start_time > env.VALIDATION_TIMEOUT:
+                            logger.warning(
+                                f"Validation {rid} has been running for "
+                                f"{current_time - start_time:.1f}s, exceeding timeout "
+                                f"of {env.VALIDATION_TIMEOUT}s"
+                            )
+                            timed_out_rids.append(rid)
+
+                if timed_out_rids:
+                    logger.error(
+                        f"Detected {len(timed_out_rids)} timed out validations. "
+                        f"Restarting worker process..."
+                    )
+                    self._restart_worker(timed_out_rids)
+
+                # Check every 5 seconds
+                time.sleep(5)
+
+            except Exception as e:
+                logger.error(f"Error in timeout monitor: {e}")
+                time.sleep(5)
+
+    def _restart_worker(self, timed_out_rids: list[str]) -> None:
+        """Restart the worker process and mark timed out validations as crashed.
+
+        Args:
+            timed_out_rids: Request ids whose validation exceeded the timeout.
+        """
+        try:
+            # Mark all timed out validations as crashed
+            for rid in timed_out_rids:
+                # The caller is blocked on the event of the output it stored, so
+                # keep a handle on that output to wake the caller below.
+                stored_output = self.rid_to_output.get(rid)
+                if stored_output is not None:
+                    validation_output = ValidationOutput(event=self.manager.Event())
+                    validation_output.set_result(
+                        ValidationStatus.CRASHED,
+                        f"Validation timed out after {env.VALIDATION_TIMEOUT}s",
+                    )
+                    # Publish the result first, then wake the caller, so it never
+                    # reads the entry before the CRASHED status is stored.
+                    self.rid_to_output[rid] = validation_output
+                    stored_output.event.set()
+
+                with self.active_validations_lock:
+                    self.active_validations.pop(rid, None)
+
+            # Kill the existing worker process
+            if self.worker and self.worker.is_alive():
+                logger.info("Terminating stuck worker process...")
+                self.worker.terminate()
+                self.worker.join(timeout=10)
+
+                # If still alive, force kill
+                if self.worker.is_alive():
+                    logger.warning("Worker didn't terminate gracefully, killing...")
+                    self.worker.kill()
+                    self.worker.join()
+
+            # Clear the model loaded event since we're restarting
+            self.model_loaded_event.clear()
+
+            # Start a new worker process
+            logger.info("Starting new worker process...")
+            self.worker = mp.Process(
+                target=_validation_worker,
+                args=(
+                    self.model_name_or_path,
+                    self.tp,
+                    self.config,
+                    self.input_queue,
+                    self.rid_to_output,
+                    self.model_loaded_event,
+                ),
+            )
+            self.worker.start()
+
+            logger.info("Worker process restarted successfully")
+
+        except Exception as e:
+            logger.error(f"Error restarting worker: {e}")
+            import traceback
+
+            traceback.print_exc()
+
     async def validate_entries(
         self, entries: "list[InferenceEntry]", step: Optional[int] = None
     ) -> Tuple[ValidationStatus, Optional[str]]:
@@ -603,6 +715,10 @@ async def validate_entries(
         validation_output = ValidationOutput(event=self.manager.Event())
         self.rid_to_output[entries[0].rid] = validation_output
 
+        # Track the start time of this validation
+        with self.active_validations_lock:
+            self.active_validations[entries[0].rid] = time.time()
+
         start_time = time.time()
         await loop.run_in_executor(
             None, self.input_queue.put, ValidationInput(entries, None, step)
@@ -610,6 +726,11 @@ async def validate_entries(
         # Get result from worker process
         await loop.run_in_executor(None, validation_output.event.wait)
         logger.info(f"Validation time: {time.time() - start_time:.2f} seconds")
+
+        # Remove from active validations
+        with self.active_validations_lock:
+            self.active_validations.pop(entries[0].rid, None)
+
         # TODO: This is kind of a hack because of the way ManagerDict works.
         # Maybe we can fix later but it works for now.
         validation_output = self.rid_to_output[entries[0].rid]
@@ -638,6 +759,10 @@ async def validate_grouped_entries(
         validation_output = ValidationOutput(event=self.manager.Event())
         self.rid_to_output[main_entries[0].rid] = validation_output
 
+        # Track the start time of this validation
+        with self.active_validations_lock:
+            self.active_validations[main_entries[0].rid] = time.time()
+
         start_time = time.time()
         await loop.run_in_executor(
             None,
@@ -648,6 +773,10 @@ async def validate_grouped_entries(
         await loop.run_in_executor(None, validation_output.event.wait)
         logger.info(f"Validation time: {time.time() - start_time:.2f} seconds")
 
+        # Remove from active validations
+        with self.active_validations_lock:
+            self.active_validations.pop(main_entries[0].rid, None)
+
         # TODO: This is kind of a hack because of the way ManagerDict works.
         # Maybe we can fix later but it works for now.
         validation_output = self.rid_to_output[main_entries[0].rid]
@@ -656,12 +785,27 @@ async def validate_grouped_entries(
         return result
 
     def stop(self):
-        self.input_queue.put(None)
-        self.worker.join()
-        self.worker = None
+        # Stop the monitoring thread
+        self.stop_monitoring = True
+        if self.monitor_thread and self.monitor_thread.is_alive():
+            self.monitor_thread.join(timeout=10)
+
+        # Stop the worker process
+        if self.worker:
+            self.input_queue.put(None)
+            self.worker.join()
+            self.worker = None
 
     def __del__(self):
         """Cleanup worker process on deletion."""
+        if hasattr(self, "stop_monitoring"):
+            self.stop_monitoring = True
+        if (
+            hasattr(self, "monitor_thread")
+            and self.monitor_thread
+            and self.monitor_thread.is_alive()
+        ):
+            self.monitor_thread.join(timeout=10)
         if hasattr(self, "worker") and self.worker is not None:
             self.input_queue.put(None)  # Send shutdown signal
             self.worker.join()
diff --git a/toplocvalidator/env.py b/toplocvalidator/env.py
index 51ba87f..7212e6e 100644
--- a/toplocvalidator/env.py
+++ b/toplocvalidator/env.py
@@ -9,6 +9,7 @@
     TOPLOCVALIDATOR_AUTH_TOKEN: str = ""
     GCP_CREDENTIALS: str = ""
     GCP_PULL_TIMEOUT: int = 5  # Timeout in seconds for GCP file pull operations
+    VALIDATION_TIMEOUT: int = 300  # Timeout in seconds for validation processing
     SKIP_DATA_SAMPLING_CHECK: bool = False  # Feature flag
     SKIP_TOPLOC2_CHECK: bool = True  # Feature flag
     # Filesha check
@@ -44,6 +45,7 @@
     "TOPLOCVALIDATOR_AUTH_TOKEN": lambda: os.getenv("TOPLOCVALIDATOR_AUTH_TOKEN", ""),
     "GCP_CREDENTIALS": lambda: os.getenv("GCP_CREDENTIALS", ""),
     "GCP_PULL_TIMEOUT": lambda: int(os.getenv("GCP_PULL_TIMEOUT", "5")),
+    "VALIDATION_TIMEOUT": lambda: int(os.getenv("VALIDATION_TIMEOUT", "300")),
     "SHARDCAST_SERVERS": lambda: os.getenv("SHARDCAST_SERVERS", "").split(",")
     if os.getenv("SHARDCAST_SERVERS", "")
     else None,