Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ repos:
args: [--line-length=120]
additional_dependencies: ['click==8.0.4']
- repo: https://github.com/pycqa/flake8
rev: 3.9.0
rev: 6.1.0
hooks:
- id: flake8
additional_dependencies: [flake8-typing-imports==1.9.0]
args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231']
args: ['--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231']
58 changes: 58 additions & 0 deletions lightllm/server/router/dynamic_prompt/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,64 @@ def evict(self, need_remove_tokens, evict_callback):

return

def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]:
"""
合并条件:
1. 父节点不是根节点。
2. 父节点的引用计数为 0。
3. 子节点的引用计数为 0。
4. 父节点只有一个子节点 (即 child_node)。
"""
parent_node = child_node.parent
# 条件检查
if (
parent_node is None
or parent_node == self.root_node
or parent_node.ref_counter != 0
or len(parent_node.children) != 1
or child_node.ref_counter != 0
):
return None

if child_node.is_leaf():
self.evict_tree_set.discard(child_node)

child_node.token_id_key = torch.cat([parent_node.token_id_key, child_node.token_id_key])
child_node.token_mem_index_value = torch.cat(
[parent_node.token_mem_index_value, child_node.token_mem_index_value]
)
child_node.node_value_len = len(child_node.token_mem_index_value)
child_node.time_id = max(parent_node.time_id, child_node.time_id)

grandparent_node = parent_node.parent
key_in_grandparent = parent_node.token_id_key[0].item()
grandparent_node.children[key_in_grandparent] = child_node
child_node.parent = grandparent_node

parent_node.parent = None

if child_node.is_leaf():
self.evict_tree_set.add(child_node)

return child_node

def merge_unreferenced_nodes(self):
worklist = collections.deque(
[
node
for node in self.evict_tree_set
if node.ref_counter == 0 and node.parent is not None and node.parent != self.root_node
]
)

while worklist:
node = worklist.popleft()
if node.parent is None:
continue
merged_node = self._try_merge(node)
if merged_node:
worklist.append(merged_node)

def assert_leafs_is_right(self):
for node in self.evict_tree_set:
if node.is_leaf() and node.ref_counter == 0:
Expand Down
30 changes: 29 additions & 1 deletion lightllm/server/router/model_infer/mode_backend/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
from lightllm.utils.dist_utils import get_dp_world_size, get_global_dp_rank, get_current_rank_in_dp
from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size
from lightllm.utils.dist_utils import get_dp_rank_in_node, create_new_group_for_current_node
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.envs_utils import (
get_env_start_args,
enable_radix_tree_timer_merge,
get_radix_tree_merge_update_delta,
)
from lightllm.distributed import dist_group_manager
from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack
Expand Down Expand Up @@ -61,6 +65,11 @@ def __init__(self) -> None:

# nixl pd mode callback func
self.nixl_prefill_chuncked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None

# counter
self._radix_tree_merge_counter: int = 0
self._enable_radix_tree_timer_merge: bool = enable_radix_tree_timer_merge()
self._radix_tree_merge_update_delta: int = get_radix_tree_merge_update_delta()
pass

def init_model(self, kvargs):
Expand Down Expand Up @@ -439,6 +448,22 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]:
"""
return [g_infer_context.requests_mapping[request_id] for request_id in req_ids]

def _timer_merge_radix_tree(self):
self._radix_tree_merge_counter += 1
if (
self._enable_radix_tree_timer_merge
and (self._radix_tree_merge_counter % self._radix_tree_merge_update_delta == 0)
and self.radix_cache is not None
):
g_infer_state_lock.acquire()
start = time.time()
self.radix_cache.merge_unreferenced_nodes()
self.logger.info(
f"radix tree merge_unreferenced_nodes cost time {time.time() - start} s in rank {self.global_rank}"
)
g_infer_state_lock.release()
return

# 一些可以复用的通用功能函数
def _get_classed_reqs(
self,
Expand All @@ -465,6 +490,9 @@ def _get_classed_reqs(
4. prefill_reqs 需要进行prefill操作的请求
5. decode_reqs 需要进行decode操作的请求
"""
# 定期对 radix cache 进行 merge,防止查询插入的操作效率下降
self._timer_merge_radix_tree()

if self.args.enable_cpu_cache and len(g_infer_context.infer_req_ids) > 0:
self.multi_level_cache_module.update_cpu_cache_task_states()

