Skip to content

Commit

Permalink
Fix stream replay in validators (#1678)
Browse files Browse the repository at this point in the history
The current implementation of replaying the stream will always replay
the first message. This PR fixes this by progressing through the
messages with each call.
  • Loading branch information
RobbeSneyders authored Mar 30, 2023
1 parent 8cebebc commit 55e376f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
9 changes: 6 additions & 3 deletions connexion/validators/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,14 @@ def _insert_messages(
receive: Receive, *, messages: t.Iterable[t.MutableMapping[str, t.Any]]
) -> Receive:
"""Insert messages at the start of the `receive` channel."""
# Ensure that messages is an iterator so each message is replayed once.
message_iterator = iter(messages)

async def receive_() -> t.MutableMapping[str, t.Any]:
for message in messages:
return message
return await receive()
try:
return next(message_iterator)
except StopIteration:
return await receive()

return receive_

Expand Down
29 changes: 28 additions & 1 deletion tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from connexion.exceptions import BadRequestProblem
from connexion.uri_parsing import Swagger2URIParser
from connexion.validators.parameter import ParameterValidator
from connexion.validators import AbstractRequestBodyValidator, ParameterValidator
from starlette.datastructures import QueryParams


Expand Down Expand Up @@ -140,3 +140,30 @@ def test_parameter_validator(monkeypatch):
with pytest.raises(BadRequestProblem) as exc:
validator.validate_request(request)
assert exc.value.detail.startswith("'x' is not one of ['a', 'b']")


async def test_stream_replay():
messages = [
{"body": b"message 1", "more_body": True},
{"body": b"message 2", "more_body": False},
]

async def receive():
return b""

wrapped_receive = AbstractRequestBodyValidator._insert_messages(
receive, messages=messages
)

replay = []
more_body = True
while more_body:
message = await wrapped_receive()
replay.append(message)
more_body = message.get("more_body", False)

assert len(replay) <= len(
messages
), "Replayed more messages than received, break out of while loop"

assert messages == replay

0 comments on commit 55e376f

Please sign in to comment.