@@ -91,5 +91,144 @@ def test_case4():
9191 return
9292
9393
94+ def test_case5 ():
95+ """
96+ 测试场景:一个简单的父子节点链 (A -> B),在 ref_counter 都为 0 时,应该成功合并。
97+ """
98+ print ("\n Test Case 5: Merging simple parent-child nodes when ref_counter is 0\n " )
99+ tree = RadixCache ("unique_name" , 100 , 0 )
100+
101+ _ , node_a = tree .insert (torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
102+ _ , node_b = tree .insert (torch .tensor ([1 , 2 , 3 , 4 , 5 ], dtype = torch .int64 ))
103+ tree .print_self ()
104+
105+ # 验证初始状态:A -> B 结构,且 ref_counter 均为 0
106+ assert node_b .parent == node_a
107+ assert torch .equal (node_a .token_id_key , torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
108+ assert len (node_a .children ) == 1
109+ assert node_a .ref_counter == 0
110+ assert node_b .ref_counter == 0
111+ assert tree .get_tree_total_tokens_num () == 5
112+
113+ # 执行合并
114+ tree .merge_unreferenced_nodes ()
115+ tree .print_self ()
116+
117+ assert torch .equal (node_b .token_id_key , torch .tensor ([1 , 2 , 3 , 4 , 5 ], dtype = torch .int64 ))
118+ assert node_b .is_leaf ()
119+ assert tree .get_tree_total_tokens_num () == 5
120+ assert tree .root_node .children [1 ] is node_b
121+
122+
123+ def test_case6 ():
124+ """
125+ 测试场景:一个长的节点链 (A -> B -> C),在 ref_counter 都为 0 时,应该级联合并成一个节点。
126+ """
127+ print ("\n Test Case 6: Merging long nodes when ref_counter is 0\n " )
128+ tree = RadixCache ("unique_name" , 100 , 0 )
129+ _ , node_a = tree .insert (torch .tensor ([1 ], dtype = torch .int64 ))
130+ _ , node_b = tree .insert (torch .tensor ([1 , 2 ], dtype = torch .int64 ))
131+ _ , node_c = tree .insert (torch .tensor ([1 , 2 , 3 , 4 ], dtype = torch .int64 ))
132+ tree .print_self ()
133+
134+ assert node_c .parent == node_b
135+ assert node_b .parent == node_a
136+ assert tree .get_tree_total_tokens_num () == 4
137+ tree .merge_unreferenced_nodes ()
138+ tree .print_self ()
139+
140+ assert len (tree .root_node .children ) == 1
141+ # 节点 C 的 key 应该是完整的 [1, 2, 3, 4]
142+ assert torch .equal (node_c .token_id_key , torch .tensor ([1 , 2 , 3 , 4 ], dtype = torch .int64 ))
143+ assert node_c .is_leaf ()
144+ assert tree .get_tree_total_tokens_num () == 4
145+
146+
147+ def test_case7 ():
148+ """
149+ 测试场景:由于父节点或子节点的 ref_counter > 0,合并不应该发生。
150+ """
151+ print ("\n Test Case 7: Merging when parent or child ref_counter > 0\n " )
152+ tree = RadixCache ("unique_name" , 100 , 0 )
153+
154+ _ , node_a = tree .insert (torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
155+ _ , node_b = tree .insert (torch .tensor ([1 , 2 , 3 , 4 , 5 ], dtype = torch .int64 ))
156+ tree .print_self ()
157+
158+ matched_node , _ , _ = tree .match_prefix (torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ), update_refs = True )
159+ assert matched_node is node_a
160+ assert node_a .ref_counter == 1
161+ assert node_b .ref_counter == 0
162+
163+ tree .merge_unreferenced_nodes ()
164+ tree .print_self ()
165+
166+ assert torch .equal (node_a .token_id_key , torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
167+ assert not node_a .is_leaf ()
168+ assert node_b .parent is node_a
169+
170+
171+ def test_case8 ():
172+ """
173+ 测试场景:由于父节点有多个子节点,合并不应该发生。
174+ """
175+ print ("\n Test Case 8: Merging when parent has multiple children\n " )
176+ tree = RadixCache ("unique_name" , 100 , 0 )
177+
178+ _ , node_a = tree .insert (torch .tensor ([1 , 2 ], dtype = torch .int64 ))
179+ _ , node_b = tree .insert (torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
180+ _ , node_c = tree .insert (torch .tensor ([1 , 2 , 4 ], dtype = torch .int64 ))
181+ tree .print_self ()
182+
183+ assert len (node_a .children ) == 2
184+ assert node_a .ref_counter == 0
185+ assert node_b .ref_counter == 0
186+ assert node_c .ref_counter == 0
187+
188+ tree .merge_unreferenced_nodes ()
189+ tree .print_self ()
190+
191+ assert len (node_a .children ) == 2
192+ assert torch .equal (node_a .token_id_key , torch .tensor ([1 , 2 ], dtype = torch .int64 ))
193+ assert tree .root_node .children [1 ].children [3 ] is node_b
194+ assert tree .root_node .children [1 ].children [4 ] is node_c
195+
196+
197+ def test_case9 ():
198+ """
199+ 测试场景:在一个复杂的树中,只有满足条件的分支被合并。
200+ """
201+ print ("\n Test Case 9: Merging in a complex tree with mixed conditions\n " )
202+ tree = RadixCache ("unique_name" , 100 , 0 )
203+
204+ # 分支1: 可合并的链 A -> B
205+ _ , node_a = tree .insert (torch .tensor ([1 , 2 ], dtype = torch .int64 ))
206+ _ , node_b = tree .insert (torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
207+
208+ # 分支2: 不可合并的链 C -> D (因为 C 被引用)
209+ _ , node_c = tree .insert (torch .tensor ([4 , 5 ], dtype = torch .int64 ))
210+ _ , node_d = tree .insert (torch .tensor ([4 , 5 , 6 ], dtype = torch .int64 ))
211+
212+ # 增加 C 的引用计数
213+ tree .match_prefix (torch .tensor ([4 , 5 ], dtype = torch .int64 ), update_refs = True )
214+ assert node_c .ref_counter == 1
215+ tree .print_self ()
216+
217+ tree .merge_unreferenced_nodes ()
218+ tree .print_self ()
219+
220+ merged_node_b = tree .root_node .children [1 ]
221+ assert torch .equal (merged_node_b .token_id_key , torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
222+ assert merged_node_b .is_leaf ()
223+
224+ unmerged_node_c = tree .root_node .children [4 ]
225+ assert torch .equal (unmerged_node_c .token_id_key , torch .tensor ([4 , 5 ], dtype = torch .int64 ))
226+ assert not unmerged_node_c .is_leaf ()
227+ assert len (unmerged_node_c .children ) == 1
228+
229+ unmerged_node_d = unmerged_node_c .children [6 ]
230+ assert torch .equal (unmerged_node_d .token_id_key , torch .tensor ([6 ], dtype = torch .int64 ))
231+
232+
94233if __name__ == "__main__" :
95234 pytest .main ()
0 commit comments