Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 133 additions & 4 deletions toplocvalidator/batch_processing/validation_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import asyncio
import multiprocessing as mp
import time
import threading
from vllm.model_executor.model_loader.loader import _process_weights_after_loading
from toplocvalidator import env
from toplocvalidator.validation_server import ValidationStatus
Expand Down Expand Up @@ -180,7 +181,7 @@ def validate_toploc2(
)

chunk_logits /= temperature

# top_logits = chunk_logits.max(dim=1).values
neg_gumbel_noise = generate_neg_gumbel_noise(
(chunk_logits.shape[0], chunk_logits.shape[-1]),
Expand Down Expand Up @@ -564,6 +565,11 @@ def __init__(
self.config = ValidationConfig.from_toml(config_path)
self.model_name_or_path = model_name_or_path
self.tp = tp
self.active_validations: Dict[str, float] = {} # rid -> start_time
self.active_validations_lock = threading.Lock()
self.monitor_thread = None
self.stop_monitoring = False
self.worker = None
self.start()

def is_model_loaded(self) -> bool:
Expand All @@ -585,6 +591,97 @@ def start(self):
)
self.worker.start()

# Start the monitoring thread
self.stop_monitoring = False
self.monitor_thread = threading.Thread(target=self._monitor_timeouts)
self.monitor_thread.daemon = True
self.monitor_thread.start()

def _monitor_timeouts(self) -> None:
    """Poll in-flight validations and restart the worker when any has timed out.

    Runs until ``self.stop_monitoring`` becomes true. On each pass it scans
    ``self.active_validations`` (rid -> start time) under
    ``self.active_validations_lock``, collects every rid that has been running
    longer than ``env.VALIDATION_TIMEOUT`` seconds, and hands that list to
    ``self._restart_worker``. Any exception raised during a pass is logged and
    the loop keeps going.
    """
    poll_interval_s = 5  # seconds between scans, also used as error back-off

    while not self.stop_monitoring:
        try:
            now = time.time()
            expired = []

            # Only the scan happens under the lock; the restart is done after
            # releasing it because _restart_worker takes the same lock.
            with self.active_validations_lock:
                for rid, started_at in self.active_validations.items():
                    elapsed = now - started_at
                    if elapsed > env.VALIDATION_TIMEOUT:
                        logger.warning(
                            f"Validation {rid} has been running for "
                            f"{elapsed:.1f}s, exceeding timeout "
                            f"of {env.VALIDATION_TIMEOUT}s"
                        )
                        expired.append(rid)

            if expired:
                logger.error(
                    f"Detected {len(expired)} timed out validations. "
                    "Restarting worker process..."
                )
                self._restart_worker(expired)

            time.sleep(poll_interval_s)

        except Exception as e:
            logger.error(f"Error in timeout monitor: {e}")
            time.sleep(poll_interval_s)

def _restart_worker(self, timed_out_rids: list[str]) -> None:
    """Fail the timed-out validations and replace the worker process.

    For every rid in ``timed_out_rids`` that still has an entry in
    ``self.rid_to_output``, a ``ValidationStatus.CRASHED`` result is stored
    and the event that was registered with the original entry is set, so a
    caller blocked on that event can wake up and read the result. Each rid is
    also dropped from ``self.active_validations``. The current worker is then
    terminated (killed if it is still alive after 10 seconds),
    ``self.model_loaded_event`` is cleared, and a new worker process is
    started with the same arguments.

    Args:
        timed_out_rids: Request ids whose validation exceeded
            ``env.VALIDATION_TIMEOUT`` seconds.

    Exceptions are logged (with traceback) instead of propagating.
    """
    try:
        for rid in timed_out_rids:
            if rid in self.rid_to_output:
                # The caller waits on the event stored with the ORIGINAL
                # entry. Keep a handle to it: replacing the entry with a
                # fresh event alone would leave that caller blocked forever.
                waiter_event = self.rid_to_output[rid].event

                crashed_output = ValidationOutput(event=self.manager.Event())
                crashed_output.set_result(
                    ValidationStatus.CRASHED,
                    f"Validation timed out after {env.VALIDATION_TIMEOUT}s",
                )
                # Publish the result before waking the caller so it reads the
                # CRASHED entry when it re-fetches from rid_to_output.
                self.rid_to_output[rid] = crashed_output
                waiter_event.set()

            with self.active_validations_lock:
                self.active_validations.pop(rid, None)

        # Kill the existing worker process
        if self.worker and self.worker.is_alive():
            logger.info("Terminating stuck worker process...")
            self.worker.terminate()
            self.worker.join(timeout=10)

            # If still alive, force kill
            if self.worker.is_alive():
                logger.warning("Worker didn't terminate gracefully, killing...")
                self.worker.kill()
                self.worker.join()

        # Clear the model loaded event since we're restarting
        self.model_loaded_event.clear()

        # NOTE(review): the new worker reuses self.input_queue. The
        # multiprocessing docs warn that terminating a process while it is
        # using a queue can corrupt that queue; confirm this is acceptable or
        # recreate the queue here.
        logger.info("Starting new worker process...")
        self.worker = mp.Process(
            target=_validation_worker,
            args=(
                self.model_name_or_path,
                self.tp,
                self.config,
                self.input_queue,
                self.rid_to_output,
                self.model_loaded_event,
            ),
        )
        self.worker.start()

        logger.info("Worker process restarted successfully")

    except Exception as e:
        logger.error(f"Error restarting worker: {e}")
        import traceback

        traceback.print_exc()

async def validate_entries(
self, entries: "list[InferenceEntry]", step: Optional[int] = None
) -> Tuple[ValidationStatus, Optional[str]]:
Expand All @@ -603,13 +700,22 @@ async def validate_entries(
validation_output = ValidationOutput(event=self.manager.Event())
self.rid_to_output[entries[0].rid] = validation_output

# Track the start time of this validation
with self.active_validations_lock:
self.active_validations[entries[0].rid] = time.time()

start_time = time.time()
await loop.run_in_executor(
None, self.input_queue.put, ValidationInput(entries, None, step)
)
# Get result from worker process
await loop.run_in_executor(None, validation_output.event.wait)
logger.info(f"Validation time: {time.time() - start_time:.2f} seconds")

# Remove from active validations
with self.active_validations_lock:
self.active_validations.pop(entries[0].rid, None)

# TODO: This is kind of a hack because of the way ManagerDict works.
# Maybe we can fix later but it works for now.
validation_output = self.rid_to_output[entries[0].rid]
Expand Down Expand Up @@ -638,6 +744,10 @@ async def validate_grouped_entries(
validation_output = ValidationOutput(event=self.manager.Event())
self.rid_to_output[main_entries[0].rid] = validation_output

# Track the start time of this validation
with self.active_validations_lock:
self.active_validations[main_entries[0].rid] = time.time()

start_time = time.time()
await loop.run_in_executor(
None,
Expand All @@ -648,6 +758,10 @@ async def validate_grouped_entries(
await loop.run_in_executor(None, validation_output.event.wait)
logger.info(f"Validation time: {time.time() - start_time:.2f} seconds")

# Remove from active validations
with self.active_validations_lock:
self.active_validations.pop(main_entries[0].rid, None)

# TODO: This is kind of a hack because of the way ManagerDict works.
# Maybe we can fix later but it works for now.
validation_output = self.rid_to_output[main_entries[0].rid]
Expand All @@ -656,12 +770,27 @@ async def validate_grouped_entries(
return result

def stop(self):
    """Stop the timeout monitor thread, then shut down the worker process.

    The monitor loop is told to exit first and is joined for at most 10
    seconds. The worker is then asked to exit by putting ``None`` (the
    shutdown signal) on ``self.input_queue`` and is joined. Calling this
    again after the worker has been cleared is a no-op for the worker part.
    """
    # Stop the monitoring thread
    self.stop_monitoring = True
    monitor = self.monitor_thread
    if monitor is not None and monitor.is_alive():
        monitor.join(timeout=10)

    # Stop the worker process; guarded so a repeated call does not enqueue a
    # stray shutdown signal or call join() on None.
    if self.worker is not None:
        self.input_queue.put(None)
        self.worker.join()
        self.worker = None

def __del__(self):
"""Cleanup worker process on deletion."""
if hasattr(self, "stop_monitoring"):
self.stop_monitoring = True
if (
hasattr(self, "monitor_thread")
and self.monitor_thread
and self.monitor_thread.is_alive()
):
self.monitor_thread.join(timeout=10)
if hasattr(self, "worker") and self.worker is not None:
self.input_queue.put(None) # Send shutdown signal
self.worker.join()
2 changes: 2 additions & 0 deletions toplocvalidator/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
TOPLOCVALIDATOR_AUTH_TOKEN: str = ""
GCP_CREDENTIALS: str = ""
GCP_PULL_TIMEOUT: int = 5 # Timeout in seconds for GCP file pull operations
VALIDATION_TIMEOUT: int = 300 # Timeout in seconds for validation processing
SKIP_DATA_SAMPLING_CHECK: bool = False # Feature flag
SKIP_TOPLOC2_CHECK: bool = True # Feature flag
# Filesha check
Expand Down Expand Up @@ -44,6 +45,7 @@
"TOPLOCVALIDATOR_AUTH_TOKEN": lambda: os.getenv("TOPLOCVALIDATOR_AUTH_TOKEN", ""),
"GCP_CREDENTIALS": lambda: os.getenv("GCP_CREDENTIALS", ""),
"GCP_PULL_TIMEOUT": lambda: int(os.getenv("GCP_PULL_TIMEOUT", "5")),
"VALIDATION_TIMEOUT": lambda: int(os.getenv("VALIDATION_TIMEOUT", "300")),
"SHARDCAST_SERVERS": lambda: os.getenv("SHARDCAST_SERVERS", "").split(",")
if os.getenv("SHARDCAST_SERVERS", "")
else None,
Expand Down
Loading