diff --git a/pyproject.toml b/pyproject.toml
index 48d92d3..b17e0b1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,6 +19,7 @@ dependencies = [
     "shardcast>=0.1.8",
     "setuptools",
     "datasets",
+    "aiohttp",
 ]
 
 [project.optional-dependencies]
diff --git a/toplocvalidator/distributed_validation.py b/toplocvalidator/distributed_validation.py
new file mode 100644
index 0000000..c06ca32
--- /dev/null
+++ b/toplocvalidator/distributed_validation.py
@@ -0,0 +1,348 @@
+"""Distributed validation module for routing requests to worker servers."""
+
+import asyncio
+import logging
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any, Callable, Dict, List, Optional
+
+import aiohttp
+
+from toplocvalidator import env
+from toplocvalidator.datamodels import GroupValidationRequest
+from toplocvalidator.validation_server import ValidationState, ValidationStatus
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class WorkerInfo:
+    """Information about a worker server."""
+
+    url: str
+    is_healthy: bool = False
+    last_health_check: Optional[datetime] = None
+    current_load: int = 0  # Number of active validations
+
+
+class DistributedValidationManager:
+    """Manages distributed validation across worker servers.
+
+    Workers are probed on ``/health``; each validation is sent to the healthy
+    worker with the fewest in-flight validations, and that worker's status
+    endpoint is polled until it reports something other than ``pending``.
+    """
+
+    def __init__(self, worker_urls: List[str]):
+        """
+        Initialize the distributed validation manager.
+
+        Args:
+            worker_urls: List of worker server URLs. Surrounding whitespace
+                and trailing slashes are removed; blank entries are dropped.
+        """
+        # Drop trailing slashes so f"{url}/health" never contains "//".
+        cleaned_urls = (url.strip().rstrip("/") for url in worker_urls)
+        self.workers = [WorkerInfo(url=url) for url in cleaned_urls if url]
+        self.health_check_interval = 30  # seconds
+        self.poll_interval = 1  # seconds between status polls
+        # Consecutive transport failures (timeout / connection error) a status
+        # poll may hit before the validation is reported as CRASHED.
+        self.max_poll_failures = 5
+        self.session: Optional[aiohttp.ClientSession] = None
+        self._health_check_task: Optional[asyncio.Task] = None
+
+    async def start(self) -> None:
+        """Start the distributed validation manager.
+
+        Opens the shared HTTP session, launches the periodic health-check task
+        and runs one health check right away. Calling it again while the
+        session is still open does nothing.
+        """
+        if self.session is not None and not self.session.closed:
+            return
+        self.session = aiohttp.ClientSession()
+        # Start health check task
+        self._health_check_task = asyncio.create_task(self._health_check_loop())
+        # Do initial health check
+        await self._check_all_workers_health()
+
+    async def stop(self) -> None:
+        """Stop the distributed validation manager.
+
+        Cancels the health-check task and closes the HTTP session. Both
+        attributes are reset to None so later calls never use a closed
+        session or await a finished task.
+        """
+        task, self._health_check_task = self._health_check_task, None
+        if task:
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
+        session, self.session = self.session, None
+        if session:
+            await session.close()
+
+    async def _health_check_loop(self) -> None:
+        """Periodically check health of all workers."""
+        while True:
+            try:
+                await asyncio.sleep(self.health_check_interval)
+                await self._check_all_workers_health()
+            except asyncio.CancelledError:
+                break
+            except Exception as e:
+                # One failed round must not stop the periodic probing.
+                logger.exception(f"Error in health check loop: {e}")
+
+    async def _check_all_workers_health(self) -> None:
+        """Check health of all worker servers."""
+        tasks = [self._check_worker_health(worker) for worker in self.workers]
+        await asyncio.gather(*tasks, return_exceptions=True)
+
+        healthy_count = sum(1 for w in self.workers if w.is_healthy)
+        logger.info(f"Worker health check: {healthy_count}/{len(self.workers)} healthy")
+
+    async def _check_worker_health(self, worker: WorkerInfo) -> None:
+        """Check health of a single worker.
+
+        The worker is healthy only when ``GET /health`` returns 200 within
+        5 seconds; any exception marks it unhealthy.
+        """
+        if not self.session:
+            logger.warning("Session not initialized for health check")
+            return
+
+        try:
+            async with self.session.get(
+                f"{worker.url}/health", timeout=aiohttp.ClientTimeout(total=5)
+            ) as response:
+                worker.is_healthy = response.status == 200
+        except Exception as e:
+            logger.warning(f"Worker {worker.url} health check failed: {e}")
+            worker.is_healthy = False
+        worker.last_health_check = datetime.now()
+
+    def _select_worker(self) -> Optional[WorkerInfo]:
+        """
+        Select a healthy worker with the lowest load.
+
+        Returns:
+            Selected worker or None if no healthy workers available
+        """
+        return min(
+            (w for w in self.workers if w.is_healthy),
+            key=lambda w: w.current_load,
+            default=None,
+        )
+
+    @staticmethod
+    def _auth_headers() -> Dict[str, str]:
+        """Return the bearer-token header, or no headers if no token is set."""
+        headers: Dict[str, str] = {}
+        if env.TOPLOCVALIDATOR_AUTH_TOKEN:
+            headers["Authorization"] = f"Bearer {env.TOPLOCVALIDATOR_AUTH_TOKEN}"
+        return headers
+
+    @staticmethod
+    def _crashed(reason: str) -> ValidationState:
+        """Build the CRASHED state returned for routing and worker failures."""
+        return ValidationState(status=ValidationStatus.CRASHED, reason=reason)
+
+    async def _run_on_worker(
+        self,
+        filepath: str,
+        *,
+        kind: str,
+        submit_endpoint: str,
+        status_endpoint: str,
+        request_data: Dict[str, Any],
+        max_polls: int,
+        build_state: Callable[[Dict[str, Any]], ValidationState],
+    ) -> ValidationState:
+        """
+        Submit a validation to a worker and poll until it completes.
+
+        Args:
+            filepath: Path of the file (or group) being validated
+            kind: Log-message qualifier: "" for files, "group " for groups
+            submit_endpoint: Worker endpoint that starts the validation
+            status_endpoint: Worker endpoint polled for the result
+            request_data: JSON body sent to the submit endpoint
+            max_polls: Maximum number of status polls before giving up
+            build_state: Converts the final status payload to a ValidationState
+
+        Returns:
+            The state built from the worker's final status, or a CRASHED state
+            when no worker is usable, the worker rejects the request, polling
+            runs out, or any other error occurs
+        """
+        worker = self._select_worker()
+        if not worker:
+            logger.error("No healthy workers available")
+            return self._crashed("No healthy workers available for validation")
+
+        session = self.session
+        if not session:
+            logger.error("Session not initialized")
+            return self._crashed(
+                "Distributed validation manager not properly initialized"
+            )
+
+        worker.current_load += 1
+        try:
+            logger.info(
+                f"Routing {kind}validation of {filepath} to worker {worker.url}"
+            )
+            headers = self._auth_headers()
+
+            # The worker starts the validation and returns immediately.
+            async with session.post(
+                f"{worker.url}/{submit_endpoint}/{filepath}",
+                json=request_data,
+                headers=headers,
+                timeout=aiohttp.ClientTimeout(
+                    total=30
+                ),  # Short timeout for initial request
+            ) as response:
+                if response.status != 200:
+                    error_text = await response.text()
+                    logger.error(
+                        f"Worker returned error: {response.status} - {error_text}"
+                    )
+                    return self._crashed(
+                        f"Worker error: {response.status} - {error_text}"
+                    )
+
+            # Polling runs after the submit response has been released, so
+            # that response is not held open for the whole validation.
+            await asyncio.sleep(self.poll_interval)  # Give it a moment to start
+
+            failures = 0
+            for _ in range(max_polls):
+                try:
+                    async with session.get(
+                        f"{worker.url}/{status_endpoint}/{filepath}",
+                        headers=headers,
+                        timeout=aiohttp.ClientTimeout(total=5),
+                    ) as status_response:
+                        if status_response.status == 200:
+                            result = await status_response.json()
+                            if result["status"] != "pending":
+                                # Validation complete
+                                return build_state(result)
+                except (asyncio.TimeoutError, aiohttp.ClientError) as e:
+                    # A single slow or dropped poll must not fail a validation
+                    # that is still running; give up after repeated failures.
+                    failures += 1
+                    if failures >= self.max_poll_failures:
+                        raise
+                    logger.warning(
+                        f"Status poll for {filepath} on worker {worker.url} "
+                        f"failed ({failures}/{self.max_poll_failures}): {e}"
+                    )
+                else:
+                    failures = 0
+
+                await asyncio.sleep(self.poll_interval)
+
+            # Timeout
+            return self._crashed("Worker validation timeout")
+
+        except asyncio.TimeoutError:
+            logger.error(
+                f"Timeout validating {kind}{filepath} on worker {worker.url}"
+            )
+            return self._crashed("Worker validation timeout")
+        except Exception as e:
+            logger.error(
+                f"Error validating {kind}{filepath} on worker {worker.url}: {e}"
+            )
+            return self._crashed(f"Worker validation error: {str(e)}")
+        finally:
+            worker.current_load -= 1
+
+    async def validate_file(
+        self, filepath: str, file_sha: str, node_address: Optional[str] = None
+    ) -> ValidationState:
+        """
+        Validate a file using a worker server.
+
+        Args:
+            filepath: Path to the file to validate
+            file_sha: SHA of the file
+            node_address: Optional node address for tracking
+
+        Returns:
+            ValidationState with the result; worker and transport failures
+            are reported as a CRASHED state
+        """
+
+        def build_state(result: Dict[str, Any]) -> ValidationState:
+            return ValidationState(
+                status=ValidationStatus(result["status"]),
+                reason="",  # Basic status endpoint doesn't return reason
+                sample_count=0,  # Basic status endpoint doesn't return sample count
+            )
+
+        return await self._run_on_worker(
+            filepath,
+            kind="",
+            submit_endpoint="validate",
+            status_endpoint="status",
+            # Request data matching the server's ValidationRequest model
+            request_data={"file_sha": file_sha, "address": node_address},
+            max_polls=600,  # 10 minutes with 1 second intervals
+            build_state=build_state,
+        )
+
+    async def validate_group(
+        self, filepath: str, file_data: GroupValidationRequest
+    ) -> ValidationState:
+        """
+        Validate a group of files using a worker server.
+
+        Args:
+            filepath: Path to the main file
+            file_data: Group validation request data
+
+        Returns:
+            ValidationState with the result; worker and transport failures
+            are reported as a CRASHED state
+        """
+
+        def build_state(result: Dict[str, Any]) -> ValidationState:
+            return ValidationState(
+                status=ValidationStatus(result["status"]),
+                reason=result.get("reason", ""),
+                failing_indices=result.get("failing_indices", []),
+                input_flops=result.get("input_flops", 0),
+                output_flops=result.get("output_flops", 0),
+                sample_count=0,  # Not available in statusgroup endpoint
+            )
+
+        return await self._run_on_worker(
+            filepath,
+            kind="group ",
+            submit_endpoint="validategroup",
+            status_endpoint="statusgroup",
+            request_data=file_data.model_dump(),
+            max_polls=1200,  # 20 minutes with 1 second intervals
+            build_state=build_state,
+        )
+
+    def get_workers_status(self) -> List[Dict[str, Any]]:
+        """Get status information for all workers."""
+        return [
+            {
+                "url": worker.url,
+                "is_healthy": worker.is_healthy,
+                "last_health_check": worker.last_health_check.isoformat()
+                if worker.last_health_check
+                else None,
+                "current_load": worker.current_load,
+            }
+            for worker in self.workers
+        ]
diff --git a/toplocvalidator/env.py b/toplocvalidator/env.py
index d78f87f..284c323 100644
--- a/toplocvalidator/env.py
+++ b/toplocvalidator/env.py
@@ -42,6 +42,10 @@
     TOPLOC_DISCORD_WEBHOOK: str = ""
     TOPLOC_FLOP_SCALE_FACTOR: int = 1
 
