From aded4cc297b2ac56465189debc457f66ba903b48 Mon Sep 17 00:00:00 2001 From: Sergiusz Kierat Date: Tue, 3 Dec 2024 15:46:46 +0100 Subject: [PATCH] Chunked transmission lasts longer than timeout *Why I did it?* In order to have a test which might confirm an issue with an interrupted request *How I did it:* I prepared `NettyCatsRequestTimeoutTest` with the folloing test scenario: - send first chunk (100 bytes) - sleep - send second chunk (100 bytes) --- .../cats/NettyCatsRequestTimeoutTest.scala | 95 +++++++++++++++++++ .../netty/cats/NettyCatsServerTest.scala | 10 +- 2 files changed, 101 insertions(+), 4 deletions(-) create mode 100644 server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsRequestTimeoutTest.scala diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsRequestTimeoutTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsRequestTimeoutTest.scala new file mode 100644 index 0000000000..0954ae11f6 --- /dev/null +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsRequestTimeoutTest.scala @@ -0,0 +1,95 @@ +package sttp.tapir.server.netty.cats + +import cats.effect.{IO, Resource} +import cats.effect.std.Dispatcher +import cats.effect.unsafe.implicits.global +import io.netty.channel.EventLoopGroup +import org.scalatest.matchers.should.Matchers._ +import sttp.capabilities.WebSockets +import sttp.capabilities.fs2.Fs2Streams +import sttp.client3._ +import sttp.model.HeaderNames +import sttp.tapir._ +import sttp.tapir.server.netty.NettyConfig +import sttp.tapir.tests.Test + +import scala.concurrent.duration.DurationInt + +class NettyCatsRequestTimeoutTest( + dispatcher: Dispatcher[IO], + eventLoopGroup: EventLoopGroup, + backend: SttpBackend[IO, Fs2Streams[IO] with WebSockets] +) { + def tests(): List[Test] = List( + Test("chunked transmission lasts longer than given timeout") { + val givenRequestTimeout = 2.seconds + val howManyChunks: Int = 2 + val chunkSize = 100 + val millisBeforeSendingSecondChunk = 1000L + + val e = + endpoint.post + .in(header[Long](HeaderNames.ContentLength)) + .in(streamTextBody(Fs2Streams[IO])(CodecFormat.TextPlain())) + .out(header[Long](HeaderNames.ContentLength)) + .out(streamTextBody(Fs2Streams[IO])(CodecFormat.TextPlain())) + .serverLogicSuccess[IO] { case (length, stream) => + IO((length, stream)) + } + + val config = + NettyConfig.default + .eventLoopGroup(eventLoopGroup) + .randomPort + .withDontShutdownEventLoopGroupOnClose + .noGracefulShutdown + .requestTimeout(givenRequestTimeout) + + val bind = NettyCatsServer(dispatcher, config).addEndpoint(e).start() + + def iterator(howManyChunks: Int, chunkSize: Int): Iterator[Byte] = new Iterator[Iterator[Byte]] { + private var chunksToGo: Int = howManyChunks + + def hasNext: Boolean = { + if (chunksToGo == 1) + Thread.sleep(millisBeforeSendingSecondChunk) + chunksToGo > 0 + } + + def next(): Iterator[Byte] = { + chunksToGo -= 1 + List.fill('A')(chunkSize).map(_.toByte).iterator + } + }.flatten + + val inputStream = fs2.Stream.fromIterator[IO](iterator(howManyChunks, chunkSize), chunkSize = chunkSize) + + Resource + .make(bind)(_.stop()) + .map(_.port) + .use { port => + basicRequest + .post(uri"http://localhost:$port") + .contentLength(howManyChunks * chunkSize) + .streamBody(Fs2Streams[IO])(inputStream) + .send(backend) + .map { _ => + fail("I've got a bad feeling about this.") + } + } + .attempt + .map { + case Left(ex: sttp.client3.SttpClientException.TimeoutException) => + ex.getCause.getMessage shouldBe "request timed out" + case Left(ex: sttp.client3.SttpClientException.ReadException) if ex.getCause.isInstanceOf[java.io.IOException] => + println(s"Unexpected IOException: $ex") + fail(s"Unexpected IOException: $ex") + case Left(ex) => + fail(s"Unexpected exception: $ex") + case Right(_) => + fail("Expected an exception but got success") + } + .unsafeToFuture() + } + ) +} diff --git a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala index 73053ed362..7f38b698d9 100644 --- a/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala +++ b/server/netty-server/cats/src/test/scala/sttp/tapir/server/netty/cats/NettyCatsServerTest.scala @@ -24,9 +24,7 @@ class NettyCatsServerTest extends TestSuite with EitherValues { val interpreter = new NettyCatsTestServerInterpreter(eventLoopGroup, dispatcher) val createServerTest = new DefaultCreateServerTest(backend, interpreter) - val ioSleeper: Sleeper[IO] = new Sleeper[IO] { - override def sleep(duration: FiniteDuration): IO[Unit] = IO.sleep(duration) - } + val ioSleeper: Sleeper[IO] = (duration: FiniteDuration) => IO.sleep(duration) def drainFs2(stream: Fs2Streams[IO]#BinaryStream): IO[Unit] = stream.compile.drain.void @@ -50,7 +48,9 @@ class NettyCatsServerTest extends TestSuite with EitherValues { ) { override def functionToPipe[A, B](f: A => B): streams.Pipe[A, B] = in => in.map(f) override def emptyPipe[A, B]: fs2.Pipe[IO, A, B] = _ => fs2.Stream.empty - }.tests() + } + .tests() ++ + new NettyCatsRequestTimeoutTest(dispatcher, eventLoopGroup, backend).tests() IO.pure((tests, eventLoopGroup)) } { case (_, eventLoopGroup) => @@ -58,4 +58,6 @@ class NettyCatsServerTest extends TestSuite with EitherValues { } .map { case (tests, _) => tests } } + + override def testNameFilter: Option[String] = Some("chunked transmission lasts longer than given timeout") }