Skip to content

Commit f0ff154

Browse files
Longxmas, shihaobai, wangzaijun
authored
implement radix_cache node merge function (#1090)
使用非递归函数实现 radix cache node 合并功能,并补充了相应的单元测试。

Co-authored-by: baishihao <[email protected]>
Co-authored-by: wangzaijun <[email protected]>
1 parent ecbfe9c commit f0ff154

File tree

5 files changed

+241
-4
lines changed

5 files changed

+241
-4
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ repos:
77
args: [--line-length=120]
88
additional_dependencies: ['click==8.0.4']
99
- repo: https://github.com/pycqa/flake8
10-
rev: 3.9.0
10+
rev: 6.1.0
1111
hooks:
1212
- id: flake8
13-
additional_dependencies: [flake8-typing-imports==1.9.0]
14-
args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231']
13+
args: ['--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231']

lightllm/server/router/dynamic_prompt/radix_cache.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,64 @@ def evict(self, need_remove_tokens, evict_callback):
342342

343343
return
344344

345+
def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]:
    """
    Try to fold the parent of ``child_node`` into ``child_node`` itself.

    The merge only happens when all of the following hold:
    1. The parent exists and is not the root node.
    2. The parent's reference counter is 0.
    3. The parent has exactly one child (i.e. ``child_node``).
    4. The child's reference counter is 0.

    On success ``child_node`` carries the parent's tokens followed by its own, is
    attached to the former grandparent, and is returned. Otherwise ``None`` is
    returned and nothing is modified.
    """
    parent = child_node.parent

    # Guard clauses, checked in the same order as the merge conditions above.
    if parent is None or parent == self.root_node:
        return None
    if parent.ref_counter != 0 or len(parent.children) != 1:
        return None
    if child_node.ref_counter != 0:
        return None

    # A leaf is taken out of the eviction set before its fields (time_id included)
    # are rewritten and is re-inserted at the end.
    # NOTE(review): presumably the set is ordered by these fields - confirm.
    if child_node.is_leaf():
        self.evict_tree_set.discard(child_node)

    # Prepend the parent's data to the child's data.
    merged_keys = torch.cat([parent.token_id_key, child_node.token_id_key])
    child_node.token_id_key = merged_keys
    merged_values = torch.cat([parent.token_mem_index_value, child_node.token_mem_index_value])
    child_node.token_mem_index_value = merged_values
    child_node.node_value_len = len(child_node.token_mem_index_value)
    child_node.time_id = max(parent.time_id, child_node.time_id)

    # The grandparent indexes its children by the first token of their key, which
    # for the merged node is still the first token of the absorbed parent.
    grandparent = parent.parent
    slot = parent.token_id_key[0].item()
    grandparent.children[slot] = child_node
    child_node.parent = grandparent

    # Detach the absorbed parent from the tree.
    parent.parent = None

    if child_node.is_leaf():
        self.evict_tree_set.add(child_node)

    return child_node
385+
386+
def merge_unreferenced_nodes(self):
    """
    Merge unreferenced nodes of the eviction set into their parents where possible.

    Every node in ``self.evict_tree_set`` that has a zero reference counter and a
    parent other than the root is queued. Each queued node is handed to
    ``self._try_merge``; when a merge succeeds the resulting node is queued again so
    that a chain of mergeable ancestors collapses step by step without recursion.
    """
    pending = collections.deque()
    # Snapshot the candidates first: merging mutates the eviction set.
    for candidate in self.evict_tree_set:
        if candidate.ref_counter != 0:
            continue
        if candidate.parent is not None and candidate.parent != self.root_node:
            pending.append(candidate)

    while len(pending) > 0:
        current = pending.popleft()
        # Skip nodes that are no longer attached to the tree.
        if current.parent is None:
            continue
        outcome = self._try_merge(current)
        if outcome:
            pending.append(outcome)
402+
345403
def assert_leafs_is_right(self):
346404
for node in self.evict_tree_set:
347405
if node.is_leaf() and node.ref_counter == 0:

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626
from lightllm.utils.dist_utils import get_dp_world_size, get_global_dp_rank, get_current_rank_in_dp
2727
from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size
2828
from lightllm.utils.dist_utils import get_dp_rank_in_node, create_new_group_for_current_node
29-
from lightllm.utils.envs_utils import get_env_start_args
29+
from lightllm.utils.envs_utils import (
30+
get_env_start_args,
31+
enable_radix_tree_timer_merge,
32+
get_radix_tree_merge_update_delta,
33+
)
3034
from lightllm.distributed import dist_group_manager
3135
from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
3236
from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack
@@ -61,6 +65,11 @@ def __init__(self) -> None:
6165

6266
# nixl pd mode callback func
6367
self.nixl_prefill_chuncked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None
68+
69+
# counter
70+
self._radix_tree_merge_counter: int = 0
71+
self._enable_radix_tree_timer_merge: bool = enable_radix_tree_timer_merge()
72+
self._radix_tree_merge_update_delta: int = get_radix_tree_merge_update_delta()
6473
pass
6574