+    # Worker configuration for distributed validation
+    WORKER_URLS: list[str] | None = None  # Parsed from comma-separated env var
+    ENABLE_DISTRIBUTED: bool = False  # Enable distributed validation
+
 _env = {
     "SKIP_FILESHA_CHECK": lambda: os.getenv("SKIP_FILESHA_CHECK", "false").lower()
     in ["true", "1", "yes", "y"],
@@ -104,6 +108,13 @@
     in ["true", "1", "yes", "y"],
     "TOPLOC_DISCORD_WEBHOOK": lambda: os.getenv("TOPLOC_DISCORD_WEBHOOK", ""),
     "TOPLOC_FLOP_SCALE_FACTOR": lambda: int(os.getenv("TOPLOC_FLOP_SCALE_FACTOR", "1")),
+    # Blank entries are dropped: WORKER_URLS="," must not count as configured.
+    "WORKER_URLS": lambda: [
+        url.strip() for url in os.getenv("WORKER_URLS", "").split(",") if url.strip()
+    ]
+    or None,
+    "ENABLE_DISTRIBUTED": lambda: os.getenv("ENABLE_DISTRIBUTED", "false").lower()
+    in ["true", "1", "yes", "y"],
 }
 
 
diff --git a/toplocvalidator/server.py b/toplocvalidator/server.py
index eccad3d..c9e1849 100644
--- a/toplocvalidator/server.py
+++ b/toplocvalidator/server.py
@@ -30,6 +30,7 @@
 from toplocvalidator.file_retrieval.gcp import GcpBucket
 from toplocvalidator.file_processing import set_topk
 from toplocvalidator.datamodels import GroupValidationRequest
