diff --git a/tests/slo/playground/configs/chaos.sh b/tests/slo/playground/configs/chaos.sh
index 550a6740..0466ecf8 100755
--- a/tests/slo/playground/configs/chaos.sh
+++ b/tests/slo/playground/configs/chaos.sh
@@ -31,7 +31,7 @@
 do
     sh -c "docker stop ${nodeForChaos} -t 10"
     sh -c "docker start ${nodeForChaos}"
-    sleep 60
+    sleep 30
 done
 
 # for i in $(seq 1 3)
diff --git a/tests/slo/requirements.txt b/tests/slo/requirements.txt
index 877870cc..c81fe92f 100644
--- a/tests/slo/requirements.txt
+++ b/tests/slo/requirements.txt
@@ -1,4 +1,5 @@
 requests==2.28.2
 ratelimiter==1.2.0.post0
+aiolimiter==1.1.0
 prometheus-client==0.17.0
 quantile-estimator==0.1.2
diff --git a/tests/slo/slo_runner.sh b/tests/slo/slo_runner.sh
index d44729e7..54012058 100755
--- a/tests/slo/slo_runner.sh
+++ b/tests/slo/slo_runner.sh
@@ -3,6 +3,7 @@
 docker compose -f playground/configs/compose.yaml up -d --wait
 
 ../../.venv/bin/python ./src topic-create grpc://localhost:2135 /Root/testdb --path /Root/testdb/slo_topic
 
-../../.venv/bin/python ./src topic-run grpc://localhost:2135 /Root/testdb --path /Root/testdb/slo_topic --prom-pgw "" --read-threads 0 --time 10
+../../.venv/bin/python ./src topic-run grpc://localhost:2135 /Root/testdb --path /Root/testdb/slo_topic --prom-pgw "" --read-threads 0 --write-rps 1 --time 120
-../../.venv/bin/python ./src topic-run grpc://localhost:2135 /Root/testdb --path /Root/testdb/slo_topic --prom-pgw "" --write-threads 0 --read-rps 1 --debug --time 600
\ No newline at end of file
+# ../../.venv/bin/python ./src topic-run grpc://localhost:2135 /Root/testdb --path /Root/testdb/slo_topic --prom-pgw "" --read-threads 0 --time 5
+# ../../.venv/bin/python ./src topic-run grpc://localhost:2135 /Root/testdb --path /Root/testdb/slo_topic --prom-pgw "" --write-threads 0 --read-rps 1 --time 200
\ No newline at end of file
diff --git a/tests/slo/src/jobs/async_topic_jobs.py b/tests/slo/src/jobs/async_topic_jobs.py
new file mode 100644
index 00000000..d0838bb0
--- /dev/null
+++ b/tests/slo/src/jobs/async_topic_jobs.py
@@ -0,0 +1,130 @@
+import asyncio
+import ydb.aio
+import time
+import logging
+from aiolimiter import AsyncLimiter
+
+from .base import BaseJobManager
+from core.metrics import OP_TYPE_READ, OP_TYPE_WRITE
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncTopicJobManager(BaseJobManager):
+    def __init__(self, driver, args, metrics):
+        super().__init__(driver, args, metrics)
+        self.driver: ydb.aio.Driver = driver
+
+    async def run_tests(self):
+        tasks = [
+            *await self._run_topic_write_jobs(),
+            *await self._run_topic_read_jobs(),
+            *self._run_metric_job(),
+        ]
+
+        await asyncio.gather(*tasks)
+
+    async def _run_topic_write_jobs(self):
+        logger.info("Start async topic write jobs")
+
+        write_limiter = AsyncLimiter(max_rate=self.args.write_rps, time_period=1)
+
+        tasks = []
+        for i in range(self.args.write_threads):
+            task = asyncio.create_task(self._run_topic_writes(write_limiter, i), name=f"slo_topic_write_{i}")
+            tasks.append(task)
+
+        return tasks
+
+    async def _run_topic_read_jobs(self):
+        logger.info("Start async topic read jobs")
+
+        read_limiter = AsyncLimiter(max_rate=self.args.read_rps, time_period=1)
+
+        tasks = []
+        for i in range(self.args.read_threads):
+            task = asyncio.create_task(self._run_topic_reads(read_limiter), name=f"slo_topic_read_{i}")
+            tasks.append(task)
+
+        return tasks
+
+    async def _run_topic_writes(self, limiter, partition_id=None):
+        start_time = time.time()
+        logger.info("Start async topic write workload")
+
+        async with self.driver.topic_client.writer(
+            self.args.path,
+            codec=ydb.TopicCodec.GZIP,
+            partition_id=partition_id,
+        ) as writer:
+            logger.info("Async topic writer created")
+
+            message_count = 0
+            while time.time() - start_time < self.args.time:
+                async with limiter:
+                    message_count += 1
+
+                    content = f"message_{message_count}_{asyncio.current_task().get_name()}".encode("utf-8")
+
+                    if len(content) < self.args.message_size:
+                        content += b"x" * (self.args.message_size - len(content))
+
+                    message = ydb.TopicWriterMessage(data=content)
+
+                    ts = self.metrics.start((OP_TYPE_WRITE,))
+                    try:
+                        await writer.write_with_ack(message)
+                        logger.info("Write message: %s", content)
+                        self.metrics.stop((OP_TYPE_WRITE,), ts)
+                    except Exception as e:
+                        self.metrics.stop((OP_TYPE_WRITE,), ts, error=e)
+                        logger.error("Write error: %s", e)
+
+        logger.info("Stop async topic write workload")
+
+    async def _run_topic_reads(self, limiter):
+        start_time = time.time()
+        logger.info("Start async topic read workload")
+
+        async with self.driver.topic_client.reader(
+            self.args.path,
+            self.args.consumer,
+        ) as reader:
+            logger.info("Async topic reader created")
+
+            while time.time() - start_time < self.args.time:
+                async with limiter:
+                    ts = self.metrics.start((OP_TYPE_READ,))
+                    try:
+                        msg = await reader.receive_message()
+                        if msg is not None:
+                            logger.info("Read message: %s", msg.data.decode())
+                            await reader.commit_with_ack(msg)
+
+                        self.metrics.stop((OP_TYPE_READ,), ts)
+                    except Exception as e:
+                        self.metrics.stop((OP_TYPE_READ,), ts, error=e)
+                        logger.error("Read error: %s", e)
+
+        logger.info("Stop async topic read workload")
+
+    def _run_metric_job(self):
+        if not self.args.prom_pgw:
+            return []
+
+        # Create async task for metrics
+        task = asyncio.create_task(self._async_metric_sender(self.args.time), name="slo_metrics_sender")
+        return [task]
+
+    async def _async_metric_sender(self, runtime):
+        start_time = time.time()
+        logger.info("Start push metrics (async)")
+
+        limiter = AsyncLimiter(max_rate=10**6 // self.args.report_period, time_period=1)
+
+        while time.time() - start_time < runtime:
+            async with limiter:
+                # Call sync metrics.push() in executor to avoid blocking
+                await asyncio.get_event_loop().run_in_executor(None, self.metrics.push)
+
+        logger.info("Stop push metrics (async)")
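`aiolimiter` (added to requirements above) replaces `ratelimiter` on the asyncio path: `AsyncLimiter(max_rate, time_period)` is a token bucket entered with `async with`, which is exactly how the job manager paces writes, reads, and metric pushes. A minimal self-contained sketch of the pattern; `RPS` and `do_op` are illustrative names, not part of the SLO harness:

```python
import asyncio
import time

from aiolimiter import AsyncLimiter

RPS = 5  # illustrative target rate, operations per second


async def do_op(i: int) -> None:
    print(f"op {i} at t={time.monotonic():.2f}s")


async def main() -> None:
    limiter = AsyncLimiter(max_rate=RPS, time_period=1)
    for i in range(15):
        async with limiter:  # suspends until the bucket has capacity again
            await do_op(i)


asyncio.run(main())
```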
diff --git a/tests/slo/src/options.py b/tests/slo/src/options.py
index a634bc89..b15bf8f7 100644
--- a/tests/slo/src/options.py
+++ b/tests/slo/src/options.py
@@ -6,6 +6,7 @@ def add_common_options(parser):
     parser.add_argument("db", help="YDB database name")
     parser.add_argument("-t", "--table-name", default="key_value", help="Table name")
     parser.add_argument("--debug", action="store_true", help="Enable debug logging")
+    parser.add_argument("--async", action="store_true", help="Use async mode for operations")
 
 
 def make_table_create_parser(subparsers):
diff --git a/tests/slo/src/root_runner.py b/tests/slo/src/root_runner.py
index 3bf8a8a0..20589c14 100644
--- a/tests/slo/src/root_runner.py
+++ b/tests/slo/src/root_runner.py
@@ -1,4 +1,6 @@
+import asyncio
 import ydb
+import ydb.aio
 import logging
 
 from typing import Dict
@@ -26,6 +28,15 @@ def run_command(self, args):
             raise ValueError(f"Unknown prefix: {prefix}. Available: {list(self.runners.keys())}")
 
         runner_instance = self.runners[prefix]()
+
+        # Check if async mode is requested and command is 'run'
+        if getattr(args, "async", False) and command == "run":
+            asyncio.run(self._run_async_command(args, runner_instance, command))
+        else:
+            self._run_sync_command(args, runner_instance, command)
+
+    def _run_sync_command(self, args, runner_instance, command):
+        """Run command in synchronous mode"""
         driver_config = ydb.DriverConfig(
             args.endpoint,
             database=args.db,
@@ -43,13 +54,33 @@ def run_command(self, args):
             elif command == "cleanup":
                 runner_instance.cleanup(args)
             else:
-                raise RuntimeError(f"Unknown command {command} for prefix {prefix}")
+                raise RuntimeError(f"Unknown command {command} for prefix {runner_instance.prefix}")
         except BaseException:
             logger.exception("Something went wrong")
             raise
         finally:
             driver.stop(timeout=getattr(args, "shutdown_time", 10))
 
+    async def _run_async_command(self, args, runner_instance, command):
+        """Run command in asynchronous mode"""
+        driver_config = ydb.DriverConfig(
+            args.endpoint,
+            database=args.db,
+            grpc_keep_alive_timeout=5000,
+        )
+
+        async with ydb.aio.Driver(driver_config) as driver:
+            await driver.wait(timeout=300)
+            try:
+                runner_instance.set_driver(driver)
+                if command == "run":
+                    await runner_instance.run_async(args)
+                else:
+                    raise RuntimeError(f"Async mode only supports 'run' command, got '{command}'")
+            except BaseException:
+                logger.exception("Something went wrong in async mode")
+                raise
+
 
 def create_runner() -> SLORunner:
     runner = SLORunner()
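One wrinkle the dispatch code works around: argparse stores `--async` under the attribute name `async`, which is a reserved keyword in Python 3, so `args.async` cannot be written and `getattr` is the only way to read the flag. A standalone demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--async", action="store_true", help="Use async mode for operations")

args = parser.parse_args(["--async"])
# print(args.async)  # SyntaxError: "async" is a reserved keyword
print(getattr(args, "async"))  # True
```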
diff --git a/tests/slo/src/runners/base.py b/tests/slo/src/runners/base.py
index 1f9eda2e..66f6fde1 100644
--- a/tests/slo/src/runners/base.py
+++ b/tests/slo/src/runners/base.py
@@ -24,6 +24,9 @@ def create(self, args):
     def run(self, args):
         pass
 
+    async def run_async(self, args):
+        raise NotImplementedError(f"Async mode not supported for {self.prefix}")
+
     @abstractmethod
     def cleanup(self, args):
         pass
diff --git a/tests/slo/src/runners/topic_runner.py b/tests/slo/src/runners/topic_runner.py
index 7a9be942..10c72125 100644
--- a/tests/slo/src/runners/topic_runner.py
+++ b/tests/slo/src/runners/topic_runner.py
@@ -1,8 +1,10 @@
 import time
 import ydb
+import ydb.aio
 
 from .base import BaseRunner
 from jobs.topic_jobs import TopicJobManager
+from jobs.async_topic_jobs import AsyncTopicJobManager
 from core.metrics import create_metrics
 
 
@@ -76,6 +78,21 @@ def run(self, args):
         if hasattr(metrics, "reset"):
             metrics.reset()
 
+    async def run_async(self, args):
+        """Async version of topic SLO tests using ydb.aio.Driver"""
+        metrics = create_metrics(args.prom_pgw)
+
+        self.logger.info("Starting async topic SLO tests")
+
+        # Use async driver for topic operations
+        job_manager = AsyncTopicJobManager(self.driver, args, metrics)
+        await job_manager.run_tests()
+
+        self.logger.info("Async topic SLO tests completed")
+
+        if hasattr(metrics, "reset"):
+            metrics.reset()
+
     def cleanup(self, args):
         self.logger.info("Cleaning up topic: %s", args.path)
 
diff --git a/ydb/_errors.py b/ydb/_errors.py
index b19de749..3f426350 100644
--- a/ydb/_errors.py
+++ b/ydb/_errors.py
@@ -1,7 +1,5 @@
 from dataclasses import dataclass
-from typing import Optional, Union
-
-import grpc
+from typing import Optional
 
 from . import issues
 
@@ -15,6 +13,7 @@
     issues.Overloaded,
     issues.SessionPoolEmpty,
     issues.ConnectionError,
+    issues.ConnectionLost,
 ]
 _errors_retriable_slow_backoff_idempotent_types = [
     issues.Undetermined,
@@ -22,6 +21,10 @@
 
 
 def check_retriable_error(err, retry_settings, attempt):
+    if isinstance(err, issues.Cancelled):
+        if retry_settings.retry_cancelled:
+            return ErrorRetryInfo(True, retry_settings.fast_backoff.calc_timeout(attempt))
+
     if isinstance(err, issues.NotFound):
         if retry_settings.retry_not_found:
             return ErrorRetryInfo(True, retry_settings.fast_backoff.calc_timeout(attempt))
@@ -54,26 +57,3 @@ def check_retriable_error(err, retry_settings, attempt):
 class ErrorRetryInfo:
     is_retriable: bool
     sleep_timeout_seconds: Optional[float]
-
-
-def stream_error_converter(exc: BaseException) -> Union[issues.Error, BaseException]:
-    """Converts gRPC stream errors to appropriate YDB exception types.
-
-    This function takes a base exception and converts specific gRPC aio stream errors
-    to their corresponding YDB exception types for better error handling and semantic
-    clarity.
-
-    Args:
-        exc (BaseException): The original exception to potentially convert.
-
-    Returns:
-        BaseException: Either a converted YDB exception or the original exception
-        if no specific conversion rule applies.
-    """
-    if isinstance(exc, (grpc.RpcError, grpc.aio.AioRpcError)):
-        if exc.code() == grpc.StatusCode.UNAVAILABLE:
-            return issues.Unavailable(exc.details() or "")
-        if exc.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
-            return issues.DeadlineExceed("Deadline exceeded on request")
-        return issues.Error("Stream has been terminated. Original exception: {}".format(str(exc.details())))
-    return exc
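The new `Cancelled` branch is checked first and honors the `retry_cancelled` knob introduced in `RetrySettings` (see `ydb/retries.py` below). A sketch of the decision in isolation, assuming the `ydb` package from this branch; `check_retriable_error` is a private helper, so its import path may change:

```python
from ydb import issues
from ydb._errors import check_retriable_error
from ydb.retries import RetrySettings

settings = RetrySettings(idempotent=True, retry_cancelled=True)

# A CANCELLED stream termination is now retriable with fast backoff
# when the caller opted in via retry_cancelled=True.
info = check_retriable_error(issues.Cancelled("stream cancelled"), settings, 0)
print(info.is_retriable)           # True
print(info.sleep_timeout_seconds)  # fast-backoff timeout for attempt 0
```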
diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py
index 9b4529de..0fb960d6 100644
--- a/ydb/_grpc/grpcwrapper/common_utils.py
+++ b/ydb/_grpc/grpcwrapper/common_utils.py
@@ -6,6 +6,7 @@
 import contextvars
 import datetime
 import functools
+import logging
 import typing
 from typing import (
     Optional,
@@ -37,6 +38,8 @@
 from ...settings import BaseRequestSettings
 from ..._constants import DEFAULT_LONG_STREAM_TIMEOUT
 
+logger = logging.getLogger(__name__)
+
 
 class IFromProto(abc.ABC):
     @staticmethod
diff --git a/ydb/_topic_reader/topic_reader.py b/ydb/_topic_reader/topic_reader.py
index d477c9ca..38ee1be6 100644
--- a/ydb/_topic_reader/topic_reader.py
+++ b/ydb/_topic_reader/topic_reader.py
@@ -85,7 +85,7 @@ def _init_message(self) -> StreamReadMessage.InitRequest:
         )
 
     def _retry_settings(self) -> RetrySettings:
-        return RetrySettings(idempotent=True)
+        return RetrySettings(idempotent=True, retry_cancelled=True)
 
 
 class RetryPolicy:
diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py
index b855a80b..818eb1a9 100644
--- a/ydb/_topic_reader/topic_reader_asyncio.py
+++ b/ydb/_topic_reader/topic_reader_asyncio.py
@@ -248,12 +248,15 @@ async def _connection_loop(self):
                 self._state_changed.set()
                 await self._stream_reader.wait_error()
             except BaseException as err:
+                logger.debug("reader %s, attempt %s connection loop error %s", self._id, attempt, err)
                 retry_info = check_retriable_error(err, self._settings._retry_settings(), attempt)
                 if not retry_info.is_retriable:
                     logger.debug("reader %s stop connection loop due to %s", self._id, err)
                     self._set_first_error(err)
                     return
 
+                logger.debug("sleep before retry for %s seconds", retry_info.sleep_timeout_seconds)
+                await asyncio.sleep(retry_info.sleep_timeout_seconds)
 
                 attempt += 1
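Net effect for reader users: the reconnect loop now actually sleeps for the computed backoff between attempts, and `CANCELLED` stream terminations are reconnected instead of surfacing to the caller. A sketch of the consumer-visible behavior, assuming a local YDB at `grpc://localhost:2136` with an existing topic and consumer (hypothetical names):

```python
import asyncio

import ydb
import ydb.aio


async def main():
    async with ydb.aio.Driver(endpoint="grpc://localhost:2136", database="/local") as driver:
        await driver.wait(timeout=5)
        # The reader's _connection_loop now backs off and reconnects on
        # CANCELLED (retry_cancelled=True in _retry_settings()) rather than
        # raising into this loop.
        async with driver.topic_client.reader("/local/my_topic", consumer="my_consumer") as reader:
            msg = await reader.receive_message()
            print(msg.data)
            reader.commit(msg)


asyncio.run(main())
```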
diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py
index d39606d1..1c0e410b 100644
--- a/ydb/_topic_writer/topic_writer_asyncio.py
+++ b/ydb/_topic_writer/topic_writer_asyncio.py
@@ -434,7 +434,7 @@ def _check_stop(self):
             raise self._stop_reason.exception()
 
     async def _connection_loop(self):
-        retry_settings = RetrySettings()  # todo
+        retry_settings = RetrySettings(retry_cancelled=True)  # todo
 
         while True:
             attempt = 0  # todo calc and reset
@@ -485,15 +485,16 @@ async def _connection_loop(self):
             except issues.Error as err:
                 err_info = check_retriable_error(err, retry_settings, attempt)
                 if not err_info.is_retriable or self._tx is not None:  # no retries in tx writer
+                    logger.debug("writer reconnector %s stop connection loop due to %s", self._id, err)
                     self._stop(err)
                     return
 
-                await asyncio.sleep(err_info.sleep_timeout_seconds)
-
                 logger.debug(
                     "writer reconnector %s retry in %s seconds",
                     self._id,
                     err_info.sleep_timeout_seconds,
                 )
+                await asyncio.sleep(err_info.sleep_timeout_seconds)
 
             except (asyncio.CancelledError, Exception) as err:
                 self._stop(err)
diff --git a/ydb/aio/_utilities.py b/ydb/aio/_utilities.py
index 53a7d412..062545d8 100644
--- a/ydb/aio/_utilities.py
+++ b/ydb/aio/_utilities.py
@@ -2,10 +2,9 @@
 
 
 class AsyncResponseIterator(object):
-    def __init__(self, it, wrapper, error_converter=None):
+    def __init__(self, it, wrapper):
         self.it = it.__aiter__()
         self.wrapper = wrapper
-        self.error_converter = error_converter
 
     def cancel(self):
         self.it.cancel()
@@ -18,12 +17,7 @@ def __aiter__(self):
         return self
 
     async def _next(self):
-        try:
-            res = self.wrapper(await self.it.__anext__())
-        except BaseException as e:
-            if self.error_converter:
-                raise self.error_converter(e) from e
-            raise e
+        res = self.wrapper(await self.it.__anext__())
         if res is not None:
             return res
 
diff --git a/ydb/aio/connection.py b/ydb/aio/connection.py
index 1f328a37..eab8ee0c 100644
--- a/ydb/aio/connection.py
+++ b/ydb/aio/connection.py
@@ -11,6 +11,7 @@
     _log_request,
     _log_response,
     _rpc_error_handler,
+    _is_disconnect_needed,
     _get_request_timeout,
     _set_server_timeouts,
     _RpcState as RpcState,
@@ -103,6 +104,30 @@ def future(self, *args, **kwargs):
         raise NotImplementedError
 
 
+class _SafeAsyncIterator:
+    def __init__(self, resp, rpc_state, on_disconnected_callback):
+        self.resp = resp
+        self.it = resp.__aiter__()
+        self.rpc_state = rpc_state
+        self.on_disconnected_callback = on_disconnected_callback
+
+    def cancel(self):
+        self.resp.cancel()
+        return self
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            return await self.it.__anext__()
+        except grpc.RpcError as rpc_error:
+            ydb_error = _rpc_error_handler(self.rpc_state, rpc_error, use_unavailable=True)
+            if _is_disconnect_needed(ydb_error):
+                await self.on_disconnected_callback()
+            raise ydb_error
+
+
 class Connection:
     __slots__ = (
         "endpoint",
@@ -192,6 +217,11 @@ async def __call__(
 
             response = await feature
             _log_response(rpc_state, response)
+
+            if hasattr(response, "__aiter__"):
+                # NOTE(vgvoleg): for stream results we should also be able to handle disconnects
+                response = _SafeAsyncIterator(response, rpc_state, on_disconnected)
+
             return response if wrap_result is None else wrap_result(rpc_state, response, *wrap_args)
         except grpc.RpcError as rpc_error:
            if on_disconnected:
@@ -235,7 +265,7 @@ async def close(self, grace: float = 30):
         been terminated are cancelled. If grace is None, this method
         will wait until all tasks are finished.
         :return: None
         """
-        logger.info("Closing channel for endpoint %s", self.endpoint)
+        logger.debug("Closing channel for endpoint %s", self.endpoint)
 
         self.closing = True
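The `_SafeAsyncIterator` idea in isolation: wrap any async iterator so a transport error is translated exactly once and a disconnect callback fires before the error propagates. A standalone sketch with plain exceptions standing in for `grpc.RpcError` and the YDB error handler; all names here are illustrative:

```python
import asyncio


class GuardedAsyncIterator:
    def __init__(self, it, translate, on_disconnect):
        self.it = it.__aiter__()
        self.translate = translate          # maps transport error -> domain error
        self.on_disconnect = on_disconnect  # async disconnect callback

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return await self.it.__anext__()
        except RuntimeError as exc:  # stand-in for grpc.RpcError
            err = self.translate(exc)
            await self.on_disconnect()  # e.g. mark the endpoint unusable
            raise err from exc


async def main():
    async def stream():
        yield 1
        raise RuntimeError("connection reset")  # mid-stream failure

    async def on_disconnect():
        print("marking endpoint disconnected")

    it = GuardedAsyncIterator(stream(), lambda e: ConnectionError(str(e)), on_disconnect)
    try:
        async for item in it:
            print("got", item)
    except ConnectionError as e:
        print("translated:", e)


asyncio.run(main())
```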
diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py
index 99a3cfdb..0c75fa30 100644
--- a/ydb/aio/pool.py
+++ b/ydb/aio/pool.py
@@ -270,7 +270,7 @@ async def __call__(
                 self._discovery.notify_disconnected()
             raise
 
-        return await connection(
+        res = await connection(
             request,
             stub,
             rpc_name,
@@ -279,3 +279,5 @@ async def __call__(
             wrap_args,
             self._on_disconnected(connection),
         )
+
+        return res
diff --git a/ydb/aio/query/session.py b/ydb/aio/query/session.py
index 13906164..98ea1849 100644
--- a/ydb/aio/query/session.py
+++ b/ydb/aio/query/session.py
@@ -23,7 +23,6 @@
 )
 
 from ..._constants import DEFAULT_INITIAL_RESPONSE_TIMEOUT
-from ..._errors import stream_error_converter
 
 
 class QuerySession(BaseQuerySession):
@@ -164,7 +163,6 @@ async def execute(
                 session=self,
                 settings=self._settings,
             ),
-            error_converter=stream_error_converter,
         )
 
     async def explain(
diff --git a/ydb/aio/query/transaction.py b/ydb/aio/query/transaction.py
index 2c313a4a..9b2db2ef 100644
--- a/ydb/aio/query/transaction.py
+++ b/ydb/aio/query/transaction.py
@@ -11,7 +11,6 @@
     BaseQueryTxContext,
     QueryTxStateEnum,
 )
-from ..._errors import stream_error_converter
 
 logger = logging.getLogger(__name__)
 
@@ -191,6 +190,5 @@ async def execute(
                 commit_tx=commit_tx,
                 settings=self.session._settings,
             ),
-            error_converter=stream_error_converter,
         )
         return self._prev_stream
diff --git a/ydb/connection.py b/ydb/connection.py
index d5b6ed50..5f9d5ec8 100644
--- a/ydb/connection.py
+++ b/ydb/connection.py
@@ -67,6 +67,7 @@
 def _rpc_error_handler(
     rpc_state,
     rpc_error: typing.Union[grpc.RpcError, grpc.aio.AioRpcError, grpc.Call, grpc.aio.Call],
     on_disconnected: typing.Callable[[], None] = None,
+    use_unavailable: bool = False,
 ):
     """
     RPC call error handler, that translates gRPC error into YDB issue
@@ -74,7 +75,7 @@ def _rpc_error_handler(
     :param rpc_error: an underlying rpc error to handle
     :param on_disconnected: a handler to call on disconnected connection
     """
-    logger.info("%s: received error, %s", rpc_state, rpc_error)
+    logger.debug("%s: received error, %s", rpc_state, rpc_error)
     if isinstance(rpc_error, (grpc.RpcError, grpc.aio.AioRpcError, grpc.Call, grpc.aio.Call)):
         if rpc_error.code() == grpc.StatusCode.UNAUTHENTICATED:
             return issues.Unauthenticated(rpc_error.details())
@@ -82,6 +83,10 @@ def _rpc_error_handler(
             return issues.DeadlineExceed("Deadline exceeded on request")
         elif rpc_error.code() == grpc.StatusCode.UNIMPLEMENTED:
             return issues.Unimplemented("Method or feature is not implemented on server!")
+        elif rpc_error.code() == grpc.StatusCode.CANCELLED:
+            return issues.Cancelled(rpc_error.details())
+        elif use_unavailable and rpc_error.code() == grpc.StatusCode.UNAVAILABLE:
+            return issues.Unavailable(rpc_error.details())
 
     logger.debug("%s: unhandled rpc error, disconnecting channel", rpc_state)
     if on_disconnected is not None:
@@ -90,6 +95,16 @@ def _rpc_error_handler(
     return issues.ConnectionLost("Rpc error, reason %s" % str(rpc_error))
 
 
+def _is_disconnect_needed(error):
+    return isinstance(
+        error,
+        (
+            issues.ConnectionLost,
+            issues.Unavailable,
+        ),
+    )
+
+
 def _on_response_callback(rpc_state, call_state_unref, wrap_result=None, on_disconnected=None, wrap_args=()):
     """
     Callback to be executed on received RPC response
@@ -325,6 +340,30 @@ def __init__(self, endpoint, node_id):
         self.node_id = node_id
 
 
+class _SafeSyncIterator:
+    def __init__(self, resp, rpc_state, on_disconnected_callback):
+        self.resp = resp
+        self.it = resp.__iter__()
+        self.rpc_state = rpc_state
+        self.on_disconnected_callback = on_disconnected_callback
+
+    def cancel(self):
+        self.resp.cancel()
+        return self
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            return self.it.__next__()
+        except grpc.RpcError as rpc_error:
+            ydb_error = _rpc_error_handler(self.rpc_state, rpc_error, use_unavailable=True)
+            if _is_disconnect_needed(ydb_error):
+                self.on_disconnected_callback()
+            raise ydb_error
+
+
 class Connection(object):
     __slots__ = (
         "endpoint",
@@ -466,6 +505,11 @@ def __call__(
                 compression=getattr(settings, "compression", None),
             )
             _log_response(rpc_state, response)
+
+            if hasattr(response, "__iter__"):
+                # NOTE(vgvoleg): for stream results we should also be able to handle disconnects
+                response = _SafeSyncIterator(response, rpc_state, on_disconnected)
+
             return response if wrap_result is None else wrap_result(rpc_state, response, *wrap_args)
         except grpc.RpcError as rpc_error:
             raise _rpc_error_handler(rpc_state, rpc_error, on_disconnected)
@@ -499,7 +543,7 @@ def close(self):
         Closes the underlying gRPC channel
         :return: None
         """
-        logger.info("Closing channel for endpoint %s", self.endpoint)
+        logger.debug("Closing channel for endpoint %s", self.endpoint)
         with self.lock:
             self.closing = True
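With `stream_error_converter` gone, stream consumers no longer get per-iterator translation; the connection layer now maps mid-stream failures (`UNAVAILABLE`, `CANCELLED`) to `ydb.issues` errors before they reach query code, and `_is_disconnect_needed` decides whether the endpoint should be dropped from the pool. A hypothetical consumer sketch, endpoint and query assumed, showing that retriable stream errors are now handled by the pool's ordinary retry path:

```python
import asyncio

import ydb
import ydb.aio


async def main():
    async with ydb.aio.Driver(endpoint="grpc://localhost:2136", database="/local") as driver:
        await driver.wait(timeout=5)
        async with ydb.aio.QuerySessionPool(driver) as pool:
            # A mid-stream disconnect now surfaces as issues.Unavailable /
            # issues.ConnectionLost, which execute_with_retries already retries.
            result_sets = await pool.execute_with_retries("SELECT 1 AS x;")
            print(result_sets[0].rows[0].x)


asyncio.run(main())
```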
diff --git a/ydb/pool.py b/ydb/pool.py
index 476ea674..0fb3e86a 100644
--- a/ydb/pool.py
+++ b/ydb/pool.py
@@ -470,7 +470,9 @@ def __call__(
             wrap_args,
             lambda: self._on_disconnected(connection),
         )
+
         tracing.trace(self.tracer, {"response": res}, trace_level=tracing.TraceLevel.DEBUG)
+
         return res
 
     @_utilities.wrap_async_call_exceptions
diff --git a/ydb/retries.py b/ydb/retries.py
index c9c23b1a..5331f1b0 100644
--- a/ydb/retries.py
+++ b/ydb/retries.py
@@ -32,6 +32,7 @@ def __init__(
         fast_backoff_settings=None,
         slow_backoff_settings=None,
         idempotent=False,
+        retry_cancelled=False,
     ):
         self.max_retries = max_retries
         self.max_session_acquire_timeout = max_session_acquire_timeout
@@ -45,6 +46,7 @@ def __init__(
         self.retry_not_found = True
         self.idempotent = idempotent
         self.retry_internal_error = True
+        self.retry_cancelled = retry_cancelled
         self.unknown_error_handler = lambda e: None
         self.get_session_client_timeout = get_session_client_timeout
         if max_session_acquire_timeout is not None:
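The new knob is strictly opt-in, so existing `RetrySettings` users keep their current behavior; only the topic reader and writer reconnect loops enable it. A quick check, assuming the `ydb` package from this branch:

```python
from ydb.retries import RetrySettings

print(RetrySettings().retry_cancelled)                      # False: unchanged default
print(RetrySettings(retry_cancelled=True).retry_cancelled)  # True: opt-in, as in the topic paths
```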