diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c481f1ff6..573ff399c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,8 +7,7 @@ repos: args: [--line-length=120] additional_dependencies: ['click==8.0.4'] - repo: https://github.com/pycqa/flake8 - rev: 3.9.0 + rev: 6.1.0 hooks: - id: flake8 - additional_dependencies: [flake8-typing-imports==1.9.0] - args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231'] \ No newline at end of file + args: ['--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231'] diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 28c4ceb1e..2bf0a4d5a 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -342,6 +342,64 @@ def evict(self, need_remove_tokens, evict_callback): return + def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]: + """ + 合并条件: + 1. 父节点不是根节点。 + 2. 父节点的引用计数为 0。 + 3. 子节点的引用计数为 0。 + 4. 父节点只有一个子节点 (即 child_node)。 + """ + parent_node = child_node.parent + # 条件检查 + if ( + parent_node is None + or parent_node == self.root_node + or parent_node.ref_counter != 0 + or len(parent_node.children) != 1 + or child_node.ref_counter != 0 + ): + return None + + if child_node.is_leaf(): + self.evict_tree_set.discard(child_node) + + child_node.token_id_key = torch.cat([parent_node.token_id_key, child_node.token_id_key]) + child_node.token_mem_index_value = torch.cat( + [parent_node.token_mem_index_value, child_node.token_mem_index_value] + ) + child_node.node_value_len = len(child_node.token_mem_index_value) + child_node.time_id = max(parent_node.time_id, child_node.time_id) + + grandparent_node = parent_node.parent + key_in_grandparent = parent_node.token_id_key[0].item() + grandparent_node.children[key_in_grandparent] = child_node + child_node.parent = grandparent_node + + parent_node.parent = None + + if child_node.is_leaf(): + self.evict_tree_set.add(child_node) + + return child_node + + def merge_unreferenced_nodes(self): + worklist = collections.deque( + [ + node + for node in self.evict_tree_set + if node.ref_counter == 0 and node.parent is not None and node.parent != self.root_node + ] + ) + + while worklist: + node = worklist.popleft() + if node.parent is None: + continue + merged_node = self._try_merge(node) + if merged_node: + worklist.append(merged_node) + def assert_leafs_is_right(self): for node in self.evict_tree_set: if node.is_leaf() and node.ref_counter == 0: diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index b6cb4d21f..95f0c9951 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -26,7 +26,11 @@ from lightllm.utils.dist_utils import get_dp_world_size, get_global_dp_rank, get_current_rank_in_dp from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size from lightllm.utils.dist_utils import get_dp_rank_in_node, create_new_group_for_current_node -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import ( + get_env_start_args, + enable_radix_tree_timer_merge, + get_radix_tree_merge_update_delta, +) from lightllm.distributed import dist_group_manager from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack @@ -61,6 +65,11 @@ def __init__(self) -> None: # nixl pd mode callback func self.nixl_prefill_chuncked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None + + # counter + self._radix_tree_merge_counter: int = 0 + self._enable_radix_tree_timer_merge: bool = enable_radix_tree_timer_merge() + self._radix_tree_merge_update_delta: int = get_radix_tree_merge_update_delta() pass def init_model(self, kvargs): @@ -439,6 +448,22 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: """ return [g_infer_context.requests_mapping[request_id] for request_id in req_ids] + def _timer_merge_radix_tree(self): + self._radix_tree_merge_counter += 1 + if ( + self._enable_radix_tree_timer_merge + and (self._radix_tree_merge_counter % self._radix_tree_merge_update_delta == 0) + and self.radix_cache is not None + ): + g_infer_state_lock.acquire() + start = time.time() + self.radix_cache.merge_unreferenced_nodes() + self.logger.info( + f"radix tree merge_unreferenced_nodes cost time {time.time() - start} s in rank {self.global_rank}" + ) + g_infer_state_lock.release() + return + # 一些可以复用的通用功能函数 def _get_classed_reqs( self, @@ -465,6 +490,9 @@ def _get_classed_reqs( 4. prefill_reqs 需要进行prefill操作的请求 5. decode_reqs 需要进行decode操作的请求 """ + # 定期对 radix cache 进行 merge,防止查询插入的操作效率下降 + self._timer_merge_radix_tree() + if self.args.enable_cpu_cache and len(g_infer_context.infer_req_ids) > 0: self.multi_level_cache_module.update_cpu_cache_task_states() diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 2f795aa23..7c221c574 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -181,3 +181,16 @@ def disable_cpu_kvcache_sync() -> bool: 实验用环境遍历,未来可能会移除 """ return enable_env_vars("LIGHTLLM_DISABLE_CPU_CACHE_SYNC") + + +@lru_cache(maxsize=None) +def enable_radix_tree_timer_merge() -> bool: + """ + 使能定期合并 radix tree的叶节点, 防止插入查询性能下降。 + """ + return enable_env_vars("LIGHTLLM_RADIX_TREE_MERGE_ENABLE") + + +@lru_cache(maxsize=None) +def get_radix_tree_merge_update_delta() -> int: + return int(os.getenv("LIGHTLMM_RADIX_TREE_MERGE_DELTA", 6000)) diff --git a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py index 505cbbc1c..605433e9d 100644 --- a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py +++ b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py @@ -91,5 +91,144 @@ def test_case4(): return +def test_case5(): + """ + 测试场景:一个简单的父子节点链 (A -> B),在 ref_counter 都为 0 时,应该成功合并。 + """ + print("\nTest Case 5: Merging simple parent-child nodes when ref_counter is 0\n") + tree = RadixCache("unique_name", 100, 0) + + _, node_a = tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64)) + _, node_b = tree.insert(torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)) + tree.print_self() + + # 验证初始状态:A -> B 结构,且 ref_counter 均为 0 + assert node_b.parent == node_a + assert torch.equal(node_a.token_id_key, torch.tensor([1, 2, 3], dtype=torch.int64)) + assert len(node_a.children) == 1 + assert node_a.ref_counter == 0 + assert node_b.ref_counter == 0 + assert tree.get_tree_total_tokens_num() == 5 + + # 执行合并 + tree.merge_unreferenced_nodes() + tree.print_self() + + assert torch.equal(node_b.token_id_key, torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)) + assert node_b.is_leaf() + assert tree.get_tree_total_tokens_num() == 5 + assert tree.root_node.children[1] is node_b + + +def test_case6(): + """ + 测试场景:一个长的节点链 (A -> B -> C),在 ref_counter 都为 0 时,应该级联合并成一个节点。 + """ + print("\nTest Case 6: Merging long nodes when ref_counter is 0\n") + tree = RadixCache("unique_name", 100, 0) + _, node_a = tree.insert(torch.tensor([1], dtype=torch.int64)) + _, node_b = tree.insert(torch.tensor([1, 2], dtype=torch.int64)) + _, node_c = tree.insert(torch.tensor([1, 2, 3, 4], dtype=torch.int64)) + tree.print_self() + + assert node_c.parent == node_b + assert node_b.parent == node_a + assert tree.get_tree_total_tokens_num() == 4 + tree.merge_unreferenced_nodes() + tree.print_self() + + assert len(tree.root_node.children) == 1 + # 节点 C 的 key 应该是完整的 [1, 2, 3, 4] + assert torch.equal(node_c.token_id_key, torch.tensor([1, 2, 3, 4], dtype=torch.int64)) + assert node_c.is_leaf() + assert tree.get_tree_total_tokens_num() == 4 + + +def test_case7(): + """ + 测试场景:由于父节点或子节点的 ref_counter > 0,合并不应该发生。 + """ + print("\nTest Case 7: Merging when parent or child ref_counter > 0\n") + tree = RadixCache("unique_name", 100, 0) + + _, node_a = tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64)) + _, node_b = tree.insert(torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)) + tree.print_self() + + matched_node, _, _ = tree.match_prefix(torch.tensor([1, 2, 3], dtype=torch.int64), update_refs=True) + assert matched_node is node_a + assert node_a.ref_counter == 1 + assert node_b.ref_counter == 0 + + tree.merge_unreferenced_nodes() + tree.print_self() + + assert torch.equal(node_a.token_id_key, torch.tensor([1, 2, 3], dtype=torch.int64)) + assert not node_a.is_leaf() + assert node_b.parent is node_a + + +def test_case8(): + """ + 测试场景:由于父节点有多个子节点,合并不应该发生。 + """ + print("\nTest Case 8: Merging when parent has multiple children\n") + tree = RadixCache("unique_name", 100, 0) + + _, node_a = tree.insert(torch.tensor([1, 2], dtype=torch.int64)) + _, node_b = tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64)) + _, node_c = tree.insert(torch.tensor([1, 2, 4], dtype=torch.int64)) + tree.print_self() + + assert len(node_a.children) == 2 + assert node_a.ref_counter == 0 + assert node_b.ref_counter == 0 + assert node_c.ref_counter == 0 + + tree.merge_unreferenced_nodes() + tree.print_self() + + assert len(node_a.children) == 2 + assert torch.equal(node_a.token_id_key, torch.tensor([1, 2], dtype=torch.int64)) + assert tree.root_node.children[1].children[3] is node_b + assert tree.root_node.children[1].children[4] is node_c + + +def test_case9(): + """ + 测试场景:在一个复杂的树中,只有满足条件的分支被合并。 + """ + print("\nTest Case 9: Merging in a complex tree with mixed conditions\n") + tree = RadixCache("unique_name", 100, 0) + + # 分支1: 可合并的链 A -> B + _, node_a = tree.insert(torch.tensor([1, 2], dtype=torch.int64)) + _, node_b = tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64)) + + # 分支2: 不可合并的链 C -> D (因为 C 被引用) + _, node_c = tree.insert(torch.tensor([4, 5], dtype=torch.int64)) + _, node_d = tree.insert(torch.tensor([4, 5, 6], dtype=torch.int64)) + + # 增加 C 的引用计数 + tree.match_prefix(torch.tensor([4, 5], dtype=torch.int64), update_refs=True) + assert node_c.ref_counter == 1 + tree.print_self() + + tree.merge_unreferenced_nodes() + tree.print_self() + + merged_node_b = tree.root_node.children[1] + assert torch.equal(merged_node_b.token_id_key, torch.tensor([1, 2, 3], dtype=torch.int64)) + assert merged_node_b.is_leaf() + + unmerged_node_c = tree.root_node.children[4] + assert torch.equal(unmerged_node_c.token_id_key, torch.tensor([4, 5], dtype=torch.int64)) + assert not unmerged_node_c.is_leaf() + assert len(unmerged_node_c.children) == 1 + + unmerged_node_d = unmerged_node_c.children[6] + assert torch.equal(unmerged_node_d.token_id_key, torch.tensor([6], dtype=torch.int64)) + + if __name__ == "__main__": pytest.main()