From bf9c1c3f734148cf1491875e3f2e19a637954b0d Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Wed, 25 Jun 2025 05:36:02 +0000 Subject: [PATCH 1/3] zero shot --- toplocvalidator/distributed_validation.py | 284 ++++++++++++++++++++++ toplocvalidator/env.py | 9 + toplocvalidator/server.py | 10 + toplocvalidator/worker_server.py | 279 +++++++++++++++++++++ 4 files changed, 582 insertions(+) create mode 100644 toplocvalidator/distributed_validation.py create mode 100644 toplocvalidator/worker_server.py diff --git a/toplocvalidator/distributed_validation.py b/toplocvalidator/distributed_validation.py new file mode 100644 index 0000000..1ceda2c --- /dev/null +++ b/toplocvalidator/distributed_validation.py @@ -0,0 +1,284 @@ +"""Distributed validation module for routing requests to worker servers.""" + +import asyncio +import logging +from typing import Optional, List, Dict, Any +import aiohttp +from dataclasses import dataclass +from datetime import datetime + +from toplocvalidator.validation_server import ValidationStatus, ValidationState +from toplocvalidator.datamodels import GroupValidationRequest + +logger = logging.getLogger(__name__) + + +@dataclass +class WorkerInfo: + """Information about a worker server.""" + + url: str + is_healthy: bool = False + last_health_check: Optional[datetime] = None + current_load: int = 0 # Number of active validations + + +class DistributedValidationManager: + """Manages distributed validation across worker servers.""" + + def __init__(self, worker_urls: List[str]): + """ + Initialize the distributed validation manager. + + Args: + worker_urls: List of worker server URLs + """ + self.workers = [ + WorkerInfo(url=url.strip()) for url in worker_urls if url.strip() + ] + self.health_check_interval = 30 # seconds + self.session: Optional[aiohttp.ClientSession] = None + self._health_check_task: Optional[asyncio.Task] = None + + async def start(self) -> None: + """Start the distributed validation manager.""" + self.session = aiohttp.ClientSession() + # Start health check task + self._health_check_task = asyncio.create_task(self._health_check_loop()) + # Do initial health check + await self._check_all_workers_health() + + async def stop(self) -> None: + """Stop the distributed validation manager.""" + if self._health_check_task: + self._health_check_task.cancel() + try: + await self._health_check_task + except asyncio.CancelledError: + pass + if self.session: + await self.session.close() + + async def _health_check_loop(self) -> None: + """Periodically check health of all workers.""" + while True: + try: + await asyncio.sleep(self.health_check_interval) + await self._check_all_workers_health() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error in health check loop: {e}") + + async def _check_all_workers_health(self) -> None: + """Check health of all worker servers.""" + tasks = [self._check_worker_health(worker) for worker in self.workers] + await asyncio.gather(*tasks, return_exceptions=True) + + healthy_count = sum(1 for w in self.workers if w.is_healthy) + logger.info(f"Worker health check: {healthy_count}/{len(self.workers)} healthy") + + async def _check_worker_health(self, worker: WorkerInfo) -> None: + """Check health of a single worker.""" + if not self.session: + logger.warning("Session not initialized for health check") + return + + try: + async with self.session.get( + f"{worker.url}/worker/health", timeout=aiohttp.ClientTimeout(total=5) + ) as response: + worker.is_healthy = response.status == 200 + worker.last_health_check = datetime.now() + except Exception as e: + logger.warning(f"Worker {worker.url} health check failed: {e}") + worker.is_healthy = False + worker.last_health_check = datetime.now() + + def _select_worker(self) -> Optional[WorkerInfo]: + """ + Select a healthy worker with the lowest load. + + Returns: + Selected worker or None if no healthy workers available + """ + healthy_workers = [w for w in self.workers if w.is_healthy] + if not healthy_workers: + return None + + # Select worker with lowest load + return min(healthy_workers, key=lambda w: w.current_load) + + async def validate_file( + self, filepath: str, file_sha: str, node_address: Optional[str] = None + ) -> ValidationState: + """ + Validate a file using a worker server. + + Args: + filepath: Path to the file to validate + file_sha: SHA of the file + node_address: Optional node address for tracking + + Returns: + ValidationState with the result + """ + worker = self._select_worker() + if not worker: + logger.error("No healthy workers available") + return ValidationState( + status=ValidationStatus.CRASHED, + reason="No healthy workers available for validation", + ) + + if not self.session: + logger.error("Session not initialized") + return ValidationState( + status=ValidationStatus.CRASHED, + reason="Distributed validation manager not properly initialized", + ) + + worker.current_load += 1 + try: + request_data = { + "filepath": filepath, + "file_sha": file_sha, + "node_address": node_address, + } + + logger.info(f"Routing validation of {filepath} to worker {worker.url}") + + async with self.session.post( + f"{worker.url}/worker/validate", + json=request_data, + timeout=aiohttp.ClientTimeout(total=600), # 10 minute timeout + ) as response: + if response.status != 200: + error_text = await response.text() + logger.error( + f"Worker returned error: {response.status} - {error_text}" + ) + return ValidationState( + status=ValidationStatus.CRASHED, + reason=f"Worker error: {response.status} - {error_text}", + ) + + result = await response.json() + + # Convert response to ValidationState + status = ValidationStatus(result["status"]) + return ValidationState( + status=status, + reason=result.get("reason", ""), + sample_count=result.get("sample_count", 0), + ) + + except asyncio.TimeoutError: + logger.error(f"Timeout validating {filepath} on worker {worker.url}") + return ValidationState( + status=ValidationStatus.CRASHED, reason="Worker validation timeout" + ) + except Exception as e: + logger.error(f"Error validating {filepath} on worker {worker.url}: {e}") + return ValidationState( + status=ValidationStatus.CRASHED, + reason=f"Worker validation error: {str(e)}", + ) + finally: + worker.current_load -= 1 + + async def validate_group( + self, filepath: str, file_data: GroupValidationRequest + ) -> ValidationState: + """ + Validate a group of files using a worker server. + + Args: + filepath: Path to the main file + file_data: Group validation request data + + Returns: + ValidationState with the result + """ + worker = self._select_worker() + if not worker: + logger.error("No healthy workers available") + return ValidationState( + status=ValidationStatus.CRASHED, + reason="No healthy workers available for validation", + ) + + if not self.session: + logger.error("Session not initialized") + return ValidationState( + status=ValidationStatus.CRASHED, + reason="Distributed validation manager not properly initialized", + ) + + worker.current_load += 1 + try: + request_data = {"filepath": filepath, "file_data": file_data.model_dump()} + + logger.info( + f"Routing group validation of {filepath} to worker {worker.url}" + ) + + async with self.session.post( + f"{worker.url}/worker/validategroup", + json=request_data, + timeout=aiohttp.ClientTimeout( + total=1200 + ), # 20 minute timeout for groups + ) as response: + if response.status != 200: + error_text = await response.text() + logger.error( + f"Worker returned error: {response.status} - {error_text}" + ) + return ValidationState( + status=ValidationStatus.CRASHED, + reason=f"Worker error: {response.status} - {error_text}", + ) + + result = await response.json() + + # Convert response to ValidationState + status = ValidationStatus(result["status"]) + return ValidationState( + status=status, + reason=result.get("reason", ""), + failing_indices=result.get("failing_indices", []), + input_flops=result.get("input_flops", 0), + output_flops=result.get("output_flops", 0), + sample_count=result.get("sample_count", 0), + ) + + except asyncio.TimeoutError: + logger.error(f"Timeout validating group {filepath} on worker {worker.url}") + return ValidationState( + status=ValidationStatus.CRASHED, reason="Worker validation timeout" + ) + except Exception as e: + logger.error( + f"Error validating group {filepath} on worker {worker.url}: {e}" + ) + return ValidationState( + status=ValidationStatus.CRASHED, + reason=f"Worker validation error: {str(e)}", + ) + finally: + worker.current_load -= 1 + + def get_workers_status(self) -> List[Dict[str, Any]]: + """Get status information for all workers.""" + return [ + { + "url": worker.url, + "is_healthy": worker.is_healthy, + "last_health_check": worker.last_health_check.isoformat() + if worker.last_health_check + else None, + "current_load": worker.current_load, + } + for worker in self.workers + ] diff --git a/toplocvalidator/env.py b/toplocvalidator/env.py index d78f87f..284c323 100644 --- a/toplocvalidator/env.py +++ b/toplocvalidator/env.py @@ -42,6 +42,10 @@ TOPLOC_DISCORD_WEBHOOK: str = "" TOPLOC_FLOP_SCALE_FACTOR: int = 1 + # Worker configuration for distributed validation + WORKER_URLS: list[str] | None = None # Comma-separated list of worker URLs + ENABLE_DISTRIBUTED: bool = False # Enable distributed validation + _env = { "SKIP_FILESHA_CHECK": lambda: os.getenv("SKIP_FILESHA_CHECK", "false").lower() in ["true", "1", "yes", "y"], @@ -104,6 +108,11 @@ in ["true", "1", "yes", "y"], "TOPLOC_DISCORD_WEBHOOK": lambda: os.getenv("TOPLOC_DISCORD_WEBHOOK", ""), "TOPLOC_FLOP_SCALE_FACTOR": lambda: int(os.getenv("TOPLOC_FLOP_SCALE_FACTOR", "1")), + "WORKER_URLS": lambda: os.getenv("WORKER_URLS", "").split(",") + if os.getenv("WORKER_URLS", "") + else None, + "ENABLE_DISTRIBUTED": lambda: os.getenv("ENABLE_DISTRIBUTED", "false").lower() + in ["true", "1", "yes", "y"], } diff --git a/toplocvalidator/server.py b/toplocvalidator/server.py index eccad3d..5067d2c 100644 --- a/toplocvalidator/server.py +++ b/toplocvalidator/server.py @@ -30,6 +30,7 @@ from toplocvalidator.file_retrieval.gcp import GcpBucket from toplocvalidator.file_processing import set_topk from toplocvalidator.datamodels import GroupValidationRequest +from toplocvalidator.distributed_validation import DistributedValidationManager # Set up logging logging.basicConfig(level=logging.INFO) @@ -37,6 +38,7 @@ server_state = ValidationStateManager() validation_engine = None +distributed_manager: DistributedValidationManager | None = None gcp_bucket: GcpBucket | None = None @@ -164,7 +166,15 @@ async def process_validation(filepath: str, sha: str, node_address: str) -> None # Use debug validation if debug time is set, otherwise use real validation if hasattr(app.state, "debug_validation_time"): result = await debug_validate_file(filepath, sha) + state = ValidationState(status=result, reason="Debug validation") + elif env.ENABLE_DISTRIBUTED and distributed_manager: + # Use distributed validation + state = await distributed_manager.validate_file(filepath, sha, node_address) + result = state.status else: + # Use local validation + if not validation_engine or not gcp_bucket: + raise RuntimeError("Validation engine or GCP bucket not initialized") state = await process_file( engine=validation_engine, gcp_bucket=gcp_bucket, diff --git a/toplocvalidator/worker_server.py b/toplocvalidator/worker_server.py new file mode 100644 index 0000000..29b9ffc --- /dev/null +++ b/toplocvalidator/worker_server.py @@ -0,0 +1,279 @@ +"""Worker server for distributed validation processing. + +This server performs validation work without maintaining state. +It receives validation requests from the main coordinator server. +""" + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +import logging +from contextlib import asynccontextmanager +import shutil +from typing import Optional +from pathlib import Path + +from toplocvalidator import env +from toplocvalidator.shardcast_downloader import run_main_bg +from toplocvalidator.validation_server import ValidationStatus +from toplocvalidator.batch_processing.validation_engine import ValidationEngine +from toplocvalidator.validation_flow import process_file, process_group_files +from toplocvalidator.file_retrieval.gcp import GcpBucket +from toplocvalidator.datamodels import GroupValidationRequest + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +validation_engine: Optional[ValidationEngine] = None +gcp_bucket: Optional[GcpBucket] = None + + +class WorkerValidationRequest(BaseModel): + filepath: str + file_sha: str + node_address: Optional[str] = None + + +class WorkerGroupValidationRequest(BaseModel): + filepath: str + file_data: GroupValidationRequest + + +class WorkerValidationResponse(BaseModel): + status: str + reason: str = "" + failing_indices: list[int] = [] + input_flops: int = 0 + output_flops: int = 0 + sample_count: int = 0 + + +def initialize_gcp_bucket(gcp_path: str) -> None: + """Initialize the GCP bucket with the given path.""" + global gcp_bucket + if env.GCP_CREDENTIALS == "": + raise RuntimeError( + "GCP credentials not set. Please set the GCP_CREDENTIALS environment variable." + ) + gcp_bucket = GcpBucket(gcp_path=gcp_path, credentials_base64=env.GCP_CREDENTIALS) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Lifespan context manager for startup and shutdown events.""" + # Startup + env.TMP_DIR.mkdir(parents=True, exist_ok=True) + if env.SHARDCAST_SERVERS: + shardcast_process = run_main_bg( + servers=env.SHARDCAST_SERVERS, + output_dir=env.SHARDCAST_OUTPUT_DIR, + versions_to_keep=env.SHARDCAST_VERSIONS_TO_KEEP, + backlog_version=env.SHARDCAST_BACKLOG_VERSION, + ) + else: + shardcast_process = None + + yield + # Shutdown + logger.info(f"Cleaning up temporary directory: {env.TMP_DIR}") + try: + if env.TMP_DIR.exists(): + shutil.rmtree(env.TMP_DIR) + logger.info("Temporary directory cleaned up successfully") + except Exception as e: + logger.error(f"Error cleaning up temporary directory: {e}") + if shardcast_process: + import os + import signal + + if shardcast_process.pid: + os.kill(shardcast_process.pid, signal.SIGKILL) + shardcast_process.join() + + +app = FastAPI(lifespan=lifespan, title="TopLoc Worker Server") + + +@app.post("/worker/validate", response_model=WorkerValidationResponse) +async def validate_file(request: WorkerValidationRequest) -> WorkerValidationResponse: + """ + Validate a single file and return the result. + + This endpoint performs the actual validation work without maintaining state. + """ + logger.info( + f"Worker received validation request for {request.filepath} with SHA: {request.file_sha}" + ) + + if validation_engine is None or gcp_bucket is None: + logger.error("Validation engine or GCP bucket not initialized") + return WorkerValidationResponse( + status=ValidationStatus.CRASHED.value, + reason="Worker not properly initialized", + ) + + try: + state = await process_file( + engine=validation_engine, + gcp_bucket=gcp_bucket, + filepath=request.filepath, + sha=request.file_sha, + ) + + return WorkerValidationResponse( + status=state.status.value, + reason=state.reason, + sample_count=state.sample_count, + ) + except Exception as e: + logger.error(f"Error validating {request.filepath}: {e}") + return WorkerValidationResponse( + status=ValidationStatus.CRASHED.value, + reason=str(e), + ) + + +@app.post("/worker/validategroup", response_model=WorkerValidationResponse) +async def validate_group( + request: WorkerGroupValidationRequest, +) -> WorkerValidationResponse: + """ + Validate a group of files and return the result. + + This endpoint performs group validation without maintaining state. + """ + logger.info(f"Worker received group validation request for {request.filepath}") + + if validation_engine is None or gcp_bucket is None: + logger.error("Validation engine or GCP bucket not initialized") + return WorkerValidationResponse( + status=ValidationStatus.CRASHED.value, + reason="Worker not properly initialized", + ) + + try: + p_filepath = Path(request.filepath) + filepaths = [ + str(p_filepath.with_name(f"{p_filepath.stem}-{i}{p_filepath.suffix}")) + for i in range(request.file_data.group_size) + ] + + state = await process_group_files( + engine=validation_engine, + gcp_bucket=gcp_bucket, + filepaths=filepaths, + file_data=request.file_data, + ) + + return WorkerValidationResponse( + status=state.status.value, + reason=state.reason, + failing_indices=state.failing_indices, + input_flops=state.input_flops, + output_flops=state.output_flops, + sample_count=state.sample_count, + ) + except Exception as e: + logger.error(f"Error validating group {request.filepath}: {e}") + return WorkerValidationResponse( + status=ValidationStatus.CRASHED.value, + reason=str(e), + ) + + +@app.get("/worker/health") +async def health_check(): + """ + Health check endpoint for the worker. + """ + global validation_engine + + if validation_engine is None: + raise HTTPException( + status_code=425, detail="Model is being loaded, please try again later" + ) + + if not validation_engine.is_model_loaded(): + raise HTTPException( + status_code=425, detail="Model is being loaded, please try again later" + ) + + return {"status": "healthy", "message": "Worker is ready"} + + +@app.get("/worker/info") +async def worker_info(): + """ + Get information about this worker instance. + """ + import socket + + hostname = socket.gethostname() + ip_address = socket.gethostbyname(hostname) + + return { + "hostname": hostname, + "ip_address": ip_address, + "model": validation_engine.model_name_or_path if validation_engine else None, + } + + +def main(): + """Main function to start the worker server.""" + import argparse + import uvicorn + from toplocvalidator.file_processing import set_topk + + parser = argparse.ArgumentParser(description="TopLoc Worker Server") + parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to") + parser.add_argument("--port", type=int, default=8001, help="Port to bind to") + parser.add_argument( + "--topk", type=int, default=None, help="Top-k value for validation" + ) + parser.add_argument( + "--gcp-path", type=str, default="", help="GCP path for file retrieval" + ) + parser.add_argument( + "--model-name-or-path", + type=str, + required=True, + help="Model name or path for validation", + ) + parser.add_argument( + "--tp", type=int, default=1, help="Tensor parallel size for the vllm engine" + ) + parser.add_argument( + "--config-path", + type=str, + default=None, + help="Path to custom validation config file", + ) + parser.add_argument("--log-level", type=str, default="INFO", help="Logging level") + + args = parser.parse_args() + + # Set logging level + logging.getLogger().setLevel(args.log_level) + + # Set top-k if provided + if args.topk is not None: + set_topk(args.topk) + + # Initialize GCP bucket + if args.gcp_path: + initialize_gcp_bucket(args.gcp_path) + + # Initialize validation engine + global validation_engine + validation_engine = ValidationEngine( + model_name_or_path=args.model_name_or_path, + tp=args.tp, + config_path=args.config_path, + ) + + # Run the server + uvicorn.run(app, host=args.host, port=args.port) + + +if __name__ == "__main__": + main() From e69d37fbe2b265a10bfb639ea91f963ea6b6765f Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Wed, 25 Jun 2025 06:28:14 +0000 Subject: [PATCH 2/3] better? --- pyproject.toml | 2 ++ toplocvalidator/server.py | 53 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 48d92d3..acf391e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "shardcast>=0.1.8", "setuptools", "datasets", + "aiohttp", ] [project.optional-dependencies] @@ -36,6 +37,7 @@ dev-dependencies = [ [project.scripts] toplocvalidator = "toplocvalidator.server:main" +toplocvalidator-worker = "toplocvalidator.worker_server:main" [build-system] requires = ["hatchling"] diff --git a/toplocvalidator/server.py b/toplocvalidator/server.py index 5067d2c..c9e1849 100644 --- a/toplocvalidator/server.py +++ b/toplocvalidator/server.py @@ -115,8 +115,23 @@ async def lifespan(app: FastAPI): else: shardcast_process = None + # Initialize distributed validation manager if enabled + global distributed_manager + if env.ENABLE_DISTRIBUTED and env.WORKER_URLS: + logger.info( + f"Initializing distributed validation with workers: {env.WORKER_URLS}" + ) + distributed_manager = DistributedValidationManager(env.WORKER_URLS) + await distributed_manager.start() + else: + distributed_manager = None + yield # Shutdown + if distributed_manager: + logger.info("Stopping distributed validation manager") + await distributed_manager.stop() + logger.info(f"Cleaning up temporary directory: {env.TMP_DIR}") try: if env.TMP_DIR.exists(): @@ -129,7 +144,8 @@ async def lifespan(app: FastAPI): import signal # SIGTERM is not working, so we use SIGKILL - os.kill(shardcast_process.pid, signal.SIGKILL) + if shardcast_process.pid: + os.kill(shardcast_process.pid, signal.SIGKILL) shardcast_process.join() @@ -188,7 +204,8 @@ async def process_validation(filepath: str, sha: str, node_address: str) -> None leaky_bucket_dict[node_address] = max( leaky_bucket_dict.get(node_address, 0) - 0.1, 0 ) - gcp_bucket.set_accepted_flag(filepath, state.sample_count) + if gcp_bucket: + gcp_bucket.set_accepted_flag(filepath, state.sample_count) elif result == ValidationStatus.MISSING: result = ValidationStatus.CRASHED # TODO: This is for BC reasons else: @@ -222,8 +239,19 @@ async def process_group_validation( try: if hasattr(app.state, "debug_validation_time"): result = await debug_validate_file(filepath, ", ".join(file_data.file_shas)) - state = ValidationState(status=result.status, reason="Debug validation") + state = ValidationState(status=result, reason="Debug validation") + elif env.ENABLE_DISTRIBUTED and distributed_manager: + # Use distributed validation + p_filepath = Path(filepath) + filepaths = [ + str(p_filepath.with_name(f"{p_filepath.stem}-{i}{p_filepath.suffix}")) + for i in range(file_data.group_size) + ] + state = await distributed_manager.validate_group(filepath, file_data) else: + # Use local validation + if not validation_engine or not gcp_bucket: + raise RuntimeError("Validation engine or GCP bucket not initialized") p_filepath = Path(filepath) filepaths = [ str(p_filepath.with_name(f"{p_filepath.stem}-{i}{p_filepath.suffix}")) @@ -328,7 +356,7 @@ async def validate_file( # Run validation in background without waiting for it to complete asyncio.create_task( - process_validation(filepath, file_data.file_sha, file_data.address) + process_validation(filepath, file_data.file_sha, file_data.address or "unknown") ) return {"message": f"Validation started for {filepath}"} @@ -428,6 +456,23 @@ async def health_check(): return {"status": "healthy", "message": "Server is ready"} +@app.get("/workers", tags=["Health Check"]) +async def get_workers_status(token: Annotated[str, Depends(verify_token)]): + """ + Get the status of all worker servers when distributed mode is enabled. + + Returns: + List of worker status information or error if distributed mode is not enabled + """ + if not env.ENABLE_DISTRIBUTED or not distributed_manager: + raise HTTPException(status_code=400, detail="Distributed mode is not enabled") + + return { + "workers": distributed_manager.get_workers_status(), + "distributed_enabled": True, + } + + class ValidationStatesDistribution(BaseModel): pending: int = 0 accepted: int = 0 From 1acff3c2d8455ea105d9e0b117135f44dc84e8c4 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Wed, 25 Jun 2025 06:35:04 +0000 Subject: [PATCH 3/3] better? --- pyproject.toml | 1 - toplocvalidator/distributed_validation.py | 114 +++++++-- toplocvalidator/worker_server.py | 279 ---------------------- 3 files changed, 88 insertions(+), 306 deletions(-) delete mode 100644 toplocvalidator/worker_server.py diff --git a/pyproject.toml b/pyproject.toml index acf391e..b17e0b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,6 @@ dev-dependencies = [ [project.scripts] toplocvalidator = "toplocvalidator.server:main" -toplocvalidator-worker = "toplocvalidator.worker_server:main" [build-system] requires = ["hatchling"] diff --git a/toplocvalidator/distributed_validation.py b/toplocvalidator/distributed_validation.py index 1ceda2c..c06ca32 100644 --- a/toplocvalidator/distributed_validation.py +++ b/toplocvalidator/distributed_validation.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from datetime import datetime +from toplocvalidator import env from toplocvalidator.validation_server import ValidationStatus, ValidationState from toplocvalidator.datamodels import GroupValidationRequest @@ -86,7 +87,7 @@ async def _check_worker_health(self, worker: WorkerInfo) -> None: try: async with self.session.get( - f"{worker.url}/worker/health", timeout=aiohttp.ClientTimeout(total=5) + f"{worker.url}/health", timeout=aiohttp.ClientTimeout(total=5) ) as response: worker.is_healthy = response.status == 200 worker.last_health_check = datetime.now() @@ -140,18 +141,27 @@ async def validate_file( worker.current_load += 1 try: + # Build request data matching the server's ValidationRequest model request_data = { - "filepath": filepath, "file_sha": file_sha, - "node_address": node_address, + "address": node_address, } logger.info(f"Routing validation of {filepath} to worker {worker.url}") + # Add auth header if token is set + headers = {} + if env.TOPLOCVALIDATOR_AUTH_TOKEN: + headers["Authorization"] = f"Bearer {env.TOPLOCVALIDATOR_AUTH_TOKEN}" + + # Call the worker's validate endpoint async with self.session.post( - f"{worker.url}/worker/validate", + f"{worker.url}/validate/{filepath}", json=request_data, - timeout=aiohttp.ClientTimeout(total=600), # 10 minute timeout + headers=headers, + timeout=aiohttp.ClientTimeout( + total=30 + ), # Short timeout for initial request ) as response: if response.status != 200: error_text = await response.text() @@ -163,14 +173,36 @@ async def validate_file( reason=f"Worker error: {response.status} - {error_text}", ) - result = await response.json() - - # Convert response to ValidationState - status = ValidationStatus(result["status"]) + # The worker will start validation and return immediately + # We need to poll for the result + await asyncio.sleep(1) # Give it a moment to start + + # Poll the status endpoint + max_polls = 600 # 10 minutes with 1 second intervals + for _ in range(max_polls): + async with self.session.get( + f"{worker.url}/status/{filepath}", + headers=headers, + timeout=aiohttp.ClientTimeout(total=5), + ) as status_response: + if status_response.status == 200: + status_result = await status_response.json() + status_value = status_result["status"] + + if status_value != "pending": + # Validation complete + return ValidationState( + status=ValidationStatus(status_value), + reason="", # Basic status endpoint doesn't return reason + sample_count=0, # Basic status endpoint doesn't return sample count + ) + + await asyncio.sleep(1) + + # Timeout return ValidationState( - status=status, - reason=result.get("reason", ""), - sample_count=result.get("sample_count", 0), + status=ValidationStatus.CRASHED, + reason="Worker validation timeout", ) except asyncio.TimeoutError: @@ -217,18 +249,26 @@ async def validate_group( worker.current_load += 1 try: - request_data = {"filepath": filepath, "file_data": file_data.model_dump()} + # Use the standard server endpoint + request_data = file_data.model_dump() logger.info( f"Routing group validation of {filepath} to worker {worker.url}" ) + # Add auth header if token is set + headers = {} + if env.TOPLOCVALIDATOR_AUTH_TOKEN: + headers["Authorization"] = f"Bearer {env.TOPLOCVALIDATOR_AUTH_TOKEN}" + + # Call the worker's validategroup endpoint async with self.session.post( - f"{worker.url}/worker/validategroup", + f"{worker.url}/validategroup/{filepath}", json=request_data, + headers=headers, timeout=aiohttp.ClientTimeout( - total=1200 - ), # 20 minute timeout for groups + total=30 + ), # Short timeout for initial request ) as response: if response.status != 200: error_text = await response.text() @@ -240,17 +280,39 @@ async def validate_group( reason=f"Worker error: {response.status} - {error_text}", ) - result = await response.json() - - # Convert response to ValidationState - status = ValidationStatus(result["status"]) + # The worker will start validation and return immediately + # We need to poll for the result using the statusgroup endpoint + await asyncio.sleep(1) # Give it a moment to start + + # Poll the statusgroup endpoint + max_polls = 1200 # 20 minutes with 1 second intervals + for _ in range(max_polls): + async with self.session.get( + f"{worker.url}/statusgroup/{filepath}", + headers=headers, + timeout=aiohttp.ClientTimeout(total=5), + ) as status_response: + if status_response.status == 200: + result = await status_response.json() + status_value = result["status"] + + if status_value != "pending": + # Validation complete + return ValidationState( + status=ValidationStatus(status_value), + reason=result.get("reason", ""), + failing_indices=result.get("failing_indices", []), + input_flops=result.get("input_flops", 0), + output_flops=result.get("output_flops", 0), + sample_count=0, # Not available in statusgroup endpoint + ) + + await asyncio.sleep(1) + + # Timeout return ValidationState( - status=status, - reason=result.get("reason", ""), - failing_indices=result.get("failing_indices", []), - input_flops=result.get("input_flops", 0), - output_flops=result.get("output_flops", 0), - sample_count=result.get("sample_count", 0), + status=ValidationStatus.CRASHED, + reason="Worker validation timeout", ) except asyncio.TimeoutError: diff --git a/toplocvalidator/worker_server.py b/toplocvalidator/worker_server.py deleted file mode 100644 index 29b9ffc..0000000 --- a/toplocvalidator/worker_server.py +++ /dev/null @@ -1,279 +0,0 @@ -"""Worker server for distributed validation processing. - -This server performs validation work without maintaining state. -It receives validation requests from the main coordinator server. -""" - -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel -import logging -from contextlib import asynccontextmanager -import shutil -from typing import Optional -from pathlib import Path - -from toplocvalidator import env -from toplocvalidator.shardcast_downloader import run_main_bg -from toplocvalidator.validation_server import ValidationStatus -from toplocvalidator.batch_processing.validation_engine import ValidationEngine -from toplocvalidator.validation_flow import process_file, process_group_files -from toplocvalidator.file_retrieval.gcp import GcpBucket -from toplocvalidator.datamodels import GroupValidationRequest - -# Set up logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -validation_engine: Optional[ValidationEngine] = None -gcp_bucket: Optional[GcpBucket] = None - - -class WorkerValidationRequest(BaseModel): - filepath: str - file_sha: str - node_address: Optional[str] = None - - -class WorkerGroupValidationRequest(BaseModel): - filepath: str - file_data: GroupValidationRequest - - -class WorkerValidationResponse(BaseModel): - status: str - reason: str = "" - failing_indices: list[int] = [] - input_flops: int = 0 - output_flops: int = 0 - sample_count: int = 0 - - -def initialize_gcp_bucket(gcp_path: str) -> None: - """Initialize the GCP bucket with the given path.""" - global gcp_bucket - if env.GCP_CREDENTIALS == "": - raise RuntimeError( - "GCP credentials not set. Please set the GCP_CREDENTIALS environment variable." - ) - gcp_bucket = GcpBucket(gcp_path=gcp_path, credentials_base64=env.GCP_CREDENTIALS) - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Lifespan context manager for startup and shutdown events.""" - # Startup - env.TMP_DIR.mkdir(parents=True, exist_ok=True) - if env.SHARDCAST_SERVERS: - shardcast_process = run_main_bg( - servers=env.SHARDCAST_SERVERS, - output_dir=env.SHARDCAST_OUTPUT_DIR, - versions_to_keep=env.SHARDCAST_VERSIONS_TO_KEEP, - backlog_version=env.SHARDCAST_BACKLOG_VERSION, - ) - else: - shardcast_process = None - - yield - # Shutdown - logger.info(f"Cleaning up temporary directory: {env.TMP_DIR}") - try: - if env.TMP_DIR.exists(): - shutil.rmtree(env.TMP_DIR) - logger.info("Temporary directory cleaned up successfully") - except Exception as e: - logger.error(f"Error cleaning up temporary directory: {e}") - if shardcast_process: - import os - import signal - - if shardcast_process.pid: - os.kill(shardcast_process.pid, signal.SIGKILL) - shardcast_process.join() - - -app = FastAPI(lifespan=lifespan, title="TopLoc Worker Server") - - -@app.post("/worker/validate", response_model=WorkerValidationResponse) -async def validate_file(request: WorkerValidationRequest) -> WorkerValidationResponse: - """ - Validate a single file and return the result. - - This endpoint performs the actual validation work without maintaining state. - """ - logger.info( - f"Worker received validation request for {request.filepath} with SHA: {request.file_sha}" - ) - - if validation_engine is None or gcp_bucket is None: - logger.error("Validation engine or GCP bucket not initialized") - return WorkerValidationResponse( - status=ValidationStatus.CRASHED.value, - reason="Worker not properly initialized", - ) - - try: - state = await process_file( - engine=validation_engine, - gcp_bucket=gcp_bucket, - filepath=request.filepath, - sha=request.file_sha, - ) - - return WorkerValidationResponse( - status=state.status.value, - reason=state.reason, - sample_count=state.sample_count, - ) - except Exception as e: - logger.error(f"Error validating {request.filepath}: {e}") - return WorkerValidationResponse( - status=ValidationStatus.CRASHED.value, - reason=str(e), - ) - - -@app.post("/worker/validategroup", response_model=WorkerValidationResponse) -async def validate_group( - request: WorkerGroupValidationRequest, -) -> WorkerValidationResponse: - """ - Validate a group of files and return the result. - - This endpoint performs group validation without maintaining state. - """ - logger.info(f"Worker received group validation request for {request.filepath}") - - if validation_engine is None or gcp_bucket is None: - logger.error("Validation engine or GCP bucket not initialized") - return WorkerValidationResponse( - status=ValidationStatus.CRASHED.value, - reason="Worker not properly initialized", - ) - - try: - p_filepath = Path(request.filepath) - filepaths = [ - str(p_filepath.with_name(f"{p_filepath.stem}-{i}{p_filepath.suffix}")) - for i in range(request.file_data.group_size) - ] - - state = await process_group_files( - engine=validation_engine, - gcp_bucket=gcp_bucket, - filepaths=filepaths, - file_data=request.file_data, - ) - - return WorkerValidationResponse( - status=state.status.value, - reason=state.reason, - failing_indices=state.failing_indices, - input_flops=state.input_flops, - output_flops=state.output_flops, - sample_count=state.sample_count, - ) - except Exception as e: - logger.error(f"Error validating group {request.filepath}: {e}") - return WorkerValidationResponse( - status=ValidationStatus.CRASHED.value, - reason=str(e), - ) - - -@app.get("/worker/health") -async def health_check(): - """ - Health check endpoint for the worker. - """ - global validation_engine - - if validation_engine is None: - raise HTTPException( - status_code=425, detail="Model is being loaded, please try again later" - ) - - if not validation_engine.is_model_loaded(): - raise HTTPException( - status_code=425, detail="Model is being loaded, please try again later" - ) - - return {"status": "healthy", "message": "Worker is ready"} - - -@app.get("/worker/info") -async def worker_info(): - """ - Get information about this worker instance. - """ - import socket - - hostname = socket.gethostname() - ip_address = socket.gethostbyname(hostname) - - return { - "hostname": hostname, - "ip_address": ip_address, - "model": validation_engine.model_name_or_path if validation_engine else None, - } - - -def main(): - """Main function to start the worker server.""" - import argparse - import uvicorn - from toplocvalidator.file_processing import set_topk - - parser = argparse.ArgumentParser(description="TopLoc Worker Server") - parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to") - parser.add_argument("--port", type=int, default=8001, help="Port to bind to") - parser.add_argument( - "--topk", type=int, default=None, help="Top-k value for validation" - ) - parser.add_argument( - "--gcp-path", type=str, default="", help="GCP path for file retrieval" - ) - parser.add_argument( - "--model-name-or-path", - type=str, - required=True, - help="Model name or path for validation", - ) - parser.add_argument( - "--tp", type=int, default=1, help="Tensor parallel size for the vllm engine" - ) - parser.add_argument( - "--config-path", - type=str, - default=None, - help="Path to custom validation config file", - ) - parser.add_argument("--log-level", type=str, default="INFO", help="Logging level") - - args = parser.parse_args() - - # Set logging level - logging.getLogger().setLevel(args.log_level) - - # Set top-k if provided - if args.topk is not None: - set_topk(args.topk) - - # Initialize GCP bucket - if args.gcp_path: - initialize_gcp_bucket(args.gcp_path) - - # Initialize validation engine - global validation_engine - validation_engine = ValidationEngine( - model_name_or_path=args.model_name_or_path, - tp=args.tp, - config_path=args.config_path, - ) - - # Run the server - uvicorn.run(app, host=args.host, port=args.port) - - -if __name__ == "__main__": - main()