@@ -14,10 +14,6 @@ use crate::node::{Node, NodeId, NodeState, ParentAndIndex};
1414#[ repr( transparent) ]
1515pub ( crate ) struct TreeIndex ( pub ( crate ) u32 ) ;
1616
17- fn nid ( id : LocalNodeId ) -> NodeId {
18- NodeId :: new ( id, TreeIndex ( 0 ) )
19- }
20-
2117#[ derive( Debug , Default ) ]
2218struct TreeIndexMap {
2319 id_to_index : HashMap < TreeId , TreeIndex > ,
@@ -40,12 +36,22 @@ impl TreeIndexMap {
4036 }
4137}
4238
39+ /// State for a subtree, including its root node and current focus.
40+ #[ derive( Clone , Debug ) ]
41+ pub ( crate ) struct SubtreeState {
42+ pub ( crate ) root : NodeId ,
43+ pub ( crate ) focus : NodeId ,
44+ }
45+
4346#[ derive( Clone , Debug ) ]
4447pub struct State {
4548 pub ( crate ) nodes : HashMap < NodeId , NodeState > ,
4649 pub ( crate ) data : TreeData ,
50+ pub ( crate ) root : NodeId ,
4751 pub ( crate ) focus : NodeId ,
4852 is_host_focused : bool ,
53+ /// Maps TreeId to the state of each subtree (root and focus).
54+ pub ( crate ) subtrees : HashMap < TreeId , SubtreeState > ,
4955}
5056
5157#[ derive( Default ) ]
@@ -57,7 +63,7 @@ struct InternalChanges {
5763
5864impl State {
5965 fn validate_global ( & self ) {
60- if !self . nodes . contains_key ( & nid ( self . data . root ) ) {
66+ if !self . nodes . contains_key ( & self . root ) {
6167 panic ! ( "Root ID {:?} is not in the node list" , self . data. root) ;
6268 }
6369 if !self . nodes . contains_key ( & self . focus ) {
@@ -70,18 +76,39 @@ impl State {
7076 update : TreeUpdate ,
7177 is_host_focused : bool ,
7278 mut changes : Option < & mut InternalChanges > ,
79+ tree_index : TreeIndex ,
7380 ) {
74- let mut unreachable = HashSet :: new ( ) ;
75- let mut seen_child_ids = HashSet :: new ( ) ;
81+ let map_id = |id : LocalNodeId | NodeId :: new ( id, tree_index) ;
82+
83+ let mut unreachable: HashSet < NodeId > = HashSet :: new ( ) ;
84+ let mut seen_child_ids: HashSet < NodeId > = HashSet :: new ( ) ;
7685
77- if let Some ( tree) = update. tree {
78- if tree. root != self . data . root {
79- unreachable. insert ( nid ( self . data . root ) ) ;
86+ let tree_id = update. tree_id ;
87+
88+ let new_tree_root = if let Some ( tree) = update. tree {
89+ let new_root = map_id ( tree. root ) ;
90+ if tree_id == TreeId :: ROOT {
91+ // Only update main tree root/data for ROOT tree
92+ if tree. root != self . data . root {
93+ unreachable. insert ( self . root ) ;
94+ }
95+ self . root = new_root;
96+ self . data = tree;
8097 }
81- self . data = tree;
82- }
98+ Some ( new_root)
99+ } else {
100+ None
101+ } ;
83102
84- let root = self . data . root ;
103+ // Use the tree's root from the update, or fallback to existing subtree/main tree root
104+ let root = new_tree_root
105+ . map ( |r| r. to_components ( ) . 0 )
106+ . unwrap_or_else ( || {
107+ self . subtrees
108+ . get ( & tree_id)
109+ . map ( |s| s. root . to_components ( ) . 0 )
110+ . unwrap_or ( self . data . root )
111+ } ) ;
85112 let mut pending_nodes: HashMap < NodeId , _ > = HashMap :: new ( ) ;
86113 let mut pending_children = HashMap :: new ( ) ;
87114
@@ -103,33 +130,34 @@ impl State {
103130 }
104131
105132 for ( local_node_id, node_data) in update. nodes {
106- let node_id = nid ( local_node_id) ;
133+ let node_id = map_id ( local_node_id) ;
107134 unreachable. remove ( & node_id) ;
108135
109136 for ( child_index, child_id) in node_data. children ( ) . iter ( ) . enumerate ( ) {
110- if seen_child_ids. contains ( child_id) {
137+ let mapped_child_id = map_id ( * child_id) ;
138+ if seen_child_ids. contains ( & mapped_child_id) {
111139 panic ! ( "TreeUpdate includes duplicate child {:?}" , child_id) ;
112140 }
113- seen_child_ids. insert ( * child_id ) ;
114- unreachable. remove ( & nid ( * child_id ) ) ;
141+ seen_child_ids. insert ( mapped_child_id ) ;
142+ unreachable. remove ( & mapped_child_id ) ;
115143 let parent_and_index = ParentAndIndex ( node_id, child_index) ;
116- if let Some ( child_state) = self . nodes . get_mut ( & nid ( * child_id ) ) {
144+ if let Some ( child_state) = self . nodes . get_mut ( & mapped_child_id ) {
117145 if child_state. parent_and_index != Some ( parent_and_index) {
118146 child_state. parent_and_index = Some ( parent_and_index) ;
119147 if let Some ( changes) = & mut changes {
120- changes. updated_node_ids . insert ( nid ( * child_id ) ) ;
148+ changes. updated_node_ids . insert ( mapped_child_id ) ;
121149 }
122150 }
123- } else if let Some ( child_data) = pending_nodes. remove ( & nid ( * child_id ) ) {
151+ } else if let Some ( child_data) = pending_nodes. remove ( & mapped_child_id ) {
124152 add_node (
125153 & mut self . nodes ,
126154 & mut changes,
127155 Some ( parent_and_index) ,
128- nid ( * child_id ) ,
156+ mapped_child_id ,
129157 child_data,
130158 ) ;
131159 } else {
132- pending_children. insert ( nid ( * child_id ) , parent_and_index) ;
160+ pending_children. insert ( mapped_child_id , parent_and_index) ;
133161 }
134162 }
135163
@@ -138,8 +166,9 @@ impl State {
138166 node_state. parent_and_index = None ;
139167 }
140168 for child_id in node_state. data . children ( ) . iter ( ) {
141- if !seen_child_ids. contains ( child_id) {
142- unreachable. insert ( nid ( * child_id) ) ;
169+ let mapped_existing_child_id = map_id ( * child_id) ;
170+ if !seen_child_ids. contains ( & mapped_existing_child_id) {
171+ unreachable. insert ( mapped_existing_child_id) ;
143172 }
144173 }
145174 if node_state. data != node_data {
@@ -178,29 +207,63 @@ impl State {
178207 ) ;
179208 }
180209
181- self . focus = nid ( update. focus ) ;
210+ // Store subtree state (root and focus) per tree
211+ let tree_focus = map_id ( update. focus ) ;
212+ if let Some ( new_root) = new_tree_root {
213+ // New tree: insert both root and focus
214+ self . subtrees . insert (
215+ tree_id,
216+ SubtreeState {
217+ root : new_root,
218+ focus : tree_focus,
219+ } ,
220+ ) ;
221+ } else if let Some ( subtree) = self . subtrees . get_mut ( & tree_id) {
222+ // Existing tree: just update focus
223+ subtree. focus = tree_focus;
224+ } else if tree_id == TreeId :: ROOT {
225+ // ROOT tree focus update without tree change (e.g., during Tree::new after take())
226+ // Use the main tree's root for the subtree state
227+ self . subtrees . insert (
228+ tree_id,
229+ SubtreeState {
230+ root : self . root ,
231+ focus : tree_focus,
232+ } ,
233+ ) ;
234+ }
235+
236+ self . focus = tree_focus;
182237 self . is_host_focused = is_host_focused;
183238
184239 if !unreachable. is_empty ( ) {
185240 fn traverse_unreachable (
186241 nodes : & mut HashMap < NodeId , NodeState > ,
187242 changes : & mut Option < & mut InternalChanges > ,
188- seen_child_ids : & HashSet < LocalNodeId > ,
243+ seen_child_ids : & HashSet < NodeId > ,
189244 id : NodeId ,
245+ map_id : impl Fn ( LocalNodeId ) -> NodeId + Copy ,
190246 ) {
191247 if let Some ( changes) = changes {
192248 changes. removed_node_ids . insert ( id) ;
193249 }
194250 let node = nodes. remove ( & id) . unwrap ( ) ;
195251 for child_id in node. data . children ( ) . iter ( ) {
196- if !seen_child_ids. contains ( child_id) {
197- traverse_unreachable ( nodes, changes, seen_child_ids, nid ( * child_id) ) ;
252+ let mapped_child_id = map_id ( * child_id) ;
253+ if !seen_child_ids. contains ( & mapped_child_id) {
254+ traverse_unreachable (
255+ nodes,
256+ changes,
257+ seen_child_ids,
258+ mapped_child_id,
259+ map_id,
260+ ) ;
198261 }
199262 }
200263 }
201264
202265 for id in unreachable {
203- traverse_unreachable ( & mut self . nodes , & mut changes, & seen_child_ids, id) ;
266+ traverse_unreachable ( & mut self . nodes , & mut changes, & seen_child_ids, id, map_id ) ;
204267 }
205268 }
206269
@@ -219,7 +282,7 @@ impl State {
219282 tree_id : TreeId :: ROOT ,
220283 focus,
221284 } ;
222- self . update ( update, is_host_focused, changes) ;
285+ self . update ( update, is_host_focused, changes, TreeIndex ( 0 ) ) ;
223286 }
224287
225288 pub fn has_node ( & self , id : NodeId ) -> bool {
@@ -235,7 +298,7 @@ impl State {
235298 }
236299
237300 pub fn root_id ( & self ) -> NodeId {
238- nid ( self . data . root )
301+ self . root
239302 }
240303
241304 pub fn root ( & self ) -> Node < ' _ > {
@@ -307,14 +370,16 @@ impl Tree {
307370 panic ! ( "Cannot initialize with a subtree. TreeUpdate::tree_id must be TreeId::ROOT." ) ;
308371 }
309372 let mut tree_index_map = TreeIndexMap :: default ( ) ;
310- tree_index_map. get_index ( initial_state. tree_id ) ;
373+ let tree_index = tree_index_map. get_index ( initial_state. tree_id ) ;
311374 let mut state = State {
312375 nodes : HashMap :: new ( ) ,
376+ root : NodeId :: new ( tree. root , tree_index) ,
313377 data : tree,
314- focus : nid ( initial_state. focus ) ,
378+ focus : NodeId :: new ( initial_state. focus , tree_index ) ,
315379 is_host_focused,
380+ subtrees : HashMap :: new ( ) ,
316381 } ;
317- state. update ( initial_state, is_host_focused, None ) ;
382+ state. update ( initial_state, is_host_focused, None , tree_index ) ;
318383 Self {
319384 next_state : state. clone ( ) ,
320385 state,
@@ -327,9 +392,14 @@ impl Tree {
327392 update : TreeUpdate ,
328393 handler : & mut impl ChangeHandler ,
329394 ) {
395+ let tree_index = self . tree_index_map . get_index ( update. tree_id ) ;
330396 let mut changes = InternalChanges :: default ( ) ;
331- self . next_state
332- . update ( update, self . state . is_host_focused , Some ( & mut changes) ) ;
397+ self . next_state . update (
398+ update,
399+ self . state . is_host_focused ,
400+ Some ( & mut changes) ,
401+ tree_index,
402+ ) ;
333403 self . process_changes ( changes, handler) ;
334404 }
335405
@@ -400,8 +470,10 @@ impl Tree {
400470 if self . state . data != self . next_state . data {
401471 self . state . data . clone_from ( & self . next_state . data ) ;
402472 }
473+ self . state . root = self . next_state . root ;
403474 self . state . focus = self . next_state . focus ;
404475 self . state . is_host_focused = self . next_state . is_host_focused ;
476+ self . state . subtrees . clone_from ( & self . next_state . subtrees ) ;
405477 }
406478
407479 pub fn state ( & self ) -> & State {
0 commit comments