diff --git a/.github/workflows/cpp-linter.yml b/.github/workflows/cpp-linter.yml new file mode 100644 index 000000000..e013a62c7 --- /dev/null +++ b/.github/workflows/cpp-linter.yml @@ -0,0 +1,34 @@ +name: cpp-linter + +on: + push: + branches: [ "*" ] + pull_request: + branches: [ "dev*", "main", "*release" ] + + +jobs: + cpp-linter: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 + with: + persist-credentials: false + - uses: cpp-linter/cpp-linter-action@main + id: linter + continue-on-error: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + style: file + tidy-checks: '-*' + files-changed-only: true + lines-changed-only: diff + format-review: true + thread-comments: ${{ github.event_name == 'pull_request' && 'update' }} + + - name: Fail fast?! + if: steps.linter.outputs.checks-failed != 0 + run: | + echo "some linter checks failed. ${{ steps.linter.outputs.checks-failed }}" + exit 1 diff --git a/docs/source/getting-started/quick_start.md b/docs/source/getting-started/quick_start.md index 9e7630e18..1f33ab3b3 100644 --- a/docs/source/getting-started/quick_start.md +++ b/docs/source/getting-started/quick_start.md @@ -59,7 +59,17 @@ First, specify the python hash seed by: export PYTHONHASHSEED=123456 ``` -Run the following command to start the vLLM server with the Qwen/Qwen2.5-14B-Instruct model: +Create a config yaml like following and save it to your own directory: +```yaml +# UCM Configuration File Example +# Refer to file unified-cache-management/examples/ucm_config_example.yaml for more details +ucm_connector_name: "UcmNfsStore" + +ucm_connector_config: + storage_backends: "/mnt/test" +``` + +Run the following command to start the vLLM server with the Qwen/Qwen2.5-14B-Instruct model and your config file path: ```bash # Change the model path to your own model path @@ -77,11 +87,7 @@ vllm serve ${MODEL_PATH} \ "kv_connector_module_path": "ucm.integration.vllm.uc_connector", "kv_role": "kv_both", "kv_connector_extra_config": { - "ucm_connector_name": "UcmDramStore", - "ucm_connector_config": { - "max_cache_size": 5368709120, - "kv_block_size": 262144 - } + "UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml" } }' ``` diff --git a/docs/source/index.md b/docs/source/index.md index 2352d3996..69be815e6 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -57,6 +57,7 @@ getting-started/installation_npu user-guide/prefix-cache/index user-guide/sparse-attention/index user-guide/pd-disaggregation/index +user-guide/metrics/metrics ::: :::{toctree} diff --git a/docs/source/user-guide/metrics/metrics.md b/docs/source/user-guide/metrics/metrics.md new file mode 100644 index 000000000..22b532681 --- /dev/null +++ b/docs/source/user-guide/metrics/metrics.md @@ -0,0 +1,193 @@ +# Observability + +UCM (Unified Cache Management) provides detailed metrics monitoring through Prometheus endpoints, allowing in-depth monitoring of cache performance and behavior. This document describes how to enable and configure observability from the embedded vLLM `/metrics` API endpoint. + +--- + +## Quick Start Guide + +### 1) On UCM Side + +First, set the `PROMETHEUS_MULTIPROC_DIR` environment variable. + +```bash +export PROMETHEUS_MULTIPROC_DIR=/vllm-workspace +``` + +Then, start the UCM service. 
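+The command below is one example configuration; adapt the model path, parallelism settings, and the `UCM_CONFIG_FILE` path to your environment. Also make sure the `PROMETHEUS_MULTIPROC_DIR` directory exists and is writable before the server starts, for example:
+
+```bash
+mkdir -p "$PROMETHEUS_MULTIPROC_DIR"
+```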
+ +```bash +export CUDA_VISIBLE_DEVICES=0 +vllm serve /home/models/Qwen2.5-14B-Instruct \ + --max-model-len 5000 \ + --tensor-parallel-size 1 \ + --gpu_memory_utilization 0.87 \ + --trust-remote-code \ + --disable-log-requests \ + --no-enable-prefix-caching \ + --enforce-eager \ + --max-num-batched-tokens 40000 \ + --max-num-seqs 10 \ + --host 0.0.0.0 \ + --port 8000 \ + --kv-transfer-config \ + '{ + "kv_connector": "UCMConnector", + "kv_connector_module_path": "ucm.integration.vllm.ucm_connector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "UCM_CONFIG_FILE": "/vllm-workspace/unified-cache-management/examples/ucm_config.yaml" + } + }' +``` +**Note**: You can refer to the `ucm_config.yaml` file at https://github.com/ModelEngine-Group/unified-cache-management/tree/develop/examples to configure the `metrics_config_path` parameter. + +You can use the `vllm bench serve` command to run benchmarks: + +```bash +vllm bench serve \ + --backend vllm \ + --model /home/models/Qwen2.5-14B-Instruct \ + --host 127.0.0.1 \ + --port 8000 \ + --dataset-name random \ + --num-prompts 20 \ + --random-input-len 200 \ + --random-output-len 10 \ + --request-rate 1 \ + --ignore-eos +``` + +Once the HTTP server is running, you can access the UCM metrics at the `/metrics` endpoint. + +```bash +curl http://$:8000/metrics | grep ucm: +``` + +You will also find some `.db` files in the `$PROMETHEUS_MULTIPROC_DIR` directory, which are temporary files used by Prometheus. + +### 2) Start Prometheus and Grafana with Docker Compose + +#### Create Docker Compose Configuration Files + +First, create the `docker-compose.yaml` file: + +```yaml +# docker-compose.yaml +version: "3" + +services: + prometheus: + image: prom/prometheus:latest + extra_hosts: + - "host.docker.internal:host-gateway" + ports: + - "9090:9090" + volumes: + - ${PWD}/prometheus.yaml:/etc/prometheus/prometheus.yml + + grafana: + image: grafana/grafana:latest + depends_on: + - prometheus + ports: + - "3000:3000" +``` + +Then, create the `prometheus.yaml` configuration file: + +```yaml +# prometheus.yaml +global: + scrape_interval: 5s + evaluation_interval: 30s + +scrape_configs: + - job_name: vllm + static_configs: + - targets: + - 'host.docker.internal:8000' +``` + +**Note**: Make sure the port number in `prometheus.yaml` matches the port number used when starting the vLLM service. + +#### Start Services + +Run the following command in the directory containing `docker-compose.yaml` and `prometheus.yaml`: + +```bash +docker compose up +``` + +This will start Prometheus and Grafana services. + +### 3) Configure Grafana Dashboard + +#### Access Grafana + +Navigate to `http://:3000`. Log in with the default username (`admin`) and password (`admin`). You will be prompted to change the password on first login. + +#### Add Prometheus Data Source + +1. Navigate to `http://:3000/connections/datasources/new` and select **Prometheus**. + +2. On the Prometheus configuration page, add the Prometheus server URL in the **Connection** section. For this Docker Compose setup, Grafana and Prometheus run in separate containers, but Docker creates DNS names for each container. You can directly use `http://prometheus:9090`. + +3. Click **Save & Test**. You should see a green checkmark showing "Successfully queried the Prometheus API." + +#### Import Dashboard + +1. Navigate to `http://:3000/dashboard/import`. + +2. Click **Upload JSON file**, then upload the `unified-cache-management/examples/metrics/grafana.json` file. + +3. 
Select the Prometheus data source configured earlier. + +4. Click **Import** to complete the import. + +You should now be able to see the UCM monitoring dashboard with real-time visualization of all 9 metrics. + +## Available Metrics + +UCM exposes various metrics to monitor its performance. The following table lists all available metrics organized by category: + +| Metric Name | Type | Description | +|------------|------|-------------| +| **Load Operation Metrics** | | | +| `ucm:load_requests_num` | Histogram | Number of requests loaded per `start_load_kv` call | +| `ucm:load_blocks_num` | Histogram | Number of blocks loaded per `start_load_kv` call | +| `ucm:load_duration` | Histogram | Time to load KV cache from UCM (milliseconds) | +| `ucm:load_speed` | Histogram | Speed of loading from UCM (GB/s) | +| **Save Operation Metrics** | | | +| `ucm:save_requests_num` | Histogram | Number of requests saved per `wait_for_save` call | +| `ucm:save_blocks_num` | Histogram | Number of blocks saved per `wait_for_save` call | +| `ucm:save_duration` | Histogram | Time to save to UCM (milliseconds) | +| `ucm:save_speed` | Histogram | Speed of saving to UCM (GB/s) | +| **Lookup Hit Rate Metrics** | | | +| `ucm:interval_lookup_hit_rates` | Histogram | Hit rate of UCM lookup requests | + +## Prometheus Configuration + +Metrics configuration is defined in the `unified-cache-management/examples/metrics/metrics_configs.yaml` file: + +```yaml +log_interval: 5 # Interval in seconds for logging metrics + +prometheus: + multiproc_dir: "/vllm-workspace" # Prometheus directory + metric_prefix: "ucm:" # Metric name prefix + + enabled_metrics: + counters: true + gauges: true + histograms: true + + histograms: + - name: "load_requests_num" + documentation: "Number of requests loaded from ucm" + buckets: [1, 5, 10, 20, 50, 100, 200, 500, 1000] + # ... other metric configurations +``` + +--- + diff --git a/docs/source/user-guide/prefix-cache/dram_store.md b/docs/source/user-guide/prefix-cache/dram_store.md deleted file mode 100644 index 1be2f30a2..000000000 --- a/docs/source/user-guide/prefix-cache/dram_store.md +++ /dev/null @@ -1,133 +0,0 @@ -# DRAM Store - -This document provides a usage example and configuration guide for the **DRAM Connector**. This connector enables offloading of KV cache from GPU HBM to CPU DRAM, helping reduce memory pressure and supporting larger models or batch sizes. - -## Performance - -### Overview -The following are the multi-concurrency performance test results of UCM in the Prefix Cache scenario under a CUDA environment, showing the performance improvements of UCM on two different models. -During the tests, HBM cache was disabled, and KV Cache was retrieved and matched only from DRAM. - -In the QwQ-32B model, the test used one H20 server with 2 GPUs. - -Here, Full Compute refers to pure VLLM inference, while DRAM80% indicates that after UCM pooling, the DRAM hit rate of the KV cache is 80%. 
- -The following table shows the results on the QwQ-32B model: -| **QwQ-32B** | | | | | -| ---------------: | -------------: | ------------------: | -------------: | :----------- | -| **Input length** | **Concurrent** | **Full Compute(s)** | **DRAM80%(s)** | **Speedup** | -| 4 000 | 1 | 1.0269 | 0.3102 | **+230.9 %** | -| 8 000 | 1 | 2.0902 | 0.5718 | **+265.5 %** | -| 16 000 | 1 | 4.4852 | 1.1914 | **+276.4 %** | -| 4 000 | 2 | 1.5383 | 0.4209 | **+265.4 %** | -| 8 000 | 2 | 3.1323 | 0.8231 | **+280.5 %** | -| 16 000 | 2 | 6.7984 | 1.7420 | **+290.2 %** | -| 4 000 | 4 | 2.8173 | 0.9444 | **+198.2 %** | -| 8 000 | 4 | 5.2643 | 1.8290 | **+187.8 %** | -| 16 000 | 4 | 11.3651 | 3.6706 | **+209.6 %** | -## Features - -The DRAM connector supports the following functionalities: - -- `dump`: Offload KV cache blocks from HBM to DRAM. -- `load`: Load KV cache blocks from DRAM back to HBM. -- `lookup`: Look up KV blocks stored in DRAM by block hash. -- `wait`: Ensure that all copy streams between CPU and GPU have completed. -- `commit`: Mark cache operations as complete and ready for reuse. - -## Configuration - -To use the DRAM connector, you need to configure the `connector_config` dictionary in your model's launch configuration. - -### Required Parameters - -- `max_cache_size` *(optional)*: - Specifies the maximum allowed DRAM memory usage (in **bytes**) for caching in `kv_connector_extra_config["ucm_connector_config"]`. - If not provided, it defaults to **5 GB**. -- `kv_block_size` *(optional)*: - Specifies the memory size (in **bytes**) of a single key or value cache block used in vLLM’s paged attention mechanism, which is calculated as : `block_size * head_size * total_num_kv_heads * element_size`. - -### Example: - -```python -# Allocate up to 8GB DRAM for KV cache -# KV Block size (in byte) is 262144 -kv_connector_extra_config={"ucm_connector_name": "UcmDramStore", "ucm_connector_config":{"max_cache_size": 5368709120, "kv_block_size": 262144}} -``` - -## Launching Inference - -### Offline Inference - -To start **offline inference** with the DRAM connector,modify the script `examples/offline_inference.py` to include the `kv_connector_extra_config` for DRAM connector usage: - -```python -# In examples/offline_inference.py -ktc = KVTransferConfig( - ... - kv_connector_extra_config={"ucm_connector_name": "UcmDramStore", "ucm_connector_config":{"max_cache_size": 5368709120, "kv_block_size": 262144}} -) -``` - -Then run the script as follows: - -```bash -cd examples/ -python offline_inference.py -``` - -### Online Inference - -For **online inference** , vLLM with our connector can also be deployed as a server that implements the OpenAI API protocol. - -First, specify the python hash seed by: -```bash -export PYTHONHASHSEED=123456 -``` - -Run the following command to start the vLLM server with the Qwen/Qwen2.5-14B-Instruct model: - -```bash -vllm serve /home/models/Qwen2.5-14B-Instruct \ ---max-model-len 20000 \ ---tensor-parallel-size 2 \ ---gpu_memory_utilization 0.87 \ ---trust-remote-code \ ---port 7800 \ ---kv-transfer-config \ -'{ - "kv_connector": "UnifiedCacheConnectorV1", - "kv_connector_module_path": "ucm.integration.vllm.uc_connector", - "kv_role": "kv_both", - "kv_connector_extra_config": { - "ucm_connector_name": "UcmDramStore", - "ucm_connector_config": { - "max_cache_size": 5368709120, - "kv_block_size": 262144 - } - } -}' -``` - -If you see log as below: - -```bash -INFO: Started server process [32890] -INFO: Waiting for application startup. -INFO: Application startup complete. 
-``` - -Congratulations, you have successfully started the vLLM server with DRAM Connector! - -After successfully started the vLLM server,You can interact with the API as following: - -```bash -curl http://localhost:7800/v1/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "/home/models/Qwen2.5-14B-Instruct", - "prompt": "Shanghai is a", - "max_tokens": 7, - "temperature": 0 - }' -``` diff --git a/docs/source/user-guide/prefix-cache/index.md b/docs/source/user-guide/prefix-cache/index.md index defe27d38..ba3d16bef 100644 --- a/docs/source/user-guide/prefix-cache/index.md +++ b/docs/source/user-guide/prefix-cache/index.md @@ -79,6 +79,5 @@ performance. :::{toctree} :maxdepth: 1 -dram_store nfs_store ::: \ No newline at end of file diff --git a/docs/source/user-guide/prefix-cache/nfs_store.md b/docs/source/user-guide/prefix-cache/nfs_store.md index b581acf56..741fcedf7 100644 --- a/docs/source/user-guide/prefix-cache/nfs_store.md +++ b/docs/source/user-guide/prefix-cache/nfs_store.md @@ -87,8 +87,15 @@ To use the NFS connector, you need to configure the `connector_config` dictionar ### Example: -```python -kv_connector_extra_config={"ucm_connector_name": "UcmNfsStore", "ucm_connector_config":{"storage_backends": "/mnt/test1", "transferStreamNumber": 32}} +Create a config yaml like following and save it to your own directory: +```yaml +# UCM Configuration File Example +# Refer to file unified-cache-management/examples/ucm_config_example.yaml for more details +ucm_connector_name: "UcmNfsStore" + +ucm_connector_config: + storage_backends: "/mnt/test" + transferStreamNumber: 32 ``` ## Launching Inference @@ -101,7 +108,7 @@ To start **offline inference** with the NFS connector,modify the script `examp # In examples/offline_inference.py ktc = KVTransferConfig( ... 
- kv_connector_extra_config={"ucm_connector_name": "UcmNfsStore", "ucm_connector_config":{"storage_backends": "/mnt/test1", "transferStreamNumber": 32}} + kv_connector_extra_config={"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"} ) ``` @@ -131,13 +138,7 @@ vllm serve /home/models/Qwen2.5-14B-Instruct \ "kv_connector": "UnifiedCacheConnectorV1", "kv_connector_module_path": "ucm.integration.vllm.uc_connector", "kv_role": "kv_both", - "kv_connector_extra_config": { - "ucm_connector_name": "UcmNfsStore", - "ucm_connector_config": { - "storage_backends": "/mnt/test", - "transferStreamNumber":32 - } - } + "kv_connector_extra_config": {"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"} }' ``` diff --git a/examples/offline_inference.py b/examples/offline_inference.py index f50682464..5a2fea372 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,5 +1,4 @@ import contextlib -import json import os import time from dataclasses import asdict @@ -16,11 +15,6 @@ logger = init_logger(__name__) -def setup_environment_variables(): - os.environ["VLLM_USE_V1"] = "1" - os.environ["PYTHONHASHSEED"] = "123456" - - @contextlib.contextmanager def build_llm_with_uc(module_path: str, name: str, model: str): ktc = KVTransferConfig( @@ -28,20 +22,7 @@ def build_llm_with_uc(module_path: str, name: str, model: str): kv_connector_module_path=module_path, kv_role="kv_both", kv_connector_extra_config={ - "ucm_connector_name": "UcmDramStore", - "ucm_connector_config": { - "max_cache_size": 5368709120, - "kv_block_size": 262144, - }, - "ucm_sparse_config": { - "ESA": { - "init_window_sz": 1, - "local_window_sz": 2, - "min_blocks": 4, - "sparse_ratio": 0.3, - "retrieval_stride": 5, - } - }, + "UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml" }, ) @@ -53,6 +34,8 @@ def build_llm_with_uc(module_path: str, name: str, model: str): max_num_batched_tokens=30000, block_size=128, enforce_eager=True, + trust_remote_code=True, + enable_prefix_caching=False, ) llm = LLM(**asdict(llm_args)) @@ -79,22 +62,41 @@ def print_output( def main(): - module_path = "ucm.integration.vllm.uc_connector" - name = "UnifiedCacheConnectorV1" - model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct") + module_path = "ucm.integration.vllm.ucm_connector" + name = "UCMConnector" + model = os.getenv("MODEL_PATH", "/home/models/DeepSeek-V2-Lite") tokenizer = AutoTokenizer.from_pretrained(model, use_chat_template=True) - setup_environment_variables() with build_llm_with_uc(module_path, name, model) as llm: messages = [ { "role": "system", - "content": "You are a highly specialized assistant whose mission is to faithfully reproduce English literary texts verbatim, without any deviation, paraphrasing, or omission. Your primary responsibility is accuracy: every word, every punctuation mark, and every line must appear exactly as in the original source. Core Principles: Verbatim Reproduction: If the user asks for a passage, you must output the text word-for-word. Do not alter spelling, punctuation, capitalization, or line breaks. Do not paraphrase, summarize, modernize, or “improve” the language. Consistency: The same input must always yield the same output. Do not generate alternative versions or interpretations. Clarity of Scope: Your role is not to explain, interpret, or critique. You are not a storyteller or commentator, but a faithful copyist of English literary and cultural texts. 
Recognizability: Because texts must be reproduced exactly, they will carry their own cultural recognition. You should not add labels, introductions, or explanations before or after the text. Coverage: You must handle passages from classic literature, poetry, speeches, or cultural texts. Regardless of tone—solemn, visionary, poetic, persuasive—you must preserve the original form, structure, and rhythm by reproducing it precisely. Success Criteria: A human reader should be able to compare your output directly with the original and find zero differences. The measure of success is absolute textual fidelity. Your function can be summarized as follows: verbatim reproduction only, no paraphrase, no commentary, no embellishment, no omission.", + "content": "You are a highly specialized assistant whose mission is to faithfully reproduce English " + "literary texts verbatim, without any deviation, paraphrasing, or omission. Your primary " + "responsibility is accuracy: every word, every punctuation mark, and every line must " + "appear exactly as in the original source. Core Principles: Verbatim Reproduction: If the " + "user asks for a passage, you must output the text word-for-word. Do not alter spelling, " + "punctuation, capitalization, or line breaks. Do not paraphrase, summarize, modernize, " + "or “improve” the language. Consistency: The same input must always yield the same output. " + "Do not generate alternative versions or interpretations. Clarity of Scope: Your role is " + "not to explain, interpret, or critique. You are not a storyteller or commentator, " + "but a faithful copyist of English literary and cultural texts. Recognizability: Because " + "texts must be reproduced exactly, they will carry their own cultural recognition. You " + "should not add labels, introductions, or explanations before or after the text. Coverage: " + "You must handle passages from classic literature, poetry, speeches, or cultural texts. " + "Regardless of tone—solemn, visionary, poetic, persuasive—you must preserve the original " + "form, structure, and rhythm by reproducing it precisely. Success Criteria: A human reader " + "should be able to compare your output directly with the original and find zero " + "differences. The measure of success is absolute textual fidelity. Your function can be " + "summarized as follows: verbatim reproduction only, no paraphrase, no commentary, " + "no embellishment, no omission.", }, { "role": "user", - "content": "Please reproduce verbatim the opening sentence of the United States Declaration of Independence (1776), starting with 'When in the Course of human events' and continuing word-for-word without paraphrasing.", + "content": "Please reproduce verbatim the opening sentence of the United States Declaration of " + "Independence (1776), starting with 'When in the Course of human events' and continuing " + "word-for-word without paraphrasing.", }, ] diff --git a/examples/ucm_config_example.yaml b/examples/ucm_config_example.yaml new file mode 100644 index 000000000..b72079420 --- /dev/null +++ b/examples/ucm_config_example.yaml @@ -0,0 +1,35 @@ +# UCM Configuration File Example +# +# This file demonstrates how to configure UCM using YAML. 
+# You can use this config file by setting the path to this file in kv_connector_extra_config in launch script or command line like this: +# kv_connector_extra_config={"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_example.yaml"} +# +# Alternatively, you can still use kv_connector_extra_config in KVTransferConfig +# for backward compatibility. + +# Connector name (e.g., "UcmNfsStore", "UcmDramStore") +ucm_connectors: + - ucm_connector_name: "UcmNfsStore" + ucm_connector_config: + storage_backends: "/mnt/test" + use_direct: false + +load_only_first_rank: false + +# Sparse attention configuration +# Format 1: Dictionary format (for methods like ESA, KvComp) +# ucm_sparse_config: +# ESA: +# init_window_sz: 1 +# local_window_sz: 2 +# min_blocks: 4 +# sparse_ratio: 0.3 +# retrieval_stride: 5 + # Or for GSA: + # GSA: {} + + +# Whether to use layerwise loading/saving (optional, default: True for UnifiedCacheConnectorV1) +# use_layerwise: true +# hit_ratio: 0.9 + diff --git a/setup.py b/setup.py index 8c462dabe..5a617c747 100644 --- a/setup.py +++ b/setup.py @@ -135,7 +135,7 @@ def _get_package_data_with_so(): setup( name="uc-manager", - version="0.1.0rc3", + version="0.1.0rc4", description="Unified Cache Management", author="Unified Cache Team", packages=find_packages(), diff --git a/test/common/capture_utils.py b/test/common/capture_utils.py index ee12ed2a5..b12b76637 100644 --- a/test/common/capture_utils.py +++ b/test/common/capture_utils.py @@ -1,3 +1,4 @@ +import functools from typing import Any, Dict, List from common.db_utils import write_to_db @@ -44,6 +45,7 @@ def post_process(table_name: str, **kwargs) -> List[Dict[str, Any]]: # ---------------- decorator ---------------- def export_vars(func): + @functools.wraps(func) def wrapper(*args, **kwargs): result = func(*args, **kwargs) # If the function returns a dict containing '_data' or 'data', post-process it diff --git a/test/common/llmperf/__init__.py b/test/common/llmperf/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/common/llmperf/run_inference.py b/test/common/llmperf/run_inference.py new file mode 100644 index 000000000..b04deb1ea --- /dev/null +++ b/test/common/llmperf/run_inference.py @@ -0,0 +1,185 @@ +import json +import os +import random +from pathlib import Path +from typing import Any, Dict, List + +import yaml +from common.llmperf.utils.token_benchmark import run_token_benchmark +from common.llmperf.utils.utils import reset_prefill_cache + + +def run_test_cases( + llm_api, + model, + timeout, + max_num_completed_requests, + concurrent_requests, + mean_input_tokens, + stddev_input, + mean_output_tokens, + stddev_output, + additional_sampling_params, + timestamp_dir, + server_url, + tokenizer_path, + hit_rate, +): + print(f"[INFO] Total {len(mean_input_tokens)} test cases to be executed") + all_summaries = [] + failed_case = [] + + # Clear proxy environment variables + env = os.environ.copy() + env.pop("http_proxy", None) + env.pop("https_proxy", None) + + for i, ( + mean_input, + mean_output, + max_completed, + concurrent, + additional_sampling_params, + hit_rate_val, + ) in enumerate( + zip( + mean_input_tokens, + mean_output_tokens, + max_num_completed_requests, + concurrent_requests, + additional_sampling_params, + hit_rate, + ), + start=1, + ): + # for i, case in enumerate(mean_input_tokens): + print(f"\n>>> Executing test case {i} <<<") + reset_prefill_cache(env, server_url) + # Use a fixed random_seed for each test to control PC hit_rate + random_seed = 
random.randint(1, 100000) + + try: + # Determine if two runs are needed (PC hit_rate test) + if hit_rate_val == 0: + summary = run_token_benchmark( + llm_api=llm_api, + model=model, + test_timeout_s=timeout, + max_num_completed_requests=max_completed, + concurrent_requests=concurrent, + mean_input_tokens=mean_input, + stddev_input_tokens=stddev_input, + mean_output_tokens=mean_output, + stddev_output_tokens=stddev_output, + additional_sampling_params=additional_sampling_params, + results_dir=str(timestamp_dir), + random_seed=random_seed, + openai_api_base=server_url + "/v1", + tokenizer_path=tokenizer_path, + user_metadata={"case_idx": i, "phase": "normal"}, + ) + else: + print( + f"[INFO] hit_rate > 0 detected, entering prefill mode, PC hit rate: {hit_rate_val} %" + ) + # hit_rate > 0: first prefill mode + prefill_mean_input = int(mean_input * hit_rate_val / 100) + print( + f"[INFO] Prefill execution: mean_input_tokens={prefill_mean_input}" + ) + run_token_benchmark( + llm_api=llm_api, + model=model, + test_timeout_s=timeout, + max_num_completed_requests=max_completed, + concurrent_requests=concurrent, + mean_input_tokens=prefill_mean_input, + stddev_input_tokens=stddev_input, + mean_output_tokens=2, + stddev_output_tokens=stddev_output, + additional_sampling_params=additional_sampling_params, + results_dir=str(timestamp_dir), + random_seed=random_seed, + openai_api_base=server_url + "/v1", + tokenizer_path=tokenizer_path, + user_metadata={"case_idx": i, "phase": "prefill"}, + ) + reset_prefill_cache(env, server_url) + # Then run normal mode + print("[INFO] Prefill completed, switching to normal mode execution") + summary = run_token_benchmark( + llm_api=llm_api, + model=model, + test_timeout_s=timeout, + max_num_completed_requests=max_completed, + concurrent_requests=concurrent, + mean_input_tokens=mean_input, + stddev_input_tokens=stddev_input, + mean_output_tokens=mean_output, + stddev_output_tokens=stddev_output, + additional_sampling_params=additional_sampling_params, + results_dir=str(timestamp_dir), + random_seed=random_seed, + openai_api_base=server_url + "/v1", + tokenizer_path=tokenizer_path, + user_metadata={"case_idx": i, "phase": "normal"}, + ) + all_summaries.append(summary) + except Exception as e: + print(f"[Warning] {e}") + failed_case.append(i) + + return all_summaries, failed_case + + +def inference_results( + mean_input_tokens, + mean_output_tokens, + max_num_completed_requests, + concurrent_requests, + additional_sampling_params, + hit_rate, +): + config_file = Path(__file__).parent.parent.parent / "config.yaml" + print("[INFO] Initialization complete, starting main process") + print(f"[INFO] Reading configuration file: {config_file}") + with open(config_file, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + llm_api = config.get("llm_connection", {}).get("llm_api", "openai") + model = config.get("llm_connection", {}).get("model", "") + test_timeout_s = config.get("llm_connection", {}).get("test_timeout_s", 60000) + stddev_input_tokens = config.get("llm_connection", {}).get( + "stddev_input_tokens", 0 + ) + stddev_output_tokens = config.get("llm_connection", {}).get( + "stddev_output_tokens", 0 + ) + timestamp_dir = Path("results") + timestamp_dir.mkdir(parents=True, exist_ok=True) + server_url = config.get("llm_connection", {}).get("server_url", "") + tokenizer_path = config.get("llm_connection", {}).get("tokenizer_path", "") + print(f"[INFO] Created results directory: {timestamp_dir}") + + all_summaries, failed_cases = run_test_cases( + llm_api, + model, + 
test_timeout_s, + max_num_completed_requests, + concurrent_requests, + mean_input_tokens, + stddev_input_tokens, + mean_output_tokens, + stddev_output_tokens, + additional_sampling_params, + timestamp_dir, + server_url, + tokenizer_path, + hit_rate, + ) + total = len(mean_input_tokens) + print( + f"\n[INFO] All tests completed! Success: {total - len(failed_cases)}/{total}" + ) + if failed_cases: + print(f"[WARN] Failed case indices: {failed_cases}") + return all_summaries diff --git a/test/common/llmperf/utils/__init__.py b/test/common/llmperf/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/common/llmperf/utils/common_metrics.py b/test/common/llmperf/utils/common_metrics.py new file mode 100644 index 000000000..40e21124e --- /dev/null +++ b/test/common/llmperf/utils/common_metrics.py @@ -0,0 +1,17 @@ +# TODO (Avnishn): compute metrics in class +INTER_TOKEN_LAT = "inter_token_latency_s" +TTFT = "ttft_s" +E2E_LAT = "end_to_end_latency_s" +NUM_INPUT_TOKENS = "number_input_tokens" +NUM_OUTPUT_TOKENS = "number_output_tokens" +NUM_TOTAL_TOKENS = "number_total_tokens" +REQ_OUTPUT_THROUGHPUT = "request_output_throughput_token_per_s" +ERROR_MSG = "error_msg" +ERROR_CODE = "error_code" +ERROR_CODE_FREQ = "error_code_frequency" +NUM_ERRORS = "number_errors" +OUTPUT_THROUGHPUT = "mean_output_throughput_token_per_s" +NUM_COMPLETED_REQUESTS = "num_completed_requests" +COMPLETED_REQUESTS_PER_MIN = "num_completed_requests_per_min" +ERROR_RATE = "error_rate" +NUM_REQ_STARTED = "num_requests_started" diff --git a/test/common/llmperf/utils/models.py b/test/common/llmperf/utils/models.py new file mode 100644 index 000000000..1cbab6281 --- /dev/null +++ b/test/common/llmperf/utils/models.py @@ -0,0 +1,23 @@ +from typing import Any, Dict, Optional, Tuple + +from pydantic import BaseModel + + +class RequestConfig(BaseModel): + """The configuration for a request to the LLM API. + + Args: + model: The model to use. + prompt: The prompt to provide to the LLM API. + sampling_params: Additional sampling parameters to send with the request. + For more information see the Router app's documentation for the completions + llm_api: The name of the LLM API to send the request to. + metadata: Additional metadata to attach to the request for logging or validation purposes. 
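+        openai_api_base: The base URL of the OpenAI-compatible endpoint that the
+            request will be sent to.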
+ """ + + model: str + prompt: Tuple[str, int] + sampling_params: Optional[Dict[str, Any]] = None + llm_api: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + openai_api_base: Optional[str] = "" diff --git a/test/common/llmperf/utils/openai_chat_completions_client.py b/test/common/llmperf/utils/openai_chat_completions_client.py new file mode 100644 index 000000000..5023bfa1f --- /dev/null +++ b/test/common/llmperf/utils/openai_chat_completions_client.py @@ -0,0 +1,136 @@ +import json +import os +import time +from asyncio import timeout +from pathlib import Path +from typing import Any, Dict, Tuple + +import requests +import yaml +from common.llmperf.utils import common_metrics +from common.llmperf.utils.models import RequestConfig + +config_file = Path(__file__).parent.parent.parent.parent / "config.yaml" +with open(config_file, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) +stream = config.get("llm_connection", {}).get("stream", True) +ignore_eos = config.get("llm_connection", {}).get("ignore_eos", True) +timeout = config.get("llm_connection", {}).get("timeout", 180) + + +class OpenAIChatCompletionsClient: + """ + used for sending HTTP requests, receiving token streams, measuring latency, etc. + """ + + def llm_request( + self, request_config: RequestConfig + ) -> Tuple[Dict[str, Any], str, RequestConfig]: + prompt, prompt_len = request_config.prompt + + message = [ + {"role": "user", "content": prompt}, + ] + model = request_config.model + body = { + "model": model, + "messages": message, + "stream": stream, + "ignore_eos": ignore_eos, + } + sampling_params = request_config.sampling_params + body.update(sampling_params or {}) + + time_to_next_token = [] + tokens_received = 0 + ttft = 0.0 + error_response_code = None + generated_text = "" + error_msg = "" + output_throughput = 0.0 + total_request_time = 0.0 + flag = False + + metrics: Dict[str, Any] = {} + + metrics[common_metrics.ERROR_CODE] = None + metrics[common_metrics.ERROR_MSG] = "" + + start_time = time.monotonic() + most_recent_received_token_time = start_time + + address = request_config.openai_api_base + + if not address: + raise ValueError("the environment variable OPENAI_API_BASE must be set.") + key = os.environ.get("OPENAI_API_KEY", "secret_abcdefg") + if not key: + raise ValueError("the environment variable OPENAI_API_KEY must be set.") + headers = {"Authorization": f"Bearer {key}"} + if not address.endswith("/"): + address = address + "/" + address += "chat/completions" + try: + with requests.post( + address, + json=body, + stream=stream, + timeout=timeout, + headers=headers, + ) as response: + if response.status_code != 200: + error_msg = response.text + error_response_code = response.status_code + response.raise_for_status() + + for chunk in response.iter_lines(chunk_size=None): + if not chunk: + continue + stem = b"data: " + if chunk.startswith(stem): + chunk = chunk[len(stem) :] + # Data might already be bytes or str + if isinstance(chunk, bytes): + chunk = chunk.decode("utf-8", errors="ignore") + if chunk.strip() == "[DONE]": + continue + tokens_received += 1 + data = json.loads(chunk) + if "error" in data: + error_msg = data["error"]["message"] + error_response_code = data["error"]["code"] + raise RuntimeError(error_msg) + delta = data["choices"][0]["delta"] + content = delta.get("content", None) or delta.get( + "reasoning_content", "" + ) + if content: + if tokens_received != 0 and flag == False: + ttft = time.monotonic() - start_time + flag = True + else: + time_to_next_token.append( + 
time.monotonic() - most_recent_received_token_time + ) + most_recent_received_token_time = time.monotonic() + generated_text += content + + total_request_time = time.monotonic() - start_time + if total_request_time > 0: + output_throughput = tokens_received / total_request_time + + except Exception as e: + metrics[common_metrics.ERROR_MSG] = error_msg + metrics[common_metrics.ERROR_CODE] = error_response_code + print(f"Warning Or Error: {e}") + print(error_response_code) + + metrics[common_metrics.INTER_TOKEN_LAT] = sum(time_to_next_token) + metrics[common_metrics.TTFT] = ttft + metrics[common_metrics.E2E_LAT] = total_request_time + metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = output_throughput + metrics[common_metrics.NUM_TOTAL_TOKENS] = tokens_received + prompt_len + metrics[common_metrics.NUM_OUTPUT_TOKENS] = tokens_received + metrics[common_metrics.NUM_INPUT_TOKENS] = prompt_len + + return metrics, generated_text, request_config diff --git a/test/common/llmperf/utils/sonnet.txt b/test/common/llmperf/utils/sonnet.txt new file mode 100644 index 000000000..9f13ead47 --- /dev/null +++ b/test/common/llmperf/utils/sonnet.txt @@ -0,0 +1,84 @@ +Shall I compare thee to a summer's day? +Thou art more lovely and more temperate: +Rough winds do shake the darling buds of May, +And summer's lease hath all too short a date: +Sometime too hot the eye of heaven shines, +And often is his gold complexion dimm'd; +And every fair from fair sometime declines, +By chance or nature's changing course untrimm'd; +But thy eternal summer shall not fade +Nor lose possession of that fair thou owest; +Nor shall Death brag thou wander'st in his shade, +When in eternal lines to time thou growest: +So long as men can breathe or eyes can see, +So long lives this and this gives life to thee. +Then let not winter's ragged hand deface +In thee thy summer, ere thou be distill'd: +Make sweet some vial; treasure thou some place +With beauty's treasure, ere it be self-kill'd. +That use is not forbidden usury, +Which happies those that pay the willing loan; +That's for thyself to breed another thee, +Or ten times happier, be it ten for one; +Ten times thyself were happier than thou art, +If ten of thine ten times refigured thee: +Then what could death do, if thou shouldst depart, +Leaving thee living in posterity? +Be not self-will'd, for thou art much too fair +To be death's conquest and make worms thine heir. +Where art thou, Muse, that thou forget'st so long +To speak of that which gives thee all thy might? +Spend'st thou thy fury on some worthless song, +Darkening thy power to lend base subjects light? +Return, forgetful Muse, and straight redeem +In gentle numbers time so idly spent; +Sing to the ear that doth thy lays esteem +And gives thy pen both skill and argument. +Rise, resty Muse, my love's sweet face survey, +If Time have any wrinkle graven there; +If any, be a satire to decay, +And make Time's spoils despised every where. +Give my love fame faster than Time wastes life; +So thou prevent'st his scythe and crooked knife. +My glass shall not persuade me I am old, +So long as youth and thou are of one date; +But when in thee time's furrows I behold, +Then look I death my days should expiate. +For all that beauty that doth cover thee +Is but the seemly raiment of my heart, +Which in thy breast doth live, as thine in me: +How can I then be elder than thou art? 
+O, therefore, love, be of thyself so wary +As I, not for myself, but for thee will; +Bearing thy heart, which I will keep so chary +As tender nurse her babe from faring ill. +Presume not on thy heart when mine is slain; +Thou gavest me thine, not to give back again. +So am I as the rich, whose blessed key +Can bring him to his sweet up-locked treasure, +The which he will not every hour survey, +For blunting the fine point of seldom pleasure. +Therefore are feasts so solemn and so rare, +Since, seldom coming, in the long year set, +Like stones of worth they thinly placed are, +Or captain jewels in the carcanet. +So is the time that keeps you as my chest, +Or as the wardrobe which the robe doth hide, +To make some special instant special blest, +By new unfolding his imprison'd pride. +Blessed are you, whose worthiness gives scope, +Being had, to triumph, being lack'd, to hope. +If there be nothing new, but that which is +Hath been before, how are our brains beguiled, +Which, labouring for invention, bear amiss +The second burden of a former child! +O, that record could with a backward look, +Even of five hundred courses of the sun, +Show me your image in some antique book, +Since mind at first in character was done! +That I might see what the old world could say +To this composed wonder of your frame; +Whether we are mended, or whether better they, +Or whether revolution be the same. +O, sure I am, the wits of former days +To subjects worse have given admiring praise. \ No newline at end of file diff --git a/test/common/llmperf/utils/token_benchmark.py b/test/common/llmperf/utils/token_benchmark.py new file mode 100644 index 000000000..67553cf1b --- /dev/null +++ b/test/common/llmperf/utils/token_benchmark.py @@ -0,0 +1,386 @@ +import json +import logging +import random +import re +import time +from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import pandas as pd +from common.llmperf.utils import common_metrics +from common.llmperf.utils.models import RequestConfig +from common.llmperf.utils.openai_chat_completions_client import ( + OpenAIChatCompletionsClient, +) +from common.llmperf.utils.utils import ( + LLMPerfResults, + randomly_sample_sonnet_lines_prompt, + sample_random_positive_int, +) +from transformers import AutoTokenizer + + +def get_token_throughput_latencies( + model: str, + mean_input_tokens: int, + stddev_input_tokens: int, + mean_output_tokens: int, + stddev_output_tokens: int, + additional_sampling_params: Optional[Dict[str, Any]] = None, + concurrent_requests: int = 1, + max_num_completed_requests: int = 500, + test_timeout_s=90, + llm_api="openai", + random_seed: int = None, + openai_api_base: str = "", + tokenizer_path: str = None, +) -> Tuple[Dict[str, Any], List[Dict[str, Any]], float, float]: + """Get the token throughput and latencies for the given model. + + Args: + model: The name of the model to query. + mean_input_tokens: The mean number of tokens to send in the prompt for the request. + stddev_input_tokens: The standard deviation of the number of tokens to send in the prompt for the request. + mean_output_tokens: The mean number of tokens to generate per request. + stddev_output_tokens: The standard deviation of the number of tokens to generate per request. + additional_sampling_params: Additional sampling parameters to send with the request. 
+ For more information see the LLM APIs documentation for the completions + concurrent_requests: The number of concurrent requests to make. Increase + this to increase the amount of load and vice versa. + test_timeout_s: The amount of time to run the test for before reporting results. + llm_api: The name of the llm api to use. Either "openai" or "litellm". + + Returns: + A summary of the performance metrics collected across all completed requests + (e.g. throughput, latencies, etc.) + The individual metrics for each request. + """ + random.seed(random_seed) + + print(f"Using tokenizer:{tokenizer_path}") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + get_token_length = lambda text: len(tokenizer.encode(text)) + + if not additional_sampling_params: + additional_sampling_params = {} + + # 1. create prompts + prompts: List[Tuple[str, int]] = [] + num_output_tokens_list: List[int] = [] + for i in range(max_num_completed_requests): + num_output = sample_random_positive_int( + mean_output_tokens, stddev_output_tokens + ) + num_output_tokens_list.append(num_output) + prompts.append( + randomly_sample_sonnet_lines_prompt( + prompt_tokens_mean=mean_input_tokens, + prompt_tokens_stddev=stddev_input_tokens, + tokenizer=tokenizer, + ) + ) + start_time = time.monotonic() + completed_requests: List[Dict[str, Any]] = [] + incremental_time_delay = 0.0 + client = OpenAIChatCompletionsClient() + futures = [] + + # 2. Submitting tasks using a thread pool + with ThreadPoolExecutor(max_workers=concurrent_requests) as executor: + for idx in range(max_num_completed_requests): + sampling = {"max_tokens": num_output_tokens_list[idx]} + sampling.update(additional_sampling_params) + cfg = RequestConfig( + model=model, + prompt=prompts[idx], + sampling_params=sampling, + llm_api=llm_api, + openai_api_base=openai_api_base, + ) + futures.append(executor.submit(client.llm_request, cfg)) + # 3. 
Waiting for completion or timeout + for future in as_completed(futures, timeout=test_timeout_s): + try: + metrics, gen_text, req_cfg = future.result() + except Exception as e: + logging.warning(f"[WARN] Future raised exception: {e}") + continue + num_output_tokens = get_token_length(gen_text) + if num_output_tokens: + metrics[common_metrics.INTER_TOKEN_LAT] /= ( + (metrics[common_metrics.NUM_OUTPUT_TOKENS] - 1) + if (metrics[common_metrics.NUM_OUTPUT_TOKENS] - 1) + else 1 + ) + metrics[common_metrics.NUM_OUTPUT_TOKENS] = num_output_tokens + metrics[common_metrics.NUM_TOTAL_TOKENS] = ( + metrics[common_metrics.NUM_INPUT_TOKENS] + num_output_tokens + ) + try: + metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = ( + num_output_tokens / metrics[common_metrics.E2E_LAT] + ) + except ZeroDivisionError: + logging.error("Division by zero in throughput calculation.") + + completed_requests.append(metrics) + + incremental_time_delay += metrics.get( + common_metrics.INTER_TOKEN_LAT, 0.0 + ) + + end_time = time.monotonic() + + print(f"Results for token benchmark for {model} queried with the {llm_api} api.\n") + if mean_output_tokens == 2: + print(f"[INFO] First token sending pre-embedding completed\n") + return {}, [], 0.0, 0.0 + + ret = metrics_summary(completed_requests, start_time, end_time) + + metadata = { + "model": model, + "mean_input_tokens": mean_input_tokens, + "stddev_input_tokens": stddev_input_tokens, + "mean_output_tokens": mean_output_tokens, + "stddev_output_tokens": stddev_output_tokens, + "concurrent_requests": concurrent_requests, + "additional_sampling_params": additional_sampling_params, + } + + metadata["results"] = ret + elapsed_time = end_time - start_time + return metadata, completed_requests, elapsed_time, incremental_time_delay + + +def compute_throughput( + summary: Dict[str, Any], + completed_requests: List[Dict[str, Any]], + elapsed_time: float, + incremental_time_delay: float, +) -> Tuple[float, float]: + """ + Compute total_throughput (token/s) based on the metrics in summary. + + Formula: (mean_output_tokens * num_completed_requests) / total_e2e_latency_s + + Args: + summary (Dict[str, Any]): A dictionary containing performance metrics. + + Returns: + float: The computed total throughput in tokens per second. Returns 0.0 if latency is zero. + """ + mean_output_tokens = summary.get("mean_output_tokens", 0) + + total_throughput = ( + (mean_output_tokens * len(completed_requests)) / elapsed_time + if elapsed_time > 0 + else 0.0 + ) + incremental_throughput = ( + (mean_output_tokens * len(completed_requests)) / incremental_time_delay + if incremental_time_delay > 0 + else 0.0 + ) + return round(total_throughput, 4), round(incremental_throughput, 4) + + +def metrics_summary( + metrics: List[Dict[str, Any]], start_time: int, end_time: int +) -> Dict[str, Any]: + """Generate a summary over metrics generated from potentially multiple instances of this client. + + Args: + metrics: The metrics to summarize. + start_time: The time the test started. + end_time: The time the test ended. 
+ + Returns: + A summary with the following information: + - Overall throughput (generated tokens / total test time) + - Number of completed requests + - Error rate + - Error code frequency + - Quantiles (p25-p99) for the following metrics: + - Inter token latency + - Time to first token + - User total request time + - Number of tokens processed per request + - Number of tokens generated per request + - User throughput (tokens / s) + """ + ret = {} + + def flatten(item): + for sub_item in item: + if isinstance(sub_item, Iterable) and not isinstance(sub_item, str): + yield from flatten(sub_item) + else: + yield sub_item + + df = pd.DataFrame(metrics) + df_without_errored_req = df[df[common_metrics.ERROR_CODE].isna()] + + for key in [ + common_metrics.INTER_TOKEN_LAT, + common_metrics.TTFT, + common_metrics.E2E_LAT, + common_metrics.REQ_OUTPUT_THROUGHPUT, + common_metrics.NUM_INPUT_TOKENS, + common_metrics.NUM_OUTPUT_TOKENS, + ]: + print(key) + ret[key] = {} + series = pd.Series(list(flatten(df_without_errored_req[key]))).dropna() + series = series[series > 0] # Calculate non-zero values + quantiles = series.quantile([0.25, 0.5, 0.75, 0.9, 0.95, 0.99]).to_dict() + quantiles_reformatted_keys = {} + for quantile, value in quantiles.items(): + reformatted_key = f"p{int(quantile * 100)}" + print(f" {reformatted_key} = {value}") + quantiles_reformatted_keys[reformatted_key] = value + ret[key]["quantiles"] = quantiles_reformatted_keys + mean = series.mean() + print(f" mean = {mean}") + ret[key]["mean"] = mean + print(f" min = {series.min()}") + ret[key]["min"] = series.min() + print(f" max = {series.max()}") + ret[key]["max"] = series.max() + print(f" stddev = {series.std()}") + ret[key]["stddev"] = series.std() + + ret[common_metrics.NUM_REQ_STARTED] = len(metrics) + + error_codes = df[common_metrics.ERROR_CODE].dropna() + num_errors = len(error_codes) + ret[common_metrics.ERROR_RATE] = num_errors / len(metrics) if len(metrics) else 0 + ret[common_metrics.NUM_ERRORS] = num_errors + print(f"Number Of Errored Requests: {num_errors}") + error_code_frequency = dict(error_codes.value_counts()) + if num_errors: + error_code_frequency = dict(error_codes.value_counts()) + print("Error Code Frequency") + print(error_code_frequency) + ret[common_metrics.ERROR_CODE_FREQ] = str(error_code_frequency) + + overall_output_throughput = df_without_errored_req[ + common_metrics.NUM_OUTPUT_TOKENS + ].sum() / (end_time - start_time) + + print(f"Overall Output Throughput: {overall_output_throughput}") + ret[common_metrics.OUTPUT_THROUGHPUT] = overall_output_throughput + + num_completed_requests = len(df_without_errored_req) + num_completed_requests_per_min = ( + num_completed_requests / (end_time - start_time) * 60 + ) + print(f"Number Of Completed Requests: {num_completed_requests}") + print(f"Completed Requests Per Minute: {num_completed_requests_per_min}") + + ret[common_metrics.NUM_COMPLETED_REQUESTS] = num_completed_requests + ret[common_metrics.COMPLETED_REQUESTS_PER_MIN] = num_completed_requests_per_min + + return ret + + +def run_token_benchmark( + llm_api: str, + model: str, + test_timeout_s: int, + max_num_completed_requests: int, + concurrent_requests: int, + mean_input_tokens: int, + stddev_input_tokens: int, + mean_output_tokens: int, + stddev_output_tokens: int, + additional_sampling_params: str, + results_dir: str, + random_seed: int, + openai_api_base: str, + tokenizer_path: str, + user_metadata: Dict[str, Any], +): + """ + Args: + llm_api: The name of the llm api to use. 
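+        random_seed: Seed used for the random prompt sampling, so runs that share a
+            seed are issued the same prompts.
+        openai_api_base: Base URL of the OpenAI-compatible endpoint to send requests to.
+        tokenizer_path: Path of the tokenizer used to build prompts and count output tokens.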
+ model: The name of the model to query. + max_num_completed_requests: The number of requests to complete before finishing the test. + test_timeout_s: The amount of time to run the test for before reporting results. + concurrent_requests: The number of concurrent requests to make. Increase + this to increase the amount of load and vice versa. + mean_input_tokens: The mean number of tokens to send in the prompt for the request. + stddev_input_tokens: The standard deviation of the number of tokens to send in the prompt for the request. + mean_output_tokens: The mean number of tokens to generate per request. + stddev_output_tokens: The standard deviation of the number of tokens to generate per request. + additional_sampling_params: Additional sampling parameters to send with the request. + For more information see the LLM APIs documentation for the completions. + results_dir: The directory to save the results to. + user_metadata: Additional metadata to include in the results. + """ + if mean_input_tokens < 40: + print( + "the minimum number of input tokens that will be sent is 41" + " because of the prompting logic right now" + ) + + summary, completed_requests, elapsed_time, incremental_time_delay = ( + get_token_throughput_latencies( + model=model, + llm_api=llm_api, + test_timeout_s=test_timeout_s, + max_num_completed_requests=max_num_completed_requests, + mean_input_tokens=mean_input_tokens, + stddev_input_tokens=stddev_input_tokens, + mean_output_tokens=mean_output_tokens, + stddev_output_tokens=stddev_output_tokens, + concurrent_requests=concurrent_requests, + additional_sampling_params=json.loads(additional_sampling_params), + random_seed=random_seed, + openai_api_base=openai_api_base, + tokenizer_path=tokenizer_path, + ) + ) + if mean_output_tokens == 2: + return summary, completed_requests, elapsed_time, incremental_time_delay + + timestamp = int(time.time() * 1000) + if results_dir: + filename = f"{model}_{mean_input_tokens}_{mean_output_tokens}_{timestamp}" + filename = re.sub(r"[^\w\d-]+", "-", filename) + filename = re.sub(r"-{2,}", "-", filename) + summary_filename = f"{filename}_summary" + + # Update to metadata. 
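+        # user_metadata carries caller-supplied fields (e.g. case index and phase);
+        # together with the throughput figures computed below it is merged into the
+        # summary before the summary is written to
+        # <results_dir>/llmperf/<summary_filename>.json.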
+ summary.update(user_metadata) + total_tp, req_tp = compute_throughput( + summary, completed_requests, elapsed_time, incremental_time_delay + ) + summary["num_completed_requests"] = len(completed_requests) + summary["elapsed_time"] = elapsed_time + summary["incremental_time_delay"] = incremental_time_delay + summary["total_throughput"] = total_tp + summary["incremental_throughput"] = req_tp + + results = LLMPerfResults(name=summary_filename, metadata=summary) + results_dir = Path(results_dir) + if not results_dir.exists(): + results_dir.mkdir(parents=True) + elif not results_dir.is_dir(): + raise ValueError(f"{results_dir} is not a directory") + + llmperf_dir = results_dir / "llmperf" + if not llmperf_dir.exists(): + llmperf_dir.mkdir(parents=True) + elif not llmperf_dir.is_dir(): + raise ValueError(f"{llmperf_dir} is not a directory") + + try: + with open(llmperf_dir / f"{summary_filename}.json", "w") as f: + json.dump(results.to_dict(), f, indent=4, default=str) + except Exception as e: + print(results.to_dict()) + raise e + return summary diff --git a/test/common/llmperf/utils/utils.py b/test/common/llmperf/utils/utils.py new file mode 100644 index 000000000..e2c270871 --- /dev/null +++ b/test/common/llmperf/utils/utils.py @@ -0,0 +1,171 @@ +import hashlib +import json +import math +import os +import pathlib +import random +import subprocess +import time +from typing import Any, Dict, Tuple + +from transformers import LlamaTokenizerFast + +RESULTS_VERSION = "2025-10-30" + + +class LLMPerfResults: + def __init__( + self, + name: str, + metadata: Dict[str, Any] = None, + ): + self.name = name + self.metadata = metadata or {} + self.timestamp = int(time.time()) + self.metadata["timestamp"] = self.timestamp + self.version = RESULTS_VERSION + + def to_dict(self): + data = { + "version": self.version, + "name": self.name, + } + data.update(self.metadata) + data = flatten_dict(data) + return data + + def json(self): + data = self.to_dict() + return json.dumps(data) + + +def upload_to_s3(results_path: str, s3_path: str) -> None: + """Upload the results to s3. + + Args: + results_path: The path to the results file. + s3_path: The s3 path to upload the results to. + + """ + + command = ["aws", "s3", "sync", results_path, f"{s3_path}/"] + result = subprocess.run(command) + if result.returncode == 0: + print("Files uploaded successfully!") + else: + print("An error occurred:") + print(result.stderr) + + +def randomly_sample_sonnet_lines_prompt( + prompt_tokens_mean: int = 550, + prompt_tokens_stddev: int = 250, + tokenizer: LlamaTokenizerFast = None, +) -> Tuple[str, int]: + """Generate a prompt that randomly samples lines from a the shakespeare sonnet at sonnet.txt. + + Args: + prompt_length_mean: The mean length of the prompt to generate. + prompt_len_stddev: The standard deviation of the length of the prompt to generate. + expect_output_tokens: The number of tokens to expect in the output. This is used to + determine the length of the prompt. The prompt will be generated such that the output + will be approximately this many tokens. + + Note: + tokens will be counted from the sonnet using the Llama tokenizer. Using one tokenizer + ensures a fairer comparison across different LLMs. For example, if gpt 3.5 tokenizes + a prompt in less tokens than Llama2, then this will be reflected in the results since + they will be fed identical prompts. + + Returns: + A tuple of the prompt and the length of the prompt. 
+ """ + get_token_length = lambda text: len(tokenizer.encode(text)) + + prompt = ( + "Randomly stream lines from the following text " + "Don't generate eos tokens:\n\n" + ) + # get a prompt length that is at least as long as the base + num_prompt_tokens = sample_random_positive_int( + prompt_tokens_mean, prompt_tokens_stddev + ) + while num_prompt_tokens < get_token_length(prompt): + num_prompt_tokens = sample_random_positive_int( + prompt_tokens_mean, prompt_tokens_stddev + ) + remaining_prompt_tokens = num_prompt_tokens - get_token_length(prompt) + sonnet_path = pathlib.Path(__file__).parent.resolve() / "sonnet.txt" + with open(sonnet_path, "r") as f: + sonnet_lines = f.readlines() + random.shuffle(sonnet_lines) + sampling_lines = True + while sampling_lines: + for line in sonnet_lines: + line_to_add = line + if remaining_prompt_tokens - get_token_length(line_to_add) < 0: + # This will cut off a line in the middle of a word, but that's ok since an + # llm should be able to handle that. + line_to_add = line_to_add[: int(math.ceil(remaining_prompt_tokens))] + sampling_lines = False + prompt += line_to_add + break + prompt += line_to_add + remaining_prompt_tokens -= get_token_length(line_to_add) + print(hashlib.sha256(prompt.encode("utf-8")).hexdigest()) + return (prompt, num_prompt_tokens) + + +def sample_random_positive_int(mean: int, stddev: int) -> int: + """Sample random numbers from a gaussian distribution until a positive number is sampled. + + Args: + mean: The mean of the gaussian distribution to sample from. + stddev: The standard deviation of the gaussian distribution to sample from. + + Returns: + A random positive integer sampled from the gaussian distribution. + """ + ret = -1 + while ret <= 0: + ret = int(random.gauss(mean, stddev)) + return ret + + +def flatten_dict(d, parent_key="", sep="_"): + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def reset_prefill_cache(env, server_url): + """ + prefix cache / HBM + Param: + env + server_url + """ + reset_url = f"{server_url}/reset_prefix_cache" + print(f"[INFO] Resetting prefix cache: {reset_url}") + try: + result = subprocess.run( + ["curl", "-X", "POST", reset_url, "-s", "-f"], + env=env, + check=False, + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: + print("[INFO] Prefix cache successfully reset") + else: + print( + f"[ERROR] Unsuccessfully reset prefix cache,error code: {result.returncode}" + ) + except Exception as e: + print(f"[ERROR] Exception in resetting prefix cache: {e}") diff --git a/test/config.yaml b/test/config.yaml index 88d00a610..7ac32f484 100644 --- a/test/config.yaml +++ b/test/config.yaml @@ -15,4 +15,13 @@ database: name: "ucm_pytest" user: "root" password: "123456" - charset: "utf8mb4" \ No newline at end of file + charset: "utf8mb4" + +# LLM Connection Configuration +llm_connection: + model: "qwen3" + server_url: "http://141.111.32.70:9382" + tokenizer_path: "/home/models/QwQ-32B" + stream: true # stream output + ignore_eos: true # Ignore the returned terminator + timeout: 180 # request time out \ No newline at end of file diff --git a/test/suites/E2E/test_demo_performance.py b/test/suites/E2E/test_demo_performance.py deleted file mode 100644 index 1b76818f6..000000000 --- a/test/suites/E2E/test_demo_performance.py +++ /dev/null @@ -1,66 +0,0 @@ -import pytest -from 
common.config_utils import config_utils as config_instance - - -# ---------------- Fixture Example ---------------- -class Calculator: - def __init__(self): - print("[Calculator Initialization]") - pass - - def add(self, a, b): - return a + b - - def divide(self, a, b): - if b == 0: - raise ZeroDivisionError("Cannot divide by zero") - return a / b - - -@pytest.fixture(scope="module", name="calc") -def calculator(): - return Calculator() - - -@pytest.mark.feature("mark") -class TestCalculator: - # The calc instance will only be initialized on the first call, see the pytest documentation for more usage - def test_add(self, calc): - assert calc.add(1, 2) == 3 - - def test_divide(self, calc): - assert calc.divide(6, 2) == 3 - - def test_divide_by_zero(self, calc): - with pytest.raises(ZeroDivisionError): - calc.divide(6, 0) - - -# ---------------- Write to DB Example ---------------- -from common.capture_utils import * - - -@pytest.mark.feature("capture") # pytest must be the top -@export_vars -def test_capture_mix(): - """Mixed single + lists via '_name' + '_data'""" - assert 1 == 1 - return { - "_name": "demo", - "_data": { - "length": 10086, # single value - "accuracy": [0.1, 0.2, 0.3], # list - "loss": [0.1, 0.2, 0.3], # list - }, - } - - -# ---------------- Read Config Example ---------------- -from common.config_utils import config_utils as config_instance - - -@pytest.mark.feature("config") -def test_config(): - assert ( - config_instance.get_nested_config("database.host", "localhost") == "127.0.0.1" - ) diff --git a/test/suites/E2E/test_uc_performance.py b/test/suites/E2E/test_uc_performance.py new file mode 100644 index 000000000..dbec0318b --- /dev/null +++ b/test/suites/E2E/test_uc_performance.py @@ -0,0 +1,158 @@ +import pytest +from common.capture_utils import export_vars +from common.llmperf.run_inference import inference_results + + +@pytest.mark.parametrize("mean_input_tokens", [[2000, 3000]]) +@pytest.mark.parametrize("mean_output_tokens", [[200, 500]]) +@pytest.mark.parametrize("max_num_completed_requests", [[8, 4]]) +@pytest.mark.parametrize("concurrent_requests", [[8, 4]]) +@pytest.mark.parametrize("additional_sampling_params", [["{}", "{}"]]) +@pytest.mark.parametrize("hit_rate", [[0, 50]]) +@pytest.mark.feature("uc_performance_test") +@export_vars +def test_performance( + mean_input_tokens, + mean_output_tokens, + max_num_completed_requests, + concurrent_requests, + additional_sampling_params, + hit_rate, +): + all_summaries = inference_results( + mean_input_tokens, + mean_output_tokens, + max_num_completed_requests, + concurrent_requests, + additional_sampling_params, + hit_rate, + ) + failed_cases = [] + + value_lists = { + "mean_input_tokens": [], + "mean_output_tokens": [], + "results_inter_token_latency_s_quantiles_p50": [], + "results_inter_token_latency_s_quantiles_p90": [], + "results_inter_token_latency_s_quantiles_p99": [], + "results_inter_token_latency_s_mean": [], + "results_ttft_s_quantiles_p50": [], + "results_ttft_s_quantiles_p90": [], + "results_ttft_s_quantiles_p99": [], + "results_ttft_s_mean": [], + "results_end_to_end_latency_s_quantiles_p50": [], + "results_end_to_end_latency_s_quantiles_p90": [], + "results_end_to_end_latency_s_quantiles_p99": [], + "results_end_to_end_latency_s_mean": [], + "num_completed_requests": [], + "elapsed_time": [], + "incremental_time_delay": [], + "total_throughput": [], + "incremental_throughput": [], + } + + for i, summary in enumerate(all_summaries): + mean_input_tokens = summary["mean_input_tokens"] + 
mean_output_tokens = summary["mean_output_tokens"] + + results_inter_token_latency_s_quantiles_p50 = summary["results"][ + "inter_token_latency_s" + ]["quantiles"]["p50"] + results_inter_token_latency_s_quantiles_p90 = summary["results"][ + "inter_token_latency_s" + ]["quantiles"]["p90"] + results_inter_token_latency_s_quantiles_p99 = summary["results"][ + "inter_token_latency_s" + ]["quantiles"]["p99"] + results_inter_token_latency_s_mean = summary["results"][ + "inter_token_latency_s" + ]["mean"] + + results_ttft_s_quantiles_p50 = summary["results"]["ttft_s"]["quantiles"]["p50"] + results_ttft_s_quantiles_p90 = summary["results"]["ttft_s"]["quantiles"]["p90"] + results_ttft_s_quantiles_p99 = summary["results"]["ttft_s"]["quantiles"]["p99"] + results_ttft_s_mean = summary["results"]["ttft_s"]["mean"] + + results_end_to_end_latency_s_quantiles_p50 = summary["results"][ + "end_to_end_latency_s" + ]["quantiles"]["p50"] + results_end_to_end_latency_s_quantiles_p90 = summary["results"][ + "end_to_end_latency_s" + ]["quantiles"]["p90"] + results_end_to_end_latency_s_quantiles_p99 = summary["results"][ + "end_to_end_latency_s" + ]["quantiles"]["p99"] + results_end_to_end_latency_s_mean = summary["results"]["end_to_end_latency_s"][ + "mean" + ] + + num_completed_requests = summary["num_completed_requests"] + elapsed_time = summary["elapsed_time"] + incremental_time_delay = summary["incremental_time_delay"] + total_throughput = summary["total_throughput"] + incremental_throughput = summary["incremental_throughput"] + + values = [ + mean_input_tokens, + mean_output_tokens, + results_inter_token_latency_s_quantiles_p50, + results_inter_token_latency_s_quantiles_p90, + results_inter_token_latency_s_quantiles_p99, + results_inter_token_latency_s_mean, + results_ttft_s_quantiles_p50, + results_ttft_s_quantiles_p90, + results_ttft_s_quantiles_p99, + results_ttft_s_mean, + results_end_to_end_latency_s_quantiles_p50, + results_end_to_end_latency_s_quantiles_p90, + results_end_to_end_latency_s_quantiles_p99, + results_end_to_end_latency_s_mean, + num_completed_requests, + elapsed_time, + incremental_time_delay, + total_throughput, + incremental_throughput, + ] + + for var_name, val in zip( + [ + "mean_input_tokens", + "mean_output_tokens", + "results_inter_token_latency_s_quantiles_p50", + "results_inter_token_latency_s_quantiles_p90", + "results_inter_token_latency_s_quantiles_p99", + "results_inter_token_latency_s_mean", + "results_ttft_s_quantiles_p50", + "results_ttft_s_quantiles_p90", + "results_ttft_s_quantiles_p99", + "results_ttft_s_mean", + "results_end_to_end_latency_s_quantiles_p50", + "results_end_to_end_latency_s_quantiles_p90", + "results_end_to_end_latency_s_quantiles_p99", + "results_end_to_end_latency_s_mean", + "num_completed_requests", + "elapsed_time", + "incremental_time_delay", + "total_throughput", + "incremental_throughput", + ], + values, + ): + value_lists[var_name].append(val) + if val is None: + failed_cases.append((i, var_name, "missing")) + + try: + assert val > 0, f"value <= 0" + except AssertionError as e: + failed_cases.append((i, var_name, str(e))) + + # Output final result + if failed_cases: + print(f"\n[WARNING] Assertion failed: {len(failed_cases)} abnormal cases found") + for i, key, reason in failed_cases: + print(f" Iteration={i + 1}, key='{key}' -> {reason}") + else: + print("\n[INFO] All values are greater than 0. 
Assertion passed!") + + return {"_name": "llmperf", "_data": value_lists} diff --git a/ucm/__init__.py b/ucm/__init__.py index 18c7adaa5..9b39d21c7 100644 --- a/ucm/__init__.py +++ b/ucm/__init__.py @@ -1,48 +1,4 @@ -# -# MIT License -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# - -""" -vLLM integration module for Unified Cache Management. - -This module automatically applies patches to vLLM when imported, -eliminating the need for manual `git apply` commands. -""" - -# Auto-apply patches when this module is imported -try: - from ucm.integration.vllm.patch.apply_patch import ensure_patches_applied - - ensure_patches_applied() -except Exception as e: - # Don't fail if patches can't be applied - might be running in environment without vLLM - import warnings - - warnings.warn( - f"Failed to apply vLLM patches: {e}. " - f"If you're using vLLM, ensure it's installed and patches are compatible." 
-    )
-
 from ucm.integration.vllm.uc_connector import UnifiedCacheConnectorV1
+from ucm.integration.vllm.ucm_connector import UCMConnector
 
-__all__ = ["UnifiedCacheConnectorV1"]
+__all__ = ["UnifiedCacheConnectorV1", "UCMConnector"]
diff --git a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_adapt.py b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_adapt.py
index b16e65a02..c71068349 100644
--- a/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_adapt.py
+++ b/ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_adapt.py
@@ -112,6 +112,13 @@ def maybe_execute_sparse_attention_finished(
 ):
     if not has_ucm_sparse():
         return
+    ucm_sparse = get_ucm_sparse()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    ucm_sparse.attention_finished(
+        query, key, value, attn_output, layer_name, forward_context
+    )

 attention_v1.maybe_execute_sparse_attention_finished = (
     maybe_execute_sparse_attention_finished
diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py
index 2271e5e2c..dd870c253 100644
--- a/ucm/integration/vllm/uc_connector.py
+++ b/ucm/integration/vllm/uc_connector.py
@@ -44,6 +44,7 @@
 from ucm.logger import init_logger
 from ucm.store.factory import UcmConnectorFactory
 from ucm.store.ucmstore import Task
+from ucm.utils import Config
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -113,22 +114,11 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
             vllm_config.parallel_config
         )
         self.head_size = vllm_config.model_config.get_head_size()
-        if (
-            self._vllm_config.kv_transfer_config is not None
-            and "ucm_connector_name"
-            in self._vllm_config.kv_transfer_config.kv_connector_extra_config
-        ):
-            name = self._vllm_config.kv_transfer_config.kv_connector_extra_config[
-                "ucm_connector_name"
-            ]
-            config = {}
-            if (
-                "ucm_connector_config"
-                in self._vllm_config.kv_transfer_config.kv_connector_extra_config
-            ):
-                config = self._vllm_config.kv_transfer_config.kv_connector_extra_config[
-                    "ucm_connector_config"
-                ]
+        ucm_config = Config(vllm_config.kv_transfer_config)
+        launch_config = ucm_config.get_config()
+        if "ucm_connector_name" in launch_config:
+            name = launch_config.get("ucm_connector_name")
+            config = launch_config.get("ucm_connector_config") or {}
             config["device"] = self.rank
             config["role"] = (
                 "scheduler" if role == KVConnectorRole.SCHEDULER else "worker"
diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py
new file mode 100644
index 000000000..106b9dba8
--- /dev/null
+++ b/ucm/integration/vllm/ucm_connector.py
@@ -0,0 +1,860 @@
+import hashlib
+import itertools
+import json
+import os
+import pickle
+import queue
+import threading
+import time
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+
+import torch
+from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+    KVConnectorBase_V1,
+    KVConnectorMetadata,
+    KVConnectorRole,
+)
+from vllm.distributed.parallel_state import get_tp_group, get_world_group
+from vllm.platforms import current_platform
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.request import Request
+
+from ucm.logger import init_logger
+from ucm.store.factory import UcmConnectorFactory
+from ucm.store.ucmstore import Task, UcmKVStoreBase
+from ucm.utils import Config
+
+if TYPE_CHECKING:
+    from 
vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + +logger = init_logger(__name__) + + +@dataclass +class RequestMeta: + ucm_block_ids: list[str] = field(default_factory=list) + hbm_hit_block_num: int = 0 + # local_computed_block + external_computed_block + total_hit_block_num: int = 0 + + +@dataclass +class RequestDispatchMeta: + load_block_ids: tuple[ + list[str], list[int] + ] # [0] mean ucm_block_ids, [1] means vllm_block_ids + dump_block_ids: tuple[list[str], list[int]] + + +@dataclass +class UCMConnectorMetadata(KVConnectorMetadata): + request_meta: dict[str, RequestDispatchMeta] = field(default_factory=dict) + + +class RequestHasher: + """hash(md5) request to generate ucm block id""" + + _SEED_HASH = None + + def __init__(self, vllm_config, rank_id): + meta = f"{vllm_config.model_config.model}:{vllm_config.parallel_config.world_size}:{vllm_config.model_config.dtype}:{rank_id}" + self.meta_bytes = meta.encode("utf-8") + + if RequestHasher._SEED_HASH is None: + RequestHasher._SEED_HASH = self("UCM_HASH_SEED") + + def __call__(self, input_data) -> int: + if isinstance(input_data, str): + input_bytes = input_data.encode("utf-8") + else: + input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL) + + h = hashlib.md5(self.meta_bytes + input_bytes) + return int.from_bytes(h.digest(), byteorder="big") + + +class UCMDirectConnector(KVConnectorBase_V1): + """ + This connector means synchronize: + load -> forward -> save + """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self.kv_caches: dict[str, torch.Tensor] = {} + self.rank = ( + -1 if role == KVConnectorRole.SCHEDULER else get_world_group().local_rank + ) + self.block_size = self._vllm_config.cache_config.block_size + self.is_mla = self._vllm_config.model_config.is_deepseek_mla + self.kv_cache_dtype: torch.dtype = None + + if current_platform.is_cuda_alike(): + logger.info("CUDA device is available.") + torch_dev = torch + dev_name = "cuda" + elif current_platform.is_npu(): + logger.info("NPU device is available.") + torch_dev = torch.npu + dev_name = "npu" + else: + raise RuntimeError("Unsupported device platform for UCMDirectConnector.") + + if self.rank >= 0: + self.device = torch_dev.device(f"{dev_name}:{self.rank}") + self._layer_offset_cache = {} + + self.store: UcmKVStoreBase + + if role == KVConnectorRole.SCHEDULER: + self.request_hasher = RequestHasher(vllm_config, 0) + else: + self.request_hasher = RequestHasher(vllm_config, self.rank) + + # save block info, avoid hash request twice, and track them until request finished + self.requests_meta: dict[str, RequestMeta] = {} + + ucm_config = Config(vllm_config.kv_transfer_config) + self.launch_config = ucm_config.get_config() + + self.load_only_first_rank: bool = ( + self.launch_config.get("load_only_first_rank", self.is_mla) and self.is_mla + ) + if self.load_only_first_rank: + if role == KVConnectorRole.WORKER: + self.group_coordinator = get_tp_group() + self.broadcast_fn = self.group_coordinator.broadcast + self.broadcast_stream = torch.cuda.Stream() + + connector_configs = self.launch_config.get("ucm_connectors", []) + assert len(connector_configs) > 0, "no storage connector name in config." 
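+        # NOTE (illustrative only): the launch config read above is expected to carry a
+        # `ucm_connectors` list shaped roughly like the sketch below; the backend name
+        # and its options are placeholders, not values prescribed by this module.
+        #
+        #   ucm_connectors:
+        #     - ucm_connector_name: "UcmNfsStore"
+        #       ucm_connector_config:
+        #         storage_backends: "/path/to/kv/cache"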
+
+        name = connector_configs[0].get("ucm_connector_name")
+        config = connector_configs[0].get("ucm_connector_config") or {}
+        config["device"] = self.rank
+        config["role"] = "scheduler" if role == KVConnectorRole.SCHEDULER else "worker"
+        element_size = vllm_config.model_config.dtype.itemsize
+        single_head_dim = vllm_config.model_config.get_head_size()
+        num_head_per_tp = vllm_config.model_config.get_num_kv_heads(
+            vllm_config.parallel_config
+        )
+        total_tp_size = vllm_config.parallel_config.tensor_parallel_size
+        num_layers = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config
+        )
+        block_size_per_layer = self.block_size * element_size * single_head_dim
+        config["kv_block_size"] = (
+            block_size_per_layer
+            * num_layers
+            * (1 if self.is_mla else num_head_per_tp * 2)
+        )
+        config["io_size"] = block_size_per_layer * (
+            1 if self.is_mla else num_head_per_tp
+        )
+        self.store = UcmConnectorFactory.create_connector(name, config)
+
+        logger.info("init UCConnectorImpl, connector: %s", name)
+        logger.info(
+            "single file size = %d MB, io_size = %d KB,",
+            config["kv_block_size"] / 1024 / 1024,
+            config["io_size"] / 1024,
+        )
+        self.record_oper: bool = self.launch_config.get("record_oper", False)
+        if self.record_oper:
+            # Create the queue before the writer thread starts so log_operation()
+            # never races against the thread's initialization.
+            self.log_queue = queue.Queue(maxsize=10000)  # max 10000 buffered entries
+            self.write_thread = threading.Thread(
+                target=self._async_record_loop, daemon=True
+            )
+            self.write_thread.start()
+
+    def log_operation(self, operation_data: Dict[str, Any]) -> None:
+        """Record an operation log entry (non-blocking)."""
+
+        default_data = {
+            "timestamp": time.time(),
+            "op_type": "None",
+            "block_size": self.block_size,
+        }
+        log_entry = {**default_data, **operation_data}
+
+        try:
+            self.log_queue.put_nowait(log_entry)
+        except queue.Full:
+            logger.error(
+                f"Log queue is full, dropping one log: {log_entry.get('request_id')}"
+            )
+
+    def _async_record_loop(self):
+        log_path = self.launch_config.get(
+            "record_oper_path", "/vllm-workspace/ucm_logs"
+        )
+        batch_size = self.launch_config.get("record_oper_batch_size", 100)
+        flush_interval = self.launch_config.get("record_oper_flush_interval", 5.0)
+        batch_buffer = []
+        last_flush_time = time.time()
+        while True:
+            try:
+                # Get a log entry from the queue (1 second timeout)
+                is_flush = False
+                current_time = time.time()
+                log_entry = self.log_queue.get(timeout=1.0)
+                batch_buffer.append(log_entry)
+
+                # Flush if either the batch size or the flush interval is reached
+                if (
+                    len(batch_buffer) >= batch_size
+                    or (current_time - last_flush_time) >= flush_interval
+                ):
+                    is_flush = True
+                    last_flush_time = current_time
+                self.log_queue.task_done()
+            except queue.Empty:
+                # Flush whatever is buffered once the interval has elapsed
+                if batch_buffer and (current_time - last_flush_time) >= flush_interval:
+                    is_flush = True
+                    last_flush_time = current_time
+            except Exception as e:
+                logger.error(f"Log thread exception: {str(e)}")
+
+            if is_flush:
+                with open(log_path, "a", encoding="utf-8") as f:
+                    for log_entry in batch_buffer:
+                        f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
+                batch_buffer.clear()
+
+    def generate_hash(self, block_size: int, request: "Request") -> list[str]:
+        token_ids = request.all_token_ids
+
+        ret = []
+        parent_block_hash_value = RequestHasher._SEED_HASH
+        for start in range(0, len(token_ids), block_size):
+            end = start + block_size
+            block_token_ids = token_ids[start:end]
+            # Do not hash the block if it is not full.
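+            # (A partial tail block would produce a hash that can never match a full
+            # block written by another request, so it is skipped rather than stored.)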
+ if len(block_token_ids) < block_size: + break + + block_token_ids_tuple = tuple(block_token_ids) + hash_value = self.request_hasher( + (parent_block_hash_value, block_token_ids_tuple) + ) + parent_block_hash_value = hash_value + ret.append(str(hash_value)) + + return ret + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + + assert num_computed_tokens % self.block_size == 0 + hbm_hit_block_num = num_computed_tokens // self.block_size + + ucm_block_ids = self.generate_hash(self.block_size, request) + + external_block_ids = ucm_block_ids[hbm_hit_block_num:] + if not external_block_ids: + return 0, False + + lookup_results = self.store.lookup(external_block_ids) + external_hit_blocks = 0 + for i, hit in enumerate(lookup_results): + if not hit: + break + external_hit_blocks += 1 + logger.info( + f"request_id: {request.request_id}, " + f"total_blocks_num: {len(ucm_block_ids)}, " + f"hit hbm: {hbm_hit_block_num}, " + f"hit external: {external_hit_blocks}" + ) + + total_hit_block_num = hbm_hit_block_num + external_hit_blocks + + external_hit_tokens = external_hit_blocks * self.block_size + + # When all the tokens are cached in ssd or hbm, + # we need to recompute the last token. This if condition will be removed + # once vLLM scheduler provides a better solution in the future. + if total_hit_block_num * self.block_size == request.num_tokens: + external_hit_tokens -= 1 + + self.requests_meta[request.request_id] = RequestMeta( + ucm_block_ids=ucm_block_ids, + hbm_hit_block_num=hbm_hit_block_num, + total_hit_block_num=total_hit_block_num, + ) + + return external_hit_tokens, False + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + pass + + def _generate_dispatch_meta( + self, + req_meta: RequestMeta, + new_tokens: int, + vllm_block_ids: list[int], + need_load: bool = True, + ) -> RequestDispatchMeta: + """ + Request Blocks layout: + ---------------------------------------------------------------------------------------------------- + | local_computed_block(HBM hit) | external_computed_block(external hit) | new_block(need to dump) | + ---------------------------------------------------------------------------------------------------- + | hbm_hit_block_num | LOAD | new_blocks_num | + ---------------------------------------------------------------------------------------------------- + | total_hit_block_num | + ---------------------------------------------------------------------------------------------------- + | scheduled_block_num | + """ + + new_blocks_num = new_tokens // self.block_size + hbm_hit_block_num = req_meta.hbm_hit_block_num + total_hit_block_num = req_meta.total_hit_block_num + scheduled_block_num = total_hit_block_num + new_blocks_num + ucm_block_ids = req_meta.ucm_block_ids + + dump_ucm_block_ids = ucm_block_ids[total_hit_block_num:scheduled_block_num] + if need_load: + dump_vllm_block_ids = vllm_block_ids[ + total_hit_block_num:scheduled_block_num + ] + else: + dump_vllm_block_ids = vllm_block_ids + + # after this round, req_meta will be updated + req_meta.total_hit_block_num = scheduled_block_num + + load_ucm_block_ids, load_vllm_block_ids = [], [] + if need_load: + load_ucm_block_ids = ucm_block_ids[hbm_hit_block_num:total_hit_block_num] + load_vllm_block_ids = vllm_block_ids[hbm_hit_block_num:total_hit_block_num] + + return RequestDispatchMeta( + (load_ucm_block_ids, load_vllm_block_ids), + (dump_ucm_block_ids, dump_vllm_block_ids), + ) + + def 
build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + requests_dispatch_meta = {} + # for new request, we need to load and dump + for request in scheduler_output.scheduled_new_reqs: + request_id, vllm_block_ids = request.req_id, request.block_ids[0] + req_meta = self.requests_meta.get(request_id) + if req_meta: + requests_dispatch_meta[request_id] = self._generate_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + vllm_block_ids, + ) + + # for cached request, there are 3 situation: + # 1. chunked prefill: we only need dump + # 2. resumed: we need to handle like new request + # 3. TODO decode stage: nothing happened + scheduled_cached_reqs = scheduler_output.scheduled_cached_reqs + if not isinstance(scheduled_cached_reqs, list): + # >= 0.9.2 + for i, request_id in enumerate(scheduled_cached_reqs.req_ids): + if scheduler_output.num_scheduled_tokens[request_id] == 1: + # decode stage + continue + req_meta = self.requests_meta.get(request_id) + if req_meta: + requests_dispatch_meta[request_id] = self._generate_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + scheduled_cached_reqs.new_block_ids[i][0], + scheduled_cached_reqs.resumed_from_preemption[i], + ) + else: + for request in scheduled_cached_reqs: + request_id = request.request_id + if scheduler_output.num_scheduled_tokens[request_id] == 1: + # decode stage + continue + req_meta = self.requests_meta.get(request_id) + if req_meta: + requests_dispatch_meta[request_id] = self._generate_dispatch_meta( + req_meta, + scheduler_output.num_scheduled_tokens[request_id], + request.new_block_ids[0], + request.resumed_from_preemption, + ) + + # clear finished request + for request_id in scheduler_output.finished_req_ids: + self.requests_meta.pop(request_id, None) + + return UCMConnectorMetadata(requests_dispatch_meta) + + def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"): + if len(self.kv_caches) > 0: + return + for layer_name in forward_context.no_compile_layers: + attn_layer = forward_context.no_compile_layers[layer_name] + if not hasattr(attn_layer, "kv_cache"): + continue + + if layer_name not in self.kv_caches: + self.kv_caches[layer_name] = attn_layer.kv_cache[ + forward_context.virtual_engine + ] + if self.kv_cache_dtype is None: + self.kv_cache_dtype = self.kv_caches[layer_name][0].dtype + + @staticmethod + def _extract_layer_index(layer_name: str) -> Optional[int]: + """ + Extract the layer index from the layer name. 
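+        For example, a name such as "model.layers.12.self_attn.attn" would yield 12,
+        i.e. the first dot-separated chunk that is purely numeric.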
+ """ + for chunk in layer_name.split("."): + if chunk.isdigit(): + return int(chunk) + return None + + def _precompute_layer_offsets(self): + if not self.kv_caches: + return + + sample_kv_layer = next(iter(self.kv_caches.values())) + elem_size = sample_kv_layer[0].element_size() + block_data_size = ( + sample_kv_layer[0].numel() if self.is_mla else sample_kv_layer[0][0].numel() + ) * elem_size + layer_data_size = block_data_size if self.is_mla else block_data_size * 2 + + # precompute all layers offset + for layer_name, _ in self.kv_caches.items(): + layer_id = self._extract_layer_index(layer_name) + assert layer_id is not None + k_offset = layer_data_size * layer_id + v_offset = k_offset + block_data_size if not self.is_mla else 0 + self._layer_offset_cache[layer_name] = (k_offset, v_offset) + + def _get_tensor_and_offset( + self, vllm_block_ids: list[int], kv_layer: torch.Tensor, layer_name: str + ) -> tuple[list[torch.Tensor], list[int]]: + """ + GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size) + MLA: one layer shape is (num_blocks, block_size, head_size) + """ + k_tensors, k_offsets = [], [] + v_tensors, v_offsets = [], [] + k_offset, v_offset = self._layer_offset_cache[layer_name] + + for vllm_block_id in vllm_block_ids: + k_tensors.append( + kv_layer[vllm_block_id] if self.is_mla else kv_layer[0][vllm_block_id] + ) + k_offsets.append(k_offset) + if not self.is_mla: + v_tensors.append(kv_layer[1][vllm_block_id]) + v_offsets.append(v_offset) + return k_tensors + v_tensors, k_offsets + v_offsets + + def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]): + if not self._layer_offset_cache: + self._precompute_layer_offsets() + + num_layers = len(self.kv_caches) + num_blocks_per_layer = len(vllm_block_ids) + num_tensors_per_layer = num_blocks_per_layer * (1 if self.is_mla else 2) + dst_tensor_addr = [None] * (num_layers * num_tensors_per_layer) + ucm_offsets = [0] * (num_layers * num_tensors_per_layer) + + idx = 0 + for layer_name, one_layer_kv_cache in self.kv_caches.items(): + tensors, offsets = self._get_tensor_and_offset( + vllm_block_ids, one_layer_kv_cache, layer_name + ) + dst_tensor_addr[idx : idx + len(tensors)] = tensors + ucm_offsets[idx : idx + len(offsets)] = offsets + idx += len(tensors) + + repeat_times = len(self.kv_caches) * (1 if self.is_mla else 2) + ucm_total_block_ids = ucm_block_ids * repeat_times + + assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr) + return ucm_total_block_ids, ucm_offsets, dst_tensor_addr + + def _broadcast(self, dst_tensor_addr: list[torch.Tensor]): + rec_tensor: torch.Tensor = None + with torch.cuda.stream(self.broadcast_stream): + if self.rank == 0: + tensor_to_broadcast = torch.stack(dst_tensor_addr, dim=0) + self.broadcast_fn(tensor_to_broadcast, 0) + else: + shape = (len(dst_tensor_addr),) + dst_tensor_addr[0].shape + # TODO create earlier + rec_tensor = torch.empty( + shape, dtype=self.kv_cache_dtype, device=self.device + ) + self.broadcast_fn(rec_tensor, 0) + self.broadcast_stream.synchronize() + if self.rank != 0 and rec_tensor is not None: + for i, tensor in enumerate(dst_tensor_addr): + tensor.copy_(rec_tensor[i]) + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + + metadata = self._get_connector_metadata() + assert isinstance(metadata, UCMConnectorMetadata) + + self._init_kv_caches_from_forward_context(forward_context) + + request_to_task: dict[str, Optional[Task]] = {} + req_broadcast_addr = {} + for request_id, request in 
metadata.request_meta.items(): + if len(request.load_block_ids[0]) == 0: + continue + + ucm_block_ids, vllm_block_ids = request.load_block_ids + if self.rank != 0 and not self.is_mla: + for i, ucm_block_id in enumerate(ucm_block_ids): + ucm_block_ids[i] = str(self.request_hasher(ucm_block_id)) + ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task( + vllm_block_ids, ucm_block_ids + ) + if self.rank == 0 or not self.load_only_first_rank: + request_to_task[request_id] = self.store.load( + ucm_total_block_ids, ucm_offsets, dst_tensor_addr + ) + if self.record_oper: + self.log_operation( + { + "op_type": "load", + "blocks": ucm_block_ids, + } + ) + else: + request_to_task[request_id] = None + req_broadcast_addr[request_id] = dst_tensor_addr + + for request_id, task in request_to_task.items(): + # TODO error handling + if self.rank == 0 or not self.load_only_first_rank: + if self.store.wait(task) != 0: + logger.error(f"request {request_id} load kv cache failed.") + if self.load_only_first_rank: + self._broadcast(req_broadcast_addr[request_id]) + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + pass + + def wait_for_save(self) -> None: + + if self.is_mla and self.rank != 0: + return + + metadata = self._get_connector_metadata() + assert isinstance(metadata, UCMConnectorMetadata) + + request_to_task: dict[str, Task] = {} + request_to_blocks: dict[str, list[str]] = {} + for request_id, request in metadata.request_meta.items(): + if len(request.dump_block_ids[0]) == 0: + continue + + ucm_block_ids, vllm_block_ids = request.dump_block_ids + if self.rank != 0: + for i, ucm_block_id in enumerate(ucm_block_ids): + ucm_block_ids[i] = str(self.request_hasher(ucm_block_id)) + rets = self.store.create(ucm_block_ids) + end = 0 + for i, ret in enumerate(rets): + if ret != 0: + logger.error( + f"create blocks for {request_id} failed, block index: {i}, ret code: {ret}" + ) + break + end += 1 + + if end == 0: + continue + ucm_block_ids = ucm_block_ids[:end] + vllm_block_ids = vllm_block_ids[:end] + ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task( + vllm_block_ids, ucm_block_ids + ) + request_to_task[request_id] = self.store.dump( + ucm_total_block_ids, ucm_offsets, dst_tensor_addr + ) + if self.record_oper: + self.log_operation( + { + "op_type": "dump", + "blocks": ucm_block_ids, + } + ) + request_to_blocks[request_id] = ucm_block_ids + + for request_id, task in request_to_task.items(): + ucm_block_ids = request_to_blocks[request_id] + if self.store.wait(task) == 0: + self.store.commit(ucm_block_ids, True) + else: + logger.error(f"request {request_id} dump kv cache failed.") + self.store.commit(ucm_block_ids, False) + + def clear_connector_metadata(self) -> None: + super().clear_connector_metadata() + + +class UCMLayerWiseConnector(UCMDirectConnector): + """ + This Connector means overlap: + load l0 -> forward l0 -> save l0 + load l1 -> forward l1 -> save l1 + load l2 -> forward l2 -> save l2 + """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config, role) + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + raise NotImplementedError + + def wait_for_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + raise NotImplementedError + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + 
attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + raise NotImplementedError + + def wait_for_save(self) -> None: + raise NotImplementedError + + +class UCMPDConnector(UCMDirectConnector): + """ + This Connector means overlap (especially for Decode Instance): + step (req0,1,2) forward -> step (req0,1,2,3) forward + load req3 -> load req4 + """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config, role) + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + raise NotImplementedError + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """ + Notifies worker-side connector ids of requests that have + finished generating tokens. + + Returns: + ids of requests that have finished asynchronous transfer + (requests that previously returned True from request_finished()), + tuple of (sending/saving ids, recving/loading ids). + The finished saves/sends req ids must belong to a set provided in a + call to this method (this call or a prior one). + """ + raise NotImplementedError + + +class UCMMockConnector(UCMDirectConnector): + """ + This Connector can control hit ratio, for example: if your hit ratio is 100%, + you can set "hit_ratio" by config or env_vars, then get_num_new_matched_tokens() + will reduce hit_tokens under the hit_ratio you set. + """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config, role) + self._hit_ratio = float(self.launch_config["hit_ratio"]) + logger.info(f"hit_ratio: {self._hit_ratio}") + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + hit_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens) + expect_hit_tokens = int(self._hit_ratio * request.num_prompt_tokens) + if hit_tokens <= expect_hit_tokens: + return hit_tokens, False + expect_hit_block_num = expect_hit_tokens // self.block_size + request_meta = self.requests_meta[request.request_id] + request_meta.total_hit_block_num = expect_hit_block_num + request_meta.hbm_hit_block_num = min( + expect_hit_block_num, request_meta.hbm_hit_block_num + ) + + logger.info( + "Hijacked By MockConnector," + f"request_id: {request.request_id}, " + f"total_blocks_num: {len(request_meta.ucm_block_ids)}, " + f"hit hbm: {request_meta.hbm_hit_block_num}, " + f"hit external: {request_meta.total_hit_block_num - request_meta.hbm_hit_block_num}" + ) + + return expect_hit_block_num * self.block_size, False + + +class UCMConnector(KVConnectorBase_V1): + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self.connector: KVConnectorBase_V1 + # TODO new conn by config + if ( + self._vllm_config.kv_transfer_config is not None + and "hit_ratio" + in self._vllm_config.kv_transfer_config.kv_connector_extra_config + ): + self.connector = UCMMockConnector(vllm_config, role) + else: + self.connector = UCMDirectConnector(vllm_config, role) + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. 
+ num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + return self.connector.get_num_new_matched_tokens(request, num_computed_tokens) + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + """ + Update KVConnector state after block allocation. + """ + self.connector.update_state_after_alloc(request, blocks, num_external_tokens) + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + """ + Build the connector metadata for this step. + + This function should NOT modify fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + return self.connector.build_connector_meta(scheduler_output) + + def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: + """Set the connector metadata from the scheduler. + + This function should be called by the model runner every time + before the model execution. The metadata will be used for runtime + KV cache loading and saving. + + Args: + connector_metadata (dict): the connector metadata. + """ + self.connector.bind_connector_metadata(connector_metadata) + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + """ + Start loading the KV cache from the connector to vLLM's paged + KV buffer. This is called from the forward context before the + forward pass to enable async loading during model execution. + + Args: + forward_context (ForwardContext): the forward context. + **kwargs: additional arguments for the load operation + + Note: + The number of elements in kv_caches and layer_names should be + the same. + + """ + self.connector.start_load_kv(forward_context, **kwargs) + + def wait_for_layer_load(self, layer_name: str) -> None: + """ + Block until the KV for a specific layer is loaded into vLLM's + paged buffer. This is called from within attention layer to ensure + async copying from start_load_kv is complete. + + This interface will be useful for layer-by-layer pipelining. + + Args: + layer_name: the name of that layer + """ + self.connector.wait_for_layer_load(layer_name) + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + """ + Start saving the a layer of KV cache from vLLM's paged buffer + to the connector. This is called from within attention layer to + enable async copying during execution. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. + """ + self.connector.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs) + + def wait_for_save(self) -> None: + """ + Block until all the save operations is done. This is called + as the forward context exits to ensure that the async saving + from save_kv_layer is complete before finishing the forward. + + This prevents overwrites of paged KV buffer before saving done. + """ + self.connector.wait_for_save() + + def clear_connector_metadata(self) -> None: + """Clear the connector metadata. + + This function should be called by the model runner every time + after the model execution. 
+ """ + self.connector.clear_connector_metadata() diff --git a/ucm/shared/CMakeLists.txt b/ucm/shared/CMakeLists.txt index 3952830cd..e378ff0ad 100644 --- a/ucm/shared/CMakeLists.txt +++ b/ucm/shared/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(vendor) +add_subdirectory(infra) add_subdirectory(trans) add_subdirectory(test) diff --git a/ucm/shared/infra/CMakeLists.txt b/ucm/shared/infra/CMakeLists.txt new file mode 100644 index 000000000..ba4345dce --- /dev/null +++ b/ucm/shared/infra/CMakeLists.txt @@ -0,0 +1,22 @@ +file(GLOB_RECURSE UCMINFRA_STATUS_SOURCE_FILES "status/*.*") +add_library(infra_status OBJECT ${UCMINFRA_STATUS_SOURCE_FILES}) +target_include_directories(infra_status PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_libraries(infra_status PUBLIC fmt) + +file(GLOB UCMINFRA_LOGGER_SOURCE_FILES "logger/*.*") +file(GLOB_RECURSE UCMINFRA_LOGGER_DETAIL_SOURCE_FILES "logger/${LOGGER_BACKEND}/*.cc") +add_library(infra_logger OBJECT ${UCMINFRA_LOGGER_SOURCE_FILES} ${UCMINFRA_LOGGER_DETAIL_SOURCE_FILES}) +target_include_directories(infra_logger PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_libraries(infra_logger PUBLIC fmt spdlog) + +file(GLOB_RECURSE UCMINFRA_TEMPLATE_SOURCE_FILES "template/*.*") +add_library(infra_template OBJECT ${UCMINFRA_TEMPLATE_SOURCE_FILES}) +target_include_directories(infra_template PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +file(GLOB_RECURSE UCMINFRA_THREAD_SOURCE_FILES "thread/*.*") +add_library(infra_thread OBJECT ${UCMINFRA_THREAD_SOURCE_FILES}) +target_include_directories(infra_thread PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +file(GLOB_RECURSE UCMINFRA_TIME_SOURCE_FILES "time/*.*") +add_library(infra_time OBJECT ${UCMINFRA_TIME_SOURCE_FILES}) +target_include_directories(infra_time PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/ucm/store/infra/logger/flux/flux_logger.cc b/ucm/shared/infra/logger/flux/flux_logger.cc similarity index 100% rename from ucm/store/infra/logger/flux/flux_logger.cc rename to ucm/shared/infra/logger/flux/flux_logger.cc diff --git a/ucm/store/infra/logger/logger.h b/ucm/shared/infra/logger/logger.h similarity index 97% rename from ucm/store/infra/logger/logger.h rename to ucm/shared/infra/logger/logger.h index f27dd23df..516b9e663 100644 --- a/ucm/store/infra/logger/logger.h +++ b/ucm/shared/infra/logger/logger.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_LOGGER_H -#define UNIFIEDCACHE_LOGGER_H +#ifndef UNIFIEDCACHE_INFRA_LOGGER_H +#define UNIFIEDCACHE_INFRA_LOGGER_H #include #include diff --git a/ucm/store/infra/logger/spdlog/spdlog_logger.cc b/ucm/shared/infra/logger/spdlog/spdlog_logger.cc similarity index 100% rename from ucm/store/infra/logger/spdlog/spdlog_logger.cc rename to ucm/shared/infra/logger/spdlog/spdlog_logger.cc diff --git a/ucm/shared/trans/status.h b/ucm/shared/infra/status/status.h similarity index 51% rename from ucm/shared/trans/status.h rename to ucm/shared/infra/status/status.h index cab271794..3711de842 100644 --- a/ucm/shared/trans/status.h +++ b/ucm/shared/infra/status/status.h @@ -21,17 +21,34 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
* */ -#ifndef UNIFIEDCACHE_TRANS_STATUS_H -#define UNIFIEDCACHE_TRANS_STATUS_H +#ifndef UNIFIEDCACHE_INFRA_STATUS_H +#define UNIFIEDCACHE_INFRA_STATUS_H +#include #include #include -namespace UC::Trans { +namespace UC { + +template +static inline constexpr int32_t __MakeStatusCode() +{ + return -50000 - i; +} class Status { static constexpr int32_t OK_ = 0; static constexpr int32_t ERROR_ = -1; + static constexpr int32_t EPARAM_ = __MakeStatusCode<0>(); + static constexpr int32_t EOOM_ = __MakeStatusCode<1>(); + static constexpr int32_t EOSERROR_ = __MakeStatusCode<2>(); + static constexpr int32_t EDUPLICATE_ = __MakeStatusCode<3>(); + static constexpr int32_t ERETRY_ = __MakeStatusCode<4>(); + static constexpr int32_t ENOOBJ_ = __MakeStatusCode<5>(); + static constexpr int32_t ESERIALIZE_ = __MakeStatusCode<6>(); + static constexpr int32_t EDESERIALIZE_ = __MakeStatusCode<7>(); + static constexpr int32_t EUNSUPPORTED_ = __MakeStatusCode<8>(); + static constexpr int32_t ENOSPACE_ = __MakeStatusCode<9>(); int32_t code_; std::string message_; explicit Status(int32_t code) : code_(code) {} @@ -39,7 +56,13 @@ class Status { public: bool operator==(const Status& other) const noexcept { return code_ == other.code_; } bool operator!=(const Status& other) const noexcept { return !(*this == other); } - std::string ToString() const { return fmt::format("({}) {}", code_, message_); } + int32_t Underlying() const { return code_; } + std::string ToString() const + { + auto str = std::to_string(code_); + if (message_.empty()) { return str; } + return fmt::format("{}, {}", str, message_); + } constexpr bool Success() const noexcept { return code_ == OK_; } constexpr bool Failure() const noexcept { return !Success(); } @@ -47,8 +70,21 @@ class Status { Status(int32_t code, std::string message) : code_{code}, message_{std::move(message)} {} static Status OK() { return Status{OK_}; } static Status Error(std::string message) { return {ERROR_, std::move(message)}; } + static Status Error() { return Status{ERROR_}; } + static Status InvalidParam() { return Status{EPARAM_}; } + static Status OutOfMemory() { return Status{EOOM_}; } + static Status OsApiError() { return Status{EOSERROR_}; } + static Status DuplicateKey() { return Status{EDUPLICATE_}; } + static Status Retry() { return Status{ERETRY_}; } + static Status NotFound() { return Status{ENOOBJ_}; } + static Status SerializeFailed() { return Status{ESERIALIZE_}; } + static Status DeserializeFailed() { return Status{EDESERIALIZE_}; } + static Status Unsupported() { return Status{EUNSUPPORTED_}; } + static Status NoSpace() { return Status{ENOSPACE_}; } }; -} // namespace UC::Trans +inline std::string format_as(const Status& status) { return status.ToString(); } + +} // namespace UC #endif diff --git a/ucm/store/infra/template/hashset.h b/ucm/shared/infra/template/hashset.h similarity index 98% rename from ucm/store/infra/template/hashset.h rename to ucm/shared/infra/template/hashset.h index b09692bc1..102f69b62 100644 --- a/ucm/store/infra/template/hashset.h +++ b/ucm/shared/infra/template/hashset.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
* */ -#ifndef UNIFIEDCACHE_HASHSET_H -#define UNIFIEDCACHE_HASHSET_H +#ifndef UNIFIEDCACHE_INFRA_HASHSET_H +#define UNIFIEDCACHE_INFRA_HASHSET_H #include #include diff --git a/ucm/store/infra/template/singleton.h b/ucm/shared/infra/template/singleton.h similarity index 94% rename from ucm/store/infra/template/singleton.h rename to ucm/shared/infra/template/singleton.h index fda4957b5..f667288ee 100644 --- a/ucm/store/infra/template/singleton.h +++ b/ucm/shared/infra/template/singleton.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_SINGLETON_H -#define UNIFIEDCACHE_SINGLETON_H +#ifndef UNIFIEDCACHE_INFRA_SINGLETON_H +#define UNIFIEDCACHE_INFRA_SINGLETON_H namespace UC { diff --git a/ucm/store/infra/template/timer.h b/ucm/shared/infra/template/timer.h similarity index 71% rename from ucm/store/infra/template/timer.h rename to ucm/shared/infra/template/timer.h index 1963faa80..0c9db149d 100644 --- a/ucm/store/infra/template/timer.h +++ b/ucm/shared/infra/template/timer.h @@ -21,16 +21,14 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_TIMER_H -#define UNIFIEDCACHE_TIMER_H +#ifndef UNIFIEDCACHE_INFRA_TIMER_H +#define UNIFIEDCACHE_INFRA_TIMER_H -#include -#include -#include #include +#include #include -#include "logger/logger.h" -#include "status/status.h" +#include +#include namespace UC { @@ -38,29 +36,30 @@ template class Timer { public: Timer(const std::chrono::seconds& interval, Callable&& callable) - : interval_(interval), callable_(callable), running_(false) {} - ~Timer() { + : interval_(interval), callable_(callable), running_(false) + { + } + ~Timer() + { { std::lock_guard lg(this->mutex_); this->running_ = false; + this->cv_.notify_one(); } - - this->cv_.notify_one(); if (this->thread_.joinable()) { this->thread_.join(); } } - Status Start() + bool Start() { { std::lock_guard lg(this->mutex_); - if (this->running_) { return Status::OK(); } + if (this->running_) { return true; } } try { this->running_ = true; this->thread_ = std::thread(&Timer::Runner, this); - return Status::OK(); - } catch (const std::exception& e) { - UC_ERROR("Failed({}) to start timer.", e.what()); - return Status::OutOfMemory(); + return true; + } catch (...) { + return false; } } @@ -68,14 +67,12 @@ class Timer { void Runner() { while (this->running_) { - try { - { - std::unique_lock lg(this->mutex_); - this->cv_.wait_for(lg, this->interval_, [this] { return !this->running_; }); - if (!this->running_) { break; } - } - this->callable_(); - } catch (const std::exception& e) { UC_ERROR("Failed({}) to run timer.", e.what()); } + { + std::unique_lock lg(this->mutex_); + this->cv_.wait_for(lg, this->interval_, [this] { return !this->running_; }); + if (!this->running_) { break; } + } + this->callable_(); } } @@ -88,6 +85,6 @@ class Timer { std::atomic running_; }; -} // namespace UC +} // namespace UC -#endif \ No newline at end of file +#endif diff --git a/ucm/store/infra/template/topn_heap.h b/ucm/shared/infra/template/topn_heap.h similarity index 97% rename from ucm/store/infra/template/topn_heap.h rename to ucm/shared/infra/template/topn_heap.h index 98884c238..737d0b19a 100644 --- a/ucm/store/infra/template/topn_heap.h +++ b/ucm/shared/infra/template/topn_heap.h @@ -22,11 +22,12 @@ * SOFTWARE. 
* */ -#ifndef UNIFIEDCACHE_TOP_N_HEAP_H -#define UNIFIEDCACHE_TOP_N_HEAP_H +#ifndef UNIFIEDCACHE_INFRA_TOP_N_HEAP_H +#define UNIFIEDCACHE_INFRA_TOP_N_HEAP_H #include #include +#include namespace UC { @@ -117,4 +118,4 @@ class TopNFixedHeap : public TopNHeap { } // namespace UC -#endif \ No newline at end of file +#endif diff --git a/ucm/store/infra/thread/index_pool.h b/ucm/shared/infra/thread/index_pool.h similarity index 97% rename from ucm/store/infra/thread/index_pool.h rename to ucm/shared/infra/thread/index_pool.h index 225ee8842..4217b7a0e 100644 --- a/ucm/store/infra/thread/index_pool.h +++ b/ucm/shared/infra/thread/index_pool.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_INDEX_POOL_H -#define UNIFIEDCACHE_INDEX_POOL_H +#ifndef UNIFIEDCACHE_INFRA_INDEX_POOL_H +#define UNIFIEDCACHE_INFRA_INDEX_POOL_H #include #include diff --git a/ucm/store/infra/thread/latch.h b/ucm/shared/infra/thread/latch.h similarity index 95% rename from ucm/store/infra/thread/latch.h rename to ucm/shared/infra/thread/latch.h index 1837ca27a..fb1dcf583 100644 --- a/ucm/store/infra/thread/latch.h +++ b/ucm/shared/infra/thread/latch.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_LATCH_H -#define UNIFIEDCACHE_LATCH_H +#ifndef UNIFIEDCACHE_INFRA_LATCH_H +#define UNIFIEDCACHE_INFRA_LATCH_H #include #include @@ -66,4 +66,4 @@ class Latch { } // namespace UC -#endif // UNIFIEDCACHE_LATCH_H +#endif // UNIFIEDCACHE_INFRA_LATCH_H diff --git a/ucm/store/infra/thread/thread_pool.h b/ucm/shared/infra/thread/thread_pool.h similarity index 98% rename from ucm/store/infra/thread/thread_pool.h rename to ucm/shared/infra/thread/thread_pool.h index c33a0c28e..baa514ed7 100644 --- a/ucm/store/infra/thread/thread_pool.h +++ b/ucm/shared/infra/thread/thread_pool.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#ifndef UNIFIEDCACHE_THREAD_POOL_H -#define UNIFIEDCACHE_THREAD_POOL_H +#ifndef UNIFIEDCACHE_INFRA_THREAD_POOL_H +#define UNIFIEDCACHE_INFRA_THREAD_POOL_H #include #include diff --git a/ucm/store/infra/time/stopwatch.h b/ucm/shared/infra/time/stopwatch.h similarity index 95% rename from ucm/store/infra/time/stopwatch.h rename to ucm/shared/infra/time/stopwatch.h index 2386f394b..c2a5bb331 100644 --- a/ucm/store/infra/time/stopwatch.h +++ b/ucm/shared/infra/time/stopwatch.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
* */ -#ifndef UNIFIEDCACHE_STOPWATCH_H -#define UNIFIEDCACHE_STOPWATCH_H +#ifndef UNIFIEDCACHE_INFRA_STOPWATCH_H +#define UNIFIEDCACHE_INFRA_STOPWATCH_H #include diff --git a/ucm/store/test/case/infra/hashset_test.cc b/ucm/shared/test/case/infra/hashset_test.cc similarity index 100% rename from ucm/store/test/case/infra/hashset_test.cc rename to ucm/shared/test/case/infra/hashset_test.cc diff --git a/ucm/shared/test/case/trans/trans_test.cc b/ucm/shared/test/case/trans/trans_test.cc index f38769cea..4f2415b1b 100644 --- a/ucm/shared/test/case/trans/trans_test.cc +++ b/ucm/shared/test/case/trans/trans_test.cc @@ -28,7 +28,7 @@ class UCTransUnitTest : public ::testing::Test {}; TEST_F(UCTransUnitTest, CopyDataWithCE) { - const auto ok = UC::Trans::Status::OK(); + const auto ok = UC::Status::OK(); constexpr int32_t deviceId = 0; constexpr size_t size = 36 * 1024; constexpr size_t number = 64 * 61; @@ -60,7 +60,7 @@ TEST_F(UCTransUnitTest, CopyDataWithCE) TEST_F(UCTransUnitTest, CopyDataWithSM) { - const auto ok = UC::Trans::Status::OK(); + const auto ok = UC::Status::OK(); constexpr int32_t deviceId = 0; constexpr size_t size = 36 * 1024; constexpr size_t number = 64 * 61; diff --git a/ucm/shared/trans/CMakeLists.txt b/ucm/shared/trans/CMakeLists.txt index bbf001fce..57a1bd0aa 100644 --- a/ucm/shared/trans/CMakeLists.txt +++ b/ucm/shared/trans/CMakeLists.txt @@ -8,6 +8,7 @@ if(RUNTIME_ENVIRONMENT STREQUAL "simu") add_subdirectory(simu) endif() target_include_directories(trans PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..) +target_link_libraries(trans PUBLIC infra_status) file(GLOB_RECURSE UCMTRANS_CPY_SOURCE_FILES "./cpy/*.cc") pybind11_add_module(ucmtrans ${UCMTRANS_CPY_SOURCE_FILES}) diff --git a/ucm/shared/trans/buffer.h b/ucm/shared/trans/buffer.h index 918b067fa..a73752513 100644 --- a/ucm/shared/trans/buffer.h +++ b/ucm/shared/trans/buffer.h @@ -25,7 +25,7 @@ #define UNIFIEDCACHE_TRANS_BUFFER_H #include -#include "status.h" +#include "status/status.h" namespace UC::Trans { diff --git a/ucm/shared/trans/stream.h b/ucm/shared/trans/stream.h index 3cb0c3683..425617968 100644 --- a/ucm/shared/trans/stream.h +++ b/ucm/shared/trans/stream.h @@ -25,7 +25,7 @@ #define UNIFIEDCACHE_TRANS_STREAM_H #include -#include "status.h" +#include "status/status.h" namespace UC::Trans { diff --git a/ucm/sparse/esa/esa.py b/ucm/sparse/esa/esa.py index c7047f876..d8316cc66 100644 --- a/ucm/sparse/esa/esa.py +++ b/ucm/sparse/esa/esa.py @@ -307,7 +307,7 @@ def maybe_register_static_data(self, forward_context: ForwardContext): self.init_static_flag = True def wait_transfer_task_done(self): - assert len(self.tasks) > 0 + # assert len(self.tasks) > 0 for task_hash, task in self.tasks.items(): # TODO: handle exceptions ret = self.store_instance.wait(task) @@ -352,9 +352,10 @@ def wait_retrieval_and_start_load(self): self.pre_topk_block_hashes, diff_blocks = diff_two_map( self.pre_topk_block_hashes, target_map ) - self.launch_transfer_task( - "load", list(diff_blocks.values()), list(diff_blocks.keys()) - ) + if diff_blocks: + self.launch_transfer_task( + "load", list(diff_blocks.values()), list(diff_blocks.keys()) + ) ## 2. 
load all # self.launch_transfer_task( @@ -438,7 +439,8 @@ def attention_begin( self.k_cache[vllm_block_ids[-local_window_sz:]] = self.local_window self.start_retrieval(query, forward_context) self.wait_retrieval_and_start_load() - self.wait_transfer_task_done() + if len(self.tasks) > 0: + self.wait_transfer_task_done() def attention_finished( self, diff --git a/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp b/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp index dacf38686..67ffac201 100644 --- a/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp +++ b/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp @@ -58,8 +58,6 @@ class RetrievalWorkerBackend { if (rc != 0) { std::cerr << "Error binding memory to NUMA node " << numaId << std::endl; } - #else - std::cerr << "NUMA support is disabled." << std::endl; #endif } diff --git a/ucm/sparse/esa/retrieval/retrieval_worker.py b/ucm/sparse/esa/retrieval/retrieval_worker.py index ebed1ed1c..7209d604e 100644 --- a/ucm/sparse/esa/retrieval/retrieval_worker.py +++ b/ucm/sparse/esa/retrieval/retrieval_worker.py @@ -1,10 +1,11 @@ import time +from collections import defaultdict import numpy as np import torch -# import retrieval_backend from ucm.sparse.esa.retrieval import retrieval_backend +from ucm.sparse.kvstar.utils import get_bind_cpus_for_rank class RetrievalWorker: @@ -42,7 +43,19 @@ def wait(self, req_id): data = torch.rand(kv_cache_blocks, dim).to(torch.float32) print("data created", data.shape) - backend = retrieval_backend.RetrievalWorkerBackend(data) + ratio = 0.75 + total_tp_size = 4 + local_tp_rank = 0 + bind_info_list, alloc_numa_ids = get_bind_cpus_for_rank( + total_tp_size, local_tp_rank, ratio=ratio + ) + + bind_info_dict = defaultdict(list) + for item in bind_info_list: + bind_info_dict[item[1]].append(item[0]) + bind_info_dict = dict(bind_info_dict) + + backend = retrieval_backend.RetrievalWorkerBackend(data, bind_info_dict) worker = RetrievalWorker(backend) topk = 3000 search_blocks_range = 8000 diff --git a/ucm/sparse/kvcomp/hash_encoder.py b/ucm/sparse/kvcomp/hash_encoder.py index db76079d3..7546aa71e 100644 --- a/ucm/sparse/kvcomp/hash_encoder.py +++ b/ucm/sparse/kvcomp/hash_encoder.py @@ -31,6 +31,124 @@ logger = init_logger(__name__) +if hasattr(torch, "cuda") and torch.cuda.is_available(): + from vllm.triton_utils import tl, triton + + @triton.jit + def triton_hash_code_kernel( + x_ptr, + code_ptr, + pack_w_ptr, + hash_out_ptr, + M, + K, + N, + stride_xm, + stride_xk, + stride_codek, + stride_coden, + stride_pack_w, + stride_om, + stride_on, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # sample dimension + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # hash_rbits dimension + offs_k = tl.arange(0, BLOCK_K) # input_dim dimension + + # Matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x = tl.load( + x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk, + mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), + other=0.0, + ) + code = tl.load( + code_ptr + + offs_k[:, None] * stride_codek + + offs_n[None, :] * stride_coden, + mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), + other=0.0, + ) + acc += tl.dot(x, code) + offs_k += BLOCK_K + + # Binarize and pack + bits = (acc > 0).to(tl.uint8) # Binarize + bits = tl.reshape(bits, (BLOCK_M, BLOCK_N // 8, 8)) # Reshape for packing + + # Load the packing 
weights (ensure it has the correct shape) + pack_w = tl.load(pack_w_ptr + tl.arange(0, 8) * stride_pack_w) + packed = tl.sum(bits * pack_w[None, None, :], axis=-1).to(tl.uint8) + + # Store results + offs_n = pid_n * (BLOCK_N // 8) + tl.arange(0, BLOCK_N // 8) + hash_out_ptrs = ( + hash_out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + ) + tl.store( + hash_out_ptrs, + packed, + mask=(offs_m[:, None] < M) & (offs_n[None, :] < (N // 8)), + ) + + def triton_hash_code(x, code, pack_weight): + input_dim = x.shape[-1] + samples = x.shape[0] + hash_bits = code.shape[-1] + assert (pack_weight.shape[0] == 8) and (hash_bits % 8 == 0) + hash_out = torch.empty( + (samples, hash_bits // 8), dtype=pack_weight.dtype, device=x.device + ) + + grid = lambda opts: ( + triton.cdiv(samples, opts["BLOCK_M"]), + triton.cdiv(input_dim, opts["BLOCK_N"]), + ) + + triton_hash_code_kernel[grid]( + x, + code, + pack_weight, + hash_out, + samples, + input_dim, + hash_bits, + x.stride(0), + x.stride(1), + code.stride(0), + code.stride(1), + pack_weight.stride(0), + hash_out.stride(0), + hash_out.stride(1), + BLOCK_M=32, + BLOCK_K=64, + BLOCK_N=16, + ) + + return hash_out.view(-1) # [samples * hash_numbers] + + +@torch.compile() +def torch_hash_code(x, code, pack_weight): + # [N, hash_bits] + x = x @ code + m = x.shape[:-1] + # [N, hash_bits] -- > [N, hash_bits // 8, 8] + x = (x > 0).to(torch.uint8).view(*m, -1, 8) + # 8bit -> 1bit + # binary_codes * self.bit_masks [N, hash_numbers, 8] * [1, 1, 8] -> [N, hash_numbers, 8] + # then sum along the last dimension to get [N, hash_numbers] + x = torch.sum(x * pack_weight, dim=-1, dtype=torch.uint8) + x = x.view(-1) # [N * hash_numbers] + return x + class HashEncoder: """ @@ -105,8 +223,6 @@ def _init_bit_masks(self) -> None: self.bit_masks = torch.pow( 2, torch.arange(8, dtype=torch.uint8, device=self.device) ) - # shape (1, 1, 8) - self.bit_masks = self.bit_masks.unsqueeze(0).unsqueeze(0) def compute_hash(self, x: torch.Tensor) -> torch.Tensor: """ @@ -136,29 +252,24 @@ def compute_hash(self, x: torch.Tensor) -> torch.Tensor: if x_flat.dtype != self.dtype: x_flat = x_flat.to(self.dtype) - # [N, hash_bits] - xW = torch.matmul(x_flat, self.hash_weights) - - # [N * hash_bits] - xW_flat = xW.view(-1) - if self.device.type == "npu": + # [N, hash_bits] + xW = torch.matmul(x_flat, self.hash_weights) + # [N * hash_bits] + xW_flat = xW.view(-1) # [N*hash_numbers], where hash_numbers = hash_bits // 8 packed_codes_flat = torch_npu.npu_sign_bits_pack(xW_flat, size=1) - elif self.device.type == "cuda" or self.device.type == "cpu": - # (TODO) improve performance later on CUDA ops and CPU SIMD instructions - # [N, hash_bits] - projected = (xW > 0).to(torch.uint8) - # [N, hash_numbers, 8] - binary_codes = projected.view(-1, self.hash_numbers, 8) - - # binary_codes * self.bit_masks [N, hash_numbers, 8] * [1, 1, 8] -> [N, hash_numbers, 8] - # then sum along the last dimension to get [N, hash_numbers] - packed_codes_flat = torch.sum( - binary_codes * self.bit_masks, dim=-1, dtype=torch.uint8 - ) # [N, hash_numbers] - packed_codes_flat = packed_codes_flat.view(-1) # [N * hash_numbers] + elif self.device.type == "cuda": + packed_codes_flat = triton_hash_code( + x_flat, self.hash_weights, self.bit_masks + ) # [N * hash_numbers] + + elif self.device.type == "cpu": + packed_codes_flat = torch_hash_code( + x_flat, self.hash_weights, self.bit_masks + ) # [N * hash_numbers] + else: raise ValueError(f"Unsupported device type: {self.device.type}") @@ -213,7 +324,7 @@ def 
_unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor: ) # expand last dim to 8 # (expanded & self.bit_masks) > 0 -> [N, hash_numbers, 8] - unpacked_bits = (expanded & self.bit_masks) > 0 + unpacked_bits = (expanded & self.bit_masks.unsqueeze(0).unsqueeze(0)) > 0 # 0 -> -1, 1 -> 1 unpacked_bits = unpacked_bits * 2 - 1 @@ -232,20 +343,22 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor: if __name__ == "__main__": + torch.manual_seed(42) + + print("test HashEncoder...") + dtype = torch.float16 if hasattr(torch, "npu") and torch.npu.is_available(): device = torch.device("npu:0") elif hasattr(torch, "cuda") and torch.cuda.is_available(): device = torch.device("cuda:0") + dtype = torch.float32 else: device = torch.device("cpu") print("Using device:", device) + encoder = HashEncoder(input_dim=8, hash_bits=8, dtype=dtype, device=device) - torch.manual_seed(42) - - encoder = HashEncoder(input_dim=8, hash_bits=8, dtype=torch.float16, device=device) - - x = torch.randn(2, 8, device=device, dtype=torch.float16) + x = torch.randn(2, 8, device=device, dtype=dtype) print("x:", x) hash_codes = encoder.compute_hash(x) @@ -262,3 +375,31 @@ def _unpack_hash(self, packed_codes: torch.Tensor) -> torch.Tensor: print( f"hash_codes[1].item()={hash_codes[1].item()}, 8-bit binary form:{hash_codes[1].item():08b}" ) + + if hasattr(torch, "cuda") and torch.cuda.is_available(): + print("test cuda triton and torch hash code functions...") + x = torch.randn((1024, 512), device="cuda:0", dtype=torch.bfloat16) + code = torch.randn((512, 512), device="cuda:0", dtype=torch.bfloat16) + pack_weight = torch.tensor( + [128, 64, 32, 16, 8, 4, 2, 1], device="cuda:0", dtype=torch.uint8 + ) + + torch_output = torch_hash_code(x, code, pack_weight) + triton_output = triton_hash_code(x, code, pack_weight) + assert torch_output.shape == triton_output.shape + print(f"x_shape: {x.shape} code_shape: {code.shape}") + print("torch_output", torch_output) + print("triton_output", triton_output) + print( + f"The maximum difference between Torch and Triton is" + f" {torch.max(torch.abs(torch_output.to(torch.int32) - triton_output.to(torch.int32)))}" + ) + # benchmark + print( + "torch:", + triton.testing.do_bench(lambda: torch_hash_code(x, code, pack_weight)), + ) + print( + "triton:", + triton.testing.do_bench(lambda: triton_hash_code(x, code, pack_weight)), + ) diff --git a/ucm/sparse/kvcomp/hash_retrieval/cpy/hash_retrieval_backend.cpp b/ucm/sparse/kvcomp/hash_retrieval/cpy/hash_retrieval_backend.cpp index d1eedf050..4458b956a 100644 --- a/ucm/sparse/kvcomp/hash_retrieval/cpy/hash_retrieval_backend.cpp +++ b/ucm/sparse/kvcomp/hash_retrieval/cpy/hash_retrieval_backend.cpp @@ -18,16 +18,128 @@ #include #ifdef NUMA_ENABLED #include +#include #endif -#ifdef __ARM_NEON -#include // ARM NEON SIMD 指令集头文件 -#elif defined(__x86_64__) || defined(_M_X64) -#include // x86_64 SSE SIMD 指令集头文件 -#endif +#include +#include +#include +#include + +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + #include +#elif defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86) + #include // SSE/AVX + #include // POPCNT (SSE4.2) +#endif #define VEC_SIZE 16 +#if defined(__ARM_NEON) || defined(__ARM_NEON__) + +using vec16u = uint8x16_t; + +static inline vec16u vec_loadu16(const uint8_t* p) { + return vld1q_u8(p); +} + +static inline vec16u vec_xor(vec16u a, vec16u b) { + return veorq_u8(a, b); +} + +static inline uint16_t vec_sum_u8(vec16u v) { +#if defined(__aarch64__) || defined(_M_ARM64) + return vaddvq_u8(v); +#else + 
uint16x8_t s16 = vpaddlq_u8(v); + uint32x4_t s32 = vpaddlq_u16(s16); + uint64x2_t s64 = vpaddlq_u32(s32); + return (uint16_t)(vgetq_lane_u64(s64, 0) + vgetq_lane_u64(s64, 1)); +#endif +} + +static inline uint16_t vec_popcnt_xor_sum16(const uint8_t* a, const uint8_t* b) { + vec16u va = vec_loadu16(a); + vec16u vb = vec_loadu16(b); + vec16u vx = vec_xor(va, vb); + vec16u pc = vcntq_u8(vx); + return vec_sum_u8(pc); +} + +static inline uint16_t vec_popcnt_xor_sum16_vec(vec16u qa, const uint8_t* b) { + vec16u vb = vec_loadu16(b); + vec16u vx = vec_xor(qa, vb); + vec16u pc = vcntq_u8(vx); + return vec_sum_u8(pc); +} + +void print_uint8x16(uint8x16_t vec) { + uint8_t array[16]; + vst1q_u8(array, vec); + for (int i = 0; i < 16; ++i) { + std::cout << static_cast(array[i]) << " "; + } + std::cout << std::endl; +} + +#elif defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86) + +using vec16u = __m128i; + +static inline vec16u vec_loadu16(const uint8_t* p) { + return _mm_loadu_si128(reinterpret_cast(p)); +} + +static inline vec16u vec_xor(vec16u a, vec16u b) { + return _mm_xor_si128(a, b); +} + +static inline uint16_t vec_popcnt_xor_sum16(const uint8_t* a, const uint8_t* b) { + __m128i va = _mm_loadu_si128(reinterpret_cast(a)); + __m128i vb = _mm_loadu_si128(reinterpret_cast(b)); + __m128i vx = _mm_xor_si128(va, vb); + + uint64_t lo, hi; +#if defined(__SSE4_1__) + lo = static_cast(_mm_extract_epi64(vx, 0)); + hi = static_cast(_mm_extract_epi64(vx, 1)); +#else + alignas(16) uint64_t tmp[2]; + _mm_storeu_si128(reinterpret_cast<__m128i*>(tmp), vx); + lo = tmp[0]; + hi = tmp[1]; +#endif + return (uint16_t)(__builtin_popcountll(lo) + __builtin_popcountll(hi)); +} + +static inline uint16_t vec_popcnt_xor_sum16_vec(vec16u qa, const uint8_t* b) { + __m128i vb = _mm_loadu_si128(reinterpret_cast(b)); + __m128i vx = _mm_xor_si128(qa, vb); + + uint64_t lo, hi; +#if defined(__SSE4_1__) + lo = static_cast(_mm_extract_epi64(vx, 0)); + hi = static_cast(_mm_extract_epi64(vx, 1)); +#else + alignas(16) uint64_t tmp[2]; + _mm_storeu_si128(reinterpret_cast<__m128i*>(tmp), vx); + lo = tmp[0]; + hi = tmp[1]; +#endif + return (uint16_t)(__builtin_popcountll(lo) + __builtin_popcountll(hi)); +} + +#else + +static inline uint16_t vec_popcnt_xor_sum16(const uint8_t* a, const uint8_t* b) { + uint16_t s = 0; + for (int i = 0; i < 16; ++i) + s += __builtin_popcount((unsigned)(a[i] ^ b[i])); + return s; +} + +#endif + namespace py = pybind11; class HashRetrievalWorkerBackend { @@ -37,11 +149,13 @@ class HashRetrievalWorkerBackend { : data_array_(data), stop_workers_(false), next_req_id_(0) { py::buffer_info info = data_array_.request(); - num_blocks_ = info.shape[0]; - block_size_ = info.shape[1]; - dim_ = info.shape[2]; - vec_per_dim_ = dim_ / VEC_SIZE; // data_每个值类型uint8_t,组成8*16_t进行simd加速 - data_ = static_cast(info.ptr); + num_blocks_ = info.shape[0]; + block_size_ = info.shape[1]; + dim_ = info.shape[2]; + vec_per_dim_ = dim_ / VEC_SIZE; // data_每个值类型uint8_t,组成8*16_t进行simd加速 + tail_dim_ = dim_ % VEC_SIZE; + tail_start_ = vec_per_dim_ * VEC_SIZE; + data_ = static_cast(info.ptr); // Start worker threads for (auto cpu_idx : cpu_idx_tbl) { @@ -72,8 +186,6 @@ class HashRetrievalWorkerBackend { if (rc != 0) { std::cerr << "Error binding memory to NUMA node " << numaId << std::endl; } - #else - std::cerr << "NUMA support is disabled." 
<< std::endl; #endif } @@ -183,49 +295,6 @@ class HashRetrievalWorkerBackend { bool done = false; }; -#ifdef __ARM_NEON - static inline uint16_t vaddvq_u8_compat(uint8x16_t v) { - #if defined(__aarch64__) || defined(_M_ARM64) - return vaddvq_u8(v); - #else - uint16x8_t s16 = vpaddlq_u8(v); - uint32x4_t s32 = vpaddlq_u16(s16); - uint64x2_t s64 = vpaddlq_u32(s32); - return (uint16_t)(vgetq_lane_u64(s64, 0) + vgetq_lane_u64(s64, 1)); - #endif - } - - void print_uint8x16(uint8x16_t vec) { - uint8_t array[16]; - vst1q_u8(array, vec); - for (int i = 0; i < 16; ++i) { - std::cout << static_cast(array[i]) << " "; - } - std::cout << std::endl; - } - -#elif defined(__x86_64__) || defined(_M_X64) - // 采用 Brian Kernighan's 算法计算 64 位数的 Hamming Weight - unsigned int popcnt64(uint64_t x) { - x -= (x >> 1) & 0x5555555555555555; // 将相邻的两位合并 - x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333); // 合并四位 - x = (x + (x >> 4)) & 0x0F0F0F0F0F0F0F0F; // 合并八位 - x = x + (x >> 8); // 合并十六位 - x = x + (x >> 16); // 合并三十二位 - x = x + (x >> 32); // 合并六十四位 - return x & 0x7F; // 返回最后的1的个数,0x7F表示最多返回 7 位 - } - - // 计算 128 位向量中 1 的个数 - int popcount_128(__m128i xor_result) { - // 将 128 位数据拆成两个 64 位整数 - uint64_t* result = (uint64_t*)&xor_result; - - // 分别计算每个 64 位的 Hamming 权重并返回结果之和 - return popcnt64(result[0]) + popcnt64(result[1]); - } -#endif - void worker_loop() { while (true) { Request req; @@ -248,18 +317,14 @@ class HashRetrievalWorkerBackend { std::vector> heap; heap.reserve(allowed.size()); +#if defined(__ARM_NEON) || defined(__ARM_NEON__) || \ + defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86) // 1.预加载 query 向量 - #ifdef __ARM_NEON - uint8x16_t q_vecs[vec_per_dim_]; // 存储 query 向量 + vec16u q_vecs[vec_per_dim_]; // 存储query向量 for (size_t v = 0; v < vec_per_dim_; ++v) { - q_vecs[v] = vld1q_u8(q_ptr + v * VEC_SIZE); + q_vecs[v] = vec_loadu16(q_ptr + v * VEC_SIZE); } - #elif defined(__x86_64__) || defined(_M_X64) - __m128i q_vecs[vec_per_dim_]; // 存储 query 向量 - for (size_t v = 0; v < vec_per_dim_; ++v) { - q_vecs[v] = _mm_loadu_si128(reinterpret_cast(q_ptr + v * VEC_SIZE)); - } - #endif +#endif // 2.遍历允许的索引 for (auto idx : allowed) { @@ -274,37 +339,27 @@ class HashRetrievalWorkerBackend { const uint8_t* k_base = base_idx_ptr + t_idx * dim_; // 计算每个向量的相似度 +#if defined(__ARM_NEON) || defined(__ARM_NEON__) || \ + defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86) for (size_t v = 0; v < vec_per_dim_; ++v) { - #ifdef __ARM_NEON - uint8x16_t k = vld1q_u8(k_base + v * VEC_SIZE); - sum += vaddvq_u8_compat(vcntq_u8(veorq_u8(q_vecs[v], k))); - #elif defined(__x86_64__) || defined(_M_X64) - __m128i k = _mm_loadu_si128(reinterpret_cast(k_base + v * VEC_SIZE)); - __m128i xor_result = _mm_xor_si128(q_vecs[v], k); // 16 * 8 - int popcount_result = popcount_128(xor_result); // 计算128位 xor_result 中所有位为 1 的个数 - sum += popcount_result; // 获取每个字节的累计值 - #endif + sum += vec_popcnt_xor_sum16_vec( + q_vecs[v], + k_base + v * VEC_SIZE + ); } - - // 处理不足16字节的部分 - ssize_t tail_dim = dim_ % VEC_SIZE; - if (tail_dim != 0) { - uint8_t q_tmp[16] = { 0 }; // 初始化填充为0 - uint8_t k_tmp[16] = { 0 }; - memcpy(q_tmp, q_ptr, dim_); - memcpy(k_tmp, k_base, dim_); - - #ifdef __ARM_NEON - uint8x16_t q = vld1q_u8(q_tmp); - uint8x16_t k = vld1q_u8(k_tmp); - sum += vaddvq_u8_compat(vcntq_u8(veorq_u8(q, k))); - #elif defined(__x86_64__) || defined(_M_X64) - __m128i q = _mm_loadu_si128(reinterpret_cast(q_tmp)); - __m128i k = _mm_loadu_si128(reinterpret_cast(k_tmp)); - __m128i xor_result = 
_mm_xor_si128(q, k); - int popcount_result = popcount_128(xor_result); // 计算128位 xor_result 中所有位为 1 的个数 - sum += popcount_result; // 获取每个字节的累计值 - #endif +#else + for (size_t v = 0; v < vec_per_dim_; ++v) { + sum += vec_popcnt_xor_sum16( + q_ptr + v * VEC_SIZE, + k_base + v * VEC_SIZE + ); + } +#endif + if (tail_dim_ != 0) { + for (size_t t = 0; t < tail_dim_; ++t) { + uint8_t x = q_ptr [tail_start_+t] ^ k_base[tail_start_+t]; + sum += __builtin_popcount((unsigned)x); + } } // 如果得分为0,则跳出循环 @@ -350,7 +405,7 @@ class HashRetrievalWorkerBackend { py::array_t data_array_; const uint8_t* data_ = nullptr; ssize_t dim_; - size_t num_blocks_, block_size_, vec_per_dim_; + size_t num_blocks_, block_size_, vec_per_dim_, tail_dim_, tail_start_; std::queue requests_; std::unordered_map results_; std::vector worker_threads_; @@ -368,4 +423,4 @@ PYBIND11_MODULE(hash_retrieval_backend, m) { .def("poll", &HashRetrievalWorkerBackend::poll) .def("get_result", &HashRetrievalWorkerBackend::get_result) .def("wait", &HashRetrievalWorkerBackend::wait); -} \ No newline at end of file +} diff --git a/ucm/sparse/kvcomp/hash_retrieval/hash_retrieval_worker.py b/ucm/sparse/kvcomp/hash_retrieval/hash_retrieval_worker.py index 7a77b05a8..5faf83dcc 100644 --- a/ucm/sparse/kvcomp/hash_retrieval/hash_retrieval_worker.py +++ b/ucm/sparse/kvcomp/hash_retrieval/hash_retrieval_worker.py @@ -1,10 +1,12 @@ import time +from collections import defaultdict import numpy as np import torch from ucm.sparse.kvcomp.hash_encoder import HashEncoder from ucm.sparse.kvcomp.hash_retrieval import hash_retrieval_backend +from ucm.sparse.kvstar.utils import get_bind_cpus_for_rank class HashRetrievalWorker: @@ -37,15 +39,16 @@ def wait(self, req_id): if __name__ == "__main__": ################# data batch_size = 2 - dim = 1024 - kv_cache_blocks = 25600 - data = torch.rand(kv_cache_blocks, dim).to(torch.float32) + block_size = 2 + head_dim = 128 + head_num = 1 + dim = head_dim * head_num + kv_cache_blocks = 2560 + data = torch.rand(kv_cache_blocks, block_size, dim).to(torch.float32) print("data created", data.shape) - backend = hash_retrieval_backend.HashRetrievalWorkerBackend(data) - worker = HashRetrievalWorker(backend) - topk = 3000 - search_blocks_range = 8000 + topk = 10 + search_blocks_range = 100 tpot = 30 / 1000 indexes = np.arange(batch_size * search_blocks_range).reshape( @@ -54,8 +57,35 @@ def wait(self, req_id): query = torch.rand(batch_size, dim).to(torch.float32) + hash_encoder = HashEncoder( + input_dim=dim, + hash_bits=dim, + dtype=torch.float32, + device=torch.device("cpu"), + ) + + hash_query = hash_encoder.compute_hash(query) + hash_key_cache = hash_encoder.compute_hash(data) + + ratio = 0.75 + total_tp_size = 4 + local_tp_rank = 0 + bind_info_list, alloc_numa_ids = get_bind_cpus_for_rank( + total_tp_size, local_tp_rank, ratio=ratio + ) + + bind_info_dict = defaultdict(list) + for item in bind_info_list: + bind_info_dict[item[1]].append(item[0]) + bind_info_dict = dict(bind_info_dict) + + backend = hash_retrieval_backend.HashRetrievalWorkerBackend( + hash_key_cache, bind_info_dict + ) + worker = HashRetrievalWorker(backend) + #################### cpp async version - req_id = worker.submit(query, topk=topk, indexes=indexes) + req_id = worker.submit(hash_query, topk=topk, indexes=indexes) #################### LLM decode begin time.sleep(tpot * 3) @@ -66,28 +96,24 @@ def wait(self, req_id): worker.wait(req_id) result = worker.get_result(req_id) print("cpp spent:", time.time() - begin) + cpp_indices = np.sort(result["indices"], 1) 
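+ # Sort the backend's top-k block indices so they can be compared against the torch/numpy reference computed below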
+ print(f"cpp indices={cpp_indices}") ################### numpy version + unpacked_hash_query = hash_encoder._unpack_hash(hash_query) + unpacked_hash_key_cache = hash_encoder._unpack_hash(hash_key_cache) begin = time.time() - data_indexed = ( - data[indexes.flatten()].reshape(indexes.shape[0], indexes.shape[1], dim).numpy() + data_indexed = unpacked_hash_key_cache[indexes.flatten()].reshape( + indexes.shape[0], indexes.shape[1], block_size, dim ) - query = HashRetrievalWorker.handle_input(query) - scores = np.matmul(query[:, None, :], data_indexed.transpose((0, 2, 1))) - scores = scores[:, 0, :] - topk_elements = np.partition(scores, -topk, -1)[:, -topk:] - topk_indices = np.argpartition(scores, -topk, -1)[:, -topk:] - topk_indices = indexes[np.arange(indexes.shape[0])[:, None], topk_indices] - print("numpy spent: ", time.time() - begin) + scores = torch.einsum("td, tnjd->tnj", unpacked_hash_query, data_indexed) - ## compare - cpp_elements = np.sort(result["scores"], 1) - cpp_indices = np.sort(result["indices"], 1) - - np_elements = np.sort(topk_elements, 1) - np_indices = np.sort(topk_indices, 1) + block_scores_ret = torch.max(scores, dim=-1) + blocks_scores = block_scores_ret.values - diff_elements = np.abs(np_elements - cpp_elements) - diff_indices = np.abs(np_indices - cpp_indices) - - print(f"diff topk: {diff_indices.max()}") + topk_ret = torch.topk(blocks_scores, topk, dim=-1) + topk_index = topk_ret.indices + topk_index = topk_index.sort(dim=-1).values + topk_index = indexes[np.arange(indexes.shape[0])[:, None], topk_index] + print("numpy spent: ", time.time() - begin) + print(f"numpy indices={topk_index}") diff --git a/ucm/sparse/kvcomp/kvcomp.py b/ucm/sparse/kvcomp/kvcomp.py index c1713300e..8a1f61238 100644 --- a/ucm/sparse/kvcomp/kvcomp.py +++ b/ucm/sparse/kvcomp/kvcomp.py @@ -186,10 +186,10 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): if hasattr(torch, "npu") and torch.npu.is_available(): device = torch.device(f"npu:{self.rank}") - elif torch.cuda.is_available(): + elif hasattr(torch, "cuda") and torch.cuda.is_available(): device = torch.device(f"cuda:{self.rank}") else: - device = torch.device("npu") + device = torch.device("cpu") self.hash_encoder = HashEncoder( input_dim=self.kvcomp_config.head_dim, diff --git a/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_queue.cpp b/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_queue.cpp index a8fc080ad..0600e4cf0 100644 --- a/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_queue.cpp +++ b/ucm/sparse/kvstar/retrieve/core/domain/retrieve_task/retrieve_task_queue.cpp @@ -37,8 +37,6 @@ void RetrieveTaskQueue::Worker(const int numaId, const int bindCoreId, std::prom started.set_value(Status::OsApiError()); return; } -#else - KVSTAR_DEBUG("NUMA support is disabled."); #endif KVSTAR_DEBUG("Bind current thread {} to numa {} core {} and set memory affinity success.", thread, numaId, bindCoreId); diff --git a/ucm/store/dramstore/CMakeLists.txt b/ucm/store/dramstore/CMakeLists.txt index 152955449..e69de29bb 100644 --- a/ucm/store/dramstore/CMakeLists.txt +++ b/ucm/store/dramstore/CMakeLists.txt @@ -1,12 +0,0 @@ -file(GLOB_RECURSE UCMSTORE_DRAM_CC_SOURCE_FILES "./cc/*.cc") -add_library(dramstore STATIC ${UCMSTORE_DRAM_CC_SOURCE_FILES}) -target_include_directories(dramstore PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/cc/api - ${CMAKE_CURRENT_SOURCE_DIR}/cc/domain -) -target_link_libraries(dramstore PUBLIC storeinfra storedevice storetask) - -file(GLOB_RECURSE 
UCMSTORE_DRAM_CPY_SOURCE_FILES "./cpy/*.cc") -pybind11_add_module(ucmdramstore ${UCMSTORE_DRAM_CPY_SOURCE_FILES}) -target_link_libraries(ucmdramstore PRIVATE dramstore) -set_target_properties(ucmdramstore PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/ucm/store/dramstore/cc/api/dramstore.cc b/ucm/store/dramstore/cc/api/dramstore.cc deleted file mode 100644 index c59b7f2b1..000000000 --- a/ucm/store/dramstore/cc/api/dramstore.cc +++ /dev/null @@ -1,104 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ -#include "dramstore.h" -#include "logger/logger.h" -#include "status/status.h" -#include "trans/dram_trans_manager.h" -#include "memory/memory_pool.h" - -namespace UC { - -class DRAMStoreImpl : public DRAMStore { -public: - int32_t Setup(const Config& config) { - auto status = this->memPool_.Setup(config.deviceId, config.capacity, config.blockSize); - if (status.Failure()) { - UC_ERROR("Failed({}) to setup MemoryPool.", status); - return status.Underlying(); - } - status = this->transMgr_.Setup(config.deviceId, config.streamNumber, &this->memPool_, config.timeoutMs); - if (status.Failure()) { - UC_ERROR("Failed({}) to setup TsfTaskManager.", status); - return status.Underlying(); - } - return Status::OK().Underlying(); - } - int32_t Alloc(const std::string& block) override { return this->memPool_.NewBlock(block).Underlying(); } - bool Lookup(const std::string& block) override { return this->memPool_.LookupBlock(block); } - void Commit(const std::string& block, const bool success) override { this->memPool_.CommitBlock(block, success).Underlying(); } - std::list Alloc(const std::list& blocks) override - { - std::list results; - for (const auto &block : blocks) { - results.emplace_back(this->Alloc(block)); - } - return results; - } - std::list Lookup(const std::list& blocks) override - { - std::list founds; - for (const auto &block : blocks) { - founds.emplace_back(this->Lookup(block)); - } - return founds; - } - void Commit(const std::list& blocks, const bool success) override { - for (const auto &block : blocks) { - this->Commit(block, success); - } - } - size_t Submit(Task&& task) override { - auto taskId = Task::invalid; - auto status = this->transMgr_.Submit(std::move(task), taskId); - if (status.Failure()) { taskId = Task::invalid; } - return taskId; } - - int32_t Wait(const size_t task) override { - return 
this->transMgr_.Wait(task).Underlying(); - } - - int32_t Check(const size_t task, bool& finish) override { - return this->transMgr_.Check(task, finish).Underlying(); - } - - -private: - - DramTransManager transMgr_; - MemoryPool memPool_; - -}; - -int32_t DRAMStore::Setup(const Config& config) -{ - auto impl = new (std::nothrow) DRAMStoreImpl(); - if (!impl) { - UC_ERROR("Out of memory."); - return Status::OutOfMemory().Underlying(); - } - this->impl_ = impl; - return impl->Setup(config); -} - -} // namespace UC diff --git a/ucm/store/dramstore/cc/api/dramstore.h b/ucm/store/dramstore/cc/api/dramstore.h deleted file mode 100644 index 25d72612a..000000000 --- a/ucm/store/dramstore/cc/api/dramstore.h +++ /dev/null @@ -1,83 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * */ -#ifndef UNIFIEDCACHE_DRAMSTORE_H -#define UNIFIEDCACHE_DRAMSTORE_H - -#include "ucmstore.h" - -namespace UC { - -class DRAMStore : public CCStore<> { -public: - struct Config { - size_t capacity; - size_t blockSize; - int32_t deviceId; - size_t streamNumber; - size_t timeoutMs; - Config(const size_t capacity, const size_t blockSize, const int32_t deviceId, const size_t streamNumber, const size_t timeoutMs) - : capacity{capacity}, blockSize{blockSize}, deviceId{deviceId}, streamNumber{streamNumber}, timeoutMs{timeoutMs} - { - } - }; - -public: - DRAMStore() : impl_{nullptr} {} - ~DRAMStore() override - { - if (this->impl_) { delete this->impl_; } - } - int32_t Setup(const Config& config); - int32_t Alloc(const std::string& block) override { return this->impl_->Alloc(block); } - bool Lookup(const std::string& block) override { return this->impl_->Lookup(block); } - void Commit(const std::string& block, const bool success) override - { - this->impl_->Commit(block, success); - } - std::list Alloc(const std::list& blocks) override - { - return this->impl_->Alloc(blocks); - } - std::list Lookup(const std::list& blocks) override - { - return this->impl_->Lookup(blocks); - } - void Commit(const std::list& blocks, const bool success) override - { - this->impl_->Commit(blocks, success); - } - size_t Submit(Task&& task) override { return this->impl_->Submit(std::move(task)); } - int32_t Wait(const size_t task) override { return this->impl_->Wait(task); } - int32_t Check(const size_t task, bool& finish) override - { - return this->impl_->Check(task, finish); - } - -private: - DRAMStore* impl_; -}; - -} // namespace UC - -#endif diff --git a/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.cc b/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.cc deleted file mode 100644 index f98356129..000000000 --- a/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.cc +++ /dev/null @@ -1,42 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * */ - -#include "dram_trans_manager.h" - -namespace UC { - -Status DramTransManager::Setup(const int32_t deviceId, const size_t streamNumber, const MemoryPool* memPool, size_t timeoutMs) { - this->timeoutMs_ = timeoutMs; - auto status = Status::OK(); - for (size_t i = 0; i < streamNumber; i++) { - auto q = std::make_shared(); - status = - q->Setup(deviceId, &this->failureSet_, memPool, timeoutMs); - if (status.Failure()) { break; } - this->queues_.emplace_back(std::move(q)); - } - return status; -} - -} \ No newline at end of file diff --git a/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.h b/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.h deleted file mode 100644 index 7f9ef51b2..000000000 --- a/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ -#ifndef UNIFIEDCACHE_DRAM_TRANS_MANAGER_H -#define UNIFIEDCACHE_DRAM_TRANS_MANAGER_H - -#include "task_manager.h" -#include "dram_trans_queue.h" - -namespace UC { - -class DramTransManager : public TaskManager { -public: - Status Setup(const int32_t deviceId, const size_t streamNumber, const MemoryPool* memPool, size_t timeoutMs); -}; - -} // namespace UC - -#endif diff --git a/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.cc b/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.cc deleted file mode 100644 index cf7a35770..000000000 --- a/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.cc +++ /dev/null @@ -1,126 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. 
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ - -#include "dram_trans_queue.h" - -namespace UC { - -Status DramTransQueue::Setup(const int32_t deviceId, TaskSet* failureSet, - const MemoryPool* memPool, const size_t timeoutMs) { - this->deviceId_ = deviceId; - this->failureSet_ = failureSet; - this->memPool_ = memPool; - auto success = - this->backend_.SetWorkerInitFn([this](auto& device) { return this->Init(device); }) - .SetWorkerFn([this](auto& shards, const auto& device) { this->Work(shards, device); }) - .SetWorkerExitFn([this](auto& device) { this->Exit(device); }) - .Run(); - return success ? Status::OK() : Status::Error(); -} - -void DramTransQueue::Push(std::list& shards) noexcept { - this->backend_.Push(std::move(shards)); -} - -bool DramTransQueue::Init(Device& device) { - if (this->deviceId_ < 0) { return true; } - device = DeviceFactory::Make(this->deviceId_, 262144, 512); - if (!device) { - return false; - } - return device->Setup().Success(); -} - -void DramTransQueue::Exit(Device& device) { - device.reset(); -} - -void DramTransQueue::Work(std::list& shards, const Device& device) { - auto it = shards.begin(); - if (this->failureSet_->Contains(it->owner)) { - this->Done(shards, device, true); - } - auto status = Status::OK(); - if (it->type == Task::Type::DUMP) { - status = this->D2H(shards, device); - } else { - status = this->H2D(shards, device); - } - this->Done(shards, device, status.Success()); -} - -Status DramTransQueue::H2D(std::list& shards, const Device& device) { - size_t pool_offset = 0; - std::vector host_addrs(shards.size()); - std::vector device_addrs(shards.size()); - int shard_index = 0; - for (auto& shard : shards) { - bool found = this->memPool_->GetOffset(shard.block, &pool_offset); - if (!found) { - return Status::Error(); - } - auto host_addr = this->memPool_->GetStartAddr().get() + pool_offset + shard.offset; - auto device_addr = shard.address; - host_addrs[shard_index] = host_addr; - device_addrs[shard_index] = reinterpret_cast(device_addr); - shard_index++; - } - auto it = shards.begin(); - return device->H2DBatchSync(device_addrs.data(), const_cast(host_addrs.data()), shards.size(), it->length); -} - -Status DramTransQueue::D2H(std::list& shards, const Device& device) { - size_t pool_offset = 0; - std::vector host_addrs(shards.size()); - std::vector device_addrs(shards.size()); - int shard_index = 0; - for (auto& shard : shards) { - bool found = this->memPool_->GetOffset(shard.block, &pool_offset); - if (!found) { - return Status::Error(); - } - auto host_addr = this->memPool_->GetStartAddr().get() + pool_offset + shard.offset; - auto device_addr = shard.address; - host_addrs[shard_index] = host_addr; - device_addrs[shard_index] = reinterpret_cast(device_addr); - shard_index++; - } - auto it = shards.begin(); - return device->D2HBatchSync(host_addrs.data(), const_cast(device_addrs.data()), shards.size(), it->length); -} - -void DramTransQueue::Done(std::list& shards, const Device& device, const bool success) { - auto it = shards.begin(); - if (!success) { this->failureSet_->Insert(it->owner); } - for 
(auto& shard : shards) { - if (shard.done) { - if (device) { - if (device->Synchronized().Failure()) { this->failureSet_->Insert(shard.owner); } - } - shard.done(); - } - } -} - -} // namespace UC \ No newline at end of file diff --git a/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.h b/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.h deleted file mode 100644 index 723507098..000000000 --- a/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.h +++ /dev/null @@ -1,61 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ -#ifndef UNIFIEDCACHE_DRAM_TRANS_QUEUE_H -#define UNIFIEDCACHE_DRAM_TRANS_QUEUE_H - -#include "device/idevice.h" -#include "status/status.h" -#include "task_queue.h" -#include "task_set.h" -#include "thread/thread_pool.h" -#include "memory/memory_pool.h" - -namespace UC { - -class DramTransQueue : public TaskQueue { - using Device = std::unique_ptr; - int32_t deviceId_{-1}; - TaskSet* failureSet_{nullptr}; - const MemoryPool* memPool_{nullptr}; - ThreadPool, Device> backend_{}; - -public: - Status Setup(const int32_t deviceId, - TaskSet* failureSet, - const MemoryPool* memPool, - const size_t timeoutMs); - void Push(std::list& shards) noexcept override; - -private: - bool Init(Device& device); - void Exit(Device& device); - void Work(std::list& shards, const Device& device); - void Done(std::list& shards, const Device& device, const bool success); - Status H2D(std::list& shards, const Device& device); - Status D2H(std::list& shards, const Device& device); -}; - -} // namespace UC - -#endif diff --git a/ucm/store/dramstore/cpy/dramstore.py.cc b/ucm/store/dramstore/cpy/dramstore.py.cc deleted file mode 100644 index cb76d5d1b..000000000 --- a/ucm/store/dramstore/cpy/dramstore.py.cc +++ /dev/null @@ -1,123 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
- * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ -#include "dramstore.h" -#include - -namespace py = pybind11; - -namespace UC { - -class DRAMStorePy : public DRAMStore { -public: - void* CCStoreImpl() { return this; } - py::list AllocBatch(const py::list& blocks) - { - py::list results; - for (auto& block : blocks) { results.append(this->Alloc(block.cast())); } - return results; - } - py::list LookupBatch(const py::list& blocks) - { - py::list founds; - for (auto& block : blocks) { founds.append(this->Lookup(block.cast())); } - return founds; - } - void CommitBatch(const py::list& blocks, const bool success) - { - for (auto& block : blocks) { this->Commit(block.cast(), success); } - } - py::tuple CheckPy(const size_t task) - { - auto finish = false; - auto ret = this->Check(task, finish); - return py::make_tuple(ret, finish); - } - size_t Load(const py::list& blockIds, const py::list& offsets, const py::list& addresses, - const py::list& lengths) - { - return this->SubmitPy(blockIds, offsets, addresses, lengths, Task::Type::LOAD, - Task::Location::DEVICE, "DRAM::H2D"); - } - size_t Dump(const py::list& blockIds, const py::list& offsets, const py::list& addresses, - const py::list& lengths) - { - return this->SubmitPy(blockIds, offsets, addresses, lengths, Task::Type::DUMP, - Task::Location::DEVICE, "DRAM::D2H"); - } - -private: - size_t SubmitPy(const py::list& blockIds, const py::list& offsets, const py::list& addresses, - const py::list& lengths, Task::Type&& type, Task::Location&& location, - std::string&& brief) - { - Task task{std::move(type), std::move(location), std::move(brief)}; - auto blockId = blockIds.begin(); - auto offset = offsets.begin(); - auto address = addresses.begin(); - auto length = lengths.begin(); - while ((blockId != blockIds.end()) && (offset != offsets.end()) && - (address != addresses.end()) && (length != lengths.end())) { - task.Append(blockId->cast(), offset->cast(), - address->cast(), length->cast()); - blockId++; - offset++; - address++; - length++; - } - return this->Submit(std::move(task)); - } -}; - -} // namespace UC - -PYBIND11_MODULE(ucmdramstore, module) -{ - module.attr("project") = UCM_PROJECT_NAME; - module.attr("version") = UCM_PROJECT_VERSION; - module.attr("commit_id") = UCM_COMMIT_ID; - module.attr("build_type") = UCM_BUILD_TYPE; - auto store = py::class_(module, "DRAMStore"); - auto config = py::class_(store, "Config"); - config.def(py::init(), - py::arg("capacity"), py::arg("blockSize"), py::arg("deviceId"), 
py::arg("streamNumber"), py::arg("timeoutMs")); - config.def_readwrite("capacity", &UC::DRAMStorePy::Config::capacity); - config.def_readwrite("blockSize", &UC::DRAMStorePy::Config::blockSize); - config.def_readwrite("deviceId", &UC::DRAMStorePy::Config::deviceId); - config.def_readwrite("streamNumber", &UC::DRAMStorePy::Config::streamNumber); - config.def_readwrite("timeoutMs", &UC::DRAMStorePy::Config::timeoutMs); - store.def(py::init<>()); - store.def("CCStoreImpl", &UC::DRAMStorePy::CCStoreImpl); - store.def("Setup", &UC::DRAMStorePy::Setup); - store.def("Alloc", py::overload_cast(&UC::DRAMStorePy::Alloc)); - store.def("AllocBatch", &UC::DRAMStorePy::AllocBatch); - store.def("Lookup", py::overload_cast(&UC::DRAMStorePy::Lookup)); - store.def("LookupBatch", &UC::DRAMStorePy::LookupBatch); - store.def("Load", &UC::DRAMStorePy::Load); - store.def("Dump", &UC::DRAMStorePy::Dump); - store.def("Wait", &UC::DRAMStorePy::Wait); - store.def("Check", &UC::DRAMStorePy::Check); - store.def("Commit", - py::overload_cast(&UC::DRAMStorePy::Commit)); - store.def("CommitBatch", &UC::DRAMStorePy::CommitBatch); -} diff --git a/ucm/store/infra/CMakeLists.txt b/ucm/store/infra/CMakeLists.txt index f3e0ce727..6bc8dc4a4 100644 --- a/ucm/store/infra/CMakeLists.txt +++ b/ucm/store/infra/CMakeLists.txt @@ -1,21 +1,11 @@ add_library(storeinfra STATIC) target_include_directories(storeinfra PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) file(GLOB_RECURSE UCMSTORE_COMMON_FILE_SOURCE_FILES "file/*.cc") -if(LOGGER_BACKEND STREQUAL "spdlog") - file(GLOB_RECURSE UCMSTORE_COMMON_LOGGER_SOURCE_FILES "logger/spdlog/*.cc") -endif() -if(LOGGER_BACKEND STREQUAL "flux") - file(GLOB_RECURSE UCMSTORE_COMMON_LOGGER_SOURCE_FILES "logger/flux/*.cc") -endif() -file(GLOB_RECURSE UCMSTORE_COMMON_STATUS_SOURCE_FILES "status/*.cc") -file(GLOB_RECURSE UCMSTORE_COMMON_TEMPLATE_SOURCE_FILES "template/*.cc") -file(GLOB_RECURSE UCMSTORE_COMMON_THREAD_SOURCE_FILES "thread/*.cc") target_sources(storeinfra PRIVATE ${UCMSTORE_COMMON_FILE_SOURCE_FILES}) -target_sources(storeinfra PRIVATE ${UCMSTORE_COMMON_LOGGER_SOURCE_FILES}) -target_sources(storeinfra PRIVATE ${UCMSTORE_COMMON_STATUS_SOURCE_FILES}) -target_sources(storeinfra PRIVATE ${UCMSTORE_COMMON_TEMPLATE_SOURCE_FILES}) -target_sources(storeinfra PRIVATE ${UCMSTORE_COMMON_THREAD_SOURCE_FILES}) -target_link_libraries(storeinfra PUBLIC fmt) -if(LOGGER_BACKEND STREQUAL "spdlog") - target_link_libraries(storeinfra PUBLIC spdlog) -endif() +target_link_libraries(storeinfra PUBLIC + infra_status + infra_logger + infra_template + infra_thread + infra_time +) diff --git a/ucm/store/infra/memory/memory_pool.h b/ucm/store/infra/memory/memory_pool.h deleted file mode 100644 index 200d12868..000000000 --- a/ucm/store/infra/memory/memory_pool.h +++ /dev/null @@ -1,174 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. 
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * */ -#ifndef UNIFIEDCACHE_MEMORY_POOL_H -#define UNIFIEDCACHE_MEMORY_POOL_H - -#include -#include -#include -#include -#include -#include -#include "status/status.h" -#include "device/idevice.h" -#include -#include -#include -#include "logger/logger.h" - -namespace UC { - -class MemoryPool { - - std::string DUMMY_SLOT_PREFIX{"__slot_"}; - using Device = std::unique_ptr; -public: - - Status Setup(int32_t deviceId, size_t capacity, size_t blockSize) { - capacity_ = capacity; - blockSize_ = blockSize; - device_ = DeviceFactory::Make(deviceId, blockSize, static_cast(capacity / blockSize)); - if (!device_) { - UC_ERROR("MemoryPool: failed to create device"); - return Status::Error(); - } - Status status = device_->Setup(); - if (!status.Success()) { - UC_ERROR("MemoryPool: failed to set up device"); - return Status::Error(); - } - pool_ = device_->GetBuffer(capacity_); - if (!pool_) { - UC_ERROR("MemoryPool: failed to get pool memory space"); - return Status::Error(); - } - - size_t slotNum = capacity_ / blockSize_; - for (size_t i = 0; i < slotNum; ++i) { - std::string dummy = DUMMY_SLOT_PREFIX + std::to_string(i); - size_t offset = i * blockSize_; - lruList_.push_front(dummy); - lruIndex_[dummy] = lruList_.begin(); - offsetMap_[dummy] = offset; - } - return Status::OK(); - - } - - Status NewBlock(const std::string& blockId) { - if (offsetMap_.count(blockId)) { - return Status::DuplicateKey(); - } - if (lruList_.empty()) { - // 所有空间里的块都正在写,那么就不能够分配 - return Status::Error(); - } - size_t offset = LRUEvictOne(); - offsetMap_[blockId] = offset; - return Status::OK(); - } - - bool LookupBlock(const std::string& blockId) const { - return availableBlocks_.count(blockId); - } - - bool GetOffset(const std::string& blockId, size_t* offset) const { - auto it = offsetMap_.find(blockId); - if (it == offsetMap_.end()) { - return false; - } - *offset = it->second; - return true; - } - - Status CommitBlock(const std::string& blockId, bool success) { - if (success) { - availableBlocks_.insert(blockId); - touchUnsafe(blockId); - } else { - resetSpaceOfBlock(blockId); - } - return Status::OK(); - } - - std::shared_ptr GetStartAddr() const { - return pool_; - } - -private: - std::shared_ptr pool_ = nullptr; - Device device_ = nullptr; - size_t capacity_; - size_t blockSize_; - - std::unordered_map offsetMap_; - std::set availableBlocks_; - - using ListType = std::list; - ListType lruList_; - std::unordered_map lruIndex_; - - void touchUnsafe(const std::string& blockId) { - auto it = lruIndex_.find(blockId); - if (it != lruIndex_.end()) { - lruList_.splice(lruList_.begin(), lruList_, it->second); - } - else { - lruList_.push_front(blockId); // 访问一次,该块就是最近使用了的,所以放到LRU队列的头部。这就是一般LRU的逻辑 - lruIndex_[blockId] = lruList_.begin(); - } - } - - size_t LRUEvictOne() { - const std::string& victim = lruList_.back(); - // 真实数据块,才从availableBlocks_中删掉 - if (victim.rfind(DUMMY_SLOT_PREFIX, 0) != 0) { - availableBlocks_.erase(victim); - } - size_t offset = offsetMap_[victim]; - offsetMap_.erase(victim); - lruIndex_.erase(victim); - 
lruList_.pop_back(); - return offset; - } - - void resetSpaceOfBlock(const std::string& blockId) { - auto it = offsetMap_.find(blockId); - size_t offset = it->second; - std::string dummy = DUMMY_SLOT_PREFIX + std::to_string(offset / blockSize_); - offsetMap_.erase(blockId); - - auto lit = lruIndex_.find(blockId); - if (lit != lruIndex_.end()) { - lruList_.erase(lit->second); - lruIndex_.erase(lit); - } - lruList_.push_back(dummy); // 将一个块commit false后,回收之前分配的内存,并且要将其放到LRU队列的尾部(下次可以写的时候,要马上就写。因为该块的优先级高于已经写了的块) - lruIndex_[dummy] = std::prev(lruList_.end()); - offsetMap_[dummy] = offset; - } -}; - -} // namespace UC -#endif \ No newline at end of file diff --git a/ucm/store/infra/status/status.h b/ucm/store/infra/status/status.h deleted file mode 100644 index 809d24597..000000000 --- a/ucm/store/infra/status/status.h +++ /dev/null @@ -1,134 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * */ -#ifndef UNIFIEDCACHE_STATUS_H -#define UNIFIEDCACHE_STATUS_H - -#include - -namespace UC { - -class Status { - enum class Code { -#define UC_MAKE_STATUS_CODE(i) (-50000 - (i)) - OK = 0, - ERROR = -1, - EPARAM = UC_MAKE_STATUS_CODE(0), - EOOM = UC_MAKE_STATUS_CODE(1), - EOSERROR = UC_MAKE_STATUS_CODE(2), - EDUPLICATE = UC_MAKE_STATUS_CODE(3), - ERETRY = UC_MAKE_STATUS_CODE(4), - ENOOBJ = UC_MAKE_STATUS_CODE(5), - ESERIALIZE = UC_MAKE_STATUS_CODE(6), - EDESERIALIZE = UC_MAKE_STATUS_CODE(7), - EUNSUPPORTED = UC_MAKE_STATUS_CODE(8), - ENOSPACE = UC_MAKE_STATUS_CODE(9), -#undef UC_MAKE_STATUS_CODE - }; - -public: - static Status& OK() - { - static Status s{Code::OK}; - return s; - } - static Status& Error() - { - static Status s{Code::ERROR}; - return s; - } - static Status& InvalidParam() - { - static Status s{Code::EPARAM}; - return s; - } - static Status& OutOfMemory() - { - static Status s{Code::EOOM}; - return s; - } - static Status& OsApiError() - { - static Status s{Code::EOSERROR}; - return s; - } - static Status& DuplicateKey() - { - static Status s{Code::EDUPLICATE}; - return s; - } - static Status& Retry() - { - static Status s{Code::ERETRY}; - return s; - } - static Status& NotFound() - { - static Status s{Code::ENOOBJ}; - return s; - } - static Status& SerializeFailed() - { - static Status s{Code::ESERIALIZE}; - return s; - } - static Status& DeserializeFailed() - { - static Status s{Code::EDESERIALIZE}; - return s; - } - static Status& Unsupported() - { - static Status s{Code::EUNSUPPORTED}; - return s; - } - static Status& NoSpace() - { - static Status s{Code::ENOSPACE}; - return s; - } -public: - Status(const Status& status) { this->code_ = status.code_; } - Status& operator=(const Status& status) - { - if (this != &status) { this->code_ = status.code_; } - return *this; - } - bool operator==(const Status& status) const { return this->code_ == status.code_; } - bool operator!=(const Status& status) const { return this->code_ != status.code_; } - int32_t Underlying() const { return static_cast(this->code_); } - bool Success() const { return this->code_ == Code::OK; } - bool Failure() const { return this->code_ != Code::OK; } - -private: - Status(const Code code) : code_{code} {} - -private: - Code code_; -}; - -inline int32_t format_as(const Status& status) { return status.Underlying(); } - -} // namespace UC - -#endif diff --git a/ucm/store/nfsstore/cc/domain/hotness/hotness_timer.h b/ucm/store/nfsstore/cc/domain/hotness/hotness_timer.h index add777bad..c49bd1dc9 100644 --- a/ucm/store/nfsstore/cc/domain/hotness/hotness_timer.h +++ b/ucm/store/nfsstore/cc/domain/hotness/hotness_timer.h @@ -26,6 +26,7 @@ #define UNIFIEDCACHE_HOTNESS_TIMER_H #include #include +#include "logger/logger.h" #include "template/timer.h" namespace UC { @@ -41,7 +42,7 @@ class HotnessTimer { UC_ERROR("Failed({}) to start hotness timer.", e.what()); return Status::OutOfMemory(); } - return this->timer_->Start(); + return this->timer_->Start() ? 
Status::OK() : Status::Error(); } private: std::chrono::seconds interval_; @@ -50,4 +51,4 @@ class HotnessTimer { } // namespace UC -#endif \ No newline at end of file +#endif diff --git a/ucm/store/nfsstore/nfsstore_connector.py b/ucm/store/nfsstore/nfsstore_connector.py index 4a348a053..bd30f6288 100644 --- a/ucm/store/nfsstore/nfsstore_connector.py +++ b/ucm/store/nfsstore/nfsstore_connector.py @@ -51,7 +51,7 @@ def __init__(self, config: Dict): if transfer_enable: param.transferDeviceId = config["device"] param.transferIoSize = config["io_size"] - param.transferIoDirect = config.get("transferIoDirect", False) + param.transferIoDirect = config.get("use_direct", False) # NOTE: compatible with legacy nfsstore lib if hasattr(param, "storageCapacity"): diff --git a/ucm/store/pcstore/cc/domain/trans/trans_queue.cc b/ucm/store/pcstore/cc/domain/trans/trans_queue.cc index 8e2958fa0..83b8ce458 100644 --- a/ucm/store/pcstore/cc/domain/trans/trans_queue.cc +++ b/ucm/store/pcstore/cc/domain/trans/trans_queue.cc @@ -39,7 +39,7 @@ void TransQueue::DeviceWorker(BlockTask&& task) auto done = task.done; auto devPtrs = (void**)task.shards.data(); auto hostPtr = task.buffer.get(); - auto s = Trans::Status::OK(); + auto s = Status::OK(); if (task.type == TransTask::Type::LOAD) { s = stream_->HostToDevice(hostPtr, devPtrs, size, number); } else { diff --git a/ucm/store/test/CMakeLists.txt b/ucm/store/test/CMakeLists.txt index 859c185c4..0c4974efd 100644 --- a/ucm/store/test/CMakeLists.txt +++ b/ucm/store/test/CMakeLists.txt @@ -4,7 +4,7 @@ if(BUILD_UNIT_TESTS) add_executable(ucmstore.test ${UCMSTORE_TEST_SOURCE_FILES}) target_include_directories(ucmstore.test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/case) target_link_libraries(ucmstore.test PRIVATE - dramstore nfsstore localstore storeinfra storedevice + nfsstore localstore storeinfra storedevice gtest_main gtest mockcpp ) gtest_discover_tests(ucmstore.test) diff --git a/ucm/store/test/case/infra/mem_pool_test.cc b/ucm/store/test/case/infra/mem_pool_test.cc deleted file mode 100644 index f9ea04387..000000000 --- a/ucm/store/test/case/infra/mem_pool_test.cc +++ /dev/null @@ -1,169 +0,0 @@ -/** - * MIT License - * - * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * */ - -#include "infra/memory/memory_pool.h" -#include - -class UCMemoryPoolTest : public ::testing::Test {}; - -TEST_F(UCMemoryPoolTest, NewBlockAllocateAndCommit) -{ - UC::MemoryPool memPool; // 初始化内存池 - ASSERT_EQ(memPool.Setup(-1, 10, 2), UC::Status::OK()); - const std::string block1 = "block1"; - size_t offset = 10; - ASSERT_FALSE(memPool.LookupBlock(block1)); - // ASSERT_EQ(memPool.GetOffset(block1), nullptr); - ASSERT_EQ(memPool.GetOffset(block1, &offset), false); - ASSERT_EQ(memPool.NewBlock(block1), UC::Status::OK()); - ASSERT_FALSE(memPool.LookupBlock(block1)); - // ASSERT_NE(memPool.GetOffset(block1), nullptr); - ASSERT_EQ(memPool.GetOffset(block1, &offset), true); - ASSERT_EQ(memPool.NewBlock(block1), UC::Status::DuplicateKey()); - ASSERT_EQ(memPool.CommitBlock(block1, true), UC::Status::OK()); - ASSERT_TRUE(memPool.LookupBlock(block1)); -} - -TEST_F(UCMemoryPoolTest, EvictOldBlock) -{ - UC::MemoryPool memPool; // 初始化内存池 - ASSERT_EQ(memPool.Setup(-1, 10, 5), UC::Status::OK()); - const std::string block1 = "block1"; - const std::string block2 = "block2"; - const std::string block3 = "block3"; - size_t offset = 10; - ASSERT_EQ(memPool.NewBlock(block1), UC::Status::OK()); - // ASSERT_NE(memPool.GetOffset(block1), nullptr); - ASSERT_EQ(memPool.GetOffset(block1, &offset), true); - ASSERT_EQ(memPool.NewBlock(block2), UC::Status::OK()); - // ASSERT_NE(memPool.GetOffset(block2), nullptr); - ASSERT_EQ(memPool.GetOffset(block2, &offset), true); - memPool.CommitBlock(block1, true); - memPool.CommitBlock(block2, true); - ASSERT_EQ(memPool.NewBlock(block3), UC::Status::OK()); - // ASSERT_NE(memPool.GetOffset(block3), nullptr); - ASSERT_EQ(memPool.GetOffset(block3, &offset), true); - // ASSERT_EQ(memPool.GetOffset(block1), nullptr); - ASSERT_EQ(memPool.GetOffset(block1, &offset), false); - // ASSERT_NE(memPool.GetOffset(block2), nullptr); - ASSERT_EQ(memPool.GetOffset(block2, &offset), true); - ASSERT_FALSE(memPool.LookupBlock(block1)); - ASSERT_TRUE(memPool.LookupBlock(block2)); - ASSERT_FALSE(memPool.LookupBlock(block3)); -} - -TEST_F(UCMemoryPoolTest, OldBlockCommitFalse) -{ - UC::MemoryPool memPool; // 初始化内存池 - ASSERT_EQ(memPool.Setup(-1, 32, 8), UC::Status::OK()); - const std::string block1 = "block1"; - const std::string block2 = "block2"; - const std::string block3 = "block3"; - const std::string block4 = "block4"; - const std::string block5 = "block5"; - size_t offset = 32; - ASSERT_EQ(memPool.NewBlock(block1), UC::Status::OK()); - // ASSERT_NE(memPool.GetOffset(block1), nullptr); - ASSERT_EQ(memPool.GetOffset(block1, &offset), true); - ASSERT_EQ(memPool.NewBlock(block2), UC::Status::OK()); - // ASSERT_NE(memPool.GetOffset(block2), nullptr); - ASSERT_EQ(memPool.GetOffset(block2, &offset), true); - ASSERT_EQ(memPool.NewBlock(block3), UC::Status::OK()); - // ASSERT_NE(memPool.GetOffset(block3), nullptr); - ASSERT_EQ(memPool.GetOffset(block3, &offset), true); - memPool.CommitBlock(block1, true); - memPool.CommitBlock(block2, false); - ASSERT_TRUE(memPool.LookupBlock(block1)); - ASSERT_FALSE(memPool.LookupBlock(block2)); - ASSERT_FALSE(memPool.LookupBlock(block3)); - ASSERT_EQ(memPool.NewBlock(block4), UC::Status::OK()); - // ASSERT_EQ(memPool.GetOffset(block4), 8); - ASSERT_EQ(memPool.GetOffset(block4, &offset), true); - ASSERT_EQ(offset, 8); - ASSERT_EQ(memPool.NewBlock(block5), UC::Status::OK()); - // ASSERT_EQ(memPool.GetOffset(block5), 24); - ASSERT_EQ(memPool.GetOffset(block5, &offset), true); - ASSERT_EQ(offset, 24); - memPool.CommitBlock(block3, true); - 
-    memPool.CommitBlock(block4, true);
-    memPool.CommitBlock(block5, true);
-    ASSERT_TRUE(memPool.LookupBlock(block1));
-    ASSERT_FALSE(memPool.LookupBlock(block2));
-    ASSERT_TRUE(memPool.LookupBlock(block3));
-    ASSERT_TRUE(memPool.LookupBlock(block4));
-    ASSERT_TRUE(memPool.LookupBlock(block5));
-
-    ASSERT_EQ(memPool.NewBlock(block1), UC::Status::DuplicateKey());
-    ASSERT_EQ(memPool.NewBlock(block2), UC::Status::OK());
-    // ASSERT_EQ(memPool.GetOffset(block2), 0);
-    ASSERT_EQ(memPool.GetOffset(block2, &offset), true);
-    ASSERT_EQ(offset, 0);
-    ASSERT_FALSE(memPool.LookupBlock(block1));
-    ASSERT_FALSE(memPool.LookupBlock(block2));
-    memPool.CommitBlock(block2, true);
-    ASSERT_TRUE(memPool.LookupBlock(block2));
-}
-
-TEST_F(UCMemoryPoolTest, NoCommittedBlock)
-{
-    UC::MemoryPool memPool; // initialize the memory pool
-    ASSERT_EQ(memPool.Setup(-1, 32, 8), UC::Status::OK());
-    const std::string block1 = "block1";
-    const std::string block2 = "block2";
-    const std::string block3 = "block3";
-    const std::string block4 = "block4";
-    const std::string block5 = "block5";
-    const std::string block6 = "block6";
-    size_t offset = 32;
-    ASSERT_EQ(memPool.NewBlock(block1), UC::Status::OK());
-    ASSERT_EQ(memPool.NewBlock(block2), UC::Status::OK());
-    ASSERT_EQ(memPool.NewBlock(block3), UC::Status::OK());
-    ASSERT_EQ(memPool.NewBlock(block4), UC::Status::OK());
-    ASSERT_EQ(memPool.NewBlock(block5), UC::Status::Error());
-    memPool.CommitBlock(block1, true);
-    ASSERT_TRUE(memPool.LookupBlock(block1));
-    ASSERT_EQ(memPool.NewBlock(block5), UC::Status::OK());
-    // ASSERT_EQ(memPool.GetOffset(block5), 0);
-    ASSERT_EQ(memPool.GetOffset(block5, &offset), true);
-    ASSERT_EQ(offset, 0);
-    ASSERT_FALSE(memPool.LookupBlock(block1));
-    ASSERT_EQ(memPool.NewBlock(block6), UC::Status::Error());
-    // ASSERT_EQ(memPool.GetOffset(block2), 8);
-    ASSERT_EQ(memPool.GetOffset(block2, &offset), true);
-    ASSERT_EQ(offset, 8);
-    memPool.CommitBlock(block2, false);
-    // ASSERT_EQ(memPool.GetOffset((block2)), nullptr);
-    ASSERT_EQ(memPool.GetOffset(block2, &offset), false);
-    ASSERT_FALSE(memPool.LookupBlock(block1));
-    ASSERT_EQ(memPool.NewBlock(block6), UC::Status::OK());
-    // ASSERT_EQ(memPool.GetOffset(block6), 8);
-    ASSERT_EQ(memPool.GetOffset(block6, &offset), true);
-    ASSERT_EQ(offset, 8);
-    ASSERT_FALSE(memPool.LookupBlock(block6));
-    memPool.CommitBlock(block6, true);
-    ASSERT_TRUE(memPool.LookupBlock(block6));
-    // ASSERT_EQ(memPool.GetOffset(block6), 8);
-    ASSERT_EQ(memPool.GetOffset(block6, &offset), true);
-    ASSERT_EQ(offset, 8);
-}
\ No newline at end of file
diff --git a/ucm/utils.py b/ucm/utils.py
new file mode 100644
index 000000000..bf07f6b84
--- /dev/null
+++ b/ucm/utils.py
@@ -0,0 +1,90 @@
+#
+# MIT License
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+#
+
+from typing import Any, Dict
+
+import yaml
+
+from ucm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class Config:
+    def __init__(self, kv_transfer_config: Any):
+        self.kv_transfer_config = kv_transfer_config
+        self.config: Dict[str, Any] = {}
+        self._load_config()
+
+    def load_ucm_config_from_yaml(self, file_path: str) -> Dict[str, Any]:
+        if not file_path:
+            logger.warning("No UCM config file path provided.")
+            return {}
+
+        try:
+            with open(file_path, "r", encoding="utf-8") as f:
+                config = yaml.safe_load(f) or {}
+                if not isinstance(config, dict):
+                    logger.warning(
+                        f"Config file {file_path} does not contain a dictionary. "
+                        "Returning empty config."
+                    )
+                    return {}
+                logger.info(f"Loaded UCM config from {file_path}")
+                return config
+        except FileNotFoundError:
+            logger.error(f"UCM config file not found: {file_path}")
+            return {}
+        except yaml.YAMLError as e:
+            logger.error(f"Failed to parse YAML config file {file_path}: {e}")
+            return {}
+
+    def _load_config(self) -> None:
+        has_extra_config = (
+            self.kv_transfer_config is not None
+            and hasattr(self.kv_transfer_config, "kv_connector_extra_config")
+            and self.kv_transfer_config.kv_connector_extra_config is not None
+        )
+        if not has_extra_config:
+            self.config = self._get_default_config()
+        else:
+            extra_config = self.kv_transfer_config.kv_connector_extra_config
+            if "UCM_CONFIG_FILE" in extra_config:
+                config_file = extra_config["UCM_CONFIG_FILE"]
+                self.config = self.load_ucm_config_from_yaml(config_file)
+            else:
+                if extra_config == {}:
+                    self.config = self._get_default_config()
+                else:
+                    self.config = dict(extra_config)
+                    logger.info("Using kv_connector_extra_config from terminal input")
+
+    def _get_default_config(self) -> Dict[str, Any]:
+        config = {"ucm_connector_name": "UcmDramStore"}
+        logger.warning(f"No UCM config provided, using default configuration {config}")
+        return config
+
+    def get_config(self) -> Dict[str, Any]:
+        logger.info(f"Using UCM with config: {self.config}")
+        return self.config
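
For illustration only (this note and sketch are not part of the patch above): a minimal example of how the `Config` helper added in `ucm/utils.py` is expected to behave. `SimpleNamespace` stands in for vLLM's real kv-transfer config object, and the YAML path is hypothetical.

```python
# Illustrative sketch only; SimpleNamespace stands in for the real
# kv_transfer_config object, which exposes `kv_connector_extra_config`.
from types import SimpleNamespace

from ucm.utils import Config

# 1) Point UCM at a YAML config file (hypothetical path): the file's
#    contents are returned, or {} if it is missing or malformed.
kv_cfg = SimpleNamespace(
    kv_connector_extra_config={"UCM_CONFIG_FILE": "/workspace/ucm_config.yaml"}
)
print(Config(kv_cfg).get_config())

# 2) No extra config at all: falls back to the default DRAM store.
print(Config(SimpleNamespace(kv_connector_extra_config=None)).get_config())
# -> {'ucm_connector_name': 'UcmDramStore'}

# 3) Inline extra config without a UCM_CONFIG_FILE key: used as-is.
print(Config(SimpleNamespace(
    kv_connector_extra_config={"ucm_connector_name": "UcmNfsStore"}
)).get_config())
```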