6675
def init_model(self, kvargs):
@@ -439,6 +448,22 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]:
439448
"""
440449
return [g_infer_context.requests_mapping[request_id] for request_id in req_ids]
441450

451+
def _timer_merge_radix_tree(self):
    """
    Periodically merge unreferenced radix cache nodes.

    Every call increments ``self._radix_tree_merge_counter``. When the feature is
    enabled, the counter is a multiple of ``self._radix_tree_merge_update_delta`` and
    a radix cache exists, ``merge_unreferenced_nodes`` is run while holding
    ``g_infer_state_lock`` and the elapsed time is logged.
    """
    self._radix_tree_merge_counter += 1
    if (
        self._enable_radix_tree_timer_merge
        and (self._radix_tree_merge_counter % self._radix_tree_merge_update_delta == 0)
        and self.radix_cache is not None
    ):
        g_infer_state_lock.acquire()
        try:
            start = time.time()
            self.radix_cache.merge_unreferenced_nodes()
            self.logger.info(
                f"radix tree merge_unreferenced_nodes cost time {time.time() - start} s in rank {self.global_rank}"
            )
        finally:
            # Always release the lock, even if the merge or the logging raises;
            # otherwise every later acquirer of g_infer_state_lock would block forever.
            g_infer_state_lock.release()
    return
466+
442467
# 一些可以复用的通用功能函数
443468
def _get_classed_reqs(
444469
self,
@@ -465,6 +490,9 @@ def _get_classed_reqs(
465490
4. prefill_reqs 需要进行prefill操作的请求
466491
5. decode_reqs 需要进行decode操作的请求
467492
"""
493+
# 定期对 radix cache 进行 merge,防止查询插入的操作效率下降
494+
self._timer_merge_radix_tree()
495+
468496
if self.args.enable_cpu_cache and len(g_infer_context.infer_req_ids) > 0:
469497
self.multi_level_cache_module.update_cpu_cache_task_states()
470498

