Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
"shardcast>=0.1.8",
"setuptools",
"datasets",
"aiohttp",
]

[project.optional-dependencies]
Expand Down
346 changes: 346 additions & 0 deletions toplocvalidator/distributed_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,346 @@
"""Distributed validation module for routing requests to worker servers."""

import asyncio
import logging
from typing import Optional, List, Dict, Any
import aiohttp
from dataclasses import dataclass
from datetime import datetime

from toplocvalidator import env
from toplocvalidator.validation_server import ValidationStatus, ValidationState
from toplocvalidator.datamodels import GroupValidationRequest

logger = logging.getLogger(__name__)


@dataclass
class WorkerInfo:
"""Information about a worker server."""

url: str
is_healthy: bool = False
last_health_check: Optional[datetime] = None
current_load: int = 0 # Number of active validations


class DistributedValidationManager:
"""Manages distributed validation across worker servers."""

def __init__(self, worker_urls: List[str]):
"""
Initialize the distributed validation manager.

Args:
worker_urls: List of worker server URLs
"""
self.workers = [
WorkerInfo(url=url.strip()) for url in worker_urls if url.strip()
]
self.health_check_interval = 30 # seconds
self.session: Optional[aiohttp.ClientSession] = None
self._health_check_task: Optional[asyncio.Task] = None

async def start(self) -> None:
"""Start the distributed validation manager."""
self.session = aiohttp.ClientSession()
# Start health check task
self._health_check_task = asyncio.create_task(self._health_check_loop())
# Do initial health check
await self._check_all_workers_health()

async def stop(self) -> None:
"""Stop the distributed validation manager."""
if self._health_check_task:
self._health_check_task.cancel()
try:
await self._health_check_task
except asyncio.CancelledError:
pass
if self.session:
await self.session.close()

async def _health_check_loop(self) -> None:
"""Periodically check health of all workers."""
while True:
try:
await asyncio.sleep(self.health_check_interval)
await self._check_all_workers_health()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error in health check loop: {e}")

async def _check_all_workers_health(self) -> None:
"""Check health of all worker servers."""
tasks = [self._check_worker_health(worker) for worker in self.workers]
await asyncio.gather(*tasks, return_exceptions=True)

healthy_count = sum(1 for w in self.workers if w.is_healthy)
logger.info(f"Worker health check: {healthy_count}/{len(self.workers)} healthy")

async def _check_worker_health(self, worker: WorkerInfo) -> None:
"""Check health of a single worker."""
if not self.session:
logger.warning("Session not initialized for health check")
return

try:
async with self.session.get(
f"{worker.url}/health", timeout=aiohttp.ClientTimeout(total=5)
) as response:
worker.is_healthy = response.status == 200
worker.last_health_check = datetime.now()
except Exception as e:
logger.warning(f"Worker {worker.url} health check failed: {e}")
worker.is_healthy = False
worker.last_health_check = datetime.now()

def _select_worker(self) -> Optional[WorkerInfo]:
"""
Select a healthy worker with the lowest load.

Returns:
Selected worker or None if no healthy workers available
"""
healthy_workers = [w for w in self.workers if w.is_healthy]
if not healthy_workers:
return None

# Select worker with lowest load
return min(healthy_workers, key=lambda w: w.current_load)

async def validate_file(
self, filepath: str, file_sha: str, node_address: Optional[str] = None
) -> ValidationState:
"""
Validate a file using a worker server.

Args:
filepath: Path to the file to validate
file_sha: SHA of the file
node_address: Optional node address for tracking

Returns:
ValidationState with the result
"""
worker = self._select_worker()
if not worker:
logger.error("No healthy workers available")
return ValidationState(
status=ValidationStatus.CRASHED,
reason="No healthy workers available for validation",
)

if not self.session:
logger.error("Session not initialized")
return ValidationState(
status=ValidationStatus.CRASHED,
reason="Distributed validation manager not properly initialized",
)

worker.current_load += 1
try:
# Build request data matching the server's ValidationRequest model
request_data = {
"file_sha": file_sha,
"address": node_address,
}

logger.info(f"Routing validation of {filepath} to worker {worker.url}")

# Add auth header if token is set
headers = {}
if env.TOPLOCVALIDATOR_AUTH_TOKEN:
headers["Authorization"] = f"Bearer {env.TOPLOCVALIDATOR_AUTH_TOKEN}"

