|
1 | 1 | import hashlib |
2 | 2 | import itertools |
| 3 | +import json |
3 | 4 | import os |
4 | 5 | import pickle |
5 | 6 | from dataclasses import dataclass, field |
| 7 | +import queue |
| 8 | +import threading |
| 9 | +import time |
6 | 10 | from typing import TYPE_CHECKING, Callable, List, Optional |
7 | 11 |
|
from typing import Dict
8 | 13 | import torch |
from typing import Any
9 | 15 | from vllm.config import VllmConfig |
10 | 16 | from vllm.distributed.kv_transfer.kv_connector.v1.base import ( |
11 | 17 | KVConnectorBase_V1, |
@@ -159,6 +165,64 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): |
159 | 165 | config["kv_block_size"] / 1024 / 1024, |
160 | 166 | config["io_size"] / 1024, |
161 | 167 | ) |
| 168 | + self.record_oper: bool = self.launch_config.get("record_oper", False) |
| 169 | + if self.record_oper: |
| 170 | + self.write_thread = threading.Thread(target=self._async_record_loop, daemon=True) |
| 171 | + self.write_thread.start() |
| 172 | + |
def log_operation(self, operation_data: Dict[str, Any]) -> None:
    """Enqueue one operation log entry without blocking the caller.

    Merges *operation_data* over the defaults (timestamp, op_type,
    block_size) and hands the entry to the async writer thread via
    ``self.log_queue``. The entry is dropped (with a log message) rather
    than blocking when the queue is full or not yet created.

    Args:
        operation_data: Fields to record; overrides the defaults
            (e.g. ``op_type``, ``blocks``, ``request_id``).
    """
    default_data = {
        "timestamp": time.time(),
        "op_type": "None",
        "block_size": self.block_size,
    }
    log_entry = {**default_data, **operation_data}

    try:
        self.log_queue.put_nowait(log_entry)
    except AttributeError:
        # The writer thread creates self.log_queue; a call landing before
        # the thread has run would otherwise crash the hot path.
        logger.warning(
            "Log queue not ready yet, dropping one log: %s",
            log_entry.get("request_id"),
        )
    except queue.Full:
        logger.error(
            "Log queue is full, dropping one log: %s",
            log_entry.get("request_id"),
        )
| 189 | + |
def _async_record_loop(self):
    """Daemon loop: drain queued log entries and append them to disk in batches.

    Creates ``self.log_queue`` (bounded, so producers never block the hot
    path), then repeatedly pulls entries and flushes them as JSON lines to
    ``record_oper_path`` once either ``record_oper_batch_size`` entries have
    accumulated or ``record_oper_flush_interval`` seconds have elapsed.
    """
    self.log_queue = queue.Queue(maxsize=10000)  # Max cache: 10000 entries
    log_path = self.launch_config.get("record_oper_path", "/vllm-workspace/ucm_logs")
    batch_size = self.launch_config.get("record_oper_batch_size", 100)
    flush_interval = self.launch_config.get("record_oper_flush_interval", 5.0)
    batch_buffer = []
    last_flush_time = time.time()
    while True:
        is_flush = False
        current_time = time.time()
        try:
            # Get log from queue (1 second timeout)
            log_entry = self.log_queue.get(timeout=1.0)
            batch_buffer.append(log_entry)

            # Flush if conditions are met
            if (
                len(batch_buffer) >= batch_size
                or (current_time - last_flush_time) >= flush_interval
            ):
                is_flush = True
                last_flush_time = current_time
            self.log_queue.task_done()
        except queue.Empty:
            # No new entries: still flush what is buffered once the
            # interval elapses, otherwise a quiet period strands entries
            # in memory forever.
            if batch_buffer and (current_time - last_flush_time) >= flush_interval:
                is_flush = True
                last_flush_time = current_time
        except Exception as e:
            logger.error("Log thread exception: %s", e)

        if is_flush and batch_buffer:
            try:
                with open(log_path, "a", encoding="utf-8") as f:
                    # Fix: flush the local batch_buffer (the original read
                    # the nonexistent self.batch_buffer and crashed here).
                    for entry in batch_buffer:
                        f.write(json.dumps(entry, ensure_ascii=False) + "\n")
            except OSError as e:
                # Keep the daemon alive even if the log file is unwritable.
                logger.error("Failed to write operation log %s: %s", log_path, e)
            batch_buffer.clear()
| 224 | + |
| 225 | + |
162 | 226 |
|
163 | 227 | def generate_hash(self, block_size: int, request: "Request") -> list[str]: |
164 | 228 | token_ids = request.all_token_ids |
@@ -465,6 +529,13 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: |
465 | 529 | request_to_task[request_id] = self.store.load( |
466 | 530 | ucm_total_block_ids, ucm_offsets, dst_tensor_addr |
467 | 531 | ) |
| 532 | + if self.record_oper: |
| 533 | + self.log_operation( |
| 534 | + { |
| 535 | + "op_type": "load", |
| 536 | + "blocks": ucm_block_ids, |
| 537 | + } |
| 538 | + ) |
468 | 539 | else: |
469 | 540 | request_to_task[request_id] = None |
470 | 541 | req_broadcast_addr[request_id] = dst_tensor_addr |
@@ -527,6 +598,13 @@ def wait_for_save(self) -> None: |
527 | 598 | request_to_task[request_id] = self.store.dump( |
528 | 599 | ucm_total_block_ids, ucm_offsets, dst_tensor_addr |
529 | 600 | ) |
| 601 | + if self.record_oper: |
| 602 | + self.log_operation( |
| 603 | + { |
| 604 | + "op_type": "dump", |
| 605 | + "blocks": ucm_block_ids, |
| 606 | + } |
| 607 | + ) |
530 | 608 | request_to_blocks[request_id] = ucm_block_ids |
531 | 609 |
|
532 | 610 | for request_id, task in request_to_task.items(): |
|
0 commit comments