Skip to content

Commit 703464c

Browse files
Fix case where get_partial_current_state_deltas could return >100 rows (#18960)
1 parent c928347 commit 703464c

File tree

4 files changed

+380
-34
lines changed

4 files changed

+380
-34
lines changed

changelog.d/18960.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix a bug in the database function for fetching state deltas that could result in unnecessarily long query times.

synapse/storage/controllers/state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@ async def get_current_state_deltas(
683683
# https://github.com/matrix-org/synapse/issues/13008
684684

685685
return await self.stores.main.get_partial_current_state_deltas(
686-
prev_stream_id, max_stream_id
686+
prev_stream_id, max_stream_id, limit=100
687687
)
688688

689689
@trace

synapse/storage/databases/main/state_deltas.py

Lines changed: 64 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -78,27 +78,41 @@ def __init__(
7878
)
7979

8080
async def get_partial_current_state_deltas(
81-
self, prev_stream_id: int, max_stream_id: int
81+
self, prev_stream_id: int, max_stream_id: int, limit: int = 100
8282
) -> tuple[int, list[StateDelta]]:
83-
"""Fetch a list of room state changes since the given stream id
83+
"""Fetch a list of room state changes since the given stream id.
8484
8585
This may be the partial state if we're lazy joining the room.
8686
87+
This method takes care to handle state deltas that share the same
88+
`stream_id`. That can happen when persisting state in a batch,
89+
potentially as the result of state resolution (both adding new state and
90+
undo'ing previous state).
91+
92+
State deltas are grouped by `stream_id`. When hitting the given `limit`
93+
would return only part of a "group" of state deltas, that entire group
94+
is omitted. Thus, this function may return *up to* `limit` state deltas,
95+
or slightly more when a single group itself exceeds `limit`.
96+
8797
Args:
8898
prev_stream_id: point to get changes since (exclusive)
8999
max_stream_id: the point that we know has been correctly persisted
90100
- ie, an upper limit to return changes from.
101+
limit: the maximum number of rows to return.
91102
92103
Returns:
93104
A tuple consisting of:
94105
- the stream id which these results go up to
95106
- list of current_state_delta_stream rows. If it is empty, we are
96107
up to date.
97-
98-
A maximum of 100 rows will be returned.
99108
"""
100109
prev_stream_id = int(prev_stream_id)
101110

111+
if limit <= 0:
112+
raise ValueError(
113+
"Invalid `limit` passed to `get_partial_current_state_deltas"
114+
)
115+
102116
# check we're not going backwards
103117
assert prev_stream_id <= max_stream_id, (
104118
f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}"
@@ -115,45 +129,62 @@ async def get_partial_current_state_deltas(
115129
def get_current_state_deltas_txn(
116130
txn: LoggingTransaction,
117131
) -> tuple[int, list[StateDelta]]:
118-
# First we calculate the max stream id that will give us less than
119-
# N results.
120-
# We arbitrarily limit to 100 stream_id entries to ensure we don't
121-
# select toooo many.
122-
sql = """
123-
SELECT stream_id, count(*)
132+
# First we group state deltas by `stream_id` and calculate which
133+
# groups can be returned without exceeding the provided `limit`.
134+
sql_grouped = """
135+
SELECT stream_id, COUNT(*) AS c
124136
FROM current_state_delta_stream
125137
WHERE stream_id > ? AND stream_id <= ?
126138
GROUP BY stream_id
127-
ORDER BY stream_id ASC
128-
LIMIT 100
139+
ORDER BY stream_id
140+
LIMIT ?
129141
"""
130-
txn.execute(sql, (prev_stream_id, max_stream_id))
131-
132-
total = 0
133-
134-
for stream_id, count in txn:
135-
total += count
136-
if total > 100:
137-
# We arbitrarily limit to 100 entries to ensure we don't
138-
# select toooo many.
139-
logger.debug(
140-
"Clipping current_state_delta_stream rows to stream_id %i",
141-
stream_id,
142-
)
143-
clipped_stream_id = stream_id
142+
group_limit = limit + 1
143+
txn.execute(sql_grouped, (prev_stream_id, max_stream_id, group_limit))
144+
grouped_rows = txn.fetchall()
145+
146+
if not grouped_rows:
147+
# Nothing to return in the range; we are up to date through max_stream_id.
148+
return max_stream_id, []
149+
150+
# Always retrieve the first group, at the bare minimum. This ensures the
151+
# caller always makes progress, even if a single group exceeds `limit`.
152+
fetch_upto_stream_id, included_rows = grouped_rows[0]
153+
154+
# Determine which other groups we can retrieve at the same time,
155+
# without blowing the budget.
156+
included_all_groups = True
157+
for stream_id, count in grouped_rows[1:]:
158+
if included_rows + count > limit:
159+
included_all_groups = False
144160
break
145-
else:
146-
# if there's no problem, we may as well go right up to the max_stream_id
147-
clipped_stream_id = max_stream_id
161+
included_rows += count
162+
fetch_upto_stream_id = stream_id
163+
164+
# If we retrieved fewer groups than the limit *and* we didn't hit the
165+
# `LIMIT ?` cap on the grouping query, we know we've caught up with
166+
# the stream.
167+
caught_up_with_stream = (
168+
included_all_groups and len(grouped_rows) < group_limit
169+
)
170+
171+
# At this point we should have advanced, or bailed out early above.
172+
assert fetch_upto_stream_id != prev_stream_id
148173

149-
# Now actually get the deltas
150-
sql = """
174+
# 2) Fetch the actual rows for only the included stream_id groups.
175+
sql_rows = """
151176
SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
152177
FROM current_state_delta_stream
153178
WHERE ? < stream_id AND stream_id <= ?
154179
ORDER BY stream_id ASC
155180
"""
156-
txn.execute(sql, (prev_stream_id, clipped_stream_id))
181+
txn.execute(sql_rows, (prev_stream_id, fetch_upto_stream_id))
182+
rows = txn.fetchall()
183+
184+
clipped_stream_id = (
185+
max_stream_id if caught_up_with_stream else fetch_upto_stream_id
186+
)
187+
157188
return clipped_stream_id, [
158189
StateDelta(
159190
stream_id=row[0],
@@ -163,7 +194,7 @@ def get_current_state_deltas_txn(
163194
event_id=row[4],
164195
prev_event_id=row[5],
165196
)
166-
for row in txn.fetchall()
197+
for row in rows
167198
]
168199

169200
return await self.db_pool.runInteraction(

0 commit comments

Comments
 (0)