+from toplocvalidator.distributed_validation import DistributedValidationManager
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
@@ -37,6 +38,7 @@
 
 server_state = ValidationStateManager()
 validation_engine = None
+distributed_manager: DistributedValidationManager | None = None
 gcp_bucket: GcpBucket | None = None
 
 
@@ -113,8 +115,23 @@ async def lifespan(app: FastAPI):
     else:
         shardcast_process = None
 
+    # Initialize distributed validation manager if enabled
+    global distributed_manager
+    if env.ENABLE_DISTRIBUTED and env.WORKER_URLS:
+        logger.info(
+            f"Initializing distributed validation with workers: {env.WORKER_URLS}"
+        )
+        distributed_manager = DistributedValidationManager(env.WORKER_URLS)
+        await distributed_manager.start()
+    else:
+        distributed_manager = None
+
     yield
     # Shutdown
+    if distributed_manager:
+        logger.info("Stopping distributed validation manager")
+        await distributed_manager.stop()
+
     logger.info(f"Cleaning up temporary directory: {env.TMP_DIR}")
     try:
         if env.TMP_DIR.exists():
@@ -127,7 +144,8 @@ async def lifespan(app: FastAPI):
         import signal
 
         # SIGTERM is not working, so we use SIGKILL
-        os.kill(shardcast_process.pid, signal.SIGKILL)
+        if shardcast_process.pid:
+            os.kill(shardcast_process.pid, signal.SIGKILL)
         shardcast_process.join()
 
 
