Skip to content

Commit 5b20c16

Browse files
authored
Merge pull request #252 group messages to batches
group messages to batches
2 parents d580933 + e3b8118 commit 5b20c16

File tree

3 files changed

+143
-8
lines changed

3 files changed

+143
-8
lines changed

ydb/_topic_writer/topic_writer.py

Lines changed: 87 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import concurrent.futures
22
import datetime
33
import enum
4+
import itertools
45
import uuid
56
from dataclasses import dataclass
67
from enum import Enum
@@ -12,6 +13,7 @@
1213
from .._grpc.grpcwrapper.ydb_topic import StreamWriteMessage
1314
from .._grpc.grpcwrapper.common_utils import IToProto
1415
from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec
16+
from .. import connection
1517

1618
Message = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"]
1719

@@ -200,14 +202,94 @@ def default_serializer_message_content(data: Any) -> bytes:
200202
def messages_to_proto_requests(
    messages: List[InternalMessage],
) -> List[StreamWriteMessage.FromClient]:
    """Pack internal messages into batched write requests.

    The messages are first divided into batches by
    ``_slit_messages_for_send``; every batch then becomes a single
    ``StreamWriteMessage.WriteRequest`` wrapped in ``FromClient``.

    :param messages: Messages to send, in the order they must be written.
    :return: One request per batch, in batch order.
    """
    groups = _slit_messages_for_send(messages)

    return [
        StreamWriteMessage.FromClient(
            StreamWriteMessage.WriteRequest(
                messages=[msg.to_message_data() for msg in group],
                # The request carries a single codec value, taken from the
                # first message of the batch.
                codec=group[0].codec,
            )
        )
        for group in groups
    ]
218+
219+
220+
# Largest value representable as a signed 64-bit integer.
_max_int = (1 << 63) - 1

# Serialized size of a complete write request that carries one message with a
# one-byte payload and large values in its other fields. It is added to the
# payload length of each message when estimating how much room the message
# takes in a batch.
_message_data_overhead = (
    StreamWriteMessage.FromClient(
        StreamWriteMessage.WriteRequest(
            messages=[
                StreamWriteMessage.WriteRequest.MessageData(
                    seq_no=_max_int,
                    created_at=datetime.datetime(3000, 1, 1, 1, 1, 1, 1),
                    data=b"\x00",
                    uncompressed_size=_max_int,
                    partitioning=StreamWriteMessage.PartitioningMessageGroupID(
                        message_group_id="a" * 100,
                    ),
                ),
            ],
            codec=20000,
        )
    )
    .to_proto()
    .ByteSize()
)
242+
243+
244+
def _slit_messages_for_send(
    messages: List[InternalMessage],
) -> List[List[InternalMessage]]:
    """Split messages into batches that can each be sent as one request.

    Consecutive messages with the same codec are grouped together, and every
    such group is then split further by
    ``_split_messages_by_size_with_default_overhead``. Message order is
    preserved, so the same codec may appear in more than one batch.

    :param messages: Messages in send order.
    :return: Batches in send order; each holds messages of a single codec.
    """
    res = []  # type: List[List[InternalMessage]]
    # groupby merges only adjacent items, which keeps the original ordering.
    for _, codec_group in itertools.groupby(messages, key=lambda msg: msg.codec):
        res.extend(_split_messages_by_size_with_default_overhead(list(codec_group)))
    return res
256+
257+
258+
def _split_messages_by_size_with_default_overhead(
    messages: List[InternalMessage],
) -> List[List[InternalMessage]]:
    """Split messages into batches bounded by the default gRPC message size.

    The size of a message is taken as the length of its ``data`` plus
    ``_message_data_overhead``.

    :param messages: Messages to split, in send order.
    :return: The batches produced by ``_split_messages_by_size``.
    """
    limit = connection._DEFAULT_MAX_GRPC_MESSAGE_SIZE

    def estimated_size(message: InternalMessage):
        payload_len = len(message.data)
        return payload_len + _message_data_overhead

    return _split_messages_by_size(messages, limit, estimated_size)
267+
268+
269+
def _split_messages_by_size(
270+
messages: List[InternalMessage],
271+
split_size: int,
272+
get_msg_size: typing.Callable[[InternalMessage], int],
273+
) -> List[List[InternalMessage]]:
274+
res = []
275+
group = []
276+
group_size = 0
277+
278+
for msg in messages:
279+
msg_size = get_msg_size(msg)
280+
281+
if len(group) == 0:
282+
group.append(msg)
283+
group_size += msg_size
284+
elif group_size + msg_size <= split_size:
285+
group.append(msg)
286+
group_size += msg_size
287+
else:
288+
res.append(group)
289+
group = [msg]
290+
group_size = msg_size
291+
292+
if len(group) > 0:
293+
res.append(group)
294+
295+
return res
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from typing import List
2+
3+
import pytest
4+
5+
from .topic_writer import _split_messages_by_size
6+
7+
8+
@pytest.mark.parametrize(
    "messages,split_size,expected",
    [
        # A zero limit still yields one message per group.
        ([1, 2, 3], 0, [[1], [2], [3]]),
        ([1, 2, 3], 1, [[1], [2], [3]]),
        # Sizes that sum exactly to the limit share a group.
        ([1, 2, 3], 3, [[1, 2], [3]]),
        ([1, 2, 3], 100, [[1, 2, 3]]),
        ([100, 2, 3], 100, [[100], [2, 3]]),
        # A single message above the limit is kept in a group of its own.
        ([200], 100, [[200]]),
        # Empty input produces no groups.
        ([], 100, []),
    ],
)
def test_split_messages_by_size(
    messages: List[int], split_size: int, expected: List[List[int]]
):
    """Check the grouping when the size of each message is the int itself."""
    res = _split_messages_by_size(messages, split_size, lambda x: x)  # noqa
    assert res == expected

ydb/connection.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
YDB_TRACE_ID_HEADER = "x-ydb-trace-id"
2525
YDB_REQUEST_TYPE_HEADER = "x-ydb-request-type"
2626

27+
# Default size limit used for both the "grpc.max_receive_message_length" and
# "grpc.max_send_message_length" channel options.
_DEFAULT_MAX_GRPC_MESSAGE_SIZE = 64 * 10**6
28+
2729

2830
def _message_to_string(message):
2931
"""
@@ -179,10 +181,9 @@ def _construct_channel_options(driver_config, endpoint_options=None):
179181
:param endpoint_options: Endpoint options
180182
:return: A channel initialization options
181183
"""
182-
_max_message_size = 64 * 10**6
183184
_default_connect_options = [
184-
("grpc.max_receive_message_length", _max_message_size),
185-
("grpc.max_send_message_length", _max_message_size),
185+
("grpc.max_receive_message_length", _DEFAULT_MAX_GRPC_MESSAGE_SIZE),
186+
("grpc.max_send_message_length", _DEFAULT_MAX_GRPC_MESSAGE_SIZE),
186187
("grpc.primary_user_agent", driver_config.primary_user_agent),
187188
(
188189
"grpc.lb_policy_name",

0 commit comments

Comments
 (0)