diff --git a/armeria-backend/src/main/scala/sttp/client4/armeria/AbstractArmeriaBackend.scala b/armeria-backend/src/main/scala/sttp/client4/armeria/AbstractArmeriaBackend.scala index cc4a379fa6..815c601c36 100644 --- a/armeria-backend/src/main/scala/sttp/client4/armeria/AbstractArmeriaBackend.scala +++ b/armeria-backend/src/main/scala/sttp/client4/armeria/AbstractArmeriaBackend.scala @@ -40,8 +40,6 @@ import sttp.model._ import sttp.monad.syntax._ import sttp.monad.{Canceler, MonadAsyncError} import sttp.client4.compression.Compressor -import sttp.client4.compression.GZipDefaultCompressor -import sttp.client4.compression.DeflateDefaultCompressor abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]]( client: WebClient = WebClient.of(), @@ -57,7 +55,7 @@ abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]]( protected def streamToPublisher(stream: streams.BinaryStream): Publisher[HttpData] - protected def compressors: List[Compressor[R]] = List(new GZipDefaultCompressor(), new DeflateDefaultCompressor()) + protected def compressors: List[Compressor[R]] = Compressor.default[R] override def send[T](request: GenericRequest[T, R]): F[Response[T]] = monad.suspend(adjustExceptions(request)(execute(request))) diff --git a/core/src/main/scala/sttp/client4/compression/Compressor.scala b/core/src/main/scala/sttp/client4/compression/Compressor.scala index 3a008d79d6..3e2de0adbf 100644 --- a/core/src/main/scala/sttp/client4/compression/Compressor.scala +++ b/core/src/main/scala/sttp/client4/compression/Compressor.scala @@ -1,89 +1,14 @@ package sttp.client4.compression import sttp.client4._ -import sttp.model.Encodings - -import Compressor._ import java.nio.ByteBuffer -import java.util.zip.DeflaterInputStream -import java.util.zip.Deflater -import java.io.ByteArrayOutputStream trait Compressor[R] { def encoding: String def apply(body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] } -class GZipDefaultCompressor[R] extends Compressor[R] { - val encoding: String = Encodings.Gzip - - def apply(body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] = - body match { - case NoBody => NoBody - case StringBody(s, encoding, defaultContentType) => - ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType) - case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType) - case ByteBufferBody(b, defaultContentType) => - ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType) - case InputStreamBody(b, defaultContentType) => - InputStreamBody(new GZIPCompressingInputStream(b), defaultContentType) - case StreamBody(b) => streamsNotSupported - case FileBody(f, defaultContentType) => - InputStreamBody(new GZIPCompressingInputStream(f.openStream()), defaultContentType) - case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported - case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported - } - - private def byteArray(bytes: Array[Byte]): Array[Byte] = { - val bos = new java.io.ByteArrayOutputStream() - val gzip = new java.util.zip.GZIPOutputStream(bos) - gzip.write(bytes) - gzip.close() - bos.toByteArray() - } -} - -class DeflateDefaultCompressor[R] extends Compressor[R] { - val encoding: String = Encodings.Deflate - - def apply(body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] = - body match { - case NoBody => NoBody - case StringBody(s, encoding, defaultContentType) => - ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType) - case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType) - case ByteBufferBody(b, defaultContentType) => - ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType) - case InputStreamBody(b, defaultContentType) => - InputStreamBody(new DeflaterInputStream(b), defaultContentType) - case StreamBody(b) => streamsNotSupported - case FileBody(f, defaultContentType) => - InputStreamBody(new DeflaterInputStream(f.openStream()), defaultContentType) - case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported - case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported - } - - private def byteArray(bytes: Array[Byte]): Array[Byte] = { - val deflater = new Deflater() - try { - deflater.setInput(bytes) - deflater.finish() - val byteArrayOutputStream = new ByteArrayOutputStream() - val readBuffer = new Array[Byte](1024) - - while (!deflater.finished()) { - val readCount = deflater.deflate(readBuffer) - if (readCount > 0) { - byteArrayOutputStream.write(readBuffer, 0, readCount) - } - } - - byteArrayOutputStream.toByteArray - } finally deflater.end() - } -} - -private[client4] object Compressor { +object Compressor extends CompressorExtensions { /** Compress the request body if needed, using the given compressors. * @return diff --git a/core/src/main/scalajs/sttp/client4/compression/CompressorExtensions.scala b/core/src/main/scalajs/sttp/client4/compression/CompressorExtensions.scala new file mode 100644 index 0000000000..9831ced732 --- /dev/null +++ b/core/src/main/scalajs/sttp/client4/compression/CompressorExtensions.scala @@ -0,0 +1,5 @@ +package sttp.client4.compression + +trait CompressorExtensions { + def default[R]: List[Compressor[R]] = Nil +} diff --git a/core/src/main/scalajvm/sttp/client4/compression/CompressorExtensions.scala b/core/src/main/scalajvm/sttp/client4/compression/CompressorExtensions.scala new file mode 100644 index 0000000000..a78a366806 --- /dev/null +++ b/core/src/main/scalajvm/sttp/client4/compression/CompressorExtensions.scala @@ -0,0 +1,5 @@ +package sttp.client4.compression + +trait CompressorExtensions { + def default[R]: List[Compressor[R]] = List(new GZipDefaultCompressor[R](), new DeflateDefaultCompressor[R]()) +} diff --git a/core/src/main/scala/sttp/client4/compression/GZIPCompressingInputStream.scala b/core/src/main/scalajvm/sttp/client4/compression/GZIPCompressingInputStream.scala similarity index 100% rename from core/src/main/scala/sttp/client4/compression/GZIPCompressingInputStream.scala rename to core/src/main/scalajvm/sttp/client4/compression/GZIPCompressingInputStream.scala diff --git a/core/src/main/scalajvm/sttp/client4/compression/defaultCompressors.scala b/core/src/main/scalajvm/sttp/client4/compression/defaultCompressors.scala new file mode 100644 index 0000000000..a63ed6c98b --- /dev/null +++ b/core/src/main/scalajvm/sttp/client4/compression/defaultCompressors.scala @@ -0,0 +1,78 @@ +package sttp.client4.compression + +import sttp.client4._ +import sttp.model.Encodings + +import Compressor._ +import java.util.zip.Deflater +import java.util.zip.DeflaterInputStream +import java.io.ByteArrayOutputStream + +class GZipDefaultCompressor[R] extends Compressor[R] { + val encoding: String = Encodings.Gzip + + def apply(body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] = + body match { + case NoBody => NoBody + case StringBody(s, encoding, defaultContentType) => + ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType) + case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType) + case ByteBufferBody(b, defaultContentType) => + ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType) + case InputStreamBody(b, defaultContentType) => + InputStreamBody(new GZIPCompressingInputStream(b), defaultContentType) + case StreamBody(b) => streamsNotSupported + case FileBody(f, defaultContentType) => + InputStreamBody(new GZIPCompressingInputStream(f.openStream()), defaultContentType) + case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported + case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported + } + + private def byteArray(bytes: Array[Byte]): Array[Byte] = { + val bos = new java.io.ByteArrayOutputStream() + val gzip = new java.util.zip.GZIPOutputStream(bos) + gzip.write(bytes) + gzip.close() + bos.toByteArray() + } +} + +class DeflateDefaultCompressor[R] extends Compressor[R] { + val encoding: String = Encodings.Deflate + + def apply(body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] = + body match { + case NoBody => NoBody + case StringBody(s, encoding, defaultContentType) => + ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType) + case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType) + case ByteBufferBody(b, defaultContentType) => + ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType) + case InputStreamBody(b, defaultContentType) => + InputStreamBody(new DeflaterInputStream(b), defaultContentType) + case StreamBody(b) => streamsNotSupported + case FileBody(f, defaultContentType) => + InputStreamBody(new DeflaterInputStream(f.openStream()), defaultContentType) + case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported + case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported + } + + private def byteArray(bytes: Array[Byte]): Array[Byte] = { + val deflater = new Deflater() + try { + deflater.setInput(bytes) + deflater.finish() + val byteArrayOutputStream = new ByteArrayOutputStream() + val readBuffer = new Array[Byte](1024) + + while (!deflater.finished()) { + val readCount = deflater.deflate(readBuffer) + if (readCount > 0) { + byteArrayOutputStream.write(readBuffer, 0, readCount) + } + } + + byteArrayOutputStream.toByteArray + } finally deflater.end() + } +} diff --git a/core/src/main/scalajvm/sttp/client4/httpurlconnection/HttpURLConnectionBackend.scala b/core/src/main/scalajvm/sttp/client4/httpurlconnection/HttpURLConnectionBackend.scala index 69d3849eab..3db3cbdf15 100644 --- a/core/src/main/scalajvm/sttp/client4/httpurlconnection/HttpURLConnectionBackend.scala +++ b/core/src/main/scalajvm/sttp/client4/httpurlconnection/HttpURLConnectionBackend.scala @@ -3,7 +3,7 @@ package sttp.client4.httpurlconnection import sttp.capabilities.Effect import sttp.client4.httpurlconnection.HttpURLConnectionBackend.EncodingHandler import sttp.client4.internal._ -import sttp.client4.compression.{Compressor, DeflateDefaultCompressor, GZipDefaultCompressor} +import sttp.client4.compression.Compressor import sttp.client4.testing.SyncBackendStub import sttp.client4.ws.{GotAWebSocketException, NotAWebSocketException} import sttp.client4.{ @@ -49,7 +49,7 @@ class HttpURLConnectionBackend private ( ) extends SyncBackend { type R = Any with Effect[Identity] - private val compressors: List[Compressor[R]] = List(new GZipDefaultCompressor(), new DeflateDefaultCompressor()) + private val compressors: List[Compressor[R]] = Compressor.default[R] override def send[T](r: GenericRequest[T, R]): Response[T] = adjustExceptions(r) { diff --git a/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala b/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala index 591fb7c532..c5e9657e0f 100644 --- a/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala +++ b/core/src/main/scalajvm/sttp/client4/internal/httpclient/BodyToHttpClient.scala @@ -3,7 +3,7 @@ package sttp.client4.internal.httpclient import sttp.capabilities.Streams import sttp.client4.internal.SttpToJavaConverters.toJavaSupplier import sttp.client4.internal.{throwNestedMultipartNotAllowed, Utf8} -import sttp.client4.compression.{Compressor, DeflateDefaultCompressor, GZipDefaultCompressor} +import sttp.client4.compression.Compressor import sttp.client4._ import sttp.model.{Header, HeaderNames, Part} import sttp.monad.MonadError @@ -52,7 +52,7 @@ private[client4] trait BodyToHttpClient[F[_], S, R] { def streamToPublisher(stream: streams.BinaryStream): F[BodyPublisher] - def compressors: List[Compressor[R]] = List(new GZipDefaultCompressor(), new DeflateDefaultCompressor()) + def compressors: List[Compressor[R]] = Compressor.default[R] private def multipartBody[T](parts: Seq[Part[GenericRequestBody[_]]]) = { val multipartBuilder = new MultiPartBodyPublisher() diff --git a/core/src/main/scalanative/sttp/client4/compression/CompressorExtensions.scala b/core/src/main/scalanative/sttp/client4/compression/CompressorExtensions.scala new file mode 100644 index 0000000000..9831ced732 --- /dev/null +++ b/core/src/main/scalanative/sttp/client4/compression/CompressorExtensions.scala @@ -0,0 +1,5 @@ +package sttp.client4.compression + +trait CompressorExtensions { + def default[R]: List[Compressor[R]] = Nil +} diff --git a/core/src/test/scala/sttp/client4/internal/compression/GZIPCompressingInputStreamTest.scala b/core/src/test/scalajvm/sttp/client4/compression/GZIPCompressingInputStreamTest.scala similarity index 88% rename from core/src/test/scala/sttp/client4/internal/compression/GZIPCompressingInputStreamTest.scala rename to core/src/test/scalajvm/sttp/client4/compression/GZIPCompressingInputStreamTest.scala index adf1cd5f68..0cf040bad7 100644 --- a/core/src/test/scala/sttp/client4/internal/compression/GZIPCompressingInputStreamTest.scala +++ b/core/src/test/scalajvm/sttp/client4/compression/GZIPCompressingInputStreamTest.scala @@ -1,4 +1,4 @@ -package sttp.client4.internal.compression +package sttp.client4.compression import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -6,7 +6,6 @@ import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks import java.io.ByteArrayInputStream import java.util.zip.GZIPInputStream -import sttp.client4.compression.GZIPCompressingInputStream class GZIPCompressingInputStreamTest extends AnyFlatSpec with Matchers with ScalaCheckPropertyChecks { implicit override val generatorDrivenConfig = diff --git a/finagle-backend/src/main/scala/sttp/client4/finagle/FinagleBackend.scala b/finagle-backend/src/main/scala/sttp/client4/finagle/FinagleBackend.scala index 6627b04d10..4025ab2db5 100644 --- a/finagle-backend/src/main/scala/sttp/client4/finagle/FinagleBackend.scala +++ b/finagle-backend/src/main/scala/sttp/client4/finagle/FinagleBackend.scala @@ -82,7 +82,7 @@ class FinagleBackend(client: Option[Client] = None) extends Backend[TFuture] { val url = r.uri.toString val (body, contentLength) = Compressor.compressIfNeeded(r, compressors) val headers = { - val hh = headersToMap(r.headers).removed(HeaderNames.ContentLength) + val hh = headersToMap(r.headers) - HeaderNames.ContentLength contentLength.fold(hh)(cl => hh.updated(HeaderNames.ContentLength, cl.toString)) } diff --git a/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpBackend.scala b/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpBackend.scala index df1242ff60..87ca0f6541 100644 --- a/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpBackend.scala +++ b/okhttp-backend/src/main/scala/sttp/client4/okhttp/OkHttpBackend.scala @@ -23,8 +23,6 @@ import sttp.model._ import scala.collection.JavaConverters._ import sttp.client4.compression.Compressor -import sttp.client4.compression.DeflateDefaultCompressor -import sttp.client4.compression.GZipDefaultCompressor abstract class OkHttpBackend[F[_], S <: Streams[S], P]( client: OkHttpClient, @@ -36,7 +34,7 @@ abstract class OkHttpBackend[F[_], S <: Streams[S], P]( val streams: Streams[S] type R = P with Effect[F] - private val compressors: List[Compressor[R]] = List(new GZipDefaultCompressor(), new DeflateDefaultCompressor) + private val compressors: List[Compressor[R]] = Compressor.default[R] override def send[T](request: GenericRequest[T, R]): F[Response[T]] = adjustExceptions(request.isWebSocket, request) {