@@ -164,7 +182,15 @@ async def process_validation(filepath: str, sha: str, node_address: str) -> None
         # Use debug validation if debug time is set, otherwise use real validation
         if hasattr(app.state, "debug_validation_time"):
             result = await debug_validate_file(filepath, sha)
+            state = ValidationState(status=result, reason="Debug validation")
+        elif env.ENABLE_DISTRIBUTED and distributed_manager:
+            # Use distributed validation
+            state = await distributed_manager.validate_file(filepath, sha, node_address)
+            result = state.status
         else:
+            # Use local validation
+            if not validation_engine or not gcp_bucket:
+                raise RuntimeError("Validation engine or GCP bucket not initialized")
             state = await process_file(
                 engine=validation_engine,
                 gcp_bucket=gcp_bucket,
@@ -178,7 +204,8 @@ async def process_validation(filepath: str, sha: str, node_address: str) -> None
             leaky_bucket_dict[node_address] = max(
                 leaky_bucket_dict.get(node_address, 0) - 0.1, 0
             )
-            gcp_bucket.set_accepted_flag(filepath, state.sample_count)
+            if gcp_bucket:
+                gcp_bucket.set_accepted_flag(filepath, state.sample_count)
         elif result == ValidationStatus.MISSING:
             result = ValidationStatus.CRASHED  # TODO: This is for BC reasons
         else:
@@ -212,8 +239,19 @@ async def process_group_validation(
     try:
         if hasattr(app.state, "debug_validation_time"):
             result = await debug_validate_file(filepath, ", ".join(file_data.file_shas))
-            state = ValidationState(status=result.status, reason="Debug validation")
+            state = ValidationState(status=result, reason="Debug validation")
+        elif env.ENABLE_DISTRIBUTED and distributed_manager:
+            # Use distributed validation
+            p_filepath = Path(filepath)
+            filepaths = [
+                str(p_filepath.with_name(f"{p_filepath.stem}-{i}{p_filepath.suffix}"))
+                for i in range(file_data.group_size)
+            ]
+            state = await distributed_manager.validate_group(filepath, file_data)
         else:
+            # Use local validation
+            if not validation_engine or not gcp_bucket:
+                raise RuntimeError("Validation engine or GCP bucket not initialized")
             p_filepath = Path(filepath)
             filepaths = [
                 str(p_filepath.with_name(f"{p_filepath.stem}-{i}{p_filepath.suffix}"))
@@ -318,7 +356,7 @@ async def validate_file(
 
     # Run validation in background without waiting for it to complete
     asyncio.create_task(
-        process_validation(filepath, file_data.file_sha, file_data.address)
+        process_validation(filepath, file_data.file_sha, file_data.address or "unknown")
     )
 
     return {"message": f"Validation started for {filepath}"}
@@ -418,6 +456,23 @@ async def health_check():
     return {"status": "healthy", "message": "Server is ready"}
 
 
+@app.get("/workers", tags=["Health Check"])
+async def get_workers_status(token: Annotated[str, Depends(verify_token)]):
+    """
+    Get the status of all worker servers when distributed mode is enabled.
+
+    Returns:
+        List of worker status information or error if distributed mode is not enabled
+    """
+    if not env.ENABLE_DISTRIBUTED or not distributed_manager:
+        raise HTTPException(status_code=400, detail="Distributed mode is not enabled")
+
+    return {
+        "workers": distributed_manager.get_workers_status(),
+        "distributed_enabled": True,
+    }
+
+
 class ValidationStatesDistribution(BaseModel):
     pending: int = 0
     accepted: int = 0