|
1 | 1 | import concurrent.futures
|
2 | 2 | import datetime
|
3 | 3 | import enum
|
| 4 | +import itertools |
4 | 5 | import uuid
|
5 | 6 | from dataclasses import dataclass
|
6 | 7 | from enum import Enum
|
|
12 | 13 | from .._grpc.grpcwrapper.ydb_topic import StreamWriteMessage
|
13 | 14 | from .._grpc.grpcwrapper.common_utils import IToProto
|
14 | 15 | from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec
|
| 16 | +from .. import connection |
15 | 17 |
|
16 | 18 | Message = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"]
|
17 | 19 |
|
@@ -200,14 +202,94 @@ def default_serializer_message_content(data: Any) -> bytes:
|
200 | 202 | def messages_to_proto_requests(
|
201 | 203 | messages: List[InternalMessage],
|
202 | 204 | ) -> List[StreamWriteMessage.FromClient]:
|
203 |
| - # todo split by proto message size and codec |
204 |
| - res = [] |
205 |
| - for msg in messages: |
| 205 | + |
| 206 | + gropus = _slit_messages_for_send(messages) |
| 207 | + |
| 208 | + res = [] # type: List[StreamWriteMessage.FromClient] |
| 209 | + for group in gropus: |
206 | 210 | req = StreamWriteMessage.FromClient(
|
207 | 211 | StreamWriteMessage.WriteRequest(
|
208 |
| - messages=[msg.to_message_data()], |
209 |
| - codec=msg.codec, |
| 212 | + messages=list(map(InternalMessage.to_message_data, group)), |
| 213 | + codec=group[0].codec, |
210 | 214 | )
|
211 | 215 | )
|
212 | 216 | res.append(req)
|
213 | 217 | return res
|
| 218 | + |
| 219 | + |
| 220 | +_max_int = 2**63 - 1 |
| 221 | + |
| 222 | +_message_data_overhead = ( |
| 223 | + StreamWriteMessage.FromClient( |
| 224 | + StreamWriteMessage.WriteRequest( |
| 225 | + messages=[ |
| 226 | + StreamWriteMessage.WriteRequest.MessageData( |
| 227 | + seq_no=_max_int, |
| 228 | + created_at=datetime.datetime(3000, 1, 1, 1, 1, 1, 1), |
| 229 | + data=bytes(1), |
| 230 | + uncompressed_size=_max_int, |
| 231 | + partitioning=StreamWriteMessage.PartitioningMessageGroupID( |
| 232 | + message_group_id="a" * 100, |
| 233 | + ), |
| 234 | + ), |
| 235 | + ], |
| 236 | + codec=20000, |
| 237 | + ) |
| 238 | + ) |
| 239 | + .to_proto() |
| 240 | + .ByteSize() |
| 241 | +) |
| 242 | + |
| 243 | + |
| 244 | +def _slit_messages_for_send( |
| 245 | + messages: List[InternalMessage], |
| 246 | +) -> List[List[InternalMessage]]: |
| 247 | + codec_groups = [] # type: List[List[InternalMessage]] |
| 248 | + for _, messages in itertools.groupby(messages, lambda x: x.codec): |
| 249 | + codec_groups.append(list(messages)) |
| 250 | + |
| 251 | + res = [] # type: List[List[InternalMessage]] |
| 252 | + for codec_group in codec_groups: |
| 253 | + group_by_size = _split_messages_by_size_with_default_overhead(codec_group) |
| 254 | + res.extend(group_by_size) |
| 255 | + return res |
| 256 | + |
| 257 | + |
| 258 | +def _split_messages_by_size_with_default_overhead( |
| 259 | + messages: List[InternalMessage], |
| 260 | +) -> List[List[InternalMessage]]: |
| 261 | + def get_message_size(msg: InternalMessage): |
| 262 | + return len(msg.data) + _message_data_overhead |
| 263 | + |
| 264 | + return _split_messages_by_size( |
| 265 | + messages, connection._DEFAULT_MAX_GRPC_MESSAGE_SIZE, get_message_size |
| 266 | + ) |
| 267 | + |
| 268 | + |
| 269 | +def _split_messages_by_size( |
| 270 | + messages: List[InternalMessage], |
| 271 | + split_size: int, |
| 272 | + get_msg_size: typing.Callable[[InternalMessage], int], |
| 273 | +) -> List[List[InternalMessage]]: |
| 274 | + res = [] |
| 275 | + group = [] |
| 276 | + group_size = 0 |
| 277 | + |
| 278 | + for msg in messages: |
| 279 | + msg_size = get_msg_size(msg) |
| 280 | + |
| 281 | + if len(group) == 0: |
| 282 | + group.append(msg) |
| 283 | + group_size += msg_size |
| 284 | + elif group_size + msg_size <= split_size: |
| 285 | + group.append(msg) |
| 286 | + group_size += msg_size |
| 287 | + else: |
| 288 | + res.append(group) |
| 289 | + group = [msg] |
| 290 | + group_size = msg_size |
| 291 | + |
| 292 | + if len(group) > 0: |
| 293 | + res.append(group) |
| 294 | + |
| 295 | + return res |
0 commit comments