Expand Down
13 changes: 13 additions & 0 deletions lightllm/utils/envs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,16 @@ def disable_cpu_kvcache_sync() -> bool:
实验用环境遍历,未来可能会移除
"""
return enable_env_vars("LIGHTLLM_DISABLE_CPU_CACHE_SYNC")


@lru_cache(maxsize=None)
def enable_radix_tree_timer_merge() -> bool:
"""
使能定期合并 radix tree的叶节点, 防止插入查询性能下降。
"""
return enable_env_vars("LIGHTLLM_RADIX_TREE_MERGE_ENABLE")


@lru_cache(maxsize=None)
def get_radix_tree_merge_update_delta() -> int:
return int(os.getenv("LIGHTLMM_RADIX_TREE_MERGE_DELTA", 6000))
139 changes: 139 additions & 0 deletions unit_tests/server/router/dynamic_prompt/test_radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,144 @@ def test_case4():
return


def test_case5():
"""
测试场景:一个简单的父子节点链 (A -> B),在 ref_counter 都为 0 时,应该成功合并。
"""
print("\nTest Case 5: Merging simple parent-child nodes when ref_counter is 0\n")
tree = RadixCache("unique_name", 100, 0)

_, node_a = tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64))
_, node_b = tree.insert(torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64))
tree.print_self()

# 验证初始状态:A -> B 结构,且 ref_counter 均为 0
assert node_b.parent == node_a
assert torch.equal(node_a.token_id_key, torch.tensor([1, 2, 3], dtype=torch.int64))
assert len(node_a.children) == 1
assert node_a.ref_counter == 0
assert node_b.ref_counter == 0
assert tree.get_tree_total_tokens_num() == 5

# 执行合并
tree.merge_unreferenced_nodes()
tree.print_self()

assert torch.equal(node_b.token_id_key, torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64))
assert node_b.is_leaf()
assert tree.get_tree_total_tokens_num() == 5
assert tree.root_node.children[1] is node_b


def test_case6():
"""
测试场景:一个长的节点链 (A -> B -> C),在 ref_counter 都为 0 时,应该级联合并成一个节点。
"""
print("\nTest Case 6: Merging long nodes when ref_counter is 0\n")
tree = RadixCache("unique_name", 100, 0)
_, node_a = tree.insert(torch.tensor([1], dtype=torch.int64))
_, node_b = tree.insert(torch.tensor([1, 2], dtype=torch.int64))
_, node_c = tree.insert(torch.tensor([1, 2, 3, 4], dtype=torch.int64))
tree.print_self()

assert node_c.parent == node_b
assert node_b.parent == node_a
assert tree.get_tree_total_tokens_num() == 4
tree.merge_unreferenced_nodes()
tree.print_self()

assert len(tree.root_node.children) == 1
# 节点 C 的 key 应该是完整的 [1, 2, 3, 4]
assert torch.equal(node_c.token_id_key, torch.tensor([1, 2, 3, 4], dtype=torch.int64))
assert node_c.is_leaf()
assert tree.get_tree_total_tokens_num() == 4


def test_case7():
"""
测试场景:由于父节点或子节点的 ref_counter > 0,合并不应该发生。
"""
print("\nTest Case 7: Merging when parent or child ref_counter > 0\n")
tree = RadixCache("unique_name", 100, 0)

_, node_a = tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64))
_, node_b = tree.insert(torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64))
tree.print_self()

matched_node, _, _ = tree.match_prefix(torch.tensor([1, 2, 3], dtype=torch.int64), update_refs=True)
assert matched_node is node_a
assert node_a.ref_counter == 1
assert node_b.ref_counter == 0

tree.merge_unreferenced_nodes()
tree.print_self()

assert torch.equal(node_a.token_id_key, torch.tensor([1, 2, 3], dtype=torch.int64))
assert not node_a.is_leaf()
assert node_b.parent is node_a


def test_case8():
"""
测试场景:由于父节点有多个子节点,合并不应该发生。
"""
print("\nTest Case 8: Merging when parent has multiple children\n")
tree = RadixCache("unique_name", 100, 0)

_, node_a = tree.insert(torch.tensor([1, 2], dtype=torch.int64))
_, node_b = tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64))
_, node_c = tree.insert(torch.tensor([1, 2, 4], dtype=torch.int64))
tree.print_self()

assert len(node_a.children) == 2
assert node_a.ref_counter == 0
assert node_b.ref_counter == 0
assert node_c.ref_counter == 0

tree.merge_unreferenced_nodes()
tree.print_self()

assert len(node_a.children) == 2
assert torch.equal(node_a.token_id_key, torch.tensor([1, 2], dtype=torch.int64))
assert tree.root_node.children[1].children[3] is node_b
assert tree.root_node.children[1].children[4] is node_c


def test_case9():
"""
测试场景:在一个复杂的树中,只有满足条件的分支被合并。
"""
print("\nTest Case 9: Merging in a complex tree with mixed conditions\n")
tree = RadixCache("unique_name", 100, 0)

# 分支1: 可合并的链 A -> B
_, node_a = tree.insert(torch.tensor([1, 2], dtype=torch.int64))
_, node_b = tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64))

# 分支2: 不可合并的链 C -> D (因为 C 被引用)
_, node_c = tree.insert(torch.tensor([4, 5], dtype=torch.int64))
_, node_d = tree.insert(torch.tensor([4, 5, 6], dtype=torch.int64))

# 增加 C 的引用计数
tree.match_prefix(torch.tensor([4, 5], dtype=torch.int64), update_refs=True)
assert node_c.ref_counter == 1
tree.print_self()

tree.merge_unreferenced_nodes()
tree.print_self()

merged_node_b = tree.root_node.children[1]
assert torch.equal(merged_node_b.token_id_key, torch.tensor([1, 2, 3], dtype=torch.int64))
assert merged_node_b.is_leaf()

unmerged_node_c = tree.root_node.children[4]
assert torch.equal(unmerged_node_c.token_id_key, torch.tensor([4, 5], dtype=torch.int64))
assert not unmerged_node_c.is_leaf()
assert len(unmerged_node_c.children) == 1

unmerged_node_d = unmerged_node_c.children[6]
assert torch.equal(unmerged_node_d.token_id_key, torch.tensor([6], dtype=torch.int64))


if __name__ == "__main__":
pytest.main()