# Call the worker's validate endpoint
async with self.session.post(
f"{worker.url}/validate/{filepath}",
json=request_data,
headers=headers,
timeout=aiohttp.ClientTimeout(
total=30
), # Short timeout for initial request
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(
f"Worker returned error: {response.status} - {error_text}"
)
return ValidationState(
status=ValidationStatus.CRASHED,
reason=f"Worker error: {response.status} - {error_text}",
)

# The worker will start validation and return immediately
# We need to poll for the result
await asyncio.sleep(1) # Give it a moment to start

# Poll the status endpoint
max_polls = 600 # 10 minutes with 1 second intervals
for _ in range(max_polls):
async with self.session.get(
f"{worker.url}/status/{filepath}",
headers=headers,
timeout=aiohttp.ClientTimeout(total=5),
) as status_response:
if status_response.status == 200:
status_result = await status_response.json()
status_value = status_result["status"]

if status_value != "pending":
# Validation complete
return ValidationState(
status=ValidationStatus(status_value),
reason="", # Basic status endpoint doesn't return reason
sample_count=0, # Basic status endpoint doesn't return sample count
)

await asyncio.sleep(1)

# Timeout
return ValidationState(
status=ValidationStatus.CRASHED,
reason="Worker validation timeout",
)

except asyncio.TimeoutError:
logger.error(f"Timeout validating {filepath} on worker {worker.url}")
return ValidationState(
status=ValidationStatus.CRASHED, reason="Worker validation timeout"
)
except Exception as e:
logger.error(f"Error validating {filepath} on worker {worker.url}: {e}")
return ValidationState(
status=ValidationStatus.CRASHED,
reason=f"Worker validation error: {str(e)}",
)
finally:
worker.current_load -= 1

async def validate_group(
self, filepath: str, file_data: GroupValidationRequest
) -> ValidationState:
"""
Validate a group of files using a worker server.

Args:
filepath: Path to the main file
file_data: Group validation request data

Returns:
ValidationState with the result
"""
worker = self._select_worker()
if not worker:
logger.error("No healthy workers available")
return ValidationState(
status=ValidationStatus.CRASHED,
reason="No healthy workers available for validation",
)

if not self.session:
logger.error("Session not initialized")
return ValidationState(
status=ValidationStatus.CRASHED,
reason="Distributed validation manager not properly initialized",
)

worker.current_load += 1
try:
# Use the standard server endpoint
request_data = file_data.model_dump()

logger.info(
f"Routing group validation of {filepath} to worker {worker.url}"
)

# Add auth header if token is set
headers = {}
if env.TOPLOCVALIDATOR_AUTH_TOKEN:
headers["Authorization"] = f"Bearer {env.TOPLOCVALIDATOR_AUTH_TOKEN}"

# Call the worker's validategroup endpoint
async with self.session.post(
f"{worker.url}/validategroup/{filepath}",
json=request_data,
headers=headers,
timeout=aiohttp.ClientTimeout(
total=30
), # Short timeout for initial request
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(
f"Worker returned error: {response.status} - {error_text}"
)
return ValidationState(
status=ValidationStatus.CRASHED,
reason=f"Worker error: {response.status} - {error_text}",
)

# The worker will start validation and return immediately
# We need to poll for the result using the statusgroup endpoint
await asyncio.sleep(1) # Give it a moment to start

# Poll the statusgroup endpoint
max_polls = 1200 # 20 minutes with 1 second intervals
for _ in range(max_polls):
async with self.session.get(
f"{worker.url}/statusgroup/{filepath}",
headers=headers,
timeout=aiohttp.ClientTimeout(total=5),
) as status_response:
if status_response.status == 200:
result = await status_response.json()
status_value = result["status"]

if status_value != "pending":
# Validation complete
return ValidationState(
status=ValidationStatus(status_value),
reason=result.get("reason", ""),
failing_indices=result.get("failing_indices", []),
input_flops=result.get("input_flops", 0),
output_flops=result.get("output_flops", 0),
sample_count=0, # Not available in statusgroup endpoint
)

await asyncio.sleep(1)

# Timeout
return ValidationState(
status=ValidationStatus.CRASHED,
reason="Worker validation timeout",
)

except asyncio.TimeoutError:
logger.error(f"Timeout validating group {filepath} on worker {worker.url}")
return ValidationState(
status=ValidationStatus.CRASHED, reason="Worker validation timeout"
)
except Exception as e:
logger.error(
f"Error validating group {filepath} on worker {worker.url}: {e}"
)
return ValidationState(
status=ValidationStatus.CRASHED,
reason=f"Worker validation error: {str(e)}",
)
finally:
worker.current_load -= 1

def get_workers_status(self) -> List[Dict[str, Any]]:
"""Get status information for all workers."""
return [
{
"url": worker.url,
"is_healthy": worker.is_healthy,
"last_health_check": worker.last_health_check.isoformat()
if worker.last_health_check
else None,
"current_load": worker.current_load,
}
for worker in self.workers
]
9 changes: 9 additions & 0 deletions toplocvalidator/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@
TOPLOC_DISCORD_WEBHOOK: str = ""
TOPLOC_FLOP_SCALE_FACTOR: int = 1

# Worker configuration for distributed validation
WORKER_URLS: list[str] | None = None # Comma-separated list of worker URLs
ENABLE_DISTRIBUTED: bool = False # Enable distributed validation

_env = {
"SKIP_FILESHA_CHECK": lambda: os.getenv("SKIP_FILESHA_CHECK", "false").lower()
in ["true", "1", "yes", "y"],
Expand Down Expand Up @@ -104,6 +108,11 @@
in ["true", "1", "yes", "y"],
"TOPLOC_DISCORD_WEBHOOK": lambda: os.getenv("TOPLOC_DISCORD_WEBHOOK", ""),
"TOPLOC_FLOP_SCALE_FACTOR": lambda: int(os.getenv("TOPLOC_FLOP_SCALE_FACTOR", "1")),
"WORKER_URLS": lambda: os.getenv("WORKER_URLS", "").split(",")
if os.getenv("WORKER_URLS", "")
else None,
"ENABLE_DISTRIBUTED": lambda: os.getenv("ENABLE_DISTRIBUTED", "false").lower()
in ["true", "1", "yes", "y"],
}


Expand Down
Loading
Loading