lightllm/utils/envs_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,16 @@ def disable_cpu_kvcache_sync() -> bool:
181181
实验用环境遍历,未来可能会移除
182182
"""
183183
return enable_env_vars("LIGHTLLM_DISABLE_CPU_CACHE_SYNC")
184+
185+
186+
@lru_cache(maxsize=None)
def enable_radix_tree_timer_merge() -> bool:
    """
    Report whether periodic merging of radix tree leaf nodes is enabled.

    The merge is meant to keep insert and query performance from degrading. The
    switch is read from the ``LIGHTLLM_RADIX_TREE_MERGE_ENABLE`` environment variable
    and the result is cached after the first call.
    """
    merge_enabled = enable_env_vars("LIGHTLLM_RADIX_TREE_MERGE_ENABLE")
    return merge_enabled
192+
193+
194+
@lru_cache(maxsize=None)
def get_radix_tree_merge_update_delta() -> int:
    """
    Return the interval, counted in calls, between two periodic radix tree merges.

    The value is read from the ``LIGHTLLM_RADIX_TREE_MERGE_DELTA`` environment
    variable. The misspelled legacy name ``LIGHTLMM_RADIX_TREE_MERGE_DELTA`` is still
    honoured as a fallback so existing deployments keep working. Defaults to 6000
    when neither is set. The result is cached after the first call.

    Raises:
        ValueError: if the configured value is not a valid integer literal.
    """
    raw_delta = os.getenv("LIGHTLLM_RADIX_TREE_MERGE_DELTA")
    if raw_delta is None:
        # Backward compatibility with the original, misspelled variable name.
        raw_delta = os.getenv("LIGHTLMM_RADIX_TREE_MERGE_DELTA", 6000)
    return int(raw_delta)

unit_tests/server/router/dynamic_prompt/test_radix_cache.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,5 +91,144 @@ def test_case4():
9191
return
9292

9393

94+
def test_case5():
    """
    Scenario: a simple parent-child chain (A -> B) whose ref_counter values are
    all 0 should be merged successfully.
    """

    def ids(*values):
        return torch.tensor(values, dtype=torch.int64)

    print("\nTest Case 5: Merging simple parent-child nodes when ref_counter is 0\n")
    tree = RadixCache("unique_name", 100, 0)

    _, node_a = tree.insert(ids(1, 2, 3))
    _, node_b = tree.insert(ids(1, 2, 3, 4, 5))
    tree.print_self()

    # Initial state: A -> B, both unreferenced.
    assert node_b.parent == node_a
    assert torch.equal(node_a.token_id_key, ids(1, 2, 3))
    assert len(node_a.children) == 1
    assert node_a.ref_counter == 0
    assert node_b.ref_counter == 0
    assert tree.get_tree_total_tokens_num() == 5

    # Run the merge.
    tree.merge_unreferenced_nodes()
    tree.print_self()

    # B now holds the whole key and hangs directly under the root.
    assert torch.equal(node_b.token_id_key, ids(1, 2, 3, 4, 5))
    assert node_b.is_leaf()
    assert tree.get_tree_total_tokens_num() == 5
    assert tree.root_node.children[1] is node_b
121+
122+
123+
def test_case6():
    """
    Scenario: a long chain (A -> B -> C) whose ref_counter values are all 0
    should cascade into a single node.
    """

    def ids(*values):
        return torch.tensor(values, dtype=torch.int64)

    print("\nTest Case 6: Merging long nodes when ref_counter is 0\n")
    tree = RadixCache("unique_name", 100, 0)
    _, node_a = tree.insert(ids(1))
    _, node_b = tree.insert(ids(1, 2))
    _, node_c = tree.insert(ids(1, 2, 3, 4))
    tree.print_self()

    # Initial state: a three-node chain holding four tokens in total.
    assert node_c.parent == node_b
    assert node_b.parent == node_a
    assert tree.get_tree_total_tokens_num() == 4

    tree.merge_unreferenced_nodes()
    tree.print_self()

    assert len(tree.root_node.children) == 1
    # Node C must now carry the complete key [1, 2, 3, 4].
    assert torch.equal(node_c.token_id_key, ids(1, 2, 3, 4))
    assert node_c.is_leaf()
    assert tree.get_tree_total_tokens_num() == 4
145+
146+
147+
def test_case7():
    """
    Scenario: no merge should happen because the parent or the child has
    ref_counter > 0.
    """

    def ids(*values):
        return torch.tensor(values, dtype=torch.int64)

    print("\nTest Case 7: Merging when parent or child ref_counter > 0\n")
    tree = RadixCache("unique_name", 100, 0)

    _, node_a = tree.insert(ids(1, 2, 3))
    _, node_b = tree.insert(ids(1, 2, 3, 4, 5))
    tree.print_self()

    # Take a reference on A by matching its prefix.
    matched_node, _, _ = tree.match_prefix(ids(1, 2, 3), update_refs=True)
    assert matched_node is node_a
    assert node_a.ref_counter == 1
    assert node_b.ref_counter == 0

    tree.merge_unreferenced_nodes()
    tree.print_self()

    # The structure must be unchanged.
    assert torch.equal(node_a.token_id_key, ids(1, 2, 3))
    assert not node_a.is_leaf()
    assert node_b.parent is node_a
169+
170+
171+
def test_case8():
    """
    Scenario: no merge should happen because the parent has more than one child.
    """

    def ids(*values):
        return torch.tensor(values, dtype=torch.int64)

    print("\nTest Case 8: Merging when parent has multiple children\n")
    tree = RadixCache("unique_name", 100, 0)

    _, node_a = tree.insert(ids(1, 2))
    _, node_b = tree.insert(ids(1, 2, 3))
    _, node_c = tree.insert(ids(1, 2, 4))
    tree.print_self()

    # Initial state: A has two unreferenced children, B and C.
    assert len(node_a.children) == 2
    assert node_a.ref_counter == 0
    assert node_b.ref_counter == 0
    assert node_c.ref_counter == 0

    tree.merge_unreferenced_nodes()
    tree.print_self()

    # The structure must be unchanged.
    assert len(node_a.children) == 2
    assert torch.equal(node_a.token_id_key, ids(1, 2))
    assert tree.root_node.children[1].children[3] is node_b
    assert tree.root_node.children[1].children[4] is node_c
195+
196+
197+
def test_case9():
    """
    Scenario: in a more complex tree only the branches that satisfy the merge
    conditions are merged.
    """

    def ids(*values):
        return torch.tensor(values, dtype=torch.int64)

    print("\nTest Case 9: Merging in a complex tree with mixed conditions\n")
    tree = RadixCache("unique_name", 100, 0)

    # Branch 1: mergeable chain A -> B.
    _, node_a = tree.insert(ids(1, 2))
    _, node_b = tree.insert(ids(1, 2, 3))

    # Branch 2: chain C -> D that must not be merged (C is referenced).
    _, node_c = tree.insert(ids(4, 5))
    _, node_d = tree.insert(ids(4, 5, 6))

    # Bump the reference counter of C.
    tree.match_prefix(ids(4, 5), update_refs=True)
    assert node_c.ref_counter == 1
    tree.print_self()

    tree.merge_unreferenced_nodes()
    tree.print_self()

    # Branch 1 collapsed into a single leaf holding the full key.
    merged_node_b = tree.root_node.children[1]
    assert torch.equal(merged_node_b.token_id_key, ids(1, 2, 3))
    assert merged_node_b.is_leaf()

    # Branch 2 is untouched: C keeps its key and its single child D.
    unmerged_node_c = tree.root_node.children[4]
    assert torch.equal(unmerged_node_c.token_id_key, ids(4, 5))
    assert not unmerged_node_c.is_leaf()
    assert len(unmerged_node_c.children) == 1

    unmerged_node_d = unmerged_node_c.children[6]
    assert torch.equal(unmerged_node_d.token_id_key, ids(6))
231+
232+
94233
if __name__ == "__main__":
95234
pytest.main()

0 commit comments

Comments
 (0)