@@ -78,27 +78,41 @@ def __init__(
7878 )
7979
8080 async def get_partial_current_state_deltas (
81- self , prev_stream_id : int , max_stream_id : int
81+ self , prev_stream_id : int , max_stream_id : int , limit : int = 100
8282 ) -> tuple [int , list [StateDelta ]]:
83- """Fetch a list of room state changes since the given stream id
83+ """Fetch a list of room state changes since the given stream id.
8484
8585 This may be the partial state if we're lazy joining the room.
8686
87+ This method takes care to handle state deltas that share the same
88+ `stream_id`. That can happen when persisting state in a batch,
89+ potentially as the result of state resolution (both adding new state and
90+ undo'ing previous state).
91+
92+ State deltas are grouped by `stream_id`. When hitting the given `limit`
93+ would return only part of a "group" of state deltas, that entire group
94+ is omitted. Thus, this function may return *up to* `limit` state deltas,
95+ or slightly more when a single group itself exceeds `limit`.
96+
8797 Args:
8898 prev_stream_id: point to get changes since (exclusive)
8999 max_stream_id: the point that we know has been correctly persisted
90100 - ie, an upper limit to return changes from.
101+ limit: the maximum number of rows to return.
91102
92103 Returns:
93104 A tuple consisting of:
94105 - the stream id which these results go up to
95106 - list of current_state_delta_stream rows. If it is empty, we are
96107 up to date.
97-
98- A maximum of 100 rows will be returned.
99108 """
100109 prev_stream_id = int (prev_stream_id )
101110
111+ if limit <= 0 :
112+ raise ValueError (
113+ "Invalid `limit` passed to `get_partial_current_state_deltas"
114+ )
115+
102116 # check we're not going backwards
103117 assert prev_stream_id <= max_stream_id , (
104118 f"New stream id { max_stream_id } is smaller than prev stream id { prev_stream_id } "
@@ -115,45 +129,62 @@ async def get_partial_current_state_deltas(
def get_current_state_deltas_txn(
    txn: LoggingTransaction,
) -> tuple[int, list[StateDelta]]:
    """Transaction function: fetch up to (roughly) `limit` state deltas.

    Deltas sharing a `stream_id` form a "group" (a batch persisted
    atomically, e.g. from state resolution). Groups are returned whole:
    if including the next group would exceed `limit`, it is omitted —
    except the very first group, which is always returned so the caller
    makes progress even when a single group exceeds `limit`.

    Returns:
        A tuple of (stream id these results go up to, list of StateDelta
        rows). The list is empty when we are already up to date.
    """
    # 1) Group state deltas by `stream_id` and work out which groups can
    #    be returned without exceeding the provided `limit`. We fetch one
    #    group more than `limit` so we can tell whether the grouping
    #    query itself was truncated (see `caught_up_with_stream` below).
    sql_grouped = """
        SELECT stream_id, COUNT(*) AS c
        FROM current_state_delta_stream
        WHERE stream_id > ? AND stream_id <= ?
        GROUP BY stream_id
        ORDER BY stream_id
        LIMIT ?
    """
    group_limit = limit + 1
    txn.execute(sql_grouped, (prev_stream_id, max_stream_id, group_limit))
    grouped_rows = txn.fetchall()

    if not grouped_rows:
        # Nothing to return in the range; we are up to date through max_stream_id.
        return max_stream_id, []

    # Always retrieve the first group, at the bare minimum. This ensures the
    # caller always makes progress, even if a single group exceeds `limit`.
    fetch_upto_stream_id, included_rows = grouped_rows[0]

    # Determine which other groups we can retrieve at the same time,
    # without blowing the budget.
    included_all_groups = True
    for stream_id, count in grouped_rows[1:]:
        if included_rows + count > limit:
            included_all_groups = False
            break
        included_rows += count
        fetch_upto_stream_id = stream_id

    # If we retrieved fewer groups than the limit *and* we didn't hit the
    # `LIMIT ?` cap on the grouping query, we know we've caught up with
    # the stream.
    caught_up_with_stream = (
        included_all_groups and len(grouped_rows) < group_limit
    )

    # At this point we should have advanced, or bailed out early above.
    assert fetch_upto_stream_id != prev_stream_id

    # 2) Fetch the actual rows for only the included stream_id groups.
    sql_rows = """
        SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
        FROM current_state_delta_stream
        WHERE ? < stream_id AND stream_id <= ?
        ORDER BY stream_id ASC
    """
    txn.execute(sql_rows, (prev_stream_id, fetch_upto_stream_id))
    rows = txn.fetchall()

    # When caught up, advance the returned token all the way to
    # max_stream_id so the caller doesn't re-scan an empty range.
    clipped_stream_id = (
        max_stream_id if caught_up_with_stream else fetch_upto_stream_id
    )

    return clipped_stream_id, [
        StateDelta(
            stream_id=row[0],
            # NOTE(review): row[1..3] reconstructed from the SELECT column
            # order above (room_id, type, state_key) — the diff hunk gap
            # hid these three kwargs.
            room_id=row[1],
            type=row[2],
            state_key=row[3],
            event_id=row[4],
            prev_event_id=row[5],
        )
        for row in rows
    ]
168199
169200 return await self .db_pool .runInteraction (
0 commit comments