Skip to content

Commit

Permalink
Move compressors to jvm
Browse files Browse the repository at this point in the history
  • Loading branch information
adamw committed Dec 28, 2024
1 parent cefad70 commit 1e65bda
Show file tree
Hide file tree
Showing 12 changed files with 102 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ import sttp.model._
import sttp.monad.syntax._
import sttp.monad.{Canceler, MonadAsyncError}
import sttp.client4.compression.Compressor
import sttp.client4.compression.GZipDefaultCompressor
import sttp.client4.compression.DeflateDefaultCompressor

abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]](
client: WebClient = WebClient.of(),
Expand All @@ -57,7 +55,7 @@ abstract class AbstractArmeriaBackend[F[_], S <: Streams[S]](

protected def streamToPublisher(stream: streams.BinaryStream): Publisher[HttpData]

protected def compressors: List[Compressor[R]] = List(new GZipDefaultCompressor(), new DeflateDefaultCompressor())
protected def compressors: List[Compressor[R]] = Compressor.default[R]

override def send[T](request: GenericRequest[T, R]): F[Response[T]] =
monad.suspend(adjustExceptions(request)(execute(request)))
Expand Down
77 changes: 1 addition & 76 deletions core/src/main/scala/sttp/client4/compression/Compressor.scala
Original file line number Diff line number Diff line change
@@ -1,89 +1,14 @@
package sttp.client4.compression

import sttp.client4._
import sttp.model.Encodings

import Compressor._
import java.nio.ByteBuffer
import java.util.zip.DeflaterInputStream
import java.util.zip.Deflater
import java.io.ByteArrayOutputStream

trait Compressor[R] {
def encoding: String
def apply(body: GenericRequestBody[R], encoding: String): GenericRequestBody[R]
}

class GZipDefaultCompressor[R] extends Compressor[R] {
val encoding: String = Encodings.Gzip

def apply(body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] =
body match {
case NoBody => NoBody
case StringBody(s, encoding, defaultContentType) =>
ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType)
case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType)
case ByteBufferBody(b, defaultContentType) =>
ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType)
case InputStreamBody(b, defaultContentType) =>
InputStreamBody(new GZIPCompressingInputStream(b), defaultContentType)
case StreamBody(b) => streamsNotSupported
case FileBody(f, defaultContentType) =>
InputStreamBody(new GZIPCompressingInputStream(f.openStream()), defaultContentType)
case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported
case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported
}

private def byteArray(bytes: Array[Byte]): Array[Byte] = {
val bos = new java.io.ByteArrayOutputStream()
val gzip = new java.util.zip.GZIPOutputStream(bos)
gzip.write(bytes)
gzip.close()
bos.toByteArray()
}
}

class DeflateDefaultCompressor[R] extends Compressor[R] {
val encoding: String = Encodings.Deflate

def apply(body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] =
body match {
case NoBody => NoBody
case StringBody(s, encoding, defaultContentType) =>
ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType)
case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType)
case ByteBufferBody(b, defaultContentType) =>
ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType)
case InputStreamBody(b, defaultContentType) =>
InputStreamBody(new DeflaterInputStream(b), defaultContentType)
case StreamBody(b) => streamsNotSupported
case FileBody(f, defaultContentType) =>
InputStreamBody(new DeflaterInputStream(f.openStream()), defaultContentType)
case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported
case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported
}

private def byteArray(bytes: Array[Byte]): Array[Byte] = {
val deflater = new Deflater()
try {
deflater.setInput(bytes)
deflater.finish()
val byteArrayOutputStream = new ByteArrayOutputStream()
val readBuffer = new Array[Byte](1024)

while (!deflater.finished()) {
val readCount = deflater.deflate(readBuffer)
if (readCount > 0) {
byteArrayOutputStream.write(readBuffer, 0, readCount)
}
}

byteArrayOutputStream.toByteArray
} finally deflater.end()
}
}

