diff --git a/test/test_ucm_dram.py b/test/test_ucm_dram.py deleted file mode 100644 index 020405d1..00000000 --- a/test/test_ucm_dram.py +++ /dev/null @@ -1,250 +0,0 @@ -# -# MIT License -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# - -import random -import unittest -import unittest.mock as mock -from contextlib import contextmanager -from typing import List -from unittest.mock import MagicMock - -import torch -from vllm.multimodal.inputs import MultiModalKwargs -from vllm.sampling_params import SamplingParams -from vllm.utils import sha256 -from vllm.v1.core.kv_cache_utils import hash_request_tokens -from vllm.v1.request import Request - - -@contextmanager -def mock_stream_context(stream=None): - yield - - -class MockStream: - def __init__(self, device=None): - self.device = device or torch.device("cpu") - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - def synchronize(self): - pass - - def record_event(self, event=None): - return event or MockEvent() - - def wait_stream(self, stream): - pass - - -class MockEvent: - def __init__(self, enable_timing=False): - self.enable_timing = enable_timing - - def record(self, stream=None): - pass - - def wait(self, stream=None): - pass - - def synchronize(self): - pass - - -def patch_cuda_for_cpu(): - mock.patch("torch.cuda.Stream", MockStream).start() - mock.patch("torch.cuda.Event", MockEvent).start() - mock.patch("torch.cuda.current_stream", return_value=MockStream()).start() - mock.patch("torch.cuda.synchronize", side_effect=lambda *a, **k: None).start() - mock.patch("torch.cuda.is_available", return_value=True).start() - mock.patch("torch.cuda.stream", mock_stream_context).start() - - -patch_cuda_for_cpu() -from ucm.store.dramstore.dramstore_connector import ( # isort: skip - DramTask, - UcmDramStore, -) - - -def make_request( - request_id, prompt_token_ids, mm_positions=None, mm_hashes=None, cache_salt=None -): - if mm_positions is None: - multi_modal_inputs = None - else: - multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions) - - return Request( - request_id=request_id, - prompt_token_ids=prompt_token_ids, - multi_modal_inputs=multi_modal_inputs, - multi_modal_hashes=mm_hashes, - multi_modal_placeholders=mm_positions, - sampling_params=SamplingParams(max_tokens=17), - pooling_params=None, - eos_token_id=100, - arrival_time=0, - lora_request=None, - cache_salt=cache_salt, - ) - - -class TestUcmDram(unittest.TestCase): - - @classmethod - def setUpClass(cls): - print("===> Before all tests (setUpClass)") - - @classmethod - def tearDownClass(cls): - print("===> After all tests (setUpClass)") - - def setUp(self): - self.config = {"block_size": 4} - self.scheduler_config = { - "role": "scheduler", - "max_cache_size": 1073741824, - "kv_block_size": 262144, - } - self.worker_config = { - "role": "worker", - "max_cache_size": 1073741824, - "kv_block_size": 262144, - } - - self.block_number = 4 - self.block_size = int(self.config["block_size"]) - self.scheduler_dram = UcmDramStore(self.scheduler_config) - self.worker_dram = UcmDramStore(self.worker_config) - random.seed(20250728) - self.request = make_request( - request_id=1, - prompt_token_ids=random.sample( - range(0, 10000), self.block_number * self.block_size - ), - mm_positions=None, - mm_hashes=None, - ) - block_hash_types = hash_request_tokens(sha256, self.block_size, self.request) - self.block_hashes: List[str] = [str(x.hash_value) for x in block_hash_types] - - def test_look_up_all_hit(self): - """ - Test for all blocks hitten in cache - """ - expected = [True] * len(self.block_hashes) - self.scheduler_dram.cached_blocks.update(self.block_hashes) - actual = self.scheduler_dram.lookup(self.block_hashes) - - self.assertEqual(actual, expected) - - def test_lookup_partial_hit(self): - """ - Test for part of the blocks hitten in cache - """ - partial_index = random.randint(0, 4) - partial_hashes = self.block_hashes[:partial_index] - self.scheduler_dram.cached_blocks.update(partial_hashes) - actual = self.scheduler_dram.lookup(self.block_hashes) - expected = [True] * partial_index + [False] * (self.block_size - partial_index) - self.assertEqual(actual, expected) - - def test_lookup_none_hit(self): - """ - Test for none of the blocks hitten in cache - """ - actual = self.scheduler_dram.lookup(self.block_hashes) - expected = [False] * len(self.block_hashes) - self.assertEqual(actual, expected) - - def test_load_success(self): - """ - Test for load from cache successfully - """ - src_tensors = [ - torch.randint(0, 100, (self.block_size,), dtype=torch.int8) - for _ in range(len(self.block_hashes)) - ] - offsets = [i for i in range(len(self.block_hashes))] - dump_task = self.worker_dram.dump(self.block_hashes, offsets, src_tensors) - self.worker_dram.wait(dump_task) - dst_tensors = [ - torch.zeros(self.block_size, dtype=torch.int8) - for _ in range(len(self.block_hashes)) - ] - load_task = self.worker_dram.load(self.block_hashes, offsets, dst_tensors) - - self.assertIsInstance(load_task, DramTask) - self.assertIsNotNone(load_task.event) - for i, (src_tensor, dst_tensor) in enumerate(zip(src_tensors, dst_tensors)): - self.assertEqual(dst_tensor.shape[0], self.block_size) - self.assertTrue( - torch.equal(src_tensor, dst_tensor), - f"Block {i} loaded data is different", - ) - - def test_dump_success(self): - """ - Test data dump successfully - """ - src_tensors = [ - torch.randint(0, 100, (self.block_size,), dtype=torch.int8) - for _ in range(len(self.block_hashes)) - ] - offsets = [i for i in range(len(self.block_hashes))] - original_data = [tensor.clone() for tensor in src_tensors] - dump_task = self.worker_dram.dump(self.block_hashes, offsets, src_tensors) - self.assertIsInstance(dump_task, DramTask) - self.assertIsNotNone(dump_task.event) - self.worker_dram.wait(dump_task) - for i, block_id in enumerate(self.block_hashes): - key = block_id + "_" + str(offsets[i]) - cached_data = self.worker_dram.dram_cache[key] - self.assertEqual(cached_data.shape[0], self.block_size) - self.assertTrue(torch.equal(cached_data, original_data[i])) - - def test_wait_success(self): - """ - Test wait for task successfully - """ - task = DramTask() - task.event = MagicMock() - result = self.worker_dram.wait(task) - self.assertEqual(result, 0) - task.event.synchronize.assert_called_once() - - def test_wait_failure(self): - task = DramTask() - task.event = None - result = self.worker_dram.wait(task) - self.assertEqual(result, -1) - - -if __name__ == "__main__": - unittest.main() diff --git a/ucm/store/dramstore/CMakeLists.txt b/ucm/store/dramstore/CMakeLists.txt index 9c4e27ec..15295544 100644 --- a/ucm/store/dramstore/CMakeLists.txt +++ b/ucm/store/dramstore/CMakeLists.txt @@ -1,7 +1,10 @@ file(GLOB_RECURSE UCMSTORE_DRAM_CC_SOURCE_FILES "./cc/*.cc") add_library(dramstore STATIC ${UCMSTORE_DRAM_CC_SOURCE_FILES}) -target_include_directories(dramstore PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cc) -target_link_libraries(dramstore PUBLIC storeinfra storetask) +target_include_directories(dramstore PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/cc/api + ${CMAKE_CURRENT_SOURCE_DIR}/cc/domain +) +target_link_libraries(dramstore PUBLIC storeinfra storedevice storetask) file(GLOB_RECURSE UCMSTORE_DRAM_CPY_SOURCE_FILES "./cpy/*.cc") pybind11_add_module(ucmdramstore ${UCMSTORE_DRAM_CPY_SOURCE_FILES}) diff --git a/ucm/store/dramstore/cc/api/dramstore.cc b/ucm/store/dramstore/cc/api/dramstore.cc index 56b4350f..c59b7f2b 100644 --- a/ucm/store/dramstore/cc/api/dramstore.cc +++ b/ucm/store/dramstore/cc/api/dramstore.cc @@ -24,27 +24,70 @@ #include "dramstore.h" #include "logger/logger.h" #include "status/status.h" +#include "trans/dram_trans_manager.h" +#include "memory/memory_pool.h" namespace UC { class DRAMStoreImpl : public DRAMStore { public: - int32_t Setup(const size_t ioSize, const size_t capacity, const int32_t deviceId) { return -1; } - int32_t Alloc(const std::string& block) override { return -1; } - bool Lookup(const std::string& block) override { return false; } - void Commit(const std::string& block, const bool success) override {} + int32_t Setup(const Config& config) { + auto status = this->memPool_.Setup(config.deviceId, config.capacity, config.blockSize); + if (status.Failure()) { + UC_ERROR("Failed({}) to setup MemoryPool.", status); + return status.Underlying(); + } + status = this->transMgr_.Setup(config.deviceId, config.streamNumber, &this->memPool_, config.timeoutMs); + if (status.Failure()) { + UC_ERROR("Failed({}) to setup TsfTaskManager.", status); + return status.Underlying(); + } + return Status::OK().Underlying(); + } + int32_t Alloc(const std::string& block) override { return this->memPool_.NewBlock(block).Underlying(); } + bool Lookup(const std::string& block) override { return this->memPool_.LookupBlock(block); } + void Commit(const std::string& block, const bool success) override { this->memPool_.CommitBlock(block, success).Underlying(); } std::list Alloc(const std::list& blocks) override { - return std::list(); + std::list results; + for (const auto &block : blocks) { + results.emplace_back(this->Alloc(block)); + } + return results; } std::list Lookup(const std::list& blocks) override { - return std::list(); + std::list founds; + for (const auto &block : blocks) { + founds.emplace_back(this->Lookup(block)); + } + return founds; + } + void Commit(const std::list& blocks, const bool success) override { + for (const auto &block : blocks) { + this->Commit(block, success); + } + } + size_t Submit(Task&& task) override { + auto taskId = Task::invalid; + auto status = this->transMgr_.Submit(std::move(task), taskId); + if (status.Failure()) { taskId = Task::invalid; } + return taskId; } + + int32_t Wait(const size_t task) override { + return this->transMgr_.Wait(task).Underlying(); + } + + int32_t Check(const size_t task, bool& finish) override { + return this->transMgr_.Check(task, finish).Underlying(); } - void Commit(const std::list& blocks, const bool success) override {} - size_t Submit(Task&& task) override { return 0; } - int32_t Wait(const size_t task) override { return -1; } - int32_t Check(const size_t task, bool& finish) override { return -1; } + + +private: + + DramTransManager transMgr_; + MemoryPool memPool_; + }; int32_t DRAMStore::Setup(const Config& config) @@ -55,7 +98,7 @@ int32_t DRAMStore::Setup(const Config& config) return Status::OutOfMemory().Underlying(); } this->impl_ = impl; - return impl->Setup(config.ioSize, config.capacity, config.deviceId); + return impl->Setup(config); } } // namespace UC diff --git a/ucm/store/dramstore/cc/api/dramstore.h b/ucm/store/dramstore/cc/api/dramstore.h index 1dc97573..1ae8ab91 100644 --- a/ucm/store/dramstore/cc/api/dramstore.h +++ b/ucm/store/dramstore/cc/api/dramstore.h @@ -31,11 +31,13 @@ namespace UC { class DRAMStore : public CCStore { public: struct Config { - size_t ioSize; size_t capacity; + size_t blockSize; int32_t deviceId; - Config(const size_t ioSize, const size_t capacity) - : ioSize{ioSize}, capacity{capacity}, deviceId{-1} + size_t streamNumber; + size_t timeoutMs; + Config(const size_t capacity, const size_t blockSize, const int32_t deviceId, const size_t streamNumber, const size_t timeoutMs) + : capacity{capacity}, blockSize{blockSize}, deviceId{deviceId}, streamNumber{streamNumber}, timeoutMs{timeoutMs} { } }; diff --git a/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.cc b/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.cc new file mode 100644 index 00000000..f9835612 --- /dev/null +++ b/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.cc @@ -0,0 +1,42 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ + +#include "dram_trans_manager.h" + +namespace UC { + +Status DramTransManager::Setup(const int32_t deviceId, const size_t streamNumber, const MemoryPool* memPool, size_t timeoutMs) { + this->timeoutMs_ = timeoutMs; + auto status = Status::OK(); + for (size_t i = 0; i < streamNumber; i++) { + auto q = std::make_shared(); + status = + q->Setup(deviceId, &this->failureSet_, memPool, timeoutMs); + if (status.Failure()) { break; } + this->queues_.emplace_back(std::move(q)); + } + return status; +} + +} \ No newline at end of file diff --git a/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.h b/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.h new file mode 100644 index 00000000..7f9ef51b --- /dev/null +++ b/ucm/store/dramstore/cc/domain/trans/dram_trans_manager.h @@ -0,0 +1,39 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#ifndef UNIFIEDCACHE_DRAM_TRANS_MANAGER_H +#define UNIFIEDCACHE_DRAM_TRANS_MANAGER_H + +#include "task_manager.h" +#include "dram_trans_queue.h" + +namespace UC { + +class DramTransManager : public TaskManager { +public: + Status Setup(const int32_t deviceId, const size_t streamNumber, const MemoryPool* memPool, size_t timeoutMs); +}; + +} // namespace UC + +#endif diff --git a/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.cc b/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.cc new file mode 100644 index 00000000..cf7a3577 --- /dev/null +++ b/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.cc @@ -0,0 +1,126 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ + +#include "dram_trans_queue.h" + +namespace UC { + +Status DramTransQueue::Setup(const int32_t deviceId, TaskSet* failureSet, + const MemoryPool* memPool, const size_t timeoutMs) { + this->deviceId_ = deviceId; + this->failureSet_ = failureSet; + this->memPool_ = memPool; + auto success = + this->backend_.SetWorkerInitFn([this](auto& device) { return this->Init(device); }) + .SetWorkerFn([this](auto& shards, const auto& device) { this->Work(shards, device); }) + .SetWorkerExitFn([this](auto& device) { this->Exit(device); }) + .Run(); + return success ? Status::OK() : Status::Error(); +} + +void DramTransQueue::Push(std::list& shards) noexcept { + this->backend_.Push(std::move(shards)); +} + +bool DramTransQueue::Init(Device& device) { + if (this->deviceId_ < 0) { return true; } + device = DeviceFactory::Make(this->deviceId_, 262144, 512); + if (!device) { + return false; + } + return device->Setup().Success(); +} + +void DramTransQueue::Exit(Device& device) { + device.reset(); +} + +void DramTransQueue::Work(std::list& shards, const Device& device) { + auto it = shards.begin(); + if (this->failureSet_->Contains(it->owner)) { + this->Done(shards, device, true); + } + auto status = Status::OK(); + if (it->type == Task::Type::DUMP) { + status = this->D2H(shards, device); + } else { + status = this->H2D(shards, device); + } + this->Done(shards, device, status.Success()); +} + +Status DramTransQueue::H2D(std::list& shards, const Device& device) { + size_t pool_offset = 0; + std::vector host_addrs(shards.size()); + std::vector device_addrs(shards.size()); + int shard_index = 0; + for (auto& shard : shards) { + bool found = this->memPool_->GetOffset(shard.block, &pool_offset); + if (!found) { + return Status::Error(); + } + auto host_addr = this->memPool_->GetStartAddr().get() + pool_offset + shard.offset; + auto device_addr = shard.address; + host_addrs[shard_index] = host_addr; + device_addrs[shard_index] = reinterpret_cast(device_addr); + shard_index++; + } + auto it = shards.begin(); + return device->H2DBatchSync(device_addrs.data(), const_cast(host_addrs.data()), shards.size(), it->length); +} + +Status DramTransQueue::D2H(std::list& shards, const Device& device) { + size_t pool_offset = 0; + std::vector host_addrs(shards.size()); + std::vector device_addrs(shards.size()); + int shard_index = 0; + for (auto& shard : shards) { + bool found = this->memPool_->GetOffset(shard.block, &pool_offset); + if (!found) { + return Status::Error(); + } + auto host_addr = this->memPool_->GetStartAddr().get() + pool_offset + shard.offset; + auto device_addr = shard.address; + host_addrs[shard_index] = host_addr; + device_addrs[shard_index] = reinterpret_cast(device_addr); + shard_index++; + } + auto it = shards.begin(); + return device->D2HBatchSync(host_addrs.data(), const_cast(device_addrs.data()), shards.size(), it->length); +} + +void DramTransQueue::Done(std::list& shards, const Device& device, const bool success) { + auto it = shards.begin(); + if (!success) { this->failureSet_->Insert(it->owner); } + for (auto& shard : shards) { + if (shard.done) { + if (device) { + if (device->Synchronized().Failure()) { this->failureSet_->Insert(shard.owner); } + } + shard.done(); + } + } +} + +} // namespace UC \ No newline at end of file diff --git a/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.h b/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.h new file mode 100644 index 00000000..72350709 --- /dev/null +++ b/ucm/store/dramstore/cc/domain/trans/dram_trans_queue.h @@ -0,0 +1,61 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#ifndef UNIFIEDCACHE_DRAM_TRANS_QUEUE_H +#define UNIFIEDCACHE_DRAM_TRANS_QUEUE_H + +#include "device/idevice.h" +#include "status/status.h" +#include "task_queue.h" +#include "task_set.h" +#include "thread/thread_pool.h" +#include "memory/memory_pool.h" + +namespace UC { + +class DramTransQueue : public TaskQueue { + using Device = std::unique_ptr; + int32_t deviceId_{-1}; + TaskSet* failureSet_{nullptr}; + const MemoryPool* memPool_{nullptr}; + ThreadPool, Device> backend_{}; + +public: + Status Setup(const int32_t deviceId, + TaskSet* failureSet, + const MemoryPool* memPool, + const size_t timeoutMs); + void Push(std::list& shards) noexcept override; + +private: + bool Init(Device& device); + void Exit(Device& device); + void Work(std::list& shards, const Device& device); + void Done(std::list& shards, const Device& device, const bool success); + Status H2D(std::list& shards, const Device& device); + Status D2H(std::list& shards, const Device& device); +}; + +} // namespace UC + +#endif diff --git a/ucm/store/dramstore/cpy/dramstore.py.cc b/ucm/store/dramstore/cpy/dramstore.py.cc index 5b50748f..cb76d5d1 100644 --- a/ucm/store/dramstore/cpy/dramstore.py.cc +++ b/ucm/store/dramstore/cpy/dramstore.py.cc @@ -21,7 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * */ -#include "api/dramstore.h" +#include "dramstore.h" #include namespace py = pybind11; @@ -99,10 +99,13 @@ PYBIND11_MODULE(ucmdramstore, module) module.attr("build_type") = UCM_BUILD_TYPE; auto store = py::class_(module, "DRAMStore"); auto config = py::class_(store, "Config"); - config.def(py::init(), py::arg("ioSize"), py::arg("capacity")); - config.def_readwrite("ioSize", &UC::DRAMStorePy::Config::ioSize); + config.def(py::init(), + py::arg("capacity"), py::arg("blockSize"), py::arg("deviceId"), py::arg("streamNumber"), py::arg("timeoutMs")); config.def_readwrite("capacity", &UC::DRAMStorePy::Config::capacity); + config.def_readwrite("blockSize", &UC::DRAMStorePy::Config::blockSize); config.def_readwrite("deviceId", &UC::DRAMStorePy::Config::deviceId); + config.def_readwrite("streamNumber", &UC::DRAMStorePy::Config::streamNumber); + config.def_readwrite("timeoutMs", &UC::DRAMStorePy::Config::timeoutMs); store.def(py::init<>()); store.def("CCStoreImpl", &UC::DRAMStorePy::CCStoreImpl); store.def("Setup", &UC::DRAMStorePy::Setup); diff --git a/ucm/store/dramstore/dramstore_connector.py b/ucm/store/dramstore/dramstore_connector.py index e18306b5..1c525aff 100644 --- a/ucm/store/dramstore/dramstore_connector.py +++ b/ucm/store/dramstore/dramstore_connector.py @@ -28,6 +28,7 @@ import torch from ucm.logger import init_logger +from ucm.store.dramstore import ucmdramstore from ucm.store.ucmstore import Task, UcmKVStoreBase logger = init_logger(__name__) @@ -48,7 +49,7 @@ @dataclass class DramTask(Task): - task_id: str = "1" + task_id: int event: Optional[Any] = None @@ -59,111 +60,50 @@ class UcmDramStore(UcmKVStoreBase): def __init__(self, config: Dict): super().__init__(config) - self.dram_cache: Dict[str, any] = {} - self.max_cache_byte = int(config.get("max_cache_size", 5368709120)) - self.kv_block_size = int(config.get("kv_block_size", 262144)) - self.max_block_num = self.max_cache_byte // self.kv_block_size - if config["role"] == "scheduler": - self.cached_blocks = set() + self.store = ucmdramstore.DRAMStore() - def cc_store(self) -> int: - """ - get the underlying implementation of Store - - Returns: - cc pointer to Store - """ - return 0 + capacity = int(config.get("capacity", 1073741824)) # Default 1GB + block_size = int(config.get("kv_block_size", 262144)) # Default 256KB + device_id = int(config.get("device_id", -1)) + stream_number = int(config.get("stream_number", 32)) + timeout_ms = int(config.get("timeout_ms", 30000)) - def create(self, block_ids: List[str]) -> List[int]: - """ - create kv cache space in storage + param = ucmdramstore.DRAMStore.Config( + capacity, block_size, device_id, stream_number, timeout_ms + ) - Args: - block_ids (List[str]): vLLM block hash. - Returns: - success mask - """ - return [SUCCESS] * len(block_ids) + ret = self.store.Setup(param) + if ret != 0: + msg = f"Failed to initialize ucmdramstore, errcode: {ret}." + raise RuntimeError(msg) - def lookup(self, block_ids: List[str]) -> List[bool]: - """ - Get number of blocks that can be loaded from the - external KV cache. + def cc_store(self) -> int: + return self.store.CCStoreImpl() - Args: - block_ids (List[str]): vLLM block hash. + def create(self, block_ids: List[str]) -> List[int]: + return self.store.AllocBatch(block_ids) - Returns: - hit block mask, True -> hit - """ - hit_list = [block_id in self.cached_blocks for block_id in block_ids] - return hit_list + def lookup(self, block_ids: List[str]) -> List[bool]: + return self.store.LookupBatch(block_ids) def prefetch(self, block_ids: List[str]) -> None: - """ - prefetch kv cache to high speed cache according to block_ids. - - Args: - block_ids (List[str]): vLLM block hash. - """ pass def load( self, block_ids: List[str], offset: List[int], dst_tensor: List[torch.Tensor] ) -> Task: - """ - load kv cache to device. - - Args: - block_ids (List[str]): vLLM block hash. - offset(List[int]): tp > 1 scene - dst_tensor: List[torch.Tensor]: device tensor addr. - Returns: - task(Task). - """ - task = DramTask() - stream = device.Stream() - task.event = device.Event(enable_timing=True) - with device.stream(stream): - for i, block_id in enumerate(block_ids): - key = block_id + "_" + str(offset[i]) - dst_tensor[i].copy_(self.dram_cache[key], non_blocking=True) - task.event.record(stream=stream) - logger.debug(f"load block {block_ids} finished.") - return task + dst_tensor_ptr = [t.data_ptr() for t in dst_tensor] + dst_tensor_size = [t.numel() * t.element_size() for t in dst_tensor] + task_id = self.store.Load(block_ids, offset, dst_tensor_ptr, dst_tensor_size) + return DramTask(task_id=task_id) def dump( self, block_ids: List[str], offset: List[int], src_tensor: List[torch.Tensor] ) -> Task: - """ - dump kv cache to device. - - Args: - block_ids (List[str]): vLLM block hash. - offset(List[int]): tp > 1 scene - src_tensor: List[torch.Tensor]: device tensor addr. - Returns: - task(Task). - """ - task = DramTask() - if len(self.dram_cache) > self.max_block_num: - logger.warning( - "Dram cache usage exceeds limit! No more kv cache offload! Try to increase your initial max_cache_size." - ) - task.task_id = "-1" - return task - else: - device.current_stream().synchronize() - stream = device.Stream() - task.event = device.Event(enable_timing=True) - with device.stream(stream): - for i, block_id in enumerate(block_ids): - key = block_id + "_" + str(offset[i]) - self.dram_cache[key] = src_tensor[i].to("cpu", non_blocking=True) - task.event.record(stream=stream) - logger.debug(f"dump block {block_ids} finished.") - return task + src_tensor_ptr = [t.data_ptr() for t in src_tensor] + src_tensor_size = [t.numel() * t.element_size() for t in src_tensor] + task_id = self.store.Dump(block_ids, offset, src_tensor_ptr, src_tensor_size) + return DramTask(task_id=task_id) def fetch_data( self, @@ -172,9 +112,8 @@ def fetch_data( dst_addr: List[int], size: List[int], ) -> Task: - raise NotImplementedError( - "Method(fetch_data) not yet implemented in this version" - ) + task_id = self.store.Load(block_ids, offset, dst_addr, size) + return DramTask(task_id=task_id) def dump_data( self, @@ -183,50 +122,14 @@ def dump_data( src_addr: List[int], size: List[int], ) -> Task: - raise NotImplementedError( - "Method(dump_data) not yet implemented in this version" - ) + task_id = self.store.Dump(block_ids, offset, src_addr, size) + return DramTask(task_id=task_id) def wait(self, task: DramTask) -> int: - """ - wait kv cache kv transfer task finished. - - Args: - task (Task): transfer engine task. - Returns: - 0 - success - others - failed. - """ - if task.task_id == "-1": - logger.warning("Dump failure with full cache usage!") - return FAILURE - try: - event = task.event - event.synchronize() - return SUCCESS - except Exception as e: - logger.error(f"Error waiting cache for block IDs: {e}") - return FAILURE + return self.store.Wait(task.task_id) def commit(self, block_ids: List[str], is_success: bool = True) -> None: - """ - commit kv cache, now kv cache can be reused. - - Args: - block_ids (List[str]): vLLM block hash. - is_success(bool): if False, we need release block - """ - if is_success: - self.cached_blocks.update(block_ids) + self.store.CommitBatch(block_ids, is_success) def check(self, task: Task) -> int: - """ - check if kv transfer task finished. - - Args: - task (Task): transfer engine task. - Returns: - 0 - finished - others - in process. - """ - pass + return self.store.Check(task.task_id) diff --git a/ucm/store/infra/memory/memory_pool.h b/ucm/store/infra/memory/memory_pool.h new file mode 100644 index 00000000..200d1286 --- /dev/null +++ b/ucm/store/infra/memory/memory_pool.h @@ -0,0 +1,174 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#ifndef UNIFIEDCACHE_MEMORY_POOL_H +#define UNIFIEDCACHE_MEMORY_POOL_H + +#include +#include +#include +#include +#include +#include +#include "status/status.h" +#include "device/idevice.h" +#include +#include +#include +#include "logger/logger.h" + +namespace UC { + +class MemoryPool { + + std::string DUMMY_SLOT_PREFIX{"__slot_"}; + using Device = std::unique_ptr; +public: + + Status Setup(int32_t deviceId, size_t capacity, size_t blockSize) { + capacity_ = capacity; + blockSize_ = blockSize; + device_ = DeviceFactory::Make(deviceId, blockSize, static_cast(capacity / blockSize)); + if (!device_) { + UC_ERROR("MemoryPool: failed to create device"); + return Status::Error(); + } + Status status = device_->Setup(); + if (!status.Success()) { + UC_ERROR("MemoryPool: failed to set up device"); + return Status::Error(); + } + pool_ = device_->GetBuffer(capacity_); + if (!pool_) { + UC_ERROR("MemoryPool: failed to get pool memory space"); + return Status::Error(); + } + + size_t slotNum = capacity_ / blockSize_; + for (size_t i = 0; i < slotNum; ++i) { + std::string dummy = DUMMY_SLOT_PREFIX + std::to_string(i); + size_t offset = i * blockSize_; + lruList_.push_front(dummy); + lruIndex_[dummy] = lruList_.begin(); + offsetMap_[dummy] = offset; + } + return Status::OK(); + + } + + Status NewBlock(const std::string& blockId) { + if (offsetMap_.count(blockId)) { + return Status::DuplicateKey(); + } + if (lruList_.empty()) { + // 所有空间里的块都正在写,那么就不能够分配 + return Status::Error(); + } + size_t offset = LRUEvictOne(); + offsetMap_[blockId] = offset; + return Status::OK(); + } + + bool LookupBlock(const std::string& blockId) const { + return availableBlocks_.count(blockId); + } + + bool GetOffset(const std::string& blockId, size_t* offset) const { + auto it = offsetMap_.find(blockId); + if (it == offsetMap_.end()) { + return false; + } + *offset = it->second; + return true; + } + + Status CommitBlock(const std::string& blockId, bool success) { + if (success) { + availableBlocks_.insert(blockId); + touchUnsafe(blockId); + } else { + resetSpaceOfBlock(blockId); + } + return Status::OK(); + } + + std::shared_ptr GetStartAddr() const { + return pool_; + } + +private: + std::shared_ptr pool_ = nullptr; + Device device_ = nullptr; + size_t capacity_; + size_t blockSize_; + + std::unordered_map offsetMap_; + std::set availableBlocks_; + + using ListType = std::list; + ListType lruList_; + std::unordered_map lruIndex_; + + void touchUnsafe(const std::string& blockId) { + auto it = lruIndex_.find(blockId); + if (it != lruIndex_.end()) { + lruList_.splice(lruList_.begin(), lruList_, it->second); + } + else { + lruList_.push_front(blockId); // 访问一次,该块就是最近使用了的,所以放到LRU队列的头部。这就是一般LRU的逻辑 + lruIndex_[blockId] = lruList_.begin(); + } + } + + size_t LRUEvictOne() { + const std::string& victim = lruList_.back(); + // 真实数据块,才从availableBlocks_中删掉 + if (victim.rfind(DUMMY_SLOT_PREFIX, 0) != 0) { + availableBlocks_.erase(victim); + } + size_t offset = offsetMap_[victim]; + offsetMap_.erase(victim); + lruIndex_.erase(victim); + lruList_.pop_back(); + return offset; + } + + void resetSpaceOfBlock(const std::string& blockId) { + auto it = offsetMap_.find(blockId); + size_t offset = it->second; + std::string dummy = DUMMY_SLOT_PREFIX + std::to_string(offset / blockSize_); + offsetMap_.erase(blockId); + + auto lit = lruIndex_.find(blockId); + if (lit != lruIndex_.end()) { + lruList_.erase(lit->second); + lruIndex_.erase(lit); + } + lruList_.push_back(dummy); // 将一个块commit false后,回收之前分配的内存,并且要将其放到LRU队列的尾部(下次可以写的时候,要马上就写。因为该块的优先级高于已经写了的块) + lruIndex_[dummy] = std::prev(lruList_.end()); + offsetMap_[dummy] = offset; + } +}; + +} // namespace UC +#endif \ No newline at end of file diff --git a/ucm/store/infra/thread/thread_pool.h b/ucm/store/infra/thread/thread_pool.h index 0aa42672..c33a0c28 100644 --- a/ucm/store/infra/thread/thread_pool.h +++ b/ucm/store/infra/thread/thread_pool.h @@ -99,7 +99,7 @@ class ThreadPool { void Push(Task&& task) noexcept { std::unique_lock lk(this->mtx_); - this->taskQ_.push_back(task); + this->taskQ_.push_back(std::move(task)); this->cv_.notify_one(); } diff --git a/ucm/store/test/case/infra/mem_pool_test.cc b/ucm/store/test/case/infra/mem_pool_test.cc new file mode 100644 index 00000000..f9ea0438 --- /dev/null +++ b/ucm/store/test/case/infra/mem_pool_test.cc @@ -0,0 +1,169 @@ +/** + * MIT License + * + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ + +#include "infra/memory/memory_pool.h" +#include + +class UCMemoryPoolTest : public ::testing::Test {}; + +TEST_F(UCMemoryPoolTest, NewBlockAllocateAndCommit) +{ + UC::MemoryPool memPool; // 初始化内存池 + ASSERT_EQ(memPool.Setup(-1, 10, 2), UC::Status::OK()); + const std::string block1 = "block1"; + size_t offset = 10; + ASSERT_FALSE(memPool.LookupBlock(block1)); + // ASSERT_EQ(memPool.GetOffset(block1), nullptr); + ASSERT_EQ(memPool.GetOffset(block1, &offset), false); + ASSERT_EQ(memPool.NewBlock(block1), UC::Status::OK()); + ASSERT_FALSE(memPool.LookupBlock(block1)); + // ASSERT_NE(memPool.GetOffset(block1), nullptr); + ASSERT_EQ(memPool.GetOffset(block1, &offset), true); + ASSERT_EQ(memPool.NewBlock(block1), UC::Status::DuplicateKey()); + ASSERT_EQ(memPool.CommitBlock(block1, true), UC::Status::OK()); + ASSERT_TRUE(memPool.LookupBlock(block1)); +} + +TEST_F(UCMemoryPoolTest, EvictOldBlock) +{ + UC::MemoryPool memPool; // 初始化内存池 + ASSERT_EQ(memPool.Setup(-1, 10, 5), UC::Status::OK()); + const std::string block1 = "block1"; + const std::string block2 = "block2"; + const std::string block3 = "block3"; + size_t offset = 10; + ASSERT_EQ(memPool.NewBlock(block1), UC::Status::OK()); + // ASSERT_NE(memPool.GetOffset(block1), nullptr); + ASSERT_EQ(memPool.GetOffset(block1, &offset), true); + ASSERT_EQ(memPool.NewBlock(block2), UC::Status::OK()); + // ASSERT_NE(memPool.GetOffset(block2), nullptr); + ASSERT_EQ(memPool.GetOffset(block2, &offset), true); + memPool.CommitBlock(block1, true); + memPool.CommitBlock(block2, true); + ASSERT_EQ(memPool.NewBlock(block3), UC::Status::OK()); + // ASSERT_NE(memPool.GetOffset(block3), nullptr); + ASSERT_EQ(memPool.GetOffset(block3, &offset), true); + // ASSERT_EQ(memPool.GetOffset(block1), nullptr); + ASSERT_EQ(memPool.GetOffset(block1, &offset), false); + // ASSERT_NE(memPool.GetOffset(block2), nullptr); + ASSERT_EQ(memPool.GetOffset(block2, &offset), true); + ASSERT_FALSE(memPool.LookupBlock(block1)); + ASSERT_TRUE(memPool.LookupBlock(block2)); + ASSERT_FALSE(memPool.LookupBlock(block3)); +} + +TEST_F(UCMemoryPoolTest, OldBlockCommitFalse) +{ + UC::MemoryPool memPool; // 初始化内存池 + ASSERT_EQ(memPool.Setup(-1, 32, 8), UC::Status::OK()); + const std::string block1 = "block1"; + const std::string block2 = "block2"; + const std::string block3 = "block3"; + const std::string block4 = "block4"; + const std::string block5 = "block5"; + size_t offset = 32; + ASSERT_EQ(memPool.NewBlock(block1), UC::Status::OK()); + // ASSERT_NE(memPool.GetOffset(block1), nullptr); + ASSERT_EQ(memPool.GetOffset(block1, &offset), true); + ASSERT_EQ(memPool.NewBlock(block2), UC::Status::OK()); + // ASSERT_NE(memPool.GetOffset(block2), nullptr); + ASSERT_EQ(memPool.GetOffset(block2, &offset), true); + ASSERT_EQ(memPool.NewBlock(block3), UC::Status::OK()); + // ASSERT_NE(memPool.GetOffset(block3), nullptr); + ASSERT_EQ(memPool.GetOffset(block3, &offset), true); + memPool.CommitBlock(block1, true); + memPool.CommitBlock(block2, false); + ASSERT_TRUE(memPool.LookupBlock(block1)); + ASSERT_FALSE(memPool.LookupBlock(block2)); + ASSERT_FALSE(memPool.LookupBlock(block3)); + ASSERT_EQ(memPool.NewBlock(block4), UC::Status::OK()); + // ASSERT_EQ(memPool.GetOffset(block4), 8); + ASSERT_EQ(memPool.GetOffset(block4, &offset), true); + ASSERT_EQ(offset, 8); + ASSERT_EQ(memPool.NewBlock(block5), UC::Status::OK()); + // ASSERT_EQ(memPool.GetOffset(block5), 24); + ASSERT_EQ(memPool.GetOffset(block5, &offset), true); + ASSERT_EQ(offset, 24); + memPool.CommitBlock(block3, true); + memPool.CommitBlock(block4, true); + memPool.CommitBlock(block5, true); + ASSERT_TRUE(memPool.LookupBlock(block1)); + ASSERT_FALSE(memPool.LookupBlock(block2)); + ASSERT_TRUE(memPool.LookupBlock(block3)); + ASSERT_TRUE(memPool.LookupBlock(block4)); + ASSERT_TRUE(memPool.LookupBlock(block5)); + + ASSERT_EQ(memPool.NewBlock(block1), UC::Status::DuplicateKey()); + ASSERT_EQ(memPool.NewBlock(block2), UC::Status::OK()); + // ASSERT_EQ(memPool.GetOffset(block2), 0); + ASSERT_EQ(memPool.GetOffset(block2, &offset), true); + ASSERT_EQ(offset, 0); + ASSERT_FALSE(memPool.LookupBlock(block1)); + ASSERT_FALSE(memPool.LookupBlock(block2)); + memPool.CommitBlock(block2, true); + ASSERT_TRUE(memPool.LookupBlock(block2)); +} + +TEST_F(UCMemoryPoolTest, NoCommittedBlock) +{ + UC::MemoryPool memPool; // 初始化内存池 + ASSERT_EQ(memPool.Setup(-1, 32, 8), UC::Status::OK()); + const std::string block1 = "block1"; + const std::string block2 = "block2"; + const std::string block3 = "block3"; + const std::string block4 = "block4"; + const std::string block5 = "block5"; + const std::string block6 = "block6"; + size_t offset = 32; + ASSERT_EQ(memPool.NewBlock(block1), UC::Status::OK()); + ASSERT_EQ(memPool.NewBlock(block2), UC::Status::OK()); + ASSERT_EQ(memPool.NewBlock(block3), UC::Status::OK()); + ASSERT_EQ(memPool.NewBlock(block4), UC::Status::OK()); + ASSERT_EQ(memPool.NewBlock(block5), UC::Status::Error()); + memPool.CommitBlock(block1, true); + ASSERT_TRUE(memPool.LookupBlock(block1)); + ASSERT_EQ(memPool.NewBlock(block5), UC::Status::OK()); + // ASSERT_EQ(memPool.GetOffset(block5), 0); + ASSERT_EQ(memPool.GetOffset(block5, &offset), true); + ASSERT_EQ(offset, 0); + ASSERT_FALSE(memPool.LookupBlock(block1)); + ASSERT_EQ(memPool.NewBlock(block6), UC::Status::Error()); + // ASSERT_EQ(memPool.GetOffset(block2), 8); + ASSERT_EQ(memPool.GetOffset(block2, &offset), true); + ASSERT_EQ(offset, 8); + memPool.CommitBlock(block2, false); + // ASSERT_EQ(memPool.GetOffset((block2)), nullptr); + ASSERT_EQ(memPool.GetOffset(block2, &offset), false); + ASSERT_FALSE(memPool.LookupBlock(block1)); + ASSERT_EQ(memPool.NewBlock(block6), UC::Status::OK()); + // ASSERT_EQ(memPool.GetOffset(block6), 8); + ASSERT_EQ(memPool.GetOffset(block6, &offset), true); + ASSERT_EQ(offset, 8); + ASSERT_FALSE(memPool.LookupBlock(block6)); + memPool.CommitBlock(block6, true); + ASSERT_TRUE(memPool.LookupBlock(block6)); + // ASSERT_EQ(memPool.GetOffset(block6), 8); + ASSERT_EQ(memPool.GetOffset(block6, &offset), true); + ASSERT_EQ(offset, 8); +} \ No newline at end of file diff --git a/ucm/store/test/e2e/dramstore_embed_and_fetch.py b/ucm/store/test/e2e/dramstore_embed_and_fetch.py new file mode 100644 index 00000000..4f9acda1 --- /dev/null +++ b/ucm/store/test/e2e/dramstore_embed_and_fetch.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +import os +import secrets +from typing import List + +import torch + +from ucm.store.dramstore.dramstore_connector import UcmDramStore +from ucm.store.ucmstore import UcmKVStoreBase + + +def setup_store( + capacity, block_size, stream_number, device_id, timeout_ms +) -> UcmKVStoreBase: + config = {} + config["capacity"] = capacity + config["kv_block_size"] = block_size + config["stream_number"] = stream_number + config["device_id"] = device_id + config["timeout_ms"] = timeout_ms + return UcmDramStore(config) + + +def make_buffers( + block_number, device_id, batch_size, block_dim, block_len, block_layer +): + hashes = [secrets.token_hex(16) for _ in range(block_number)] + tensors = [ + [ + torch.rand( + [block_dim, block_len], + dtype=torch.bfloat16, + device="cuda:{}".format(device_id), + ) + for _ in range(block_layer) + ] + for _ in range(batch_size) + ] + return hashes, tensors + + +def embed(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Tensor]]): + results = store.create(hashes) + assert sum(results) == 0 + block_ids = [] + offsets = [] + layers = [] + for hash_id, block in zip(hashes, tensors): + offset = 0 + for layer in block: + block_ids.append(hash_id) + offsets.append(offset) + layers.append(layer) + offset += layer.untyped_storage().size() + task = store.dump(block_ids, offsets, layers) + assert task.task_id > 0 + ret = store.wait(task) + assert ret == 0 + store.commit(hashes, True) + + +def fetch(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Tensor]]): + founds = store.lookup(hashes) + for found in founds: + assert found + block_ids = [] + offsets = [] + layers = [] + for hash_id, block in zip(hashes, tensors): + offset = 0 + for layer in block: + block_ids.append(hash_id) + offsets.append(offset) + layers.append(layer) + offset += layer.untyped_storage().size() + task = store.load(block_ids, offsets, layers) + assert task.task_id > 0 + ret = store.wait(task) + assert ret == 0 + + +def main(): + block_number = 4096 + device_id = 1 + block_dim = 576 + block_len = 128 + block_elem_size = 2 + block_layer = 61 + io_size = block_dim * block_len * block_elem_size + block_size = io_size * block_layer + batch_size = 256 + stream_number = 10 + timeout_ms = 1000000 + capacity = block_number * block_size * 2 + batch_number = 64 + + store = setup_store(capacity, block_size, stream_number, device_id, timeout_ms) + hashes, tensors = make_buffers( + block_number, device_id, batch_size, block_dim, block_len, block_layer + ) + total_batches = (block_number + batch_size - 1) // batch_size + + for batch in range(total_batches): + start = batch_size * batch + end = min(start + batch_size, block_number) + embed(store, hashes[start:end], tensors) + + _, new_tensors = make_buffers( + block_number, device_id, batch_size, block_dim, block_len, block_layer + ) + for batch in range(total_batches): + start = batch_size * batch + end = start + batch_size + fetch(store, hashes[start:end], new_tensors) + + +if __name__ == "__main__": + os.environ["UC_LOGGER_LEVEL"] = "debug" + main()