private[client4] object Compressor {
object Compressor extends CompressorExtensions {

/** Compress the request body if needed, using the given compressors.
* @return
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sttp.client4.compression

trait CompressorExtensions {
def default[R]: List[Compressor[R]] = Nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sttp.client4.compression

trait CompressorExtensions {
def default[R]: List[Compressor[R]] = List(new GZipDefaultCompressor[R](), new DeflateDefaultCompressor[R]())
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package sttp.client4.compression

import sttp.client4._
import sttp.model.Encodings

import Compressor._
import java.util.zip.Deflater
import java.util.zip.DeflaterInputStream
import java.io.ByteArrayOutputStream

class GZipDefaultCompressor[R] extends Compressor[R] {
val encoding: String = Encodings.Gzip

def apply(body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] =
body match {
case NoBody => NoBody
case StringBody(s, encoding, defaultContentType) =>
ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType)
case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType)
case ByteBufferBody(b, defaultContentType) =>
ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType)
case InputStreamBody(b, defaultContentType) =>
InputStreamBody(new GZIPCompressingInputStream(b), defaultContentType)
case StreamBody(b) => streamsNotSupported
case FileBody(f, defaultContentType) =>
InputStreamBody(new GZIPCompressingInputStream(f.openStream()), defaultContentType)
case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported
case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported
}

private def byteArray(bytes: Array[Byte]): Array[Byte] = {
val bos = new java.io.ByteArrayOutputStream()
val gzip = new java.util.zip.GZIPOutputStream(bos)
gzip.write(bytes)
gzip.close()
bos.toByteArray()
}
}

class DeflateDefaultCompressor[R] extends Compressor[R] {
val encoding: String = Encodings.Deflate

def apply(body: GenericRequestBody[R], encoding: String): GenericRequestBody[R] =
body match {
case NoBody => NoBody
case StringBody(s, encoding, defaultContentType) =>
ByteArrayBody(byteArray(s.getBytes(encoding)), defaultContentType)
case ByteArrayBody(b, defaultContentType) => ByteArrayBody(byteArray(b), defaultContentType)
case ByteBufferBody(b, defaultContentType) =>
ByteArrayBody(byteArray(byteBufferToArray(b)), defaultContentType)
case InputStreamBody(b, defaultContentType) =>
InputStreamBody(new DeflaterInputStream(b), defaultContentType)
case StreamBody(b) => streamsNotSupported
case FileBody(f, defaultContentType) =>
InputStreamBody(new DeflaterInputStream(f.openStream()), defaultContentType)
case MultipartStreamBody(parts) => compressingMultipartBodiesNotSupported
case BasicMultipartBody(parts) => compressingMultipartBodiesNotSupported
}

private def byteArray(bytes: Array[Byte]): Array[Byte] = {
val deflater = new Deflater()
try {
deflater.setInput(bytes)
deflater.finish()
val byteArrayOutputStream = new ByteArrayOutputStream()
val readBuffer = new Array[Byte](1024)

while (!deflater.finished()) {
val readCount = deflater.deflate(readBuffer)
if (readCount > 0) {
byteArrayOutputStream.write(readBuffer, 0, readCount)
}
}

byteArrayOutputStream.toByteArray
} finally deflater.end()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package sttp.client4.httpurlconnection
import sttp.capabilities.Effect
import sttp.client4.httpurlconnection.HttpURLConnectionBackend.EncodingHandler
import sttp.client4.internal._
import sttp.client4.compression.{Compressor, DeflateDefaultCompressor, GZipDefaultCompressor}
import sttp.client4.compression.Compressor
import sttp.client4.testing.SyncBackendStub
import sttp.client4.ws.{GotAWebSocketException, NotAWebSocketException}
import sttp.client4.{
Expand Down Expand Up @@ -49,7 +49,7 @@ class HttpURLConnectionBackend private (
) extends SyncBackend {
type R = Any with Effect[Identity]

private val compressors: List[Compressor[R]] = List(new GZipDefaultCompressor(), new DeflateDefaultCompressor())
private val compressors: List[Compressor[R]] = Compressor.default[R]

override def send[T](r: GenericRequest[T, R]): Response[T] =
adjustExceptions(r) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package sttp.client4.internal.httpclient
import sttp.capabilities.Streams
import sttp.client4.internal.SttpToJavaConverters.toJavaSupplier
import sttp.client4.internal.{throwNestedMultipartNotAllowed, Utf8}
import sttp.client4.compression.{Compressor, DeflateDefaultCompressor, GZipDefaultCompressor}
import sttp.client4.compression.Compressor
import sttp.client4._
import sttp.model.{Header, HeaderNames, Part}
import sttp.monad.MonadError
Expand Down Expand Up @@ -52,7 +52,7 @@ private[client4] trait BodyToHttpClient[F[_], S, R] {

def streamToPublisher(stream: streams.BinaryStream): F[BodyPublisher]

def compressors: List[Compressor[R]] = List(new GZipDefaultCompressor(), new DeflateDefaultCompressor())
def compressors: List[Compressor[R]] = Compressor.default[R]

private def multipartBody[T](parts: Seq[Part[GenericRequestBody[_]]]) = {
val multipartBuilder = new MultiPartBodyPublisher()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sttp.client4.compression

trait CompressorExtensions {
def default[R]: List[Compressor[R]] = Nil
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
package sttp.client4.internal.compression
package sttp.client4.compression

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks

import java.io.ByteArrayInputStream
import java.util.zip.GZIPInputStream
import sttp.client4.compression.GZIPCompressingInputStream

class GZIPCompressingInputStreamTest extends AnyFlatSpec with Matchers with ScalaCheckPropertyChecks {
implicit override val generatorDrivenConfig =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class FinagleBackend(client: Option[Client] = None) extends Backend[TFuture] {
val url = r.uri.toString
val (body, contentLength) = Compressor.compressIfNeeded(r, compressors)
val headers = {
val hh = headersToMap(r.headers).removed(HeaderNames.ContentLength)
val hh = headersToMap(r.headers) - HeaderNames.ContentLength
contentLength.fold(hh)(cl => hh.updated(HeaderNames.ContentLength, cl.toString))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ import sttp.model._

import scala.collection.JavaConverters._
import sttp.client4.compression.Compressor
import sttp.client4.compression.DeflateDefaultCompressor
import sttp.client4.compression.GZipDefaultCompressor

abstract class OkHttpBackend[F[_], S <: Streams[S], P](
client: OkHttpClient,
Expand All @@ -36,7 +34,7 @@ abstract class OkHttpBackend[F[_], S <: Streams[S], P](
val streams: Streams[S]
type R = P with Effect[F]

private val compressors: List[Compressor[R]] = List(new GZipDefaultCompressor(), new DeflateDefaultCompressor)
private val compressors: List[Compressor[R]] = Compressor.default[R]

override def send[T](request: GenericRequest[T, R]): F[Response[T]] =
adjustExceptions(request.isWebSocket, request) {
Expand Down

0 comments on commit 1e65bda

Please